Skip to content

Commit 3b75e75

Browse files
authored Feb 22, 2023
Allow all compute units to be selected (huggingface#30)
* Allow all compute units to be selected by the user. * Remove commented code. * Simplify labels. * Remove warning. * Align picker left. * Apply suggestions from code review.
1 parent d69e554 commit 3b75e75

File tree

4 files changed

+63
-38
lines changed

4 files changed

+63
-38
lines changed
 

‎Diffusion-macOS/ControlsView.swift

+27-20
Original file line number | Diff line number | Diff line change
@@ -57,7 +57,6 @@ struct ControlsView: View {
5757
@State private var disclosedSteps = false
5858
@State private var disclosedSeed = false
5959
@State private var disclosedAdvanced = false
60-
@State private var useANE = (Settings.shared.userSelectedAttentionVariant ?? ModelInfo.defaultAttention) == .splitEinsum
6160

6261
// TODO: refactor download with similar code in Loading.swift (iOS)
6362
@State private var stateSubscriber: Cancellable?
@@ -82,14 +81,18 @@ struct ControlsView: View {
8281
mustShowSafetyCheckerDisclaimer = generation.disableSafety && !Settings.shared.safetyCheckerDisclaimerShown
8382
}
8483

85-
func updateANEState() {
86-
Settings.shared.userSelectedAttentionVariant = useANE ? .splitEinsum : .original
84+
func updateComputeUnitsState() {
85+
Settings.shared.userSelectedComputeUnits = generation.computeUnits
8786
modelDidChange(model: Settings.shared.currentModel)
8887
}
8988

89+
func resetComputeUnitsState() {
90+
generation.computeUnits = Settings.shared.userSelectedComputeUnits ?? ModelInfo.defaultComputeUnits
91+
}
92+
9093
func modelDidChange(model: ModelInfo) {
91-
guard pipelineLoader?.model != model || pipelineLoader?.variant != Settings.shared.userSelectedAttentionVariant else {
92-
print("Reusing same model \(model) with attention \(String(describing: Settings.shared.userSelectedAttentionVariant))")
94+
guard pipelineLoader?.model != model || pipelineLoader?.computeUnits != generation.computeUnits else {
95+
print("Reusing same model \(model) with units \(generation.computeUnits)")
9396
return
9497
}
9598

@@ -99,7 +102,7 @@ struct ControlsView: View {
99102
pipelineLoader?.cancel()
100103
pipelineState = .downloading(0)
101104
Task.init {
102-
let loader = PipelineLoader(model: model, variant: Settings.shared.userSelectedAttentionVariant, maxSeed: maxSeed)
105+
let loader = PipelineLoader(model: model, computeUnits: generation.computeUnits, maxSeed: maxSeed)
103106
self.pipelineLoader = loader
104107
stateSubscriber = loader.statePublisher.sink { state in
105108
DispatchQueue.main.async {
@@ -128,8 +131,8 @@ struct ControlsView: View {
128131
}
129132
}
130133

131-
func isModelDownloaded(_ model: ModelInfo, variant: AttentionVariant? = nil) -> Bool {
132-
PipelineLoader(model: model, variant: variant ?? Settings.shared.userSelectedAttentionVariant).ready
134+
func isModelDownloaded(_ model: ModelInfo, computeUnits: ComputeUnits? = nil) -> Bool {
135+
PipelineLoader(model: model, computeUnits: computeUnits ?? generation.computeUnits).ready
133136
}
134137

135138
func modelLabel(_ model: ModelInfo) -> Text {
@@ -301,21 +304,25 @@ struct ControlsView: View {
301304
Divider()
302305
DisclosureGroup(isExpanded: $disclosedAdvanced) {
303306
HStack {
304-
Toggle("Use Neural Engine", isOn: $useANE).onChange(of: useANE) { value in
305-
guard let currentModel = ModelInfo.from(modelVersion: model) else { return }
306-
let variantDownloaded = isModelDownloaded(currentModel, variant: useANE ? .splitEinsum : .original)
307-
if variantDownloaded {
308-
updateANEState()
309-
} else {
310-
mustShowModelDownloadDisclaimer.toggle()
311-
}
312-
}
313-
.padding(.leading, 10)
307+
Picker(selection: $generation.computeUnits, label: Text("Use")) {
308+
Text("GPU").tag(ComputeUnits.cpuAndGPU)
309+
Text("Neural Engine").tag(ComputeUnits.cpuAndNeuralEngine)
310+
Text("GPU and Neural Engine").tag(ComputeUnits.all)
311+
}.pickerStyle(.radioGroup).padding(.leading)
314312
Spacer()
315313
}
314+
.onChange(of: generation.computeUnits) { units in
315+
guard let currentModel = ModelInfo.from(modelVersion: model) else { return }
316+
let variantDownloaded = isModelDownloaded(currentModel, computeUnits: units)
317+
if variantDownloaded {
318+
updateComputeUnitsState()
319+
} else {
320+
mustShowModelDownloadDisclaimer.toggle()
321+
}
322+
}
316323
.alert("Download Required", isPresented: $mustShowModelDownloadDisclaimer, actions: {
317-
Button("Cancel", role: .destructive) { useANE.toggle() }
318-
Button("Download", role: .cancel) { updateANEState() }
324+
Button("Cancel", role: .destructive) { resetComputeUnitsState() }
325+
Button("Download", role: .cancel) { updateComputeUnitsState() }
319326
}, message: {
320327
Text("This setting requires a new version of the selected model.")
321328
})

‎Diffusion/ModelInfo.swift

+9-3
Original file line number | Diff line number | Diff line change
@@ -13,6 +13,10 @@ enum AttentionVariant: String {
1313
case splitEinsum
1414
}
1515

16+
extension AttentionVariant {
17+
var defaultComputeUnits: MLComputeUnits { self == .original ? .cpuAndGPU : .cpuAndNeuralEngine }
18+
}
19+
1620
struct ModelInfo {
1721
/// Hugging Face model Id that contains .zip archives with compiled Core ML models
1822
let modelId: String
@@ -39,6 +43,7 @@ struct ModelInfo {
3943
}
4044

4145
extension ModelInfo {
46+
//TODO: set compute units instead and derive variant from it
4247
static var defaultAttention: AttentionVariant {
4348
guard runningOnMac else { return .splitEinsum }
4449
#if os(macOS)
@@ -49,9 +54,10 @@ extension ModelInfo {
4954
#endif
5055
}
5156

52-
var bestAttention: AttentionVariant {
53-
return ModelInfo.defaultAttention
54-
}
57+
static var defaultComputeUnits: MLComputeUnits { defaultAttention.defaultComputeUnits }
58+
59+
var bestAttention: AttentionVariant { ModelInfo.defaultAttention }
60+
var defaultComputeUnits: MLComputeUnits { bestAttention.defaultComputeUnits }
5561

5662
func modelURL(for variant: AttentionVariant) -> URL {
5763
// Pattern: https://huggingface.co/pcuenq/coreml-stable-diffusion/resolve/main/coreml-stable-diffusion-v1-5_original_compiled.zip

‎Diffusion/Pipeline/PipelineLoader.swift

+13-7
Original file line number | Diff line number | Diff line change
@@ -18,14 +18,14 @@ class PipelineLoader {
1818
static let models = Path.applicationSupport / "hf-diffusion-models"
1919

2020
let model: ModelInfo
21-
let variant: AttentionVariant
21+
let computeUnits: ComputeUnits
2222
let maxSeed: UInt32
2323

2424
private var downloadSubscriber: Cancellable?
2525

26-
init(model: ModelInfo, variant: AttentionVariant? = nil, maxSeed: UInt32 = UInt32.max) {
26+
init(model: ModelInfo, computeUnits: ComputeUnits? = nil, maxSeed: UInt32 = UInt32.max) {
2727
self.model = model
28-
self.variant = variant ?? model.bestAttention
28+
self.computeUnits = computeUnits ?? model.defaultComputeUnits
2929
self.maxSeed = maxSeed
3030
state = .undetermined
3131
setInitialState()
@@ -98,11 +98,17 @@ extension PipelineLoader {
9898
return compiledPath.exists
9999
}
100100

101-
// TODO: measure performance on different devices, disassociate from variant
102-
var computeUnits: MLComputeUnits {
103-
variant == .original ? .cpuAndGPU : .cpuAndNeuralEngine
101+
var variant: AttentionVariant {
102+
switch computeUnits {
103+
case .cpuOnly : return .original // Not supported yet
104+
case .cpuAndGPU : return .original
105+
case .cpuAndNeuralEngine: return .splitEinsum
106+
case .all : return .splitEinsum
107+
@unknown default:
108+
fatalError("Unknown MLComputeUnits")
109+
}
104110
}
105-
111+
106112
// TODO: maybe receive Progress to add another progress as child
107113
func prepare() async throws -> Pipeline {
108114
do {

‎Diffusion/State.swift

+14-8
Original file line number | Diff line number | Diff line change
@@ -9,6 +9,7 @@
99
import Combine
1010
import SwiftUI
1111
import StableDiffusion
12+
import CoreML
1213

1314
let DEFAULT_MODEL = ModelInfo.v2Base
1415
let DEFAULT_PROMPT = "Labrador in the style of Vermeer"
@@ -21,6 +22,8 @@ enum GenerationState {
2122
case failed(Error)
2223
}
2324

25+
typealias ComputeUnits = MLComputeUnits
26+
2427
class GenerationContext: ObservableObject {
2528
let scheduler = StableDiffusionScheduler.dpmSolverMultistepScheduler
2629

@@ -48,6 +51,8 @@ class GenerationContext: ObservableObject {
4851
@Published var seed = -1.0
4952
@Published var guidanceScale = 7.5
5053
@Published var disableSafety = false
54+
55+
@Published var computeUnits: ComputeUnits = Settings.shared.userSelectedComputeUnits ?? ModelInfo.defaultComputeUnits
5156

5257
private var progressSubscriber: Cancellable?
5358

@@ -78,14 +83,14 @@ class Settings {
7883
enum Keys: String {
7984
case model
8085
case safetyCheckerDisclaimer
81-
case variant
86+
case computeUnits
8287
}
8388

8489
private init() {
8590
defaults.register(defaults: [
8691
Keys.model.rawValue: ModelInfo.v2Base.modelId,
8792
Keys.safetyCheckerDisclaimer.rawValue: false,
88-
Keys.variant.rawValue: "- default -"
93+
Keys.computeUnits.rawValue: -1 // Use default
8994
])
9095
}
9196

@@ -109,15 +114,16 @@ class Settings {
109114
}
110115

111116
/// Returns the option selected by the user, if overridden
112-
/// `nil` means: guess best for this {model, device}
113-
var userSelectedAttentionVariant: AttentionVariant? {
117+
/// `nil` means: guess best
118+
var userSelectedComputeUnits: ComputeUnits? {
114119
set {
115-
// Any String other than the supported ones would cause `get` to return `nil`
116-
defaults.set(newValue?.rawValue ?? "- default -", forKey: Keys.variant.rawValue)
120+
// Any value other than the supported ones would cause `get` to return `nil`
121+
defaults.set(newValue?.rawValue ?? -1, forKey: Keys.computeUnits.rawValue)
117122
}
118123
get {
119-
let current = defaults.string(forKey: Keys.variant.rawValue)
120-
return AttentionVariant(rawValue: current ?? "")
124+
let current = defaults.integer(forKey: Keys.computeUnits.rawValue)
125+
guard current != -1 else { return nil }
126+
return ComputeUnits(rawValue: current)
121127
}
122128
}
123129
}

0 commit comments

Comments
 (0)
Please sign in to comment.