Skip to content

Commit 096d219

Browse files
committed Dec 18, 2022
Guidance Handling
* Update how guidance scaling is handled
1 parent 9e02241 commit 096d219

File tree

1 file changed

+8
-15
lines changed

1 file changed

+8
-15
lines changed
 

‎CoreML/pipeline/StableDiffusionPipeline.swift

+8-15
Original file line number | Diff line number | Diff line change
@@ -33,9 +33,6 @@ public struct StableDiffusionPipeline: ResourceManaging {
3333
/// Optional model for checking safety of generated image
3434
var safetyChecker: SafetyChecker? = nil
3535

36-
/// Controls the influence of the text prompt on sampling process (0=random images)
37-
var guidanceScale: Float = 7.5
38-
3936
/// Reports whether this pipeline can perform safety checks
4037
public var canSafetyCheck: Bool {
4138
safetyChecker != nil
@@ -56,7 +53,7 @@ public struct StableDiffusionPipeline: ResourceManaging {
5653
/// - unet: Model for noise prediction on latent samples
5754
/// - decoder: Model for decoding latent sample to image
5855
/// - safetyChecker: Optional model for checking safety of generated images
59-
/// - guidanceScale: Influence of the text prompt on generation process
56+
/// - guidanceScale: Influence of the text prompt on generation process (0=random images)
6057
/// - reduceMemory: Option to enable reduced memory mode
6158
/// - Returns: Pipeline ready for image generation
6259
public init(textEncoder: TextEncoder,
@@ -108,7 +105,7 @@ public struct StableDiffusionPipeline: ResourceManaging {
108105
/// - prompt: Text prompt to guide sampling
109106
/// - stepCount: Number of inference steps to perform
110107
/// - imageCount: Number of samples/images to generate for the input prompt
111-
/// - seed: Random seed which
108+
/// - seed: Random seed which allows us to re-generate the same image for the same prompt by re-using the seed
112109
/// - guidanceScale: For classifier guidance
113110
/// - disableSafety: Safety checks are only performed if `self.canSafetyCheck && !disableSafety`
114111
/// - progressHandler: Callback to perform after each step, stops on receiving false response
@@ -125,7 +122,7 @@ public struct StableDiffusionPipeline: ResourceManaging {
125122
scheduler: StableDiffusionScheduler = .pndm,
126123
progressHandler: (Progress) -> Bool = { _ in true }
127124
) throws -> [CGImage?] {
128-
125+
129126
// Encode the input prompt as well as a blank unconditioned input
130127
let promptEmbedding = try textEncoder.encode(prompt)
131128
let negativePromptEmbedding = try textEncoder.encode(negativePrompt)
@@ -172,7 +169,7 @@ public struct StableDiffusionPipeline: ResourceManaging {
172169
hiddenStates: hiddenStates
173170
)
174171

175-
noise = performGuidance(noise)
172+
noise = performGuidance(noise: noise, guidance: guidanceScale)
176173

177174
// Have the scheduler compute the previous (t-1) latent
178175
// sample given the predicted noise and current sample
@@ -235,22 +232,18 @@ public struct StableDiffusionPipeline: ResourceManaging {
235232
return states
236233
}
237234

238-
func performGuidance(_ noise: [MLShapedArray<Float32>]) -> [MLShapedArray<Float32>] {
239-
noise.map { performGuidance($0) }
235+
func performGuidance(noise: [MLShapedArray<Float32>], guidance: Float) -> [MLShapedArray<Float32>] {
236+
noise.map { performGuidance(noise: $0, guidance: guidance) }
240237
}
241238

242-
func performGuidance(_ noise: MLShapedArray<Float32>) -> MLShapedArray<Float32> {
243-
239+
func performGuidance(noise: MLShapedArray<Float32>, guidance: Float) -> MLShapedArray<Float32> {
244240
let blankNoiseScalars = noise[0].scalars
245241
let textNoiseScalars = noise[1].scalars
246-
247242
var resultScalars = blankNoiseScalars
248-
249243
for i in 0..<resultScalars.count {
250244
// unconditioned + guidance*(text - unconditioned)
251-
resultScalars[i] += guidanceScale*(textNoiseScalars[i]-blankNoiseScalars[i])
245+
resultScalars[i] += guidance * (textNoiseScalars[i]-blankNoiseScalars[i])
252246
}
253-
254247
var shape = noise.shape
255248
shape[0] = 1
256249
return MLShapedArray<Float32>(scalars: resultScalars, shape: shape)

0 commit comments

Comments
 (0)
Please sign in to comment.