forked from huggingface/swift-coreml-diffusers
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathModelInfo.swift
203 lines (172 loc) · 6.9 KB
/
ModelInfo.swift
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
//
// ModelInfo.swift
// Diffusion
//
// Created by Pedro Cuenca on 29/12/22.
// See LICENSE at https://github.com/huggingface/swift-coreml-diffusers/LICENSE
//
import CoreML
/// Attention implementation variant of a compiled Core ML model archive.
///
/// Each case's raw value is simply its case name (e.g. `"splitEinsumV2"`).
enum AttentionVariant: String {
    case original, splitEinsum, splitEinsumV2
}
extension AttentionVariant {
    /// Compute units paired with this attention variant:
    /// CPU + GPU for `.original`, CPU + Neural Engine for both split-einsum variants.
    var defaultComputeUnits: MLComputeUnits {
        switch self {
        case .original:
            return .cpuAndGPU
        case .splitEinsum, .splitEinsumV2:
            return .cpuAndNeuralEngine
        }
    }
}
struct ModelInfo {
    /// Hugging Face model id (e.g. "pcuenq/coreml-stable-diffusion-1-4") of the repository
    /// that hosts the `.zip` archives of compiled Core ML models.
    let modelId: String
    /// Free-form, human-readable label used only for presentation, e.g. "2.1-base".
    let modelVersion: String
    /// Archive suffix for the ORIGINAL attention variant, typically "original_compiled".
    let originalAttentionSuffix: String
    /// Archive suffix for the SPLIT_EINSUM attention variant, typically "split_einsum_compiled".
    let splitAttentionSuffix: String
    /// Archive suffix for the SPLIT_EINSUM_V2 attention variant, typically "split_einsum_v2_compiled".
    let splitAttentionV2Suffix: String
    /// Whether the archive ships the VAE encoder (needed for image-to-image). Not yet in use.
    let supportsEncoder: Bool
    /// Whether a SPLIT_EINSUM_V2 archive is available. Ideally this would be discovered from the repo contents.
    let supportsAttentionV2: Bool
    /// Whether the weights are quantized; consulted only when deciding whether to use `reduceMemory`.
    let quantized: Bool

    //TODO: refactor all these properties
    /// Creates a model description.
    ///
    /// Only `modelId` and `modelVersion` are required; the archive suffixes default to the
    /// conventional names and every capability flag defaults to `false`.
    init(
        modelId: String,
        modelVersion: String,
        originalAttentionSuffix: String = "original_compiled",
        splitAttentionSuffix: String = "split_einsum_compiled",
        splitAttentionV2Suffix: String = "split_einsum_v2_compiled",
        supportsEncoder: Bool = false,
        supportsAttentionV2: Bool = false,
        quantized: Bool = false
    ) {
        // Identity and presentation.
        self.modelId = modelId
        self.modelVersion = modelVersion

        // Capability flags.
        self.supportsEncoder = supportsEncoder
        self.supportsAttentionV2 = supportsAttentionV2
        self.quantized = quantized

        // Archive naming.
        self.originalAttentionSuffix = originalAttentionSuffix
        self.splitAttentionSuffix = splitAttentionSuffix
        self.splitAttentionV2Suffix = splitAttentionV2Suffix
    }
}
extension ModelInfo {
    //TODO: set compute units instead and derive variant from it
    /// Platform-wide default attention variant, used when a model offers nothing better.
    ///
    /// - `runningOnMac` is false: `.splitEinsum`.
    /// - macOS build where `Capabilities.hasANE` is false: `.original`.
    /// - macOS build where `Capabilities.hasANE` is true: `.original` with 8 or more
    ///   performance cores, otherwise `.splitEinsum`.
    /// - `runningOnMac` is true but this is not a macOS build: `.splitEinsum`.
    static var defaultAttention: AttentionVariant {
        guard runningOnMac else { return .splitEinsum }
        #if os(macOS)
        guard Capabilities.hasANE else { return .original }
        return Capabilities.performanceCores >= 8 ? .original : .splitEinsum
        #else
        return .splitEinsum
        #endif
    }

    /// Compute units that go with `defaultAttention`.
    static var defaultComputeUnits: MLComputeUnits { defaultAttention.defaultComputeUnits }

    /// Attention variant to use for this model: `.splitEinsumV2` when not running on a Mac
    /// and the model supports it, otherwise `ModelInfo.defaultAttention`.
    var bestAttention: AttentionVariant {
        if !runningOnMac && supportsAttentionV2 { return .splitEinsumV2 }
        return ModelInfo.defaultAttention
    }

    /// Compute units that go with `bestAttention`.
    var defaultComputeUnits: MLComputeUnits { bestAttention.defaultComputeUnits }

