Skip to content

Commit f1db987

Browse files
Merge pull request AUTOMATIC1111#8958 from MrCheeze/variations-model
Add support for the unclip (Variations) models, unclip-h and unclip-l
2 parents e49c479 + 1f08600 commit f1db987

8 files changed

+88
-30
lines changed

launch.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -235,7 +235,7 @@ def prepare_environment():
235235
codeformer_repo = os.environ.get('CODEFORMER_REPO', 'https://github.com/sczhou/CodeFormer.git')
236236
blip_repo = os.environ.get('BLIP_REPO', 'https://github.com/salesforce/BLIP.git')
237237

238-
stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "47b6b607fdd31875c9279cd2f4f16b92e4ea958e")
238+
stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "cf1d67a6fd5ea1aa600c4df58e5b47da45f6bdbf")
239239
taming_transformers_commit_hash = os.environ.get('TAMING_TRANSFORMERS_COMMIT_HASH', "24268930bf1dce879235a7fddd0b2355b84d7ea6")
240240
k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "5b3af030dd83e0297272d861c19477735d0317ec")
241241
codeformer_commit_hash = os.environ.get('CODEFORMER_COMMIT_HASH', "c5b4593074ba6214284d6acd5f1719b6c5d739af")

models/karlo/ViT-L-14_stats.th

6.91 KB
Binary file not shown.

modules/lowvram.py

+6-4
Original file line numberDiff line numberDiff line change
@@ -55,12 +55,12 @@ def first_stage_model_decode_wrap(z):
5555
if hasattr(sd_model.cond_stage_model, 'model'):
5656
sd_model.cond_stage_model.transformer = sd_model.cond_stage_model.model
5757

58-
# remove four big modules, cond, first_stage, depth (if applicable), and unet from the model and then
58+
# remove several big modules: cond, first_stage, depth/embedder (if applicable), and unet from the model and then
5959
# send the model to GPU. Then put modules back. the modules will be in CPU.
60-
stored = sd_model.cond_stage_model.transformer, sd_model.first_stage_model, getattr(sd_model, 'depth_model', None), sd_model.model
61-
sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.depth_model, sd_model.model = None, None, None, None
60+
stored = sd_model.cond_stage_model.transformer, sd_model.first_stage_model, getattr(sd_model, 'depth_model', None), getattr(sd_model, 'embedder', None), sd_model.model
61+
sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.depth_model, sd_model.embedder, sd_model.model = None, None, None, None, None
6262
sd_model.to(devices.device)
63-
sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.depth_model, sd_model.model = stored
63+
sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.depth_model, sd_model.embedder, sd_model.model = stored
6464

6565
# register hooks for those the first three models
6666
sd_model.cond_stage_model.transformer.register_forward_pre_hook(send_me_to_gpu)
@@ -69,6 +69,8 @@ def first_stage_model_decode_wrap(z):
6969
sd_model.first_stage_model.decode = first_stage_model_decode_wrap
7070
if sd_model.depth_model:
7171
sd_model.depth_model.register_forward_pre_hook(send_me_to_gpu)
72+
if sd_model.embedder:
73+
sd_model.embedder.register_forward_pre_hook(send_me_to_gpu)
7274
parents[sd_model.cond_stage_model.transformer] = sd_model.cond_stage_model
7375

7476
if hasattr(sd_model.cond_stage_model, 'model'):

modules/processing.py

+29-12
Original file line numberDiff line numberDiff line change
@@ -78,21 +78,27 @@ def apply_overlay(image, paste_loc, index, overlays):
7878

7979

8080
def txt2img_image_conditioning(sd_model, x, width, height):
81-
if sd_model.model.conditioning_key not in {'hybrid', 'concat'}:
82-
# Dummy zero conditioning if we're not using inpainting model.
83-
# Still takes up a bit of memory, but no encoder call.
84-
# Pretty sure we can just make this a 1x1 image since its not going to be used besides its batch size.
85-
return x.new_zeros(x.shape[0], 5, 1, 1, dtype=x.dtype, device=x.device)
81+
if sd_model.model.conditioning_key in {'hybrid', 'concat'}: # Inpainting models
8682

