Skip to content

Commit a561fae

Browse files
Add sd3 pipeline and models (#96)
* Add sd3 pipeline and models * Add sd3 pipeline and models * Update minimum deployment target for macOS * Add checks for ANE support when selecting compute unit * Fix progress for non-background downloads
1 parent 17ee716 commit a561fae

File tree

10 files changed

+236
-65
lines changed

10 files changed

+236
-65
lines changed

Diffusion-macOS/ControlsView.swift

+43-5
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,15 @@ struct ControlsView: View {
108108
return
109109
}
110110

111+
if !model.supportsNeuralEngine && generation.computeUnits == .cpuAndNeuralEngine {
112+
// Reset compute units to GPU if Neural Engine is not supported
113+
Settings.shared.userSelectedComputeUnits = .cpuAndGPU
114+
resetComputeUnitsState()
115+
print("Neural Engine not supported for model \(model), switching to GPU")
116+
} else {
117+
resetComputeUnitsState()
118+
}
119+
111120
Settings.shared.currentModel = model
112121

113122
pipelineLoader?.cancel()
@@ -155,9 +164,15 @@ struct ControlsView: View {
155164
VStack {
156165
Spacer()
157166
PromptTextField(text: $generation.positivePrompt, isPositivePrompt: true, model: $model)
167+
.onChange(of: generation.positivePrompt) { prompt in
168+
Settings.shared.prompt = prompt
169+
}
158170
.padding(.top, 5)
159171
Spacer()
160172
PromptTextField(text: $generation.negativePrompt, isPositivePrompt: false, model: $model)
173+
.onChange(of: generation.negativePrompt) { negativePrompt in
174+
Settings.shared.negativePrompt = negativePrompt
175+
}
161176
.padding(.bottom, 5)
162177
Spacer()
163178
}
@@ -242,7 +257,11 @@ struct ControlsView: View {
242257
Text("Guidance Scale")
243258
Spacer()
244259
Text(guidanceScaleValue)
245-
}.padding(.leading, 10)
260+
}
261+
.onChange(of: generation.guidanceScale) { guidanceScale in
262+
Settings.shared.guidanceScale = guidanceScale
263+
}
264+
.padding(.leading, 10)
246265
} label: {
247266
HStack {
248267
Label("Guidance Scale", systemImage: "scalemass").foregroundColor(.secondary)
@@ -269,7 +288,11 @@ struct ControlsView: View {
269288
Text("Steps")
270289
Spacer()
271290
Text("\(Int(generation.steps))")
272-
}.padding(.leading, 10)
291+
}
292+
.onChange(of: generation.steps) { steps in
293+
Settings.shared.stepCount = steps
294+
}
295+
.padding(.leading, 10)
273296
} label: {
274297
HStack {
275298
Label("Step count", systemImage: "square.3.layers.3d.down.left").foregroundColor(.secondary)
@@ -295,7 +318,11 @@ struct ControlsView: View {
295318
Text("Previews")
296319
Spacer()
297320
Text("\(Int(generation.previews))")
298-
}.padding(.leading, 10)
321+
}
322+
.onChange(of: generation.previews) { previews in
323+
Settings.shared.previewCount = previews
324+
}
325+
.padding(.leading, 10)
299326
} label: {
300327
HStack {
301328
Label("Preview count", systemImage: "eye.square").foregroundColor(.secondary)
@@ -334,25 +361,32 @@ struct ControlsView: View {
334361
seedHelp($showSeedHelp)
335362
}
336363
} else {
337-
Text("\(Int(generation.seed))")
364+
Text(generation.seed.formatted(.number.grouping(.never)))
338365
}
339366
}
340367
.foregroundColor(.secondary)
341368
}
342369