    /// Download URL of the compiled-model archive for `variant`.
    ///
    /// Pattern: `https://huggingface.co/<modelId>/resolve/main/<repo>_<suffix>.zip`, where `<repo>`
    /// is the last `/`-separated component of `modelId` and `<suffix>` is the variant's archive
    /// suffix. For example, `modelId` "pcuenq/coreml-stable-diffusion-v1-5" with `.original` gives
    /// https://huggingface.co/pcuenq/coreml-stable-diffusion-v1-5/resolve/main/coreml-stable-diffusion-v1-5_original_compiled.zip
    func modelURL(for variant: AttentionVariant) -> URL {
        let suffix: String
        switch variant {
        case .original: suffix = originalAttentionSuffix
        case .splitEinsum: suffix = splitAttentionSuffix
        case .splitEinsumV2: suffix = splitAttentionV2Suffix
        }
        // `split` drops empty pieces, so an empty or slash-only `modelId` has no last
        // component; fall back to the whole id instead of trapping on a force-unwrap.
        let repo = modelId.split(separator: "/").last.map { String($0) } ?? modelId

        // Build through URLComponents so characters that are illegal in a URL path are
        // percent-encoded rather than making `URL(string:)` return nil.
        var components = URLComponents()
        components.scheme = "https"
        components.host = "huggingface.co"
        components.path = "/\(modelId)/resolve/main/\(repo)_\(suffix).zip"
        // With a host set, `url` is nil only for a non-empty path lacking a leading "/";
        // the path above always starts with "/".
        guard let url = components.url else {
            preconditionFailure("Cannot build model URL for \(modelId)")
        }
        return url
    }

    /// Archive URL for `bestAttention` (`split_einsum_v2` when eligible off-Mac,
    /// otherwise the platform default from `defaultAttention`).
    var bestURL: URL { modelURL(for: bestAttention) }

    /// Value for the pipeline's `reduceMemory` option.
    ///
    /// Always `false` when running on a Mac. Otherwise `true`, except for quantized
    /// models on devices where `deviceHas6GBOrMore` is true.
    var reduceMemory: Bool {
        if runningOnMac { return false }
        return !(quantized && deviceHas6GBOrMore)
    }
}
extension ModelInfo {
    // MARK: - Stable Diffusion 1.4

    static let v14Base = ModelInfo(modelId: "pcuenq/coreml-stable-diffusion-1-4",
                                   modelVersion: "CompVis SD 1.4")

    static let v14Palettized = ModelInfo(modelId: "apple/coreml-stable-diffusion-1-4-palettized",
                                         modelVersion: "CompVis SD 1.4 [6 bit]",
                                         supportsEncoder: true,
                                         supportsAttentionV2: true,
                                         quantized: true)

    // MARK: - Stable Diffusion 1.5

    static let v15Base = ModelInfo(modelId: "pcuenq/coreml-stable-diffusion-v1-5",
                                   modelVersion: "RunwayML SD 1.5")

    static let v15Palettized = ModelInfo(modelId: "apple/coreml-stable-diffusion-v1-5-palettized",
                                         modelVersion: "RunwayML SD 1.5 [6 bit]",
                                         supportsEncoder: true,
                                         supportsAttentionV2: true,
                                         quantized: true)

    // MARK: - Stable Diffusion 2.0

    static let v2Base = ModelInfo(modelId: "pcuenq/coreml-stable-diffusion-2-base",
                                  modelVersion: "StabilityAI SD 2.0",
                                  supportsEncoder: true)

    static let v2Palettized = ModelInfo(modelId: "apple/coreml-stable-diffusion-2-base-palettized",
                                        modelVersion: "StabilityAI SD 2.0 [6 bit]",
                                        supportsEncoder: true,
                                        supportsAttentionV2: true,
                                        quantized: true)

    // MARK: - Stable Diffusion 2.1

    static let v21Base = ModelInfo(modelId: "pcuenq/coreml-stable-diffusion-2-1-base",
                                   modelVersion: "StabilityAI SD 2.1",
                                   supportsEncoder: true)

    static let v21Palettized = ModelInfo(modelId: "apple/coreml-stable-diffusion-2-1-base-palettized",
                                         modelVersion: "StabilityAI SD 2.1 [6 bit]",
                                         supportsEncoder: true,
                                         supportsAttentionV2: true,
                                         quantized: true)

    // MARK: - Other

    static let ofaSmall = ModelInfo(modelId: "pcuenq/coreml-small-stable-diffusion-v0",
                                    modelVersion: "OFA-Sys/small-stable-diffusion-v0")

    // MARK: - Lookup

    /// Models offered on this device, in display order.
    ///
    /// When `deviceSupportsQuantization` is false, every model flagged `quantized`
    /// (the palettized variants) is left out; the remaining order is unchanged.
    static let MODELS: [ModelInfo] = {
        let catalog: [ModelInfo] = [
            .v14Base, .v14Palettized,
            .v15Base, .v15Palettized,
            .v2Base, .v2Palettized,
            .v21Base, .v21Palettized,
            .ofaSmall,
        ]
        return deviceSupportsQuantization ? catalog : catalog.filter { !$0.quantized }
    }()

    /// Returns the first entry of `MODELS` whose `modelVersion` matches, or `nil`.
    static func from(modelVersion: String) -> ModelInfo? {
        MODELS.first { candidate in candidate.modelVersion == modelVersion }
    }

    /// Returns the first entry of `MODELS` whose `modelId` matches, or `nil`.
    static func from(modelId: String) -> ModelInfo? {
        MODELS.first { candidate in candidate.modelId == modelId }
    }
}
extension ModelInfo: Equatable {
    /// Two descriptions are equal when they share the same `modelId`;
    /// every other property is ignored.
    static func == (lhs: Self, rhs: Self) -> Bool {
        return lhs.modelId == rhs.modelId
    }
}