Skip to content

Commit e03c0e0

Browse files
authored Jun 14, 2023
Palettization Support (#49)
* Add palettized model definitions.
* Adapt to new API. TODO: update SPM package dependency.
* Allow download of private repos with HF token. This is not exposed in the UI.
* Use split_einsum_v2 when available.
* Minor README changes.
* iOS: enable reduceMemory if the model is not quantized or the device has < 6GB RAM.
* Only show quantized models in iOS 17, macOS 14.
* Update ml-stable-diffusion package reference.
* Update URLs of quantized models.
1 parent 81a2d49 commit e03c0e0

File tree

10 files changed

+184
-50
lines changed

10 files changed

+184
-50
lines changed
 

‎Diffusion-macOS/Capabilities.swift

+10
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,16 @@
99
import Foundation
1010

1111
let runningOnMac = true
12+
let deviceHas6GBOrMore = true
13+
14+
let deviceSupportsQuantization = {
15+
if #available(macOS 14, *) {
16+
true
17+
} else {
18+
false
19+
}
20+
}()
21+
1222

1323
#if canImport(MLCompute)
1424
import MLCompute

‎Diffusion.xcodeproj/project.pbxproj

+30-12
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,8 @@
88

99
/* Begin PBXBuildFile section */
1010
EB067F872992E561004D1AD9 /* HelpContent.swift in Sources */ = {isa = PBXBuildFile; fileRef = EB067F862992E561004D1AD9 /* HelpContent.swift */; };
11-
EB33A51D2954D89F00B16357 /* StableDiffusion in Frameworks */ = {isa = PBXBuildFile; productRef = EB33A51C2954D89F00B16357 /* StableDiffusion */; };
11+
EB25B3D62A3A2DC4000E25A1 /* StableDiffusion in Frameworks */ = {isa = PBXBuildFile; productRef = EB25B3D52A3A2DC4000E25A1 /* StableDiffusion */; };
12+
EB25B3D82A3A2DD5000E25A1 /* StableDiffusion in Frameworks */ = {isa = PBXBuildFile; productRef = EB25B3D72A3A2DD5000E25A1 /* StableDiffusion */; };
1213
EB560F0429A3C20800C0F8B8 /* Capabilities.swift in Sources */ = {isa = PBXBuildFile; fileRef = EB560F0329A3C20800C0F8B8 /* Capabilities.swift */; };
1314
EBB5BA5329425BEE003A2A5B /* PipelineLoader.swift in Sources */ = {isa = PBXBuildFile; fileRef = EBB5BA5229425BEE003A2A5B /* PipelineLoader.swift */; };
1415
EBB5BA5829425E17003A2A5B /* Path in Frameworks */ = {isa = PBXBuildFile; productRef = EBB5BA5729425E17003A2A5B /* Path */; };
@@ -18,7 +19,6 @@
1819
EBDD7DAB29731F7500C1C4B2 /* PipelineLoader.swift in Sources */ = {isa = PBXBuildFile; fileRef = EBB5BA5229425BEE003A2A5B /* PipelineLoader.swift */; };
1920
EBDD7DAD29731FB300C1C4B2 /* Path in Frameworks */ = {isa = PBXBuildFile; productRef = EBDD7DAC29731FB300C1C4B2 /* Path */; };
2021
EBDD7DAF29731FB300C1C4B2 /* ZIPFoundation in Frameworks */ = {isa = PBXBuildFile; productRef = EBDD7DAE29731FB300C1C4B2 /* ZIPFoundation */; };
21-
EBDD7DB129731FB300C1C4B2 /* StableDiffusion in Frameworks */ = {isa = PBXBuildFile; productRef = EBDD7DB029731FB300C1C4B2 /* StableDiffusion */; };
2222
EBDD7DB32973200200C1C4B2 /* Utils.swift in Sources */ = {isa = PBXBuildFile; fileRef = EBDD7DB22973200200C1C4B2 /* Utils.swift */; };
2323
EBDD7DB42973200200C1C4B2 /* Utils.swift in Sources */ = {isa = PBXBuildFile; fileRef = EBDD7DB22973200200C1C4B2 /* Utils.swift */; };
2424
EBDD7DB52973201800C1C4B2 /* ModelInfo.swift in Sources */ = {isa = PBXBuildFile; fileRef = EBE3FF4B295E1EFE00E921AA /* ModelInfo.swift */; };
@@ -105,7 +105,7 @@
105105
buildActionMask = 2147483647;
106106
files = (
107107
EBB5BA5829425E17003A2A5B /* Path in Frameworks */,
108-
EB33A51D2954D89F00B16357 /* StableDiffusion in Frameworks */,
108+
EB25B3D62A3A2DC4000E25A1 /* StableDiffusion in Frameworks */,
109109
EBB5BA5D294504DE003A2A5B /* ZIPFoundation in Frameworks */,
110110
);
111111
runOnlyForDeploymentPostprocessing = 0;
@@ -129,7 +129,7 @@
129129
buildActionMask = 2147483647;
130130
files = (
131131
F155203C297118E700DC009B /* CompactSlider in Frameworks */,
132-
EBDD7DB129731FB300C1C4B2 /* StableDiffusion in Frameworks */,
132+
EB25B3D82A3A2DD5000E25A1 /* StableDiffusion in Frameworks */,
133133
EBDD7DAD29731FB300C1C4B2 /* Path in Frameworks */,
134134
EBDD7DAF29731FB300C1C4B2 /* ZIPFoundation in Frameworks */,
135135
);
@@ -280,12 +280,13 @@
280280
buildRules = (
281281
);
282282
dependencies = (
283+
EBF61AB32A2F976600482CF3 /* PBXTargetDependency */,
283284
);
284285
name = Diffusion;
285286
packageProductDependencies = (
286287
EBB5BA5729425E17003A2A5B /* Path */,
287288
EBB5BA5C294504DE003A2A5B /* ZIPFoundation */,
288-
EB33A51C2954D89F00B16357 /* StableDiffusion */,
289+
EB25B3D52A3A2DC4000E25A1 /* StableDiffusion */,
289290
);
290291
productName = Diffusion;
291292
productReference = EBE755C5293E37DD00806B32 /* Diffusion.app */;
@@ -339,13 +340,14 @@
339340
buildRules = (
340341
);
341342
dependencies = (
343+
EB0199492A31FEAF00B133E2 /* PBXTargetDependency */,
342344
);
343345
name = "Diffusion-macOS";
344346
packageProductDependencies = (
345347
F155203B297118E700DC009B /* CompactSlider */,
346348
EBDD7DAC29731FB300C1C4B2 /* Path */,
347349
EBDD7DAE29731FB300C1C4B2 /* ZIPFoundation */,
348-
EBDD7DB029731FB300C1C4B2 /* StableDiffusion */,
350+
EB25B3D72A3A2DD5000E25A1 /* StableDiffusion */,
349351
);
350352
productName = "Diffusion-macOS";
351353
productReference = F15520212971093300DC009B /* Diffusers.app */;
@@ -389,8 +391,8 @@
389391
packageReferences = (
390392
EBB5BA5629425E17003A2A5B /* XCRemoteSwiftPackageReference "Path.swift" */,
391393
EBB5BA5B294504DE003A2A5B /* XCRemoteSwiftPackageReference "ZIPFoundation" */,
392-
EB33A51B2954D89F00B16357 /* XCRemoteSwiftPackageReference "ml-stable-diffusion" */,
393394
F155203A297118E600DC009B /* XCRemoteSwiftPackageReference "CompactSlider" */,
395+
EB25B3D42A3A2DC4000E25A1 /* XCRemoteSwiftPackageReference "ml-stable-diffusion" */,
394396
);
395397
productRefGroup = EBE755C6293E37DD00806B32 /* Products */;
396398
projectDirPath = "";
@@ -517,6 +519,10 @@
517519
/* End PBXSourcesBuildPhase section */
518520

