Skip to content

Commit 367df55

Browse files
committed
Pass seed back, report generation errors.
1 parent e2cc04f commit 367df55

File tree

6 files changed

+57
-40
lines changed

6 files changed

+57
-40
lines changed

Diffusion-macOS/ContentView.swift

+1-1
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ struct ContentView: View {
6666
@StateObject var generation = GenerationContext()
6767

6868
func toolbar() -> any View {
69-
if case .complete(let prompt, let cgImage, _) = generation.state, let cgImage = cgImage {
69+
if case .complete(let prompt, let cgImage, let seed, _) = generation.state, let cgImage = cgImage {
7070
return ShareButtons(image: cgImage, name: prompt)
7171
} else {
7272
let prompt = DEFAULT_PROMPT

Diffusion-macOS/GeneratedImageView.swift

+3-1
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ struct GeneratedImageView: View {
2323
let fraction = Double(step) / Double(progress.stepCount)
2424
let label = "Step \(step) of \(progress.stepCount)"
2525
return AnyView(ProgressView(label, value: fraction, total: 1).padding())
26-
case .complete(_, let image, _):
26+
case .complete(_, let image, _, _):
2727
guard let theImage = image else {
2828
return AnyView(Image(systemName: "exclamationmark.triangle").resizable())
2929
}
@@ -32,6 +32,8 @@ struct GeneratedImageView: View {
3232
.resizable()
3333
.clipShape(RoundedRectangle(cornerRadius: 20))
3434
)
35+
case .failed(_):
36+
return AnyView(Image(systemName: "exclamationmark.triangle").resizable())
3537
}
3638
}
3739
}

Diffusion-macOS/StatusView.swift

+39-31
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,42 @@ struct StatusView: View {
1818
if case .running = generation.state { return }
1919
Task {
2020
generation.state = .running(nil)
21-
let interval: TimeInterval?
22-
let image: CGImage?
23-
(image, interval) = await generation.generate() ?? (nil, nil)
24-
generation.state = .complete(generation.positivePrompt, image, interval)
21+
do {
22+
let result = try await generation.generate()
23+
generation.state = .complete(generation.positivePrompt, result.image, result.lastSeed, result.interval)
24+
} catch {
25+
generation.state = .failed(error)
26+
}
2527
}
2628
}
27-
29+
30+
func errorWithDetails(_ message: String, error: Error) -> any View {
31+
HStack {
32+
Text(message)
33+
Spacer()
34+
Button {
35+
showErrorPopover.toggle()
36+
} label: {
37+
Image(systemName: "info.circle")
38+
}.buttonStyle(.plain)
39+
.popover(isPresented: $showErrorPopover) {
40+
VStack {
41+
Text(verbatim: "\(error)")
42+
.lineLimit(nil)
43+
.padding(.all, 5)
44+
Button {
45+
showErrorPopover.toggle()
46+
} label: {
47+
Text("Dismiss").frame(maxWidth: 200)
48+
}
49+
.padding(.bottom)
50+
}
51+
.frame(minWidth: 400, idealWidth: 400, maxWidth: 400)
52+
.fixedSize()
53+
}
54+
}
55+
}
56+
2857
func generationStatusView() -> any View {
2958
switch generation.state {
3059
case .startup: return EmptyView()
@@ -42,7 +71,7 @@ struct StatusView: View {
4271
Text("Generating \(Int(round(100*fraction)))%")
4372
Spacer()
4473
}
45-
case .complete(_, let image, let interval):
74+
case .complete(_, let image, let lastSeed, let interval):
4675
guard let _ = image else {
4776
return HStack {
4877
Text("Safety checker triggered, please try a different prompt or seed")
@@ -55,9 +84,11 @@ struct StatusView: View {
5584
Text(intervalString)
5685
Spacer()
5786
}.frame(maxHeight: 25)
87+
case .failed(let error):
88+
return errorWithDetails("Generation error", error: error)
5889
}
5990
}
60-
91+
6192
var body: some View {
6293
switch pipelineState.wrappedValue {
6394
case .downloading(let progress):
@@ -80,30 +111,7 @@ struct StatusView: View {
80111
AnyView(generationStatusView())
81112
}
82113
case .failed(let error):
83-
HStack {
84-
Text("Pipeline loading error")
85-
Spacer()
86-
Button {
87-
showErrorPopover.toggle()
88-
} label: {
89-
Image(systemName: "info.circle")
90-
}.buttonStyle(.plain)
91-
.popover(isPresented: $showErrorPopover) {
92-
VStack {
93-
Text(verbatim: "\(error)")
94-
.lineLimit(nil)
95-
.padding(.all, 5)
96-
Button {
97-
showErrorPopover.toggle()
98-
} label: {
99-
Text("Dismiss").frame(maxWidth: 200)
100-
}
101-
.padding(.bottom)
102-
}
103-
.frame(minWidth: 400, idealWidth: 400, maxWidth: 400)
104-
.fixedSize()
105-
}
106-
}
114+
AnyView(errorWithDetails("Pipeline loading error", error: error))
107115
}
108116
}
109117
}

Diffusion/Pipeline/Pipeline.swift

+8-2
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,12 @@ import StableDiffusion
1414

1515
typealias StableDiffusionProgress = StableDiffusionPipeline.Progress
1616

17+
struct GenerationResult {
18+
var image: CGImage?
19+
var lastSeed: UInt32
20+
var interval: TimeInterval?
21+
}
22+
1723
class Pipeline {
1824
let pipeline: StableDiffusionPipeline
1925

@@ -37,7 +43,7 @@ class Pipeline {
3743
seed: UInt32? = nil,
3844
guidanceScale: Float = 7.5,
3945
disableSafety: Bool = false
40-
) throws -> (CGImage, TimeInterval) {
46+
) throws -> GenerationResult {
4147
let beginDate = Date()
4248
print("Generating...")
4349
let theSeed = seed ?? UInt32.random(in: 0..<UInt32.max)
@@ -59,7 +65,7 @@ class Pipeline {
5965

6066
// unwrap the 1 image we asked for
6167
guard let image = images.compactMap({ $0 }).first else { throw "Generation failed" }
62-
return (image, interval)
68+
return GenerationResult(image: image, lastSeed: theSeed, interval: interval)
6369
}
6470

6571
func handleProgress(_ progress: StableDiffusionPipeline.Progress) {

Diffusion/State.swift

+5-4
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@ let DEFAULT_PROMPT = "Labrador in the style of Vermeer"
1616
enum GenerationState {
1717
case startup
1818
case running(StableDiffusionProgress?)
19-
case complete(String, CGImage?, TimeInterval?)
19+
case complete(String, CGImage?, UInt32, TimeInterval?)
20+
case failed(Error)
2021
}
2122

2223
class GenerationContext: ObservableObject {
@@ -49,10 +50,10 @@ class GenerationContext: ObservableObject {
4950

5051
private var progressSubscriber: Cancellable?
5152

52-
func generate() async -> (CGImage, TimeInterval)? {
53-
guard let pipeline = pipeline else { return nil }
53+
func generate() async throws -> GenerationResult {
54+
guard let pipeline = pipeline else { throw "No pipeline" }
5455
let seed = self.seed >= 0 ? UInt32(self.seed) : nil
55-
return try? pipeline.generate(
56+
return try pipeline.generate(
5657
prompt: positivePrompt,
5758
negativePrompt: negativePrompt,
5859
scheduler: scheduler,

Diffusion/Views/TextToImage.swift

+1-1
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ struct TextToImage: View {
9999
generation.state = .running(nil)
100100
let interval: TimeInterval?
101101
let image: CGImage?
102-
(image, interval) = await generation.generate() ?? (nil, nil)
102+
let result = await generation.generate()
103103
generation.state = .complete(generation.positivePrompt, image, interval)
104104
}
105105
}

0 commit comments

Comments
 (0)