forked from huggingface/swift-coreml-diffusers
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathState.swift
129 lines (110 loc) · 3.83 KB
/
State.swift
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
//
// State.swift
// Diffusion
//
// Created by Pedro Cuenca on 17/1/23.
// See LICENSE at https://github.com/huggingface/swift-coreml-diffusers/LICENSE
//
import Combine
import SwiftUI
import StableDiffusion
import CoreML
/// Model that `Settings.currentModel` falls back to when no usable model id is stored.
let DEFAULT_MODEL = ModelInfo.v2Base
/// Initial value of `GenerationContext.positivePrompt`.
let DEFAULT_PROMPT = "Labrador in the style of Vermeer"
/// State of image generation, as published by `GenerationContext.state`.
enum GenerationState {
/// Initial value of `GenerationContext.state`.
case startup
/// Generation is in progress. `GenerationContext` sets this with the latest
/// non-nil progress value received from the pipeline's progress publisher.
case running(StableDiffusionProgress?)
/// Generation finished.
/// NOTE(review): the associated values are unlabeled; presumably they are the
/// prompt, the resulting image, the seed and the elapsed time — confirm against
/// the code that constructs this case.
case complete(String, CGImage?, UInt32, TimeInterval?)
/// NOTE(review): not assigned anywhere in this file; presumably set by the caller
/// after `GenerationContext.cancelGeneration()` — confirm.
case userCanceled
/// Carries an error. NOTE(review): not assigned anywhere in this file; presumably
/// set by the caller when generation throws — confirm.
case failed(Error)
}
/// Local alias for Core ML's `MLComputeUnits`, used by `GenerationContext.computeUnits`
/// and `Settings.userSelectedComputeUnits`.
typealias ComputeUnits = MLComputeUnits
/// Observable model that holds the user-editable generation parameters, the
/// pipeline used to generate images, and the current `GenerationState`.
class GenerationContext: ObservableObject {
    /// Scheduler passed to the pipeline on every `generate()` call.
    let scheduler = StableDiffusionScheduler.dpmSolverMultistepScheduler

    /// Pipeline used by `generate()`.
    ///
    /// Assigning a pipeline subscribes to its `progressPublisher` so that every
    /// non-nil progress value is republished as `.running(progress)` on the main
    /// queue. Assigning `nil` releases the current subscription.
    @Published var pipeline: Pipeline? = nil {
        didSet {
            guard let pipeline = pipeline else {
                // Dropping the stored cancellable ends the subscription to the
                // previous pipeline instead of keeping it alive indefinitely.
                progressSubscriber = nil
                return
            }
            progressSubscriber = pipeline
                .progressPublisher
                .receive(on: DispatchQueue.main)
                // `self` owns the subscription, so the closure must not own `self`
                // in return; a strong capture here would form a retain cycle.
                .sink { [weak self] progress in
                    guard let self = self, let progress = progress else { return }
                    self.state = .running(progress)
                }
        }
    }

    /// Current state of generation; starts as `.startup`.
    @Published var state: GenerationState = .startup

    /// Prompt passed as `prompt` to the pipeline.
    @Published var positivePrompt = DEFAULT_PROMPT
    /// Prompt passed as `negativePrompt` to the pipeline.
    @Published var negativePrompt = ""

    // FIXME: the numeric parameters below are `Double` to support the slider component.
    /// Number of inference steps; converted to `Int` when generating.
    @Published var steps = 25.0
    /// Not read by `generate()` in this file.
    @Published var numImages = 1.0
    /// Seed for generation. A negative value means "no explicit seed": `nil` is
    /// passed to the pipeline.
    @Published var seed = -1.0
    /// Guidance scale; converted to `Float` when generating.
    @Published var guidanceScale = 7.5
    /// Forwarded as `disableSafety` to the pipeline.
    @Published var disableSafety = false
    /// Compute units to use: the user's stored override if there is one,
    /// otherwise `ModelInfo.defaultComputeUnits`.
    @Published var computeUnits: ComputeUnits = Settings.shared.userSelectedComputeUnits ?? ModelInfo.defaultComputeUnits

