Skip to content

Commit d041577

Browse files
authored
SDXL for iOS, and refiner for Mac (huggingface#83)
* Bring model definitions over from benchmark branch
* PromptTextField uses location from PipelineLoader. This could be merged in main.
* Fix potential model mismatch.
* Disable previews only on iOS
* Fix macOS build
* Update package
* Enable SDXL on big iPhones
* Update model description
* Refiner model
* Use Karras by default for XL models
* Update URL
1 parent 4202e12 commit d041577

File tree

9 files changed

+50
-31
lines changed

9 files changed

+50
-31
lines changed

Diffusion-macOS/Capabilities.swift

+2
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@ import Foundation
1010

1111
let runningOnMac = true
1212
let deviceHas6GBOrMore = true
13+
let deviceHas8GBOrMore = true
14+
let BENCHMARK = false
1315

1416
let deviceSupportsQuantization = {
1517
if #available(macOS 14, *) {

Diffusion.xcodeproj/project.xcworkspace/xcshareddata/swiftpm/Package.resolved

+1-1
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
"location" : "https://github.com/apple/ml-stable-diffusion",
1616
"state" : {
1717
"branch" : "main",
18-
"revision" : "ce8ee78e28613d8a2e4c8b56932b236cb57e7e20"
18+
"revision" : "94814cfa41935efd8151a43758360a54e4e3c5d5"
1919
}
2020
},
2121
{

Diffusion/Common/ModelInfo.swift

+30-6
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,7 @@ extension ModelInfo {
107107
var reduceMemory: Bool {
108108
// Enable on iOS devices, except when using quantization
109109
if runningOnMac { return false }
110+
if isXL { return !deviceHas8GBOrMore }
110111
return !(quantized && deviceHas6GBOrMore)
111112
}
112113
}
@@ -173,33 +174,56 @@ extension ModelInfo {
173174

174175
static let xl = ModelInfo(
175176
modelId: "apple/coreml-stable-diffusion-xl-base",
176-
modelVersion: "Stable Diffusion XL base",
177+
modelVersion: "SDXL base (1024, macOS)",
177178
supportsEncoder: true,
178179
isXL: true
179180
)
180181

182+
static let xlWithRefiner = ModelInfo(
183+
modelId: "apple/coreml-stable-diffusion-xl-base-with-refiner",
184+
modelVersion: "SDXL with refiner (1024, macOS)",
185+
supportsEncoder: true,
186+
isXL: true
187+
)
188+
181189
static let xlmbp = ModelInfo(
182190
modelId: "apple/coreml-stable-diffusion-mixed-bit-palettization",
183-
modelVersion: "Stable Diffusion XL base [4.5 bit]",
191+
modelVersion: "SDXL base (1024, macOS) [4.5 bit]",
184192
supportsEncoder: true,
185193
quantized: true,
186194
isXL: true
187195
)
188196

197+
static let xlmbpChunked = ModelInfo(
198+
modelId: "apple/coreml-stable-diffusion-xl-base-ios",
199+
modelVersion: "SDXL base (768, iOS) [4 bit]",
200+
supportsEncoder: false,
201+
quantized: true,
202+
isXL: true
203+
)
204+
189205
static let MODELS: [ModelInfo] = {
190206
if deviceSupportsQuantization {
191-
return [
207+
var models = [
192208
ModelInfo.v14Base,
193209
ModelInfo.v14Palettized,
194210
ModelInfo.v15Base,
195211
ModelInfo.v15Palettized,
196212
ModelInfo.v2Base,
197213
ModelInfo.v2Palettized,
198214
ModelInfo.v21Base,
199-
ModelInfo.v21Palettized,
200-
ModelInfo.xl,
201-
ModelInfo.xlmbp
215+
ModelInfo.v21Palettized
202216
]
217+
if runningOnMac {
218+
models.append(contentsOf: [
219+
ModelInfo.xl,
220+
ModelInfo.xlWithRefiner,
221+
ModelInfo.xlmbp
222+
])
223+
} else {
224+
models.append(ModelInfo.xlmbpChunked)
225+
}
226+
return models
203227
} else {
204228
return [
205229
ModelInfo.v14Base,

Diffusion/Common/Pipeline/Pipeline.swift

+1
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@ class Pipeline {
9191
if isXL {
9292
config.encoderScaleFactor = 0.13025
9393
config.decoderScaleFactor = 0.13025
94+
config.schedulerTimestepSpacing = .karras
9495
}
9596

9697
// Evenly distribute previews based on inference steps

Diffusion/Common/State.swift

+1-1
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ class GenerationContext: ObservableObject {
6666
@Published var numImages = 1.0
6767
@Published var seed: UInt32 = 0
6868
@Published var guidanceScale = 7.5
69-
@Published var previews = 5.0
69+
@Published var previews = runningOnMac ? 5.0 : 0.0
7070
@Published var disableSafety = false
7171
@Published var previewImage: CGImage? = nil
7272

Diffusion/Common/Views/PromptTextField.swift

+5-19
Original file line numberDiff line numberDiff line change
@@ -28,27 +28,13 @@ struct PromptTextField: View {
2828
ModelInfo.from(modelVersion: $model.wrappedValue)
2929
}
3030

31-
private var filename: String? {
32-
let variant = modelInfo?.bestAttention ?? .original
33-
return modelInfo?.modelURL(for: variant).lastPathComponent
31+
private var pipelineLoader: PipelineLoader? {
32+
guard let modelInfo = modelInfo else { return nil }
33+
return PipelineLoader(model: modelInfo)
3434
}
35-
36-
private var downloadedURL: URL? {
37-
if let filename = filename {
38-
return PipelineLoader.models.appendingPathComponent(filename)
39-
}
40-
return nil
41-
}
42-
43-
private var packagesFilename: String? {
44-
(filename as NSString?)?.deletingPathExtension
45-
}
46-
35+
4736
private var compiledURL: URL? {
48-
if let packagesFilename = packagesFilename {
49-
return downloadedURL?.deletingLastPathComponent().appendingPathComponent(packagesFilename)
50-
}
51-
return nil
37+
return pipelineLoader?.compiledURL
5238
}
5339

5440
private var textColor: Color {

Diffusion/DiffusionApp.swift

+2-1
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,8 @@ struct DiffusionApp: App {
1818
}
1919

2020
let runningOnMac = ProcessInfo.processInfo.isMacCatalystApp
21-
let deviceHas6GBOrMore = ProcessInfo.processInfo.physicalMemory > 5924000000 // Different devices report different amounts, so approximate
21+
let deviceHas6GBOrMore = ProcessInfo.processInfo.physicalMemory > 5910000000 // Reported by iOS 17 beta (21A5319a) on iPhone 13 Pro: 5917753344
22+
let deviceHas8GBOrMore = ProcessInfo.processInfo.physicalMemory > 7900000000 // Reported by iOS 17.0.2 on iPhone 15 Pro Max: 8021032960
2223

2324
let deviceSupportsQuantization = {
2425
if #available(iOS 17, *) {

Diffusion/Views/Loading.swift

+7-2
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,14 @@
99
import SwiftUI
1010
import Combine
1111

12-
let model = deviceSupportsQuantization ? ModelInfo.v21Palettized : ModelInfo.v21Base
12+
func iosModel() -> ModelInfo {
13+
guard deviceSupportsQuantization else { return ModelInfo.v21Base }
14+
if deviceHas6GBOrMore { return ModelInfo.xlmbpChunked }
15+
return ModelInfo.v21Palettized
16+
}
1317

1418
struct LoadingView: View {
19+
1520
@StateObject var generation = GenerationContext()
1621

1722
@State private var preparationPhase = "Downloading…"
@@ -40,7 +45,7 @@ struct LoadingView: View {
4045
.environmentObject(generation)
4146
.onAppear {
4247
Task.init {
43-
let loader = PipelineLoader(model: model)
48+
let loader = PipelineLoader(model: iosModel())
4449
stateSubscriber = loader.statePublisher.sink { state in
4550
DispatchQueue.main.async {
4651
switch state {

Diffusion/Views/TextToImage.swift

+1-1
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ struct TextToImage: View {
124124
var body: some View {
125125
VStack {
126126
HStack {
127-
PromptTextField(text: $generation.positivePrompt, isPositivePrompt: true, model: deviceSupportsQuantization ? ModelInfo.v21Palettized.modelVersion : ModelInfo.v21Base.modelVersion)
127+
PromptTextField(text: $generation.positivePrompt, isPositivePrompt: true, model: iosModel().modelVersion)
128128
Button("Generate") {
129129
submit()
130130
}

0 commit comments

Comments
 (0)