87-
# The "masked-image" in this case will just be all zeros since the entire image is masked.
88-
image_conditioning = torch.zeros(x.shape[0], 3, height, width, device=x.device)
89-
image_conditioning = sd_model.get_first_stage_encoding(sd_model.encode_first_stage(image_conditioning))
83+
# The "masked-image" in this case will just be all zeros since the entire image is masked.
84+
image_conditioning = torch.zeros(x.shape[0], 3, height, width, device=x.device)
85+
image_conditioning = sd_model.get_first_stage_encoding(sd_model.encode_first_stage(image_conditioning))
9086

91-
# Add the fake full 1s mask to the first dimension.
92-
image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0)
93-
image_conditioning = image_conditioning.to(x.dtype)
87+
# Add the fake full 1s mask to the first dimension.
88+
image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0)
89+
image_conditioning = image_conditioning.to(x.dtype)
9490

95-
return image_conditioning
91+
return image_conditioning
92+
93+
elif sd_model.model.conditioning_key == "crossattn-adm": # UnCLIP models
94+
95+
return x.new_zeros(x.shape[0], 2*sd_model.noise_augmentor.time_embed.dim, dtype=x.dtype, device=x.device)
96+
97+
else:
98+
# Dummy zero conditioning if we're not using inpainting or unclip models.
99+
# Still takes up a bit of memory, but no encoder call.
100+
# Pretty sure we can just make this a 1x1 image since its not going to be used besides its batch size.
101+
return x.new_zeros(x.shape[0], 5, 1, 1, dtype=x.dtype, device=x.device)
96102

97103

98104
class StableDiffusionProcessing:
@@ -190,6 +196,14 @@ def edit_image_conditioning(self, source_image):
190196

191197
return conditioning_image
192198

199+
def unclip_image_conditioning(self, source_image):
200+
c_adm = self.sd_model.embedder(source_image)
201+
if self.sd_model.noise_augmentor is not None:
202+
noise_level = 0 # TODO: Allow other noise levels?
203+
c_adm, noise_level_emb = self.sd_model.noise_augmentor(c_adm, noise_level=repeat(torch.tensor([noise_level]).to(c_adm.device), '1 -> b', b=c_adm.shape[0]))
204+
c_adm = torch.cat((c_adm, noise_level_emb), 1)
205+
return c_adm
206+
193207
def inpainting_image_conditioning(self, source_image, latent_image, image_mask=None):
194208
self.is_using_inpainting_conditioning = True
195209

