Skip to content

Commit b434cee

Browse files
authored
SD XL support (#67)
* Support SDXL base model. * Add support for 4.5 bit SDXL model. * Update package version
1 parent 839c9ba commit b434cee

File tree

4 files changed

+44
-12
lines changed

4 files changed

+44
-12
lines changed

Diffusion.xcodeproj/project.xcworkspace/xcshareddata/swiftpm/Package.resolved

+1-1
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
"location" : "https://github.com/apple/ml-stable-diffusion",
1616
"state" : {
1717
"branch" : "main",
18-
"revision" : "b61c9aea05370d4bc06fce2dc00a002b21f13da5"
18+
"revision" : "8cf34376f9faf87fc6fe63159e5fae6cbbb71de6"
1919
}
2020
},
2121
{

Diffusion/Common/ModelInfo.swift

+24-3
Original file line numberDiff line numberDiff line change
@@ -43,14 +43,19 @@ struct ModelInfo {
4343
/// Are weights quantized? This is only used to decide whether to use `reduceMemory`
4444
let quantized: Bool
4545

46+
/// Whether this is a Stable Diffusion XL model
47+
// TODO: retrieve from remote config
48+
let isXL: Bool
49+
4650
//TODO: refactor all these properties
4751
init(modelId: String, modelVersion: String,
4852
originalAttentionSuffix: String = "original_compiled",
4953
splitAttentionSuffix: String = "split_einsum_compiled",
5054
splitAttentionV2Suffix: String = "split_einsum_v2_compiled",
5155
supportsEncoder: Bool = false,
5256
supportsAttentionV2: Bool = false,
53-
quantized: Bool = false) {
57+
quantized: Bool = false,
58+
isXL: Bool = false) {
5459
self.modelId = modelId
5560
self.modelVersion = modelVersion
5661
self.originalAttentionSuffix = originalAttentionSuffix
@@ -59,6 +64,7 @@ struct ModelInfo {
5964
self.supportsEncoder = supportsEncoder
6065
self.supportsAttentionV2 = supportsAttentionV2
6166
self.quantized = quantized
67+
self.isXL = isXL
6268
}
6369
}
6470

@@ -165,6 +171,21 @@ extension ModelInfo {
165171
modelVersion: "OFA-Sys/small-stable-diffusion-v0"
166172
)
167173

174+
static let xl = ModelInfo(
175+
modelId: "apple/coreml-stable-diffusion-xl-base",
176+
modelVersion: "Stable Diffusion XL base",
177+
supportsEncoder: true,
178+
isXL: true
179+
)
180+
181+
static let xlmbp = ModelInfo(
182+
modelId: "apple/coreml-stable-diffusion-mixed-bit-palettization",
183+
modelVersion: "Stable Diffusion XL base [4.5 bit]",
184+
supportsEncoder: true,
185+
quantized: true,
186+
isXL: true
187+
)
188+
168189
static let MODELS: [ModelInfo] = {
169190
if deviceSupportsQuantization {
170191
return [
@@ -176,15 +197,15 @@ extension ModelInfo {
176197
ModelInfo.v2Palettized,
177198
ModelInfo.v21Base,
178199
ModelInfo.v21Palettized,
179-
ModelInfo.ofaSmall
200+
ModelInfo.xl,
201+
ModelInfo.xlmbp
180202
]
181203
} else {
182204
return [
183205
ModelInfo.v14Base,
184206
ModelInfo.v15Base,
185207
ModelInfo.v2Base,
186208
ModelInfo.v21Base,
187-
ModelInfo.ofaSmall
188209
]
189210
}
190211
}()

Diffusion/Common/Pipeline/Pipeline.swift

+2-2
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ struct GenerationResult {
4040
}
4141

4242
class Pipeline {
43-
let pipeline: StableDiffusionPipeline
43+
let pipeline: StableDiffusionPipelineProtocol
4444
let maxSeed: UInt32
4545

4646
var progress: StableDiffusionProgress? = nil {
@@ -52,7 +52,7 @@ class Pipeline {
5252

5353
private var canceled = false
5454

55-
init(_ pipeline: StableDiffusionPipeline, maxSeed: UInt32 = UInt32.max) {
55+
init(_ pipeline: StableDiffusionPipelineProtocol, maxSeed: UInt32 = UInt32.max) {
5656
self.pipeline = pipeline
5757
self.maxSeed = maxSeed
5858
}

Diffusion/Common/Pipeline/PipelineLoader.swift

+17-6
Original file line numberDiff line numberDiff line change
@@ -161,15 +161,26 @@ extension PipelineLoader {
161161
state = .readyOnDisk
162162
}
163163

164-
func load(url: URL) async throws -> StableDiffusionPipeline {
164+
func load(url: URL) async throws -> StableDiffusionPipelineProtocol {
165165
let beginDate = Date()
166166
let configuration = MLModelConfiguration()
167167
configuration.computeUnits = computeUnits
168-
let pipeline = try StableDiffusionPipeline(resourcesAt: url,
169-
controlNet: [],
170-
configuration: configuration,
171-
disableSafety: false,
172-
reduceMemory: model.reduceMemory)
168+
let pipeline: StableDiffusionPipelineProtocol
169+
if model.isXL {
170+
if #available(macOS 14.0, iOS 17.0, *) {
171+
pipeline = try StableDiffusionXLPipeline(resourcesAt: url,
172+
configuration: configuration,
173+
reduceMemory: model.reduceMemory)
174+
} else {
175+
throw "Stable Diffusion XL requires macOS 14"
176+
}
177+
} else {
178+
pipeline = try StableDiffusionPipeline(resourcesAt: url,
179+
controlNet: [],
180+
configuration: configuration,
181+
disableSafety: false,
182+
reduceMemory: model.reduceMemory)
183+
}
173184
try pipeline.loadResources()
174185
print("Pipeline loaded in \(Date().timeIntervalSince(beginDate))")
175186
state = .loaded

0 commit comments

Comments
 (0)