@@ -70,8 +70,13 @@ def before_sample(self, x, ts, cond, unconditional_conditioning):
70
70
71
71
# Have to unwrap the inpainting conditioning here to perform pre-processing
72
72
image_conditioning = None
73
+ uc_image_conditioning = None
73
74
if isinstance (cond , dict ):
74
- image_conditioning = cond ["c_concat" ][0 ]
75
+ if self .conditioning_key == "crossattn-adm" :
76
+ image_conditioning = cond ["c_adm" ]
77
+ uc_image_conditioning = unconditional_conditioning ["c_adm" ]
78
+ else :
79
+ image_conditioning = cond ["c_concat" ][0 ]
75
80
cond = cond ["c_crossattn" ][0 ]
76
81
unconditional_conditioning = unconditional_conditioning ["c_crossattn" ][0 ]
77
82
@@ -98,8 +103,12 @@ def before_sample(self, x, ts, cond, unconditional_conditioning):
98
103
# Wrap the image conditioning back up since the DDIM code can accept the dict directly.
99
104
# Note that they need to be lists because it just concatenates them later.
100
105
if image_conditioning is not None :
101
- cond = {"c_concat" : [image_conditioning ], "c_crossattn" : [cond ]}
102
- unconditional_conditioning = {"c_concat" : [image_conditioning ], "c_crossattn" : [unconditional_conditioning ]}
106
+ if self .conditioning_key == "crossattn-adm" :
107
+ cond = {"c_adm" : image_conditioning , "c_crossattn" : [cond ]}
108
+ unconditional_conditioning = {"c_adm" : uc_image_conditioning , "c_crossattn" : [unconditional_conditioning ]}
109
+ else :
110
+ cond = {"c_concat" : [image_conditioning ], "c_crossattn" : [cond ]}
111
+ unconditional_conditioning = {"c_concat" : [image_conditioning ], "c_crossattn" : [unconditional_conditioning ]}
103
112
104
113
return x , ts , cond , unconditional_conditioning
105
114
@@ -176,8 +185,12 @@ def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning,
176
185
177
186
# Wrap the conditioning models with additional image conditioning for inpainting model
178
187
if image_conditioning is not None :
179
- conditioning = {"c_concat" : [image_conditioning ], "c_crossattn" : [conditioning ]}
180
- unconditional_conditioning = {"c_concat" : [image_conditioning ], "c_crossattn" : [unconditional_conditioning ]}
188
+ if self .conditioning_key == "crossattn-adm" :
189
+ conditioning = {"c_adm" : image_conditioning , "c_crossattn" : [conditioning ]}
190
+ unconditional_conditioning = {"c_adm" : torch .zeros_like (image_conditioning ), "c_crossattn" : [unconditional_conditioning ]}
191
+ else :
192
+ conditioning = {"c_concat" : [image_conditioning ], "c_crossattn" : [conditioning ]}
193
+ unconditional_conditioning = {"c_concat" : [image_conditioning ], "c_crossattn" : [unconditional_conditioning ]}
181
194
182
195
samples = self .launch_sampling (t_enc + 1 , lambda : self .sampler .decode (x1 , conditioning , t_enc , unconditional_guidance_scale = p .cfg_scale , unconditional_conditioning = unconditional_conditioning ))
183
196
@@ -195,8 +208,12 @@ def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, ima
195
208
# Wrap the conditioning models with additional image conditioning for inpainting model
196
209
# dummy_for_plms is needed because PLMS code checks the first item in the dict to have the right shape
197
210
if image_conditioning is not None :
198
- conditioning = {"dummy_for_plms" : np .zeros ((conditioning .shape [0 ],)), "c_crossattn" : [conditioning ], "c_concat" : [image_conditioning ]}
199
- unconditional_conditioning = {"c_crossattn" : [unconditional_conditioning ], "c_concat" : [image_conditioning ]}
211
+ if self .conditioning_key == "crossattn-adm" :
212
+ conditioning = {"dummy_for_plms" : np .zeros ((conditioning .shape [0 ],)), "c_crossattn" : [conditioning ], "c_adm" : image_conditioning }
213
+ unconditional_conditioning = {"c_crossattn" : [unconditional_conditioning ], "c_adm" : torch .zeros_like (image_conditioning )}
214
+ else :
215
+ conditioning = {"dummy_for_plms" : np .zeros ((conditioning .shape [0 ],)), "c_crossattn" : [conditioning ], "c_concat" : [image_conditioning ]}
216
+ unconditional_conditioning = {"c_crossattn" : [unconditional_conditioning ], "c_concat" : [image_conditioning ]}
200
217
201
218
samples_ddim = self .launch_sampling (steps , lambda : self .sampler .sample (S = steps , conditioning = conditioning , batch_size = int (x .shape [0 ]), shape = x [0 ].shape , verbose = False , unconditional_guidance_scale = p .cfg_scale , unconditional_conditioning = unconditional_conditioning , x_T = x , eta = self .eta )[0 ])
202
219
0 commit comments