    /// Subscription to the current pipeline's progress publisher.
    private var progressSubscriber: Cancellable?

    /// Runs the pipeline with the current parameters.
    ///
    /// - Returns: The result produced by `Pipeline.generate`.
    /// - Throws: `"No pipeline"` when `pipeline` is `nil`, or whatever
    ///   `Pipeline.generate` throws.
    func generate() async throws -> GenerationResult {
        guard let pipeline = pipeline else { throw "No pipeline" }
        // `UInt32(_:)` traps on values it cannot represent (above `UInt32.max`,
        // infinity). The checked conversion yields `nil` for those instead, which
        // is handled like "no explicit seed". Truncation toward zero matches what
        // `UInt32(_:)` does for in-range values.
        let seed: UInt32? = self.seed >= 0
            ? UInt32(exactly: self.seed.rounded(.towardZero))
            : nil
        return try pipeline.generate(
            prompt: positivePrompt,
            negativePrompt: negativePrompt,
            scheduler: scheduler,
            numInferenceSteps: Int(steps),
            seed: seed,
            guidanceScale: Float(guidanceScale),
            disableSafety: disableSafety
        )
    }

    /// Asks the current pipeline, if any, to cancel generation.
    func cancelGeneration() {
        pipeline?.setCancelled()
    }
}
/// Typed accessors for the app's persisted preferences, backed by
/// `UserDefaults.standard`. Use the `shared` instance.
final class Settings {
    /// The single instance; `init` is private.
    static let shared = Settings()

    /// Backing store for every setting in this class.
    let defaults = UserDefaults.standard

    /// Keys under which the settings are stored.
    enum Keys: String {
        case model
        case safetyCheckerDisclaimer
        case computeUnits
    }

    /// Registers the fallback values that are read when a key has never been written.
    private init() {
        defaults.register(defaults: [
            // Use the same constant the `currentModel` getter falls back to, so the
            // registered default and the fallback cannot drift apart.
            Keys.model.rawValue: DEFAULT_MODEL.modelId,
            Keys.safetyCheckerDisclaimer.rawValue: false,
            Keys.computeUnits.rawValue: -1 // Use default
        ])
    }

    /// The selected model, persisted by its `modelId`.
    /// Reads fall back to `DEFAULT_MODEL` when no id is stored or the stored id
    /// is not recognized by `ModelInfo.from(modelId:)`.
    var currentModel: ModelInfo {
        get {
            guard let modelId = defaults.string(forKey: Keys.model.rawValue) else { return DEFAULT_MODEL }
            return ModelInfo.from(modelId: modelId) ?? DEFAULT_MODEL
        }
        set {
            defaults.set(newValue.modelId, forKey: Keys.model.rawValue)
        }
    }

    /// Persisted flag stored under the `safetyCheckerDisclaimer` key; `false` until set.
    var safetyCheckerDisclaimerShown: Bool {
        get {
            defaults.bool(forKey: Keys.safetyCheckerDisclaimer.rawValue)
        }
        set {
            defaults.set(newValue, forKey: Keys.safetyCheckerDisclaimer.rawValue)
        }
    }

    /// Returns the option selected by the user, if overridden.
    /// `nil` means: guess best.
    var userSelectedComputeUnits: ComputeUnits? {
        get {
            let current = defaults.integer(forKey: Keys.computeUnits.rawValue)
            // `-1` is the sentinel stored for "no override".
            guard current != -1 else { return nil }
            return ComputeUnits(rawValue: current)
        }
        set {
            // `nil` is persisted as the `-1` sentinel.
            // NOTE(review): the original comment states that any unsupported raw value
            // also makes `get` return `nil`; that relies on `ComputeUnits(rawValue:)`
            // rejecting such values — confirm.
            defaults.set(newValue?.rawValue ?? -1, forKey: Keys.computeUnits.rawValue)
        }
    }
}