Skip to content

Commit 8796695

Browse files
authored
Advanced Settings: ANE (huggingface#28)
* Preparing to allow users to override inference settings. * ANE setting. * Do not show advanced settings if ANE is not available.
1 parent 8986b34 commit 8796695

File tree

6 files changed

+133
-26
lines changed

6 files changed

+133
-26
lines changed

Diffusion-macOS/ControlsView.swift

+59-7
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,8 @@ struct ControlsView: View {
5656
@State private var disclosedGuidance = false
5757
@State private var disclosedSteps = false
5858
@State private var disclosedSeed = false
59+
@State private var disclosedAdvanced = false
60+
@State private var useANE = (Settings.shared.userSelectedAttentionVariant ?? ModelInfo.defaultAttention) == .splitEinsum
5961

6062
// TODO: refactor download with similar code in Loading.swift (iOS)
6163
@State private var stateSubscriber: Cancellable?
@@ -64,28 +66,35 @@ struct ControlsView: View {
6466

6567
// TODO: make this computed, and observable, and easy to read
6668
@State private var mustShowSafetyCheckerDisclaimer = false
67-
69+
@State private var mustShowModelDownloadDisclaimer = false // When changing advanced settings
70+
6871
@State private var showModelsHelp = false
6972
@State private var showPromptsHelp = false
7073
@State private var showGuidanceHelp = false
7174
@State private var showStepsHelp = false
7275
@State private var showSeedHelp = false
73-
76+
@State private var showAdvancedHelp = false
77+
7478
// Reasonable range for the slider
7579
let maxSeed: UInt32 = 1000
7680

7781
func updateSafetyCheckerState() {
7882
mustShowSafetyCheckerDisclaimer = generation.disableSafety && !Settings.shared.safetyCheckerDisclaimerShown
7983
}
8084

85+
func updateANEState() {
86+
Settings.shared.userSelectedAttentionVariant = useANE ? .splitEinsum : .original
87+
modelDidChange(model: Settings.shared.currentModel)
88+
}
89+
8190
func modelDidChange(model: ModelInfo) {
8291
print("Loading model \(model)")
8392
Settings.shared.currentModel = model
8493

8594
pipelineLoader?.cancel()
8695
pipelineState = .downloading(0)
8796
Task.init {
88-
let loader = PipelineLoader(model: model, maxSeed: maxSeed)
97+
let loader = PipelineLoader(model: model, variant: Settings.shared.userSelectedAttentionVariant, maxSeed: maxSeed)
8998
self.pipelineLoader = loader
9099
stateSubscriber = loader.statePublisher.sink { state in
91100
DispatchQueue.main.async {
@@ -114,16 +123,20 @@ struct ControlsView: View {
114123
}
115124
}
116125

126+
func isModelDownloaded(_ model: ModelInfo, variant: AttentionVariant? = nil) -> Bool {
127+
PipelineLoader(model: model, variant: variant ?? Settings.shared.userSelectedAttentionVariant).ready
128+
}
129+
117130
func modelLabel(_ model: ModelInfo) -> Text {
118-
let downloaded = PipelineLoader(model: model).ready
131+
let downloaded = isModelDownloaded(model)
119132
let prefix = downloaded ? "" : "" //"○ "
120133
return Text(prefix).foregroundColor(downloaded ? .accentColor : .secondary) + Text(model.modelVersion)
121134
}
122135

123136
var body: some View {
124137
VStack(alignment: .leading) {
125138

126-
Label("Adjustments", systemImage: "gearshape.2")
139+
Label("Generation Options", systemImage: "gearshape.2")
127140
.font(.headline)
128141
.fontWeight(.bold)
129142
Divider()
@@ -217,7 +230,6 @@ struct ControlsView: View {
217230
}
218231
}.foregroundColor(.secondary)
219232
}
220-
Divider()
221233

222234
DisclosureGroup(isExpanded: $disclosedSteps) {
223235
CompactSlider(value: $generation.steps, in: 0...150, step: 5) {
@@ -244,7 +256,6 @@ struct ControlsView: View {
244256
}
245257
}.foregroundColor(.secondary)
246258
}
247-
Divider()
248259

249260
DisclosureGroup(isExpanded: $disclosedSeed) {
250261
let sliderLabel = generation.seed < 0 ? "Random Seed" : "Seed"
@@ -272,6 +283,47 @@ struct ControlsView: View {
272283
}
273284
}.foregroundColor(.secondary)
274285
}
286+
287+
if hasANE {
288+
Divider()
289+
DisclosureGroup(isExpanded: $disclosedAdvanced) {
290+
HStack {
291+
Toggle("Use Neural Engine", isOn: $useANE).onChange(of: useANE) { value in
292+
guard let currentModel = ModelInfo.from(modelVersion: model) else { return }
293+
let variantDownloaded = isModelDownloaded(currentModel, variant: useANE ? .splitEinsum : .original)
294+
if variantDownloaded {
295+
updateANEState()
296+
} else {
297+
mustShowModelDownloadDisclaimer.toggle()
298+
}
299+
}
300+
.padding(.leading, 10)
301+
Spacer()
302+
}
303+
.alert("Download Required", isPresented: $mustShowModelDownloadDisclaimer, actions: {
304+
Button("Cancel", role: .destructive) { useANE.toggle() }
305+
Button("Download", role: .cancel) { updateANEState() }
306+
}, message: {
307+
Text("This setting requires a new version of the selected model.")
308+
})
309+
} label: {
310+
HStack {
311+
Label("Advanced", systemImage: "terminal").foregroundColor(.secondary)
312+
Spacer()
313+
if disclosedAdvanced {
314+
Button {
315+
showAdvancedHelp.toggle()
316+
} label: {
317+
Image(systemName: "info.circle")
318+
}
319+
.buttonStyle(.plain)
320+
.popover(isPresented: $showAdvancedHelp, arrowEdge: .trailing) {
321+
advancedHelp($showAdvancedHelp)
322+
}
323+
}
324+
}.foregroundColor(.secondary)
325+
}
326+
}
275327
}
276328
}
277329
.disclosureGroupStyle(LabelToggleDisclosureGroupStyle())

Diffusion-macOS/Diffusion_macOSApp.swift

+7
Original file line numberDiff line numberDiff line change
@@ -18,3 +18,10 @@ struct Diffusion_macOSApp: App {
1818
}
1919

2020
let runningOnMac = true
21+
22+
#if canImport(MLCompute)
23+
import MLCompute
24+
let hasANE = MLCDevice.ane() != nil
25+
#else
26+
let hasANE = false
27+
#endif

Diffusion-macOS/HelpContent.swift

+14
Original file line numberDiff line numberDiff line change
@@ -123,3 +123,17 @@ func seedHelp(_ showing: Binding<Bool>) -> some View {
123123
"""
124124
return helpContent(title: "Generation Seed", description: description, showing: showing)
125125
}
126+
127+
func advancedHelp(_ showing: Binding<Bool>) -> some View {
128+
let description =
129+
"""
130+
This section allows you to try different optimization settings.
131+
132+
Diffusers will try to select the best configuration for you, but it may not always be optimal \
133+
for your computer. You can experiment with these settings to verify the combination that works faster \
134+
in your system.
135+
136+
Please, note that these settings may trigger downloads of additional model variants.
137+
"""
138+
return helpContent(title: "Advanced Model Settings", description: description, showing: showing)
139+
}

Diffusion/ModelInfo.swift

+27-15
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,11 @@
88

99
import CoreML
1010

11+
enum AttentionVariant: String {
12+
case original
13+
case splitEinsum
14+
}
15+
1116
struct ModelInfo {
1217
/// Hugging Face model Id that contains .zip archives with compiled Core ML models
1318
let modelId: String
@@ -19,38 +24,45 @@ struct ModelInfo {
1924
let originalAttentionSuffix: String
2025

2126
/// Suffix of the archive containing the SPLIT_EINSUM attention variant. Usually something like "split_einsum_compiled"
22-
let splitAttentionName: String
27+
let splitAttentionSuffix: String
2328

2429
/// Whether the archive contains the VAE Encoder (for image to image tasks). Not yet in use.
2530
let supportsEncoder: Bool
2631

27-
init(modelId: String, modelVersion: String, originalAttentionSuffix: String = "original_compiled", splitAttentionName: String = "split_einsum_compiled", supportsEncoder: Bool = false) {
32+
init(modelId: String, modelVersion: String, originalAttentionSuffix: String = "original_compiled", splitAttentionSuffix: String = "split_einsum_compiled", supportsEncoder: Bool = false) {
2833
self.modelId = modelId
2934
self.modelVersion = modelVersion
3035
self.originalAttentionSuffix = originalAttentionSuffix
31-
self.splitAttentionName = splitAttentionName
36+
self.splitAttentionSuffix = splitAttentionSuffix
3237
self.supportsEncoder = supportsEncoder
3338
}
3439
}
3540

3641
extension ModelInfo {
37-
/// Best variant for the current platform.
38-
/// Currently using `split_einsum` for iOS and `original` for macOS, but could vary depending on model.
39-
var bestURL: URL {
42+
static var defaultAttention: AttentionVariant {
43+
return runningOnMac ? .original : .splitEinsum
44+
}
45+
46+
// TODO: heuristics per {model, device}
47+
var bestAttention: AttentionVariant {
48+
return ModelInfo.defaultAttention
49+
}
50+
51+
func modelURL(for variant: AttentionVariant) -> URL {
4052
// Pattern: https://huggingface.co/pcuenq/coreml-stable-diffusion/resolve/main/coreml-stable-diffusion-v1-5_original_compiled.zip
41-
let suffix = runningOnMac ? originalAttentionSuffix : splitAttentionName
53+
let suffix: String
54+
switch variant {
55+
case .original: suffix = originalAttentionSuffix
56+
case .splitEinsum: suffix = splitAttentionSuffix
57+
}
4258
let repo = modelId.split(separator: "/").last!
4359
return URL(string: "https://huggingface.co/\(modelId)/resolve/main/\(repo)_\(suffix).zip")!
4460
}
4561

46-
/// Best units for current platform.
47-
/// Currently using `cpuAndNeuralEngine` for iOS and `cpuAndGPU` for macOS, but could vary depending on model.
48-
/// .all works for v1.4, but not for v1.5.
49-
// TODO: measure performance on different devices.
50-
var bestComputeUnits: MLComputeUnits {
51-
return runningOnMac ? .cpuAndGPU : .cpuAndNeuralEngine
52-
}
53-
62+
/// Best variant for the current platform.
63+
/// Currently using `split_einsum` for iOS and `original` for macOS, but could vary depending on model.
64+
var bestURL: URL { modelURL(for: bestAttention) }
65+
5466
var reduceMemory: Bool {
5567
return !runningOnMac
5668
}

Diffusion/Pipeline/PipelineLoader.swift

+10-3
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,14 @@ class PipelineLoader {
1818
static let models = Path.applicationSupport / "hf-diffusion-models"
1919

2020
let model: ModelInfo
21+
let variant: AttentionVariant
2122
let maxSeed: UInt32
2223

2324
private var downloadSubscriber: Cancellable?
2425

25-
init(model: ModelInfo, maxSeed: UInt32 = UInt32.max) {
26+
init(model: ModelInfo, variant: AttentionVariant? = nil, maxSeed: UInt32 = UInt32.max) {
2627
self.model = model
28+
self.variant = variant ?? model.bestAttention
2729
self.maxSeed = maxSeed
2830
state = .undetermined
2931
setInitialState()
@@ -73,7 +75,7 @@ extension PipelineLoader {
7375

7476
extension PipelineLoader {
7577
var url: URL {
76-
return model.bestURL
78+
return model.modelURL(for: variant)
7779
}
7880

7981
var filename: String {
@@ -95,6 +97,11 @@ extension PipelineLoader {
9597
var ready: Bool {
9698
return compiledPath.exists
9799
}
100+
101+
// TODO: measure performance on different devices, disassociate from variant
102+
var computeUnits: MLComputeUnits {
103+
variant == .original ? .cpuAndGPU : .cpuAndNeuralEngine
104+
}
98105

99106
// TODO: maybe receive Progress to add another progress as child
100107
func prepare() async throws -> Pipeline {
@@ -142,7 +149,7 @@ extension PipelineLoader {
142149
func load(url: URL) async throws -> StableDiffusionPipeline {
143150
let beginDate = Date()
144151
let configuration = MLModelConfiguration()
145-
configuration.computeUnits = model.bestComputeUnits
152+
configuration.computeUnits = computeUnits
146153
let pipeline = try StableDiffusionPipeline(resourcesAt: url,
147154
configuration: configuration,
148155
disableSafety: false,

Diffusion/State.swift

+16-1
Original file line numberDiff line numberDiff line change
@@ -73,12 +73,14 @@ class Settings {
7373
enum Keys: String {
7474
case model
7575
case safetyCheckerDisclaimer
76+
case variant
7677
}
7778

7879
private init() {
7980
defaults.register(defaults: [
8081
Keys.model.rawValue: ModelInfo.v2Base.modelId,
81-
Keys.safetyCheckerDisclaimer.rawValue: false
82+
Keys.safetyCheckerDisclaimer.rawValue: false,
83+
Keys.variant.rawValue: "- default -"
8284
])
8385
}
8486

@@ -100,4 +102,17 @@ class Settings {
100102
return defaults.bool(forKey: Keys.safetyCheckerDisclaimer.rawValue)
101103
}
102104
}
105+
106+
/// Returns the option selected by the user, if overridden
107+
/// `nil` means: guess best for this {model, device}
108+
var userSelectedAttentionVariant: AttentionVariant? {
109+
set {
110+
// Any String other than the supported ones would cause `get` to return `nil`
111+
defaults.set(newValue?.rawValue ?? "- default -", forKey: Keys.variant.rawValue)
112+
}
113+
get {
114+
let current = defaults.string(forKey: Keys.variant.rawValue)
115+
return AttentionVariant(rawValue: current ?? "")
116+
}
117+
}
103118
}

0 commit comments

Comments
 (0)