@@ -241,6 +255,9 @@ def img2img_image_conditioning(self, source_image, latent_image, image_mask=None
241255
if self.sampler.conditioning_key in {'hybrid', 'concat'}:
242256
return self.inpainting_image_conditioning(source_image, latent_image, image_mask=image_mask)
243257

258+
if self.sampler.conditioning_key == "crossattn-adm":
259+
return self.unclip_image_conditioning(source_image)
260+
244261
# Dummy zero conditioning if we're not using inpainting or depth model.
245262
return latent_image.new_zeros(latent_image.shape[0], 5, 1, 1)
246263

modules/sd_models.py

+8
Original file line numberDiff line numberDiff line change
@@ -383,6 +383,14 @@ def repair_config(sd_config):
383383
elif shared.cmd_opts.upcast_sampling:
384384
sd_config.model.params.unet_config.params.use_fp16 = True
385385

386+
if getattr(sd_config.model.params.first_stage_config.params.ddconfig, "attn_type", None) == "vanilla-xformers" and not shared.xformers_available:
387+
sd_config.model.params.first_stage_config.params.ddconfig.attn_type = "vanilla"
388+
389+
# For UnCLIP-L, override the hardcoded karlo directory
390+
if hasattr(sd_config.model.params, "noise_aug_config") and hasattr(sd_config.model.params.noise_aug_config.params, "clip_stats_path"):
391+
karlo_path = os.path.join(paths.models_path, 'karlo')
392+
sd_config.model.params.noise_aug_config.params.clip_stats_path = sd_config.model.params.noise_aug_config.params.clip_stats_path.replace("checkpoints/karlo_models", karlo_path)
393+
386394

387395
sd1_clip_weight = 'cond_stage_model.transformer.text_model.embeddings.token_embedding.weight'
388396
sd2_clip_weight = 'cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight'

modules/sd_models_config.py

+7
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
config_sd2v = os.path.join(sd_repo_configs_path, "v2-inference-v.yaml")
1515
config_sd2_inpainting = os.path.join(sd_repo_configs_path, "v2-inpainting-inference.yaml")
1616
config_depth_model = os.path.join(sd_repo_configs_path, "v2-midas-inference.yaml")
17+
config_unclip = os.path.join(sd_repo_configs_path, "v2-1-stable-unclip-l-inference.yaml")
18+
config_unopenclip = os.path.join(sd_repo_configs_path, "v2-1-stable-unclip-h-inference.yaml")
1719
config_inpainting = os.path.join(sd_configs_path, "v1-inpainting-inference.yaml")
1820
config_instruct_pix2pix = os.path.join(sd_configs_path, "instruct-pix2pix.yaml")
1921
config_alt_diffusion = os.path.join(sd_configs_path, "alt-diffusion-inference.yaml")
@@ -65,9 +67,14 @@ def is_using_v_parameterization_for_sd2(state_dict):
6567
def guess_model_config_from_state_dict(sd, filename):
6668
sd2_cond_proj_weight = sd.get('cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight', None)
6769
diffusion_model_input = sd.get('model.diffusion_model.input_blocks.0.0.weight', None)
70+
sd2_variations_weight = sd.get('embedder.model.ln_final.weight', None)
6871

6972
if sd.get('depth_model.model.pretrained.act_postprocess3.0.project.0.bias', None) is not None:
7073
return config_depth_model
74+
elif sd2_variations_weight is not None and sd2_variations_weight.shape[0] == 768:
75+
return config_unclip
76+
elif sd2_variations_weight is not None and sd2_variations_weight.shape[0] == 1024:
77+
return config_unopenclip
7178

7279
if sd2_cond_proj_weight is not None and sd2_cond_proj_weight.shape[1] == 1024:
7380
if diffusion_model_input.shape[1] == 9:

modules/sd_samplers_compvis.py

+24-7
Original file line numberDiff line numberDiff line change
@@ -70,8 +70,13 @@ def before_sample(self, x, ts, cond, unconditional_conditioning):
7070

7171
# Have to unwrap the inpainting conditioning here to perform pre-processing
7272
image_conditioning = None
73+
uc_image_conditioning = None
7374
if isinstance(cond, dict):
74-
image_conditioning = cond["c_concat"][0]
75+
if self.conditioning_key == "crossattn-adm":
76+
image_conditioning = cond["c_adm"]
77+
uc_image_conditioning = unconditional_conditioning["c_adm"]
78+
else:
79+
image_conditioning = cond["c_concat"][0]
7580
cond = cond["c_crossattn"][0]
7681
unconditional_conditioning = unconditional_conditioning["c_crossattn"][0]
7782

@@ -98,8 +103,12 @@ def before_sample(self, x, ts, cond, unconditional_conditioning):
98103
# Wrap the image conditioning back up since the DDIM code can accept the dict directly.
99104
# Note that they need to be lists because it just concatenates them later.
100105
if image_conditioning is not None:
101-
cond = {"c_concat": [image_conditioning], "c_crossattn": [cond]}
102-
unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]}
106+
if self.conditioning_key == "crossattn-adm":
107+
cond = {"c_adm": image_conditioning, "c_crossattn": [cond]}
108+
unconditional_conditioning = {"c_adm": uc_image_conditioning, "c_crossattn": [unconditional_conditioning]}
109+
else:
110+
cond = {"c_concat": [image_conditioning], "c_crossattn": [cond]}
111+
unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]}
103112

104113
return x, ts, cond, unconditional_conditioning
105114

@@ -176,8 +185,12 @@ def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning,
176185

177186
# Wrap the conditioning models with additional image conditioning for inpainting model
178187
if image_conditioning is not None:
179-
conditioning = {"c_concat": [image_conditioning], "c_crossattn": [conditioning]}
180-
unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]}
188+
if self.conditioning_key == "crossattn-adm":
189+
conditioning = {"c_adm": image_conditioning, "c_crossattn": [conditioning]}
190+
unconditional_conditioning = {"c_adm": torch.zeros_like(image_conditioning), "c_crossattn": [unconditional_conditioning]}
191+
else:
192+
conditioning = {"c_concat": [image_conditioning], "c_crossattn": [conditioning]}
193+
unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]}
181194

182195
samples = self.launch_sampling(t_enc + 1, lambda: self.sampler.decode(x1, conditioning, t_enc, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning))
183196

