Skip to content

Commit a4af8eb

Browse files
Adding reference script used for fine-tuning an unconditional model
1 parent 87e0400 commit a4af8eb

File tree

1 file changed

+120
-0
lines changed

1 file changed

+120
-0
lines changed

unit2/finetune_model.py

+120
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
import wandb
2+
import numpy as np
3+
import torch, torchvision
4+
import torch.nn.functional as F
5+
from PIL import Image
6+
from tqdm.auto import tqdm
7+
from fastcore.script import call_parse
8+
from torchvision import transforms
9+
from diffusers import DDPMPipeline
10+
from diffusers import DDIMScheduler
11+
from datasets import load_dataset
12+
from matplotlib import pyplot as plt
13+
14+
@call_parse
15+
def train(
16+
image_size = 256,
17+
batch_size = 16,
18+
grad_accumulation_steps = 2,
19+
num_epochs = 1,
20+
start_model = "google/ddpm-bedroom-256",
21+
dataset_name = "huggan/wikiart",
22+
device='cuda',
23+
model_save_name='wikiart_1e',
24+
wandb_project='dm_finetune',
25+
log_samples_every = 250,
26+
save_model_every = 2500,
27+
):
28+
29+
# Initialize wandb for logging
30+
wandb.init(project=wandb_project, config=locals())
31+
32+
33+
# Prepare pretrained model
34+
image_pipe = DDPMPipeline.from_pretrained(start_model);
35+
image_pipe.to(device)
36+
37+
# Get a scheduler for sampling
38+
sampling_scheduler = DDIMScheduler.from_config(start_model)
39+
sampling_scheduler.set_timesteps(num_inference_steps=50)
40+
41+
# Prepare dataset
42+
dataset = load_dataset(dataset_name, split="train")
43+
preprocess = transforms.Compose(
44+
[
45+
transforms.Resize((image_size, image_size)),
46+
transforms.RandomHorizontalFlip(),
47+
transforms.ToTensor(),
48+
transforms.Normalize([0.5], [0.5]),
49+
]
50+
)
51+
def transform(examples):
52+
images = [preprocess(image.convert("RGB")) for image in examples["image"]]
53+
return {"images": images}
54+
dataset.set_transform(transform)
55+
train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)
56+
57+
58+
# Optimizer & lr scheduler
59+
optimizer = torch.optim.AdamW(image_pipe.unet.parameters(), lr=1e-5)
60+
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)
61+
62+
for epoch in range(num_epochs):
63+
for step, batch in tqdm(enumerate(train_dataloader), total=len(train_dataloader)):
64+
65+
# Get the clean images
66+
clean_images = batch['images'].to(device)
67+
68+
# Sample noise to add to the images
69+
noise = torch.randn(clean_images.shape).to(clean_images.device)
70+
bs = clean_images.shape[0]
71+
72+
# Sample a random timestep for each image
73+
timesteps = torch.randint(0, image_pipe.scheduler.num_train_timesteps, (bs,), device=clean_images.device).long()
74+
75+
# Add noise to the clean images according to the noise magnitude at each timestep
76+
# (this is the forward diffusion process)
77+
noisy_images = image_pipe.scheduler.add_noise(clean_images, noise, timesteps)
78+
79+
# Get the model prediction for the noise
80+
noise_pred = image_pipe.unet(noisy_images, timesteps, return_dict=False)[0]
81+
82+
# Compare the prediction with the actual noise:
83+
loss = F.mse_loss(noise_pred, noise)
84+
85+
# Log the loss
86+
wandb.log({'loss':loss.item()})
87+
88+
# Calculate the gradients
89+
loss.backward()
90+
91+
# Gradient Acccumulation: Only update every grad_accumulation_steps
92+
if (step+1)%grad_accumulation_steps == 0:
93+
optimizer.step()
94+
optimizer.zero_grad()
95+
96+
# Occasionally log samples
97+
if (step+1)%log_samples_every == 0:
98+
x = torch.randn(8, 3, 256, 256).to(device) # Batch of 8
99+
for i, t in tqdm(enumerate(sampling_scheduler.timesteps)):
100+
model_input = sampling_scheduler.scale_model_input(x, t)
101+
with torch.no_grad():
102+
noise_pred = image_pipe.unet(model_input, t)["sample"]
103+
x = sampling_scheduler.step(noise_pred, t, x).prev_sample
104+
grid = torchvision.utils.make_grid(x, nrow=4)
105+
im = grid.permute(1, 2, 0).cpu().clip(-1, 1)*0.5 + 0.5
106+
im = Image.fromarray(np.array(im*255).astype(np.uint8))
107+
wandb.log({'Sample generations': wandb.Image(im)})
108+
109+
# Occasionally save model
110+
if (step+1)%save_model_every == 0:
111+
image_pipe.save_pretrained(model_save_name+f'step_{step+1}')
112+
113+
# Update the learning rate for the next epoch
114+
scheduler.step()
115+
116+
# Save the pipeline one last time
117+
image_pipe.save_pretrained(model_save_name+f'step_{step+1}')
118+
119+
# Wrap up the run
120+
wandb.finish()

0 commit comments

Comments
 (0)