@@ -11,6 +11,7 @@ import CoreML
11
11
/// Attention implementation that a converted model archive was built with.
///
/// No explicit raw values are declared, so each case's raw `String` value is
/// simply the case name (e.g. `"splitEinsumV2"`).
enum AttentionVariant: String {
    case original
    case splitEinsum
    case splitEinsumV2
}
15
16
16
17
extension AttentionVariant {
@@ -30,15 +31,34 @@ struct ModelInfo {
30
31
/// Suffix of the archive containing the SPLIT_EINSUM attention variant. Usually something like "split_einsum_compiled"
let splitAttentionSuffix: String

/// Suffix of the archive containing the SPLIT_EINSUM_V2 attention variant. Usually something like "split_einsum_v2_compiled"
let splitAttentionV2Suffix: String

/// Whether the archive contains the VAE Encoder (for image to image tasks). Not yet in use.
let supportsEncoder: Bool

/// Is attention v2 supported? (Ideally, we should know by looking at the repo contents)
let supportsAttentionV2: Bool

/// Are weights quantized? This is only used to decide whether to use `reduceMemory`
let quantized: Bool

// TODO: refactor all these properties

/// Creates a model description.
///
/// - Parameters:
///   - modelId: Repository identifier; it is interpolated into the download URL built by `modelURL(for:)`.
///   - modelVersion: Name used to look the model up via `from(modelVersion:)`.
///   - originalAttentionSuffix: Archive suffix for the `.original` variant.
///   - splitAttentionSuffix: Archive suffix for the `.splitEinsum` variant.
///   - splitAttentionV2Suffix: Archive suffix for the `.splitEinsumV2` variant.
///   - supportsEncoder: Whether the archive contains the VAE Encoder.
///   - supportsAttentionV2: Whether a `.splitEinsumV2` archive is available for this model.
///   - quantized: Whether the weights are quantized.
init(modelId: String, modelVersion: String,
     originalAttentionSuffix: String = "original_compiled",
     splitAttentionSuffix: String = "split_einsum_compiled",
     splitAttentionV2Suffix: String = "split_einsum_v2_compiled",
     supportsEncoder: Bool = false,
     supportsAttentionV2: Bool = false,
     quantized: Bool = false) {
    self.modelId = modelId
    self.modelVersion = modelVersion
    self.originalAttentionSuffix = originalAttentionSuffix
    self.splitAttentionSuffix = splitAttentionSuffix
    self.splitAttentionV2Suffix = splitAttentionV2Suffix
    self.supportsEncoder = supportsEncoder
    self.supportsAttentionV2 = supportsAttentionV2
    self.quantized = quantized
}
43
63
}
44
64
@@ -56,7 +76,10 @@ extension ModelInfo {
56
76
57
77
/// Compute units associated with the type-level `defaultAttention`.
static var defaultComputeUnits: MLComputeUnits { defaultAttention.defaultComputeUnits }

/// Attention variant to use for this particular model.
///
/// Returns `.splitEinsumV2` when not running on a Mac and the model declares
/// support for it; in every other case falls back to `ModelInfo.defaultAttention`.
var bestAttention: AttentionVariant {
    if !runningOnMac && supportsAttentionV2 { return .splitEinsumV2 }
    return ModelInfo.defaultAttention
}

/// Compute units associated with this model's `bestAttention`.
var defaultComputeUnits: MLComputeUnits { bestAttention.defaultComputeUnits }
61
84
62
85
func modelURL( for variant: AttentionVariant ) -> URL {
@@ -65,6 +88,7 @@ extension ModelInfo {
65
88
switch variant {
66
89
case . original: suffix = originalAttentionSuffix
67
90
case . splitEinsum: suffix = splitAttentionSuffix
91
+ case . splitEinsumV2: suffix = splitAttentionV2Suffix
68
92
}
69
93
let repo = modelId. split ( separator: " / " ) . last!
70
94
return URL ( string: " https://huggingface.co/ \( modelId) /resolve/main/ \( repo) _ \( suffix) .zip " ) !
@@ -73,47 +97,97 @@ extension ModelInfo {
73
97
/// Download URL for the best variant on the current platform, i.e. `modelURL(for: bestAttention)`.
/// Uses `split_einsum_v2` on non-Mac devices when the model supports it; otherwise falls back to
/// `ModelInfo.defaultAttention` (previously documented here as `split_einsum` for iOS and simple
/// performance heuristics for macOS).
var bestURL: URL { modelURL(for: bestAttention) }
76
-
100
+
77
101
/// Whether the `reduceMemory` option should be enabled for this model.
///
/// Always `false` on a Mac. On other devices it is `true` unless the model is
/// quantized *and* `deviceHas6GBOrMore` is true.
var reduceMemory: Bool {
    // Disabled on macOS.
    guard !runningOnMac else { return false }
    // Enabled on iOS devices, except for quantized models on devices with 6 GB or more
    // (presumably of RAM; confirm against the definition of `deviceHas6GBOrMore`).
    return !(quantized && deviceHas6GBOrMore)
}
80
106
}
81
107
82
108
extension ModelInfo {
83
- // TODO: repo does not exist yet
84
109
// Known model descriptions. Each "Base" entry has a "Palettized" counterpart
// (where one exists) that is flagged as quantized and as supporting attention v2.

static let v14Base = ModelInfo(
    modelId: "pcuenq/coreml-stable-diffusion-1-4",
    modelVersion: "CompVis SD 1.4"
)

static let v14Palettized = ModelInfo(
    modelId: "apple/coreml-stable-diffusion-1-4-palettized",
    modelVersion: "CompVis SD 1.4 [6 bit]",
    supportsEncoder: true,
    supportsAttentionV2: true,
    quantized: true
)

static let v15Base = ModelInfo(
    modelId: "pcuenq/coreml-stable-diffusion-v1-5",
    modelVersion: "RunwayML SD 1.5"
)

static let v15Palettized = ModelInfo(
    modelId: "apple/coreml-stable-diffusion-v1-5-palettized",
    modelVersion: "RunwayML SD 1.5 [6 bit]",
    supportsEncoder: true,
    supportsAttentionV2: true,
    quantized: true
)

static let v2Base = ModelInfo(
    modelId: "pcuenq/coreml-stable-diffusion-2-base",
    modelVersion: "StabilityAI SD 2.0",
    supportsEncoder: true
)

static let v2Palettized = ModelInfo(
    modelId: "apple/coreml-stable-diffusion-2-base-palettized",
    modelVersion: "StabilityAI SD 2.0 [6 bit]",
    supportsEncoder: true,
    supportsAttentionV2: true,
    quantized: true
)

static let v21Base = ModelInfo(
    modelId: "pcuenq/coreml-stable-diffusion-2-1-base",
    modelVersion: "StabilityAI SD 2.1",
    supportsEncoder: true
)

static let v21Palettized = ModelInfo(
    modelId: "apple/coreml-stable-diffusion-2-1-base-palettized",
    modelVersion: "StabilityAI SD 2.1 [6 bit]",
    supportsEncoder: true,
    supportsAttentionV2: true,
    quantized: true
)

static let ofaSmall = ModelInfo(
    modelId: "pcuenq/coreml-small-stable-diffusion-v0",
    modelVersion: "OFA-Sys/small-stable-diffusion-v0"
)
109
-
110
/// Known models, searched by `from(modelVersion:)`.
///
/// The palettized entries are included only when `deviceSupportsQuantization`
/// is true; otherwise the list contains just the non-palettized models.
/// The closure is evaluated once, when the static property is first accessed.
static let MODELS: [ModelInfo] = {
    guard deviceSupportsQuantization else {
        return [
            ModelInfo.v14Base,
            ModelInfo.v15Base,
            ModelInfo.v2Base,
            ModelInfo.v21Base,
            ModelInfo.ofaSmall
        ]
    }
    return [
        ModelInfo.v14Base,
        ModelInfo.v14Palettized,
        ModelInfo.v15Base,
        ModelInfo.v15Palettized,
        ModelInfo.v2Base,
        ModelInfo.v2Palettized,
        ModelInfo.v21Base,
        ModelInfo.v21Palettized,
        ModelInfo.ofaSmall
    ]
}()
117
191
118
192
static func from( modelVersion: String ) -> ModelInfo ? {
119
193
ModelInfo . MODELS. first ( where: { $0. modelVersion == modelVersion} )
0 commit comments