343370
if Capabilities.hasANE {
344371
Divider()
372+
let isNeuralEngineDisabled = !(ModelInfo.from(modelVersion: model)?.supportsNeuralEngine ?? true)
345373
DisclosureGroup(isExpanded: $disclosedAdvanced) {
346374
HStack {
347375
Picker(selection: $generation.computeUnits, label: Text("Use")) {
348376
Text("GPU").tag(ComputeUnits.cpuAndGPU)
349-
Text("Neural Engine").tag(ComputeUnits.cpuAndNeuralEngine)
377+
Text("Neural Engine\(isNeuralEngineDisabled ? " (unavailable)" : "")")
378+
.foregroundColor(isNeuralEngineDisabled ? .secondary : .primary)
379+
.tag(ComputeUnits.cpuAndNeuralEngine)
350380
Text("GPU and Neural Engine").tag(ComputeUnits.all)
351381
}.pickerStyle(.radioGroup).padding(.leading)
352382
Spacer()
353383
}
354384
.onChange(of: generation.computeUnits) { units in
355385
guard let currentModel = ModelInfo.from(modelVersion: model) else { return }
386+
if isNeuralEngineDisabled && units == .cpuAndNeuralEngine {
387+
resetComputeUnitsState()
388+
return
389+
}
356390
let variantDownloaded = isModelDownloaded(currentModel, computeUnits: units)
357391
if variantDownloaded {
358392
updateComputeUnitsState()
@@ -430,8 +464,10 @@ struct ControlsView: View {
430464
set: { newValue in
431465
if let seed = UInt32(newValue) {
432466
generation.seed = seed
467+
Settings.shared.seed = seed
433468
} else {
434469
generation.seed = 0
470+
Settings.shared.seed = 0
435471
}
436472
}
437473
)
@@ -442,8 +478,10 @@ struct ControlsView: View {
442478
.onChange(of: seedBinding.wrappedValue, perform: { newValue in
443479
if let seed = UInt32(newValue) {
444480
generation.seed = seed
481+
Settings.shared.seed = seed
445482
} else {
446483
generation.seed = 0
484+
Settings.shared.seed = 0
447485
}
448486
})
449487
.onReceive(Just(seedBinding.wrappedValue)) { newValue in

Diffusion.xcodeproj/project.pbxproj

+16-16
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
objects = {
88

99
/* Begin PBXBuildFile section */
10+
16AFDD4F2C1B7D6200536A62 /* StableDiffusion in Frameworks */ = {isa = PBXBuildFile; productRef = 16AFDD4E2C1B7D6200536A62 /* StableDiffusion */; };
11+
16AFDD512C1B7D6700536A62 /* StableDiffusion in Frameworks */ = {isa = PBXBuildFile; productRef = 16AFDD502C1B7D6700536A62 /* StableDiffusion */; };
1012
8C4B32042A770C1D0090EF17 /* DiffusionImage+macOS.swift in Sources */ = {isa = PBXBuildFile; fileRef = 8C4B32032A770C1D0090EF17 /* DiffusionImage+macOS.swift */; };
1113
8C4B32062A770C300090EF17 /* DiffusionImage+iOS.swift in Sources */ = {isa = PBXBuildFile; fileRef = 8C4B32052A770C300090EF17 /* DiffusionImage+iOS.swift */; };
1214
8C4B32082A77F90C0090EF17 /* Utils_iOS.swift in Sources */ = {isa = PBXBuildFile; fileRef = 8C4B32072A77F90C0090EF17 /* Utils_iOS.swift */; };
@@ -16,8 +18,6 @@
1618
8CEEB7D92A54C88C00C23829 /* DiffusionImage.swift in Sources */ = {isa = PBXBuildFile; fileRef = 8CEEB7D82A54C88C00C23829 /* DiffusionImage.swift */; };
1719
8CEEB7DA2A54C88C00C23829 /* DiffusionImage.swift in Sources */ = {isa = PBXBuildFile; fileRef = 8CEEB7D82A54C88C00C23829 /* DiffusionImage.swift */; };
1820
EB067F872992E561004D1AD9 /* HelpContent.swift in Sources */ = {isa = PBXBuildFile; fileRef = EB067F862992E561004D1AD9 /* HelpContent.swift */; };
19-
EB25B3D62A3A2DC4000E25A1 /* StableDiffusion in Frameworks */ = {isa = PBXBuildFile; productRef = EB25B3D52A3A2DC4000E25A1 /* StableDiffusion */; };
20-
EB25B3D82A3A2DD5000E25A1 /* StableDiffusion in Frameworks */ = {isa = PBXBuildFile; productRef = EB25B3D72A3A2DD5000E25A1 /* StableDiffusion */; };
2121
EB560F0429A3C20800C0F8B8 /* Capabilities.swift in Sources */ = {isa = PBXBuildFile; fileRef = EB560F0329A3C20800C0F8B8 /* Capabilities.swift */; };
2222
EBB5BA5329425BEE003A2A5B /* PipelineLoader.swift in Sources */ = {isa = PBXBuildFile; fileRef = EBB5BA5229425BEE003A2A5B /* PipelineLoader.swift */; };
2323
EBB5BA5A29426E06003A2A5B /* Downloader.swift in Sources */ = {isa = PBXBuildFile; fileRef = EBB5BA5929426E06003A2A5B /* Downloader.swift */; };
@@ -116,8 +116,8 @@
116116
isa = PBXFrameworksBuildPhase;
117117
buildActionMask = 2147483647;
118118
files = (
119-
EB25B3D62A3A2DC4000E25A1 /* StableDiffusion in Frameworks */,
120119
EBB5BA5D294504DE003A2A5B /* ZIPFoundation in Frameworks */,
120+
16AFDD512C1B7D6700536A62 /* StableDiffusion in Frameworks */,
121121
);
122122
runOnlyForDeploymentPostprocessing = 0;
123123
};
@@ -140,7 +140,7 @@
140140
buildActionMask = 2147483647;
141141
files = (
142142
F155203C297118E700DC009B /* CompactSlider in Frameworks */,
143-
EB25B3D82A3A2DD5000E25A1 /* StableDiffusion in Frameworks */,
143+
16AFDD4F2C1B7D6200536A62 /* StableDiffusion in Frameworks */,
144144
EBDD7DAF29731FB300C1C4B2 /* ZIPFoundation in Frameworks */,
145145
);
146146
runOnlyForDeploymentPostprocessing = 0;
@@ -318,7 +318,7 @@
318318
name = Diffusion;
319319
packageProductDependencies = (
320320
EBB5BA5C294504DE003A2A5B /* ZIPFoundation */,
321-
EB25B3D52A3A2DC4000E25A1 /* StableDiffusion */,
321+
16AFDD502C1B7D6700536A62 /* StableDiffusion */,
322322
);
323323
productName = Diffusion;
324324
productReference = EBE755C5293E37DD00806B32 /* Diffusion.app */;
@@ -378,7 +378,7 @@
378378
packageProductDependencies = (
379379
F155203B297118E700DC009B /* CompactSlider */,
380380
EBDD7DAE29731FB300C1C4B2 /* ZIPFoundation */,
381-
EB25B3D72A3A2DD5000E25A1 /* StableDiffusion */,
381+
16AFDD4E2C1B7D6200536A62 /* StableDiffusion */,
382382
);
383383
productName = "Diffusion-macOS";
384384
productReference = F15520212971093300DC009B /* Diffusers.app */;
@@ -422,7 +422,7 @@
422422
packageReferences = (
423423
EBB5BA5B294504DE003A2A5B /* XCRemoteSwiftPackageReference "ZIPFoundation" */,
424424
F155203A297118E600DC009B /* XCRemoteSwiftPackageReference "CompactSlider" */,
425-
EB25B3D42A3A2DC4000E25A1 /* XCRemoteSwiftPackageReference "ml-stable-diffusion" */,
425+
16AFDD4D2C1B7D4800536A62 /* XCRemoteSwiftPackageReference "ml-stable-diffusion" */,
426426
);
427427
productRefGroup = EBE755C6293E37DD00806B32 /* Products */;
428428
projectDirPath = "";
@@ -876,7 +876,7 @@
876876
"$(inherited)",
877877
"@executable_path/../Frameworks",
878878
);
879-
MACOSX_DEPLOYMENT_TARGET = 13.1;
879+
MACOSX_DEPLOYMENT_TARGET = 14.0;
880880
SDKROOT = macosx;
881881
SWIFT_EMIT_LOC_STRINGS = YES;
882882
SWIFT_VERSION = 5.0;
@@ -904,7 +904,7 @@
904904
"$(inherited)",
905905
"@executable_path/../Frameworks",
906906
);
907-
MACOSX_DEPLOYMENT_TARGET = 13.1;
907+
MACOSX_DEPLOYMENT_TARGET = 14.0;
908908
PRODUCT_BUNDLE_IDENTIFIER = com.huggingface.Diffusers;
909909
SDKROOT = macosx;
910910
SWIFT_EMIT_LOC_STRINGS = YES;
@@ -963,9 +963,9 @@
963963
/* End XCConfigurationList section */
964964

965965
/* Begin XCRemoteSwiftPackageReference section */
966-
EB25B3D42A3A2DC4000E25A1 /* XCRemoteSwiftPackageReference "ml-stable-diffusion" */ = {
966+
16AFDD4D2C1B7D4800536A62 /* XCRemoteSwiftPackageReference "ml-stable-diffusion" */ = {
967967
isa = XCRemoteSwiftPackageReference;
968-
repositoryURL = "https://github.com/apple/ml-stable-diffusion";
968+
repositoryURL = "https://github.com/argmaxinc/ml-stable-diffusion.git";
969969
requirement = {
970970
branch = main;
971971
kind = branch;
@@ -990,18 +990,18 @@
990990
/* End XCRemoteSwiftPackageReference section */
991991

992992
/* Begin XCSwiftPackageProductDependency section */
993-
EB0199482A31FEAF00B133E2 /* StableDiffusion */ = {
993+
16AFDD4E2C1B7D6200536A62 /* StableDiffusion */ = {
994994
isa = XCSwiftPackageProductDependency;
995+
package = 16AFDD4D2C1B7D4800536A62 /* XCRemoteSwiftPackageReference "ml-stable-diffusion" */;
995996
productName = StableDiffusion;
996997
};
997-
EB25B3D52A3A2DC4000E25A1 /* StableDiffusion */ = {
998+
16AFDD502C1B7D6700536A62 /* StableDiffusion */ = {
998999
isa = XCSwiftPackageProductDependency;
999-
package = EB25B3D42A3A2DC4000E25A1 /* XCRemoteSwiftPackageReference "ml-stable-diffusion" */;
1000+
package = 16AFDD4D2C1B7D4800536A62 /* XCRemoteSwiftPackageReference "ml-stable-diffusion" */;
10001001
productName = StableDiffusion;
10011002
};
1002-
EB25B3D72A3A2DD5000E25A1 /* StableDiffusion */ = {
1003+
EB0199482A31FEAF00B133E2 /* StableDiffusion */ = {
10031004
isa = XCSwiftPackageProductDependency;
1004-
package = EB25B3D42A3A2DC4000E25A1 /* XCRemoteSwiftPackageReference "ml-stable-diffusion" */;
10051005
productName = StableDiffusion;
10061006
};
10071007
EBB5BA5C294504DE003A2A5B /* ZIPFoundation */ = {

Diffusion.xcodeproj/project.xcworkspace/xcshareddata/swiftpm/Package.resolved

+6-5
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
{
2+
"originHash" : "e97aab54879429ea40e58df49ffe4eef5228d95a28a7cf4d5dca9204c33564e1",
23
"pins" : [
34
{
45
"identity" : "compactslider",
@@ -12,19 +13,19 @@
1213
{
1314
"identity" : "ml-stable-diffusion",
1415
"kind" : "remoteSourceControl",
15-
"location" : "https://github.com/apple/ml-stable-diffusion",
16+
"location" : "https://github.com/argmaxinc/ml-stable-diffusion.git",
1617
"state" : {
1718
"branch" : "main",
18-
"revision" : "d456a972cd7d84cab2ec353a29896d59b8602248"
19+
"revision" : "d1f0604fab5345011e0b9f5b87ee0c155612565f"
1920
}
2021
},
2122
{
2223
"identity" : "swift-argument-parser",
2324
"kind" : "remoteSourceControl",
2425
"location" : "https://github.com/apple/swift-argument-parser.git",
2526
"state" : {
26-
"revision" : "fddd1c00396eed152c45a46bea9f47b98e59301d",
27-
"version" : "1.2.0"
27+
"revision" : "0fbc8848e389af3bb55c182bc19ca9d5dc2f255b",
28+
"version" : "1.4.0"
2829
}
2930
},
3031
{
@@ -37,5 +38,5 @@
3738
}
3839
}
3940
],
40-
"version" : 2
41+
"version" : 3
4142
}
Loading

Diffusion/Common/Downloader.swift

+9-3
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,14 @@ class Downloader: NSObject, ObservableObject {
2828
self.destination = destination
2929
super.init()
3030

31+
var config = URLSessionConfiguration.default
32+
#if !os(macOS)
3133
// .background allows downloads to proceed in the background
32-
let config = URLSessionConfiguration.background(withIdentifier: "net.pcuenca.diffusion.download")
34+
// helpful for devices that may not keep the app in the foreground for the download duration
35+
config = URLSessionConfiguration.background(withIdentifier: "net.pcuenca.diffusion.download")
36+
config.isDiscretionary = false
37+
config.sessionSendsLaunchEvents = true
38+
#endif
3339
urlSession = URLSession(configuration: config, delegate: self, delegateQueue: OperationQueue())
3440
downloadState.value = .downloading(0)
3541
urlSession?.getAllTasks { tasks in
@@ -75,8 +81,8 @@ class Downloader: NSObject, ObservableObject {
7581
}
7682

7783
extension Downloader: URLSessionDelegate, URLSessionDownloadDelegate {
78-
func urlSession(_: URLSession, downloadTask: URLSessionDownloadTask, didWriteData _: Int64, totalBytesWritten _: Int64, totalBytesExpectedToWrite _: Int64) {
79-
downloadState.value = .downloading(downloadTask.progress.fractionCompleted)
84+
func urlSession(_: URLSession, downloadTask: URLSessionDownloadTask, didWriteData _: Int64, totalBytesWritten: Int64, totalBytesExpectedToWrite: Int64) {
85+
downloadState.value = .downloading(Double(totalBytesWritten) / Double(totalBytesExpectedToWrite))
8086
}
8187

8288
func urlSession(_: URLSession, downloadTask _: URLSessionDownloadTask, didFinishDownloadingTo location: URL) {

Diffusion/Common/ModelInfo.swift

+35-4
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,10 @@ struct ModelInfo {
3333

3434
/// Suffix of the archive containing the SPLIT_EINSUM_V2 attention variant. Usually something like "split_einsum_v2_compiled"
3535
let splitAttentionV2Suffix: String
36-
36+
37+
/// Whether the archive contains ANE optimized models
38+
let supportsNeuralEngine: Bool
39+
3740
/// Whether the archive contains the VAE Encoder (for image to image tasks). Not yet in use.
3841
let supportsEncoder: Bool
3942

@@ -46,25 +49,33 @@ struct ModelInfo {
4649
/// Whether this is a Stable Diffusion XL model
4750
// TODO: retrieve from remote config
4851
let isXL: Bool
49-
52+
53+
/// Whether this is a Stable Diffusion 3 model
54+
// TODO: retrieve from remote config
55+
let isSD3: Bool
56+
5057
//TODO: refactor all these properties
5158
init(modelId: String, modelVersion: String,
5259
originalAttentionSuffix: String = "original_compiled",
5360
splitAttentionSuffix: String = "split_einsum_compiled",
5461
splitAttentionV2Suffix: String = "split_einsum_v2_compiled",
62+
supportsNeuralEngine: Bool = true,
5563
supportsEncoder: Bool = false,
5664
supportsAttentionV2: Bool = false,
5765
quantized: Bool = false,
58-
isXL: Bool = false) {
66+
isXL: Bool = false,
67+
isSD3: Bool = false) {
5968
self.modelId = modelId
6069
self.modelVersion = modelVersion
6170
self.originalAttentionSuffix = originalAttentionSuffix
6271
self.splitAttentionSuffix = splitAttentionSuffix
6372
self.splitAttentionV2Suffix = splitAttentionV2Suffix
73+
self.supportsNeuralEngine = supportsNeuralEngine
6474
self.supportsEncoder = supportsEncoder
6575
self.supportsAttentionV2 = supportsAttentionV2
6676
self.quantized = quantized
6777
self.isXL = isXL
78+
self.isSD3 = isSD3
6879
}
6980
}
7081

@@ -202,6 +213,24 @@ extension ModelInfo {
202213
isXL: true
203214
)
204215

216+
static let sd3 = ModelInfo(
217+
modelId: "argmaxinc/coreml-stable-diffusion-3-medium",
218+
modelVersion: "SD3 medium (512, macOS)",
219+
supportsNeuralEngine: false, // TODO: support SD3 on ANE
220+
supportsEncoder: false,
221+
quantized: false,
222+
isSD3: true
223+
)
224+
225+
static let sd3highres = ModelInfo(
226+
modelId: "argmaxinc/coreml-stable-diffusion-3-medium-1024-t5",
227+
modelVersion: "SD3 medium (1024, T5, macOS)",
228+
supportsNeuralEngine: false, // TODO: support SD3 on ANE
229+
supportsEncoder: false,
230+
quantized: false,
231+
isSD3: true
232+
)
233+
205234
static let MODELS: [ModelInfo] = {
206235
if deviceSupportsQuantization {
207236
var models = [
@@ -218,7 +247,9 @@ extension ModelInfo {
218247
models.append(contentsOf: [
219248
ModelInfo.xl,
220249
ModelInfo.xlWithRefiner,
221-
ModelInfo.xlmbp
250+
ModelInfo.xlmbp,
251+
ModelInfo.sd3,
252+
ModelInfo.sd3highres,
222253
])
223254
} else {
224255
models.append(ModelInfo.xlmbpChunked)

0 commit comments

Comments
 (0)