Skip to content

Commit 1084ea7

Browse files
committed
Allow user to cancel generation.
This will be more useful when we allow multiple images.
1 parent d78d430 commit 1084ea7

File tree

4 files changed

+36
-5
lines changed

4 files changed

+36
-5
lines changed

Diffusion-macOS/GeneratedImageView.swift

+11-1
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,15 @@ struct GeneratedImageView: View {
2222
let step = Int(progress.step) + 1
2323
let fraction = Double(step) / Double(progress.stepCount)
2424
let label = "Step \(step) of \(progress.stepCount)"
25-
return AnyView(ProgressView(label, value: fraction, total: 1).padding())
25+
return AnyView(HStack {
26+
ProgressView(label, value: fraction, total: 1).padding()
27+
Button {
28+
generation.cancelGeneration()
29+
} label: {
30+
Image(systemName: "x.circle.fill").foregroundColor(.gray)
31+
}
32+
.buttonStyle(.plain)
33+
})
2634
case .complete(_, let image, _, _):
2735
guard let theImage = image else {
2836
return AnyView(Image(systemName: "exclamationmark.triangle").resizable())
@@ -34,6 +42,8 @@ struct GeneratedImageView: View {
3442
)
3543
case .failed(_):
3644
return AnyView(Image(systemName: "exclamationmark.triangle").resizable())
45+
case .userCanceled:
46+
return AnyView(Text("Generation canceled"))
3747
}
3848
}
3949
}

Diffusion-macOS/StatusView.swift

+10-1
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,11 @@ struct StatusView: View {
2020
generation.state = .running(nil)
2121
do {
2222
let result = try await generation.generate()
23-
generation.state = .complete(generation.positivePrompt, result.image, result.lastSeed, result.interval)
23+
if result.userCanceled {
24+
generation.state = .userCanceled
25+
} else {
26+
generation.state = .complete(generation.positivePrompt, result.image, result.lastSeed, result.interval)
27+
}
2428
} catch {
2529
generation.state = .failed(error)
2630
}
@@ -92,6 +96,11 @@ struct StatusView: View {
9296
}.frame(maxHeight: 25)
9397
case .failed(let error):
9498
return errorWithDetails("Generation error", error: error)
99+
case .userCanceled:
100+
return HStack {
101+
Text("Generation canceled.")
102+
Spacer()
103+
}
95104
}
96105
}
97106

Diffusion/Pipeline/Pipeline.swift

+10-3
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ struct GenerationResult {
1818
var image: CGImage?
1919
var lastSeed: UInt32
2020
var interval: TimeInterval?
21+
var userCanceled: Bool
2122
}
2223

2324
class Pipeline {
@@ -30,7 +31,8 @@ class Pipeline {
3031
}
3132
}
3233
lazy private(set) var progressPublisher: CurrentValueSubject<StableDiffusionProgress?, Never> = CurrentValueSubject(progress)
33-
34+
35+
private var canceled = false
3436

3537
init(_ pipeline: StableDiffusionPipeline, maxSeed: UInt32 = UInt32.max) {
3638
self.pipeline = pipeline
@@ -47,6 +49,7 @@ class Pipeline {
4749
disableSafety: Bool = false
4850
) throws -> GenerationResult {
4951
let beginDate = Date()
52+
canceled = false
5053
print("Generating...")
5154
let theSeed = seed ?? UInt32.random(in: 0...maxSeed)
5255
let images = try pipeline.generateImages(
@@ -60,17 +63,21 @@ class Pipeline {
6063
scheduler: scheduler
6164
) { progress in
6265
handleProgress(progress)
63-
return true
66+
return !canceled
6467
}
6568
let interval = Date().timeIntervalSince(beginDate)
6669
print("Got images: \(images) in \(interval)")
6770

6871
// Unwrap the 1 image we asked for, nil means safety checker triggered
6972
let image = images.compactMap({ $0 }).first
70-
return GenerationResult(image: image, lastSeed: theSeed, interval: interval)
73+
return GenerationResult(image: image, lastSeed: theSeed, interval: interval, userCanceled: canceled)
7174
}
7275

7376
func handleProgress(_ progress: StableDiffusionPipeline.Progress) {
7477
self.progress = progress
7578
}
79+
80+
func setCancelled() {
81+
canceled = true
82+
}
7683
}

Diffusion/State.swift

+5
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ enum GenerationState {
1717
case startup
1818
case running(StableDiffusionProgress?)
1919
case complete(String, CGImage?, UInt32, TimeInterval?)
20+
case userCanceled
2021
case failed(Error)
2122
}
2223

@@ -63,6 +64,10 @@ class GenerationContext: ObservableObject {
6364
disableSafety: disableSafety
6465
)
6566
}
67+
68+
func cancelGeneration() {
69+
pipeline?.setCancelled()
70+
}
6671
}
6772

6873
class Settings {

0 commit comments

Comments
 (0)