1
+ import wandb
2
+ import numpy as np
3
+ import torch , torchvision
4
+ import torch .nn .functional as F
5
+ from PIL import Image
6
+ from tqdm .auto import tqdm
7
+ from fastcore .script import call_parse
8
+ from torchvision import transforms
9
+ from diffusers import DDPMPipeline
10
+ from diffusers import DDIMScheduler
11
+ from datasets import load_dataset
12
+ from matplotlib import pyplot as plt
13
+
14
# NOTE: this helper is deliberately defined *before* `train`. When the file is
# executed as a script, fastcore's `call_parse` may invoke the decorated
# function immediately at decoration time, so everything `train` relies on
# has to exist by then.
def _generate_sample_grid(image_pipe, sampling_scheduler, image_size, device, n_samples=8):
    """Sample `n_samples` images with the current UNet and return them as one PIL grid.

    Starts from Gaussian noise of shape (n_samples, 3, image_size, image_size)
    and walks through `sampling_scheduler.timesteps`, feeding each UNet
    prediction back into the scheduler. The result is arranged 4 images per
    row, clipped to [-1, 1], rescaled to [0, 255] and converted to uint8.
    """
    x = torch.randn(n_samples, 3, image_size, image_size, device=device)
    # Sampling is inference only, so no autograd graph is needed.
    with torch.no_grad():
        for t in tqdm(sampling_scheduler.timesteps):
            model_input = sampling_scheduler.scale_model_input(x, t)
            noise_pred = image_pipe.unet(model_input, t)["sample"]
            x = sampling_scheduler.step(noise_pred, t, x).prev_sample
    grid = torchvision.utils.make_grid(x, nrow=4)
    # (C, H, W) in [-1, 1]  ->  (H, W, C) in [0, 1]
    im = grid.permute(1, 2, 0).cpu().clip(-1, 1) * 0.5 + 0.5
    return Image.fromarray(np.array(im * 255).astype(np.uint8))


@call_parse
def train(
    image_size: int = 256,  # Side length (pixels) images are resized to; also used for logged samples
    batch_size: int = 16,  # Number of images per dataloader batch
    grad_accumulation_steps: int = 2,  # Number of batches whose gradients are summed before each optimizer step
    num_epochs: int = 1,  # Number of passes over the training split
    start_model: str = "google/ddpm-bedroom-256",  # Pretrained DDPM pipeline to fine-tune
    dataset_name: str = "huggan/wikiart",  # Dataset passed to `load_dataset`; must expose an "image" column
    device: str = 'cuda',  # Torch device the pipeline and batches are moved to
    model_save_name: str = 'wikiart_1e',  # Prefix of the directories checkpoints are written to
    wandb_project: str = 'dm_finetune',  # Weights & Biases project name
    log_samples_every: int = 250,  # Log a grid of generated samples every this many steps
    save_model_every: int = 2500,  # Save the pipeline every this many steps
):
    """Fine-tune the UNet of a pretrained DDPM pipeline on an image dataset.

    Each training step adds noise to a batch of images at random timesteps,
    asks the UNet to predict that noise and minimises the MSE between the
    prediction and the true noise. The loss is logged to Weights & Biases on
    every step; sample grids and checkpoints are produced periodically, and
    the pipeline is saved once more when training ends.

    Checkpoints are written to `model_save_name + f'step_{step + 1}'`, where
    `step` is the batch index within the current epoch.
    """
    # Initialize wandb for logging. This must stay the first statement so that
    # `locals()` contains only the call arguments (i.e. the run configuration).
    wandb.init(project=wandb_project, config=locals())

    # Prepare pretrained model
    image_pipe = DDPMPipeline.from_pretrained(start_model)
    image_pipe.to(device)

    # Get a (faster) scheduler for sampling
    sampling_scheduler = DDIMScheduler.from_config(start_model)
    sampling_scheduler.set_timesteps(num_inference_steps=50)

    # Prepare dataset
    dataset = load_dataset(dataset_name, split="train")
    preprocess = transforms.Compose(
        [
            transforms.Resize((image_size, image_size)),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),  # maps [0, 1] -> [-1, 1]
        ]
    )

    def transform(examples):
        images = [preprocess(image.convert("RGB")) for image in examples["image"]]
        return {"images": images}

    dataset.set_transform(transform)
    train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

    # Optimizer & lr scheduler
    optimizer = torch.optim.AdamW(image_pipe.unet.parameters(), lr=1e-5)
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)

    # Keep `step` bound even if no batch is ever processed (num_epochs == 0 or
    # an empty dataloader), so the final save below cannot raise a NameError.
    step = -1

    for epoch in range(num_epochs):
        for step, batch in tqdm(enumerate(train_dataloader), total=len(train_dataloader)):

            # Get the clean images
            clean_images = batch['images'].to(device)

            # Sample noise to add to the images (created directly on the
            # images' device instead of on CPU followed by a copy)
            noise = torch.randn_like(clean_images)
            bs = clean_images.shape[0]

            # Sample a random timestep for each image
            timesteps = torch.randint(
                0,
                image_pipe.scheduler.num_train_timesteps,
                (bs,),
                device=clean_images.device,
            ).long()

            # Add noise to the clean images according to the noise magnitude at each timestep
            # (this is the forward diffusion process)
            noisy_images = image_pipe.scheduler.add_noise(clean_images, noise, timesteps)

            # Get the model prediction for the noise
            noise_pred = image_pipe.unet(noisy_images, timesteps, return_dict=False)[0]

            # Compare the prediction with the actual noise:
            loss = F.mse_loss(noise_pred, noise)

            # Log the loss
            wandb.log({'loss': loss.item()})

            # Calculate the gradients
            loss.backward()

            # Gradient accumulation: only update every grad_accumulation_steps.
            # The loss is not rescaled, so the accumulated gradient is the sum
            # (not the mean) over the accumulated batches.
            if (step + 1) % grad_accumulation_steps == 0:
                optimizer.step()
                optimizer.zero_grad()

            # Occasionally log samples, generated at the training resolution
            if (step + 1) % log_samples_every == 0:
                im = _generate_sample_grid(image_pipe, sampling_scheduler, image_size, device)
                wandb.log({'Sample generations': wandb.Image(im)})

            # Occasionally save model
            if (step + 1) % save_model_every == 0:
                image_pipe.save_pretrained(model_save_name + f'step_{step + 1}')

        # Update the learning rate for the next epoch
        scheduler.step()

    # Save the pipeline one last time
    image_pipe.save_pretrained(model_save_name + f'step_{step + 1}')

    # Wrap up the run
    wandb.finish()
0 commit comments