Skip to content

Commit 9c54eb4

Browse files
committed
Allow last random seed to be reused
1 parent 367df55 commit 9c54eb4

File tree

5 files changed

+22
-7
lines changed

5 files changed

+22
-7
lines changed

Diffusion-macOS/ContentView.swift

+2-1
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,8 @@ struct ContentView: View {
6666
@StateObject var generation = GenerationContext()
6767

6868
func toolbar() -> any View {
69-
if case .complete(let prompt, let cgImage, let seed, _) = generation.state, let cgImage = cgImage {
69+
if case .complete(let prompt, let cgImage, _, _) = generation.state, let cgImage = cgImage {
70+
// TODO: share seed too
7071
return ShareButtons(image: cgImage, name: prompt)
7172
} else {
7273
let prompt = DEFAULT_PROMPT

Diffusion-macOS/ControlsView.swift

+5-2
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,9 @@ struct ControlsView: View {
7070
@State private var showGuidanceHelp = false
7171
@State private var showStepsHelp = false
7272
@State private var showSeedHelp = false
73+
74+
// Reasonable range for the slider
75+
let maxSeed: UInt32 = 1000
7376

7477
func updateSafetyCheckerState() {
7578
mustShowSafetyCheckerDisclaimer = generation.disableSafety && !Settings.shared.safetyCheckerDisclaimerShown
@@ -82,7 +85,7 @@ struct ControlsView: View {
8285
pipelineLoader?.cancel()
8386
pipelineState = .downloading(0)
8487
Task.init {
85-
let loader = PipelineLoader(model: model)
88+
let loader = PipelineLoader(model: model, maxSeed: maxSeed)
8689
self.pipelineLoader = loader
8790
stateSubscriber = loader.statePublisher.sink { state in
8891
DispatchQueue.main.async {
@@ -245,7 +248,7 @@ struct ControlsView: View {
245248

246249
DisclosureGroup(isExpanded: $disclosedSeed) {
247250
let sliderLabel = generation.seed < 0 ? "Random Seed" : "Seed"
248-
CompactSlider(value: $generation.seed, in: -1...1000, step: 1) {
251+
CompactSlider(value: $generation.seed, in: -1...Double(maxSeed), step: 1) {
249252
Text(sliderLabel)
250253
Spacer()
251254
Text("\(Int(generation.seed))")

Diffusion-macOS/StatusView.swift

+6
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,12 @@ struct StatusView: View {
8383
let intervalString = String(format: "Time: %.1fs", interval ?? 0)
8484
Text(intervalString)
8585
Spacer()
86+
if generation.seed != Double(lastSeed) {
87+
Text("Seed: \(lastSeed)")
88+
Button("Set") {
89+
generation.seed = Double(lastSeed)
90+
}
91+
}
8692
}.frame(maxHeight: 25)
8793
case .failed(let error):
8894
return errorWithDetails("Generation error", error: error)

Diffusion/Pipeline/Pipeline.swift

+4-2
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ struct GenerationResult {
2222

2323
class Pipeline {
2424
let pipeline: StableDiffusionPipeline
25+
let maxSeed: UInt32
2526

2627
var progress: StableDiffusionProgress? = nil {
2728
didSet {
@@ -31,8 +32,9 @@ class Pipeline {
3132
lazy private(set) var progressPublisher: CurrentValueSubject<StableDiffusionProgress?, Never> = CurrentValueSubject(progress)
3233

3334

34-
init(_ pipeline: StableDiffusionPipeline) {
35+
init(_ pipeline: StableDiffusionPipeline, maxSeed: UInt32 = UInt32.max) {
3536
self.pipeline = pipeline
37+
self.maxSeed = maxSeed
3638
}
3739

3840
func generate(
@@ -46,7 +48,7 @@ class Pipeline {
4648
) throws -> GenerationResult {
4749
let beginDate = Date()
4850
print("Generating...")
49-
let theSeed = seed ?? UInt32.random(in: 0..<UInt32.max)
51+
let theSeed = seed ?? UInt32.random(in: 0...maxSeed)
5052
let images = try pipeline.generateImages(
5153
prompt: prompt,
5254
negativePrompt: negativePrompt,

Diffusion/Pipeline/PipelineLoader.swift

+5-2
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,13 @@ class PipelineLoader {
1818
static let models = Path.applicationSupport / "hf-diffusion-models"
1919

2020
let model: ModelInfo
21+
let maxSeed: UInt32
22+
2123
private var downloadSubscriber: Cancellable?
2224

23-
init(model: ModelInfo) {
25+
init(model: ModelInfo, maxSeed: UInt32 = UInt32.max) {
2426
self.model = model
27+
self.maxSeed = maxSeed
2528
state = .undetermined
2629
setInitialState()
2730
}
@@ -100,7 +103,7 @@ extension PipelineLoader {
100103
try await download()
101104
try await unzip()
102105
let pipeline = try await load(url: compiledPath.url)
103-
return Pipeline(pipeline)
106+
return Pipeline(pipeline, maxSeed: maxSeed)
104107
} catch {
105108
state = .failed(error)
106109
throw error

0 commit comments

Comments
 (0)