519521
/* Begin PBXTargetDependency section */
522+
EB0199492A31FEAF00B133E2 /* PBXTargetDependency */ = {
523+
isa = PBXTargetDependency;
524+
productRef = EB0199482A31FEAF00B133E2 /* StableDiffusion */;
525+
};
520526
EBE755D8293E37DE00806B32 /* PBXTargetDependency */ = {
521527
isa = PBXTargetDependency;
522528
target = EBE755C4293E37DD00806B32 /* Diffusion */;
@@ -527,6 +533,10 @@
527533
target = EBE755C4293E37DD00806B32 /* Diffusion */;
528534
targetProxy = EBE755E1293E37DE00806B32 /* PBXContainerItemProxy */;
529535
};
536+
EBF61AB32A2F976600482CF3 /* PBXTargetDependency */ = {
537+
isa = PBXTargetDependency;
538+
productRef = EBF61AB22A2F976600482CF3 /* StableDiffusion */;
539+
};
530540
/* End PBXTargetDependency section */
531541

532542
/* Begin XCBuildConfiguration section */
@@ -915,7 +925,7 @@
915925
/* End XCConfigurationList section */
916926

917927
/* Begin XCRemoteSwiftPackageReference section */
918-
EB33A51B2954D89F00B16357 /* XCRemoteSwiftPackageReference "ml-stable-diffusion" */ = {
928+
EB25B3D42A3A2DC4000E25A1 /* XCRemoteSwiftPackageReference "ml-stable-diffusion" */ = {
919929
isa = XCRemoteSwiftPackageReference;
920930
repositoryURL = "https://github.com/apple/ml-stable-diffusion";
921931
requirement = {
@@ -950,9 +960,18 @@
950960
/* End XCRemoteSwiftPackageReference section */
951961

952962
/* Begin XCSwiftPackageProductDependency section */
953-
EB33A51C2954D89F00B16357 /* StableDiffusion */ = {
963+
EB0199482A31FEAF00B133E2 /* StableDiffusion */ = {
964+
isa = XCSwiftPackageProductDependency;
965+
productName = StableDiffusion;
966+
};
967+
EB25B3D52A3A2DC4000E25A1 /* StableDiffusion */ = {
968+
isa = XCSwiftPackageProductDependency;
969+
package = EB25B3D42A3A2DC4000E25A1 /* XCRemoteSwiftPackageReference "ml-stable-diffusion" */;
970+
productName = StableDiffusion;
971+
};
972+
EB25B3D72A3A2DD5000E25A1 /* StableDiffusion */ = {
954973
isa = XCSwiftPackageProductDependency;
955-
package = EB33A51B2954D89F00B16357 /* XCRemoteSwiftPackageReference "ml-stable-diffusion" */;
974+
package = EB25B3D42A3A2DC4000E25A1 /* XCRemoteSwiftPackageReference "ml-stable-diffusion" */;
956975
productName = StableDiffusion;
957976
};
958977
EBB5BA5729425E17003A2A5B /* Path */ = {
@@ -975,9 +994,8 @@
975994
package = EBB5BA5B294504DE003A2A5B /* XCRemoteSwiftPackageReference "ZIPFoundation" */;
976995
productName = ZIPFoundation;
977996
};
978-
EBDD7DB029731FB300C1C4B2 /* StableDiffusion */ = {
997+
EBF61AB22A2F976600482CF3 /* StableDiffusion */ = {
979998
isa = XCSwiftPackageProductDependency;
980-
package = EB33A51B2954D89F00B16357 /* XCRemoteSwiftPackageReference "ml-stable-diffusion" */;
981999
productName = StableDiffusion;
9821000
};
9831001
F155203B297118E700DC009B /* CompactSlider */ = {

‎Diffusion.xcodeproj/project.xcworkspace/xcshareddata/swiftpm/Package.resolved

+1-1
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
"location" : "https://github.com/apple/ml-stable-diffusion",
1616
"state" : {
1717
"branch" : "main",
18-
"revision" : "fb1fa01c9d30e9b2e02a8b7ed35d905e272a0262"
18+
"revision" : "48f07f24891155a14c51dd835bba7371bdf32d0e"
1919
}
2020
},
2121
{

‎Diffusion/DiffusionApp.swift

+9
Original file line numberDiff line numberDiff line change
@@ -18,3 +18,12 @@ struct DiffusionApp: App {
1818
}
1919

2020
let runningOnMac = ProcessInfo.processInfo.isMacCatalystApp
21+
let deviceHas6GBOrMore = ProcessInfo.processInfo.physicalMemory > 5924000000 // Different devices report different amounts, so approximate
22+
23+
let deviceSupportsQuantization = {
24+
if #available(iOS 17, *) {
25+
true
26+
} else {
27+
false
28+
}
29+
}()

‎Diffusion/Downloader.swift

+13-3
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ class Downloader: NSObject, ObservableObject {
2525

2626
private var urlSession: URLSession? = nil
2727

28-
init(from url: URL, to destination: URL) {
28+
init(from url: URL, to destination: URL, using authToken: String? = nil) {
2929
self.destination = destination
3030
super.init()
3131

@@ -40,7 +40,13 @@ class Downloader: NSObject, ObservableObject {
4040
return
4141
}
4242
print("Starting download of \(url)")
43-
self.urlSession?.downloadTask(with: url).resume()
43+
44+
var request = URLRequest(url: url)
45+
if let authToken = authToken {
46+
request.setValue("Bearer \(authToken)", forHTTPHeaderField: "Authorization")
47+
}
48+
49+
self.urlSession?.downloadTask(with: request).resume()
4450
}
4551
}
4652

@@ -91,9 +97,13 @@ extension Downloader: URLSessionDelegate, URLSessionDownloadDelegate {
9197
}
9298
}
9399

94-
func urlSession(_: URLSession, task: URLSessionTask, didCompleteWithError error: Error?) {
100+
func urlSession(_ session: URLSession, task: URLSessionTask, didCompleteWithError error: Error?) {
95101
if let error = error {
96102
downloadState.value = .failed(error)
103+
} else if let response = task.response as? HTTPURLResponse {
104+
print("HTTP response status code: \(response.statusCode)")
105+
// let headers = response.allHeaderFields
106+
// print("HTTP response headers: \(headers)")
97107
}
98108
}
99109
}

‎Diffusion/ModelInfo.swift

+92-18
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import CoreML
1111
enum AttentionVariant: String {
1212
case original
1313
case splitEinsum
14+
case splitEinsumV2
1415
}
1516

1617
extension AttentionVariant {
@@ -30,15 +31,34 @@ struct ModelInfo {
3031
/// Suffix of the archive containing the SPLIT_EINSUM attention variant. Usually something like "split_einsum_compiled"
3132
let splitAttentionSuffix: String
3233

34+
/// Suffix of the archive containing the SPLIT_EINSUM_V2 attention variant. Usually something like "split_einsum_v2_compiled"
35+
let splitAttentionV2Suffix: String
36+
3337
/// Whether the archive contains the VAE Encoder (for image to image tasks). Not yet in use.
3438
let supportsEncoder: Bool
35-
36-
init(modelId: String, modelVersion: String, originalAttentionSuffix: String = "original_compiled", splitAttentionSuffix: String = "split_einsum_compiled", supportsEncoder: Bool = false) {
39+
40+
/// Is attention v2 supported? (Ideally, we should know by looking at the repo contents)
41+
let supportsAttentionV2: Bool
42+
43+
/// Are weights quantized? This is only used to decide whether to use `reduceMemory`
44+
let quantized: Bool
45+
46+
//TODO: refactor all these properties
47+
init(modelId: String, modelVersion: String,
48+
originalAttentionSuffix: String = "original_compiled",
49+
splitAttentionSuffix: String = "split_einsum_compiled",
50+
splitAttentionV2Suffix: String = "split_einsum_v2_compiled",
51+
supportsEncoder: Bool = false,
52+
supportsAttentionV2: Bool = false,
53+
quantized: Bool = false) {
3754
self.modelId = modelId
3855
self.modelVersion = modelVersion
3956
self.originalAttentionSuffix = originalAttentionSuffix
4057
self.splitAttentionSuffix = splitAttentionSuffix
58+
self.splitAttentionV2Suffix = splitAttentionV2Suffix
4159
self.supportsEncoder = supportsEncoder
60+
self.supportsAttentionV2 = supportsAttentionV2
61+
self.quantized = quantized
4262
}
4363
}
4464

@@ -56,7 +76,10 @@ extension ModelInfo {
5676

5777
static var defaultComputeUnits: MLComputeUnits { defaultAttention.defaultComputeUnits }
5878

59-
var bestAttention: AttentionVariant { ModelInfo.defaultAttention }
79+
var bestAttention: AttentionVariant {
80+
if !runningOnMac && supportsAttentionV2 { return .splitEinsumV2 }
81+
return ModelInfo.defaultAttention
82+
}
6083
var defaultComputeUnits: MLComputeUnits { bestAttention.defaultComputeUnits }
6184

6285
func modelURL(for variant: AttentionVariant) -> URL {
@@ -65,6 +88,7 @@ extension ModelInfo {
6588
switch variant {
6689
case .original: suffix = originalAttentionSuffix
6790
case .splitEinsum: suffix = splitAttentionSuffix
91+
case .splitEinsumV2: suffix = splitAttentionV2Suffix
6892
}
6993
let repo = modelId.split(separator: "/").last!
7094
return URL(string: "https://huggingface.co/\(modelId)/resolve/main/\(repo)_\(suffix).zip")!
@@ -73,47 +97,97 @@ extension ModelInfo {
7397
/// Best variant for the current platform.
7498
/// Currently using `split_einsum` for iOS and simple performance heuristics for macOS.
7599
var bestURL: URL { modelURL(for: bestAttention) }
76-
100+
77101
var reduceMemory: Bool {
78-
return !runningOnMac
102+
// Enable on iOS devices, except when using quantization
103+
if runningOnMac { return false }
104+
return !(quantized && deviceHas6GBOrMore)
79105
}
80106
}
81107

82108
extension ModelInfo {
83-
// TODO: repo does not exist yet
84109
static let v14Base = ModelInfo(
85110
modelId: "pcuenq/coreml-stable-diffusion-1-4",
86-
modelVersion: "CompVis/stable-diffusion-v1-4"
111+
modelVersion: "CompVis SD 1.4"
112+
)
113+
114+
static let v14Palettized = ModelInfo(
115+
modelId: "apple/coreml-stable-diffusion-1-4-palettized",
116+
modelVersion: "CompVis SD 1.4 [6 bit]",
117+
supportsEncoder: true,
118+
supportsAttentionV2: true,
119+
quantized: true
87120
)
88121

89122
static let v15Base = ModelInfo(
90123
modelId: "pcuenq/coreml-stable-diffusion-v1-5",
91-
modelVersion: "runwayml/stable-diffusion-v1-5"
124+
modelVersion: "RunwayML SD 1.5"
125+
)
126+
127+
static let v15Palettized = ModelInfo(
128+
modelId: "apple/coreml-stable-diffusion-v1-5-palettized",
129+
modelVersion: "RunwayML SD 1.5 [6 bit]",
130+
supportsEncoder: true,
131+
supportsAttentionV2: true,
132+
quantized: true
92133
)
93134

94135
static let v2Base = ModelInfo(
95136
modelId: "pcuenq/coreml-stable-diffusion-2-base",
96-
modelVersion: "stabilityai/stable-diffusion-2-base"
137+
modelVersion: "StabilityAI SD 2.0",
138+
supportsEncoder: true
139+
)
140+
141+
static let v2Palettized = ModelInfo(
142+
modelId: "apple/coreml-stable-diffusion-2-base-palettized",
143+
modelVersion: "StabilityAI SD 2.0 [6 bit]",
144+
supportsEncoder: true,
145+
supportsAttentionV2: true,
146+
quantized: true
97147
)
98148

99149
static let v21Base = ModelInfo(
100150
modelId: "pcuenq/coreml-stable-diffusion-2-1-base",
101-
modelVersion: "stabilityai/stable-diffusion-2-1-base",
151+
modelVersion: "StabilityAI SD 2.1",
102152
supportsEncoder: true
103153
)
104154

155+
static let v21Palettized = ModelInfo(
156+
modelId: "apple/coreml-stable-diffusion-2-1-base-palettized",
157+
modelVersion: "StabilityAI SD 2.1 [6 bit]",
158+
supportsEncoder: true,
159+
supportsAttentionV2: true,
160+
quantized: true
161+
)
162+
105163
static let ofaSmall = ModelInfo(
106164
modelId: "pcuenq/coreml-small-stable-diffusion-v0",
107165
modelVersion: "OFA-Sys/small-stable-diffusion-v0"
108166
)
109-
110-
static let MODELS = [
111-
ModelInfo.v14Base,
112-
ModelInfo.v15Base,
113-
ModelInfo.v2Base,
114-
ModelInfo.v21Base,
115-
ModelInfo.ofaSmall
116-
]
167+
168+
static let MODELS: [ModelInfo] = {
169+
if deviceSupportsQuantization {
170+
return [
171+
ModelInfo.v14Base,
172+
ModelInfo.v14Palettized,
173+
ModelInfo.v15Base,
174+
ModelInfo.v15Palettized,
175+
ModelInfo.v2Base,
176+
ModelInfo.v2Palettized,
177+
ModelInfo.v21Base,
178+
ModelInfo.v21Palettized,
179+
ModelInfo.ofaSmall
180+
]
181+
} else {
182+
return [
183+
ModelInfo.v14Base,
184+
ModelInfo.v15Base,
185+
ModelInfo.v2Base,
186+
ModelInfo.v21Base,
187+
ModelInfo.ofaSmall
188+
]
189+
}
190+
}()
117191

118192
static func from(modelVersion: String) -> ModelInfo? {
119193
ModelInfo.MODELS.first(where: {$0.modelVersion == modelVersion})

0 commit comments

Comments
 (0)