@@ -195,8 +208,12 @@ def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, ima
195208
# Wrap the conditioning models with additional image conditioning for inpainting model
196209
# dummy_for_plms is needed because PLMS code checks the first item in the dict to have the right shape
197210
if image_conditioning is not None:
198-
conditioning = {"dummy_for_plms": np.zeros((conditioning.shape[0],)), "c_crossattn": [conditioning], "c_concat": [image_conditioning]}
199-
unconditional_conditioning = {"c_crossattn": [unconditional_conditioning], "c_concat": [image_conditioning]}
211+
if self.conditioning_key == "crossattn-adm":
212+
conditioning = {"dummy_for_plms": np.zeros((conditioning.shape[0],)), "c_crossattn": [conditioning], "c_adm": image_conditioning}
213+
unconditional_conditioning = {"c_crossattn": [unconditional_conditioning], "c_adm": torch.zeros_like(image_conditioning)}
214+
else:
215+
conditioning = {"dummy_for_plms": np.zeros((conditioning.shape[0],)), "c_crossattn": [conditioning], "c_concat": [image_conditioning]}
216+
unconditional_conditioning = {"c_crossattn": [unconditional_conditioning], "c_concat": [image_conditioning]}
200217

201218
samples_ddim = self.launch_sampling(steps, lambda: self.sampler.sample(S=steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=self.eta)[0])
202219

modules/sd_samplers_kdiffusion.py

+13-6
Original file line numberDiff line numberDiff line change
@@ -92,14 +92,21 @@ def forward(self, x, sigma, uncond, cond, cond_scale, image_cond):
9292
batch_size = len(conds_list)
9393
repeats = [len(conds_list[i]) for i in range(batch_size)]
9494

95+
if shared.sd_model.model.conditioning_key == "crossattn-adm":
96+
image_uncond = torch.zeros_like(image_cond)
97+
make_condition_dict = lambda c_crossattn, c_adm: {"c_crossattn": c_crossattn, "c_adm": c_adm}
98+
else:
99+
image_uncond = image_cond
100+
make_condition_dict = lambda c_crossattn, c_concat: {"c_crossattn": c_crossattn, "c_concat": [c_concat]}
101+
95102
if not is_edit_model:
96103
x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x])
97104
sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma])
98-
image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_cond])
105+
image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_uncond])
99106
else:
100107
x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x] + [x])
101108
sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma] + [sigma])
102-
image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_cond] + [torch.zeros_like(self.init_latent)])
109+
image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_uncond] + [torch.zeros_like(self.init_latent)])
103110

104111
denoiser_params = CFGDenoiserParams(x_in, image_cond_in, sigma_in, state.sampling_step, state.sampling_steps, tensor, uncond)
105112
cfg_denoiser_callback(denoiser_params)
@@ -116,13 +123,13 @@ def forward(self, x, sigma, uncond, cond, cond_scale, image_cond):
116123
cond_in = torch.cat([tensor, uncond, uncond])
117124

118125
if shared.batch_cond_uncond:
119-
x_out = self.inner_model(x_in, sigma_in, cond={"c_crossattn": [cond_in], "c_concat": [image_cond_in]})
126+
x_out = self.inner_model(x_in, sigma_in, cond=make_condition_dict([cond_in], image_cond_in))
120127
else:
121128
x_out = torch.zeros_like(x_in)
122129
for batch_offset in range(0, x_out.shape[0], batch_size):
123130
a = batch_offset
124131
b = a + batch_size
125-
x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond={"c_crossattn": [cond_in[a:b]], "c_concat": [image_cond_in[a:b]]})
132+
x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict([cond_in[a:b]], image_cond_in[a:b]))
126133
else:
127134
x_out = torch.zeros_like(x_in)
128135
batch_size = batch_size*2 if shared.batch_cond_uncond else batch_size
@@ -135,9 +142,9 @@ def forward(self, x, sigma, uncond, cond, cond_scale, image_cond):
135142
else:
136143
c_crossattn = torch.cat([tensor[a:b]], uncond)
137144

138-
x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond={"c_crossattn": c_crossattn, "c_concat": [image_cond_in[a:b]]})
145+
x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict(c_crossattn, image_cond_in[a:b]))
139146

140-
x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond={"c_crossattn": [uncond], "c_concat": [image_cond_in[-uncond.shape[0]:]]})
147+
x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond=make_condition_dict([uncond], image_cond_in[-uncond.shape[0]:]))
141148

142149
denoised_params = CFGDenoisedParams(x_out, state.sampling_step, state.sampling_steps)
143150
cfg_denoised_callback(denoised_params)

0 commit comments

Comments
 (0)