@@ -33,9 +33,6 @@ public struct StableDiffusionPipeline: ResourceManaging {
33
33
/// Optional model for checking safety of generated images
34
34
var safetyChecker : SafetyChecker ? = nil
35
35
36
- /// Controls the influence of the text prompt on sampling process (0=random images)
37
- var guidanceScale : Float = 7.5
38
-
39
36
/// Reports whether this pipeline can perform safety checks
40
37
public var canSafetyCheck : Bool {
41
38
safetyChecker != nil
@@ -56,7 +53,7 @@ public struct StableDiffusionPipeline: ResourceManaging {
56
53
/// - unet: Model for noise prediction on latent samples
57
54
/// - decoder: Model for decoding latent sample to image
58
55
/// - safetyChecker: Optional model for checking safety of generated images
59
- /// - guidanceScale: Influence of the text prompt on generation process
56
+ /// - guidanceScale: Influence of the text prompt on generation process (0=random images)
60
57
/// - reduceMemory: Option to enable reduced memory mode
61
58
/// - Returns: Pipeline ready for image generation
62
59
public init ( textEncoder: TextEncoder ,
@@ -108,7 +105,7 @@ public struct StableDiffusionPipeline: ResourceManaging {
108
105
/// - prompt: Text prompt to guide sampling
109
106
/// - stepCount: Number of inference steps to perform
110
107
/// - imageCount: Number of samples/images to generate for the input prompt
111
- /// - seed: Random seed which
108
+ /// - seed: Random seed which allows us to re-generate the same image for the same prompt by re-using the seed
112
109
/// - guidanceScale: Controls the influence of the text prompt on sampling process (0=random images)
113
110
/// - disableSafety: Safety checks are only performed if `self.canSafetyCheck && !disableSafety`
114
111
/// - progressHandler: Callback to perform after each step, stops on receiving false response
@@ -125,7 +122,7 @@ public struct StableDiffusionPipeline: ResourceManaging {
125
122
scheduler: StableDiffusionScheduler = . pndm,
126
123
progressHandler: ( Progress ) -> Bool = { _ in true }
127
124
) throws -> [ CGImage ? ] {
128
-
125
+
129
126
// Encode the input prompt as well as the negative prompt (used in place of a blank unconditioned input)
130
127
let promptEmbedding = try textEncoder. encode ( prompt)
131
128
let negativePromptEmbedding = try textEncoder. encode ( negativePrompt)
@@ -172,7 +169,7 @@ public struct StableDiffusionPipeline: ResourceManaging {
172
169
hiddenStates: hiddenStates
173
170
)
174
171
175
- noise = performGuidance ( noise)
172
+ noise = performGuidance ( noise: noise , guidance : guidanceScale )
176
173
177
174
// Have the scheduler compute the previous (t-1) latent
178
175
// sample given the predicted noise and current sample
@@ -235,22 +232,18 @@ public struct StableDiffusionPipeline: ResourceManaging {
235
232
return states
236
233
}
237
234
238
- func performGuidance( _ noise: [ MLShapedArray < Float32 > ] ) -> [ MLShapedArray < Float32 > ] {
239
- noise. map { performGuidance ( $0 ) }
235
+ func performGuidance( noise: [ MLShapedArray < Float32 > ] , guidance : Float ) -> [ MLShapedArray < Float32 > ] {
236
+ noise. map { performGuidance ( noise : $0 , guidance : guidance ) }
240
237
}
241
238
242
- func performGuidance( _ noise: MLShapedArray < Float32 > ) -> MLShapedArray < Float32 > {
243
-
239
+ func performGuidance( noise: MLShapedArray < Float32 > , guidance: Float ) -> MLShapedArray < Float32 > {
244
240
let blankNoiseScalars = noise [ 0 ] . scalars
245
241
let textNoiseScalars = noise [ 1 ] . scalars
246
-
247
242
var resultScalars = blankNoiseScalars
248
-
249
243
for i in 0 ..< resultScalars. count {
250
244
// unconditioned + guidance*(text - unconditioned)
251
- resultScalars [ i] += guidanceScale* ( textNoiseScalars [ i] - blankNoiseScalars[ i] )
245
+ resultScalars [ i] += guidance * ( textNoiseScalars [ i] - blankNoiseScalars[ i] )
252
246
}
253
-
254
247
var shape = noise. shape
255
248
shape [ 0 ] = 1
256
249
return MLShapedArray < Float32 > ( scalars: resultScalars, shape: shape)
0 commit comments