Commit 6051920 (initial commit, "release")

24 files changed: +4,045 −0 lines

.gitignore (+3)

```
.DS_Store
__pycache__/
```
README.md (+76)
# improved-diffusion

This is the codebase for "Improved Denoising Diffusion Probabilistic Models".

# Usage

This section of the README walks through how to train and sample from a model.
## Installation

Clone this repository and navigate to it in your terminal. Then run:

```
pip install -e .
```

This should install the `improved_diffusion` Python package that the scripts depend on.
## Preparing Data

The training code reads images from a directory of image files. In the [datasets](datasets) folder, we have provided instructions/scripts for preparing these directories for ImageNet, LSUN bedrooms, and CIFAR-10.

To create your own dataset, simply dump all of your images into a directory with ".jpg", ".jpeg", or ".png" extensions. If you wish to train a class-conditional model, name the files like "mylabel1_XXX.jpg", "mylabel2_YYY.jpg", etc., so that the data loader knows that "mylabel1" and "mylabel2" are the labels. Subdirectories are enumerated automatically as well, so the images can be organized into a recursive structure (although the directory names are ignored, and only the underscore-delimited filename prefixes are used as labels; see the sketch below).

The images will automatically be scaled and center-cropped by the data-loading pipeline. Simply pass `--data_dir path/to/images` to the training script, and it will take care of the rest.
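As a rough illustration of this naming convention (not the codebase's actual loader, which lives elsewhere in the repository; `label_from_filename` is a hypothetical helper), a label could be recovered from a filename like so:

```python
import os

def label_from_filename(path):
    # "mylabel1_XXX.jpg" -> "mylabel1": the label is everything before
    # the first underscore in the base filename.
    return os.path.basename(path).split("_")[0]

print(label_from_filename("path/to/images/mylabel1_0001.jpg"))  # mylabel1
```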
## Training

To train your model, you should first decide on some hyperparameters. We split the hyperparameters into three groups: model architecture, diffusion process, and training flags. Here are some reasonable defaults for a baseline:

```
MODEL_FLAGS="--image_size 64 --num_channels 128 --num_res_blocks 3"
DIFFUSION_FLAGS="--diffusion_steps 4000 --noise_schedule linear"
TRAIN_FLAGS="--lr 1e-4 --batch_size 128"
```

Here are some changes we experiment with, and how to set them in the flags:

* **Learned sigmas:** add `--learn_sigma True` to `MODEL_FLAGS`.
* **Cosine schedule:** change `--noise_schedule linear` to `--noise_schedule cosine`.
* **Reweighted VLB:** add `--use_kl True` to `DIFFUSION_FLAGS` and add `--schedule_sampler loss-second-moment` to `TRAIN_FLAGS`.
* **Class-conditional:** add `--class_cond True` to `MODEL_FLAGS`.
Once you have set up your hyperparameters, you can run an experiment like so:

```
python scripts/image_train.py --data_dir path/to/images $MODEL_FLAGS $DIFFUSION_FLAGS $TRAIN_FLAGS
```

You may also want to train in a distributed manner. In this case, run the same command with `mpiexec`:

```
mpiexec -n $NUM_GPUS python scripts/image_train.py --data_dir path/to/images $MODEL_FLAGS $DIFFUSION_FLAGS $TRAIN_FLAGS
```

When training in a distributed manner, you must manually divide the `--batch_size` argument by the number of ranks; for example, to keep an effective batch size of 128 across 4 GPUs, pass `--batch_size 32`. In lieu of distributed training, you may use `--microbatch 16` (or `--microbatch 1` in extreme memory-limited cases) to reduce memory usage.

The logs and saved models will be written to a logging directory determined by the `OPENAI_LOGDIR` environment variable. If it is not set, a temporary directory will be created under `/tmp`.
## Sampling

The above training script saves checkpoints to `.pt` files in the logging directory. These checkpoints will have names like `ema_0.9999_200000.pt` and `model200000.pt`. You will likely want to sample from the EMA models, since those produce much better samples.

Once you have a path to your model, you can generate a large batch of samples like so:

```
python scripts/image_sample.py --model_path /path/to/model.pt $MODEL_FLAGS $DIFFUSION_FLAGS
```

Again, this will save results to a logging directory. Samples are saved as a large `npz` file, where `arr_0` in the file is a large batch of samples.
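As a quick sanity check, you can inspect the saved batch with NumPy; the filename below is a placeholder for whatever the sampling script writes to your logging directory:

```python
import numpy as np

data = np.load("samples.npz")  # placeholder path; use the .npz in $OPENAI_LOGDIR
samples = data["arr_0"]        # the batch of sampled images
print(samples.shape, samples.dtype)
```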
Just like for training, you can run `image_sample.py` through MPI to use multiple GPUs and machines.

You can change the number of sampling steps using the `--timestep_respacing` argument. For example, `--timestep_respacing 250` uses 250 steps to sample. Passing `--timestep_respacing ddim250` is similar, but uses the uniform stride from the [DDIM paper](https://arxiv.org/abs/2010.02502) rather than our stride.
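To make the uniform-stride idea concrete, here is a rough arithmetic sketch; this illustrates the concept only and is not the codebase's actual respacing logic:

```python
# Illustrative only: evenly subsample 4000 diffusion steps down to 250
# sampling steps with a uniform stride (the flavor of stride DDIM uses).
num_diffusion_steps = 4000
num_sample_steps = 250

stride = num_diffusion_steps // num_sample_steps  # 16
timesteps = list(range(0, num_diffusion_steps, stride))
assert len(timesteps) == num_sample_steps
```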
To sample using [DDIM](https://arxiv.org/abs/2010.02502), pass `--use_ddim True`.

datasets/README.md (+37)
# Downloading datasets

This directory includes instructions and scripts for downloading ImageNet, LSUN bedrooms, and CIFAR-10 for use in this codebase.

## ImageNet-64

To download unconditional ImageNet-64, go to [this page on image-net.org](http://www.image-net.org/small/download.php) and click on "Train (64x64)". Simply download the file and unzip it, and use the resulting directory as the data directory (the `--data_dir` argument for the training script).
## Class-conditional ImageNet

For our class-conditional models, we use the official ILSVRC2012 dataset with manual center cropping and downsampling. To obtain this dataset, navigate to [this page on image-net.org](http://www.image-net.org/challenges/LSVRC/2012/downloads) and sign in (or create an account if you do not already have one). Then click on the link reading "Training images (Task 1 & 2)". This is a 138GB tar file containing 1000 sub-tar files, one per class.

Once the file is downloaded, extract it and look inside. You should see 1000 `.tar` files, each of which needs to be extracted in turn; doing so by hand is impractical. To automate the process on a Unix-based system, you can `cd` into the directory and run this short shell script:

```
for file in *.tar; do tar xf "$file"; rm "$file"; done
```

This will extract and remove each tar file in turn.

Once all of the images have been extracted, the resulting directory should be usable as a data directory (the `--data_dir` argument for the training script). The filenames should all start with a WNID (class ID) followed by an underscore, like `n01440764_2708.JPEG`. Conveniently (and not by accident), this is exactly how the automated data loader expects to discover class labels.
## CIFAR-10

For CIFAR-10, we created a script, [cifar10.py](cifar10.py), that creates `cifar_train` and `cifar_test` directories. These directories contain files named like `truck_49997.png`, so that the class name is discernible to the data loader.

The `cifar_train` and `cifar_test` directories can be passed directly to the training scripts via the `--data_dir` argument.
## LSUN bedroom
30+
31+
To download and pre-process LSUN bedroom, clone [fyu/lsun](https://github.com/fyu/lsun) on GitHub and run their download script `python3 download.py bedroom`. The result will be an "lmdb" database named like `bedroom_train_lmdb`. You can pass this to our [lsun_bedroom.py](lsun_bedroom.py) script like so:
32+
33+
```
34+
python lsun_bedroom.py bedroom_train_lmdb lsun_train_output_dir
35+
```
36+
37+
This creates a directory called `lsun_train_output_dir`. This directory can be passed to the training scripts via the `--data_dir` argument.

datasets/cifar10.py (+43)
```python
import os
import tempfile

import torchvision
from tqdm.auto import tqdm

# CIFAR-10 class names, indexed by label. These become filename prefixes
# so that the data loader can recover the class of each image.
CLASSES = (
    "plane",
    "car",
    "bird",
    "cat",
    "deer",
    "dog",
    "frog",
    "horse",
    "ship",
    "truck",
)


def main():
    for split in ["train", "test"]:
        out_dir = f"cifar_{split}"
        if os.path.exists(out_dir):
            print(f"skipping split {split} since {out_dir} already exists.")
            continue

        print("downloading...")
        with tempfile.TemporaryDirectory() as tmp_dir:
            # CIFAR-10 is loaded fully into memory, so the temporary
            # download directory can be deleted before dumping images.
            dataset = torchvision.datasets.CIFAR10(
                root=tmp_dir, train=split == "train", download=True
            )

        print("dumping images...")
        os.mkdir(out_dir)
        for i in tqdm(range(len(dataset))):
            image, label = dataset[i]
            filename = os.path.join(out_dir, f"{CLASSES[label]}_{i:05d}.png")
            image.save(filename)


if __name__ == "__main__":
    main()
```

datasets/lsun_bedroom.py (+54)
```python
"""
Convert an LSUN lmdb database into a directory of images.
"""

import argparse
import io
import os

from PIL import Image
import lmdb
import numpy as np


def read_images(lmdb_path, image_size):
    env = lmdb.open(lmdb_path, map_size=1099511627776, max_readers=100, readonly=True)
    with env.begin(write=False) as transaction:
        cursor = transaction.cursor()
        for _, webp_data in cursor:
            img = Image.open(io.BytesIO(webp_data))
            width, height = img.size
            # Resize so that the short side equals image_size...
            scale = image_size / min(width, height)
            img = img.resize(
                (int(round(scale * width)), int(round(scale * height))),
                resample=Image.BOX,
            )
            # ...then center-crop to image_size x image_size.
            arr = np.array(img)
            h, w, _ = arr.shape
            h_off = (h - image_size) // 2
            w_off = (w - image_size) // 2
            arr = arr[h_off : h_off + image_size, w_off : w_off + image_size]
            yield arr


def dump_images(out_dir, images, prefix):
    if not os.path.exists(out_dir):
        os.mkdir(out_dir)
    for i, img in enumerate(images):
        Image.fromarray(img).save(os.path.join(out_dir, f"{prefix}_{i:07d}.png"))


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--image-size", help="new image size", type=int, default=256)
    parser.add_argument("--prefix", help="class name", type=str, default="bedroom")
    parser.add_argument("lmdb_path", help="path to an LSUN lmdb database")
    parser.add_argument("out_dir", help="path to output directory")
    args = parser.parse_args()

    images = read_images(args.lmdb_path, args.image_size)
    dump_images(args.out_dir, images, args.prefix)


if __name__ == "__main__":
    main()
```

improved_diffusion/__init__.py (+3)

```python
"""
Codebase for "Improved Denoising Diffusion Probabilistic Models".
"""
```

improved_diffusion/dist_util.py (+82)
```python
"""
Helpers for distributed training.
"""

import io
import os
import socket

import blobfile as bf
from mpi4py import MPI
import torch as th
import torch.distributed as dist

# Change this to reflect your cluster layout.
# The GPU for a given rank is (rank % GPUS_PER_NODE).
GPUS_PER_NODE = 8

SETUP_RETRY_COUNT = 3


def setup_dist():
    """
    Set up a distributed process group.
    """
    if dist.is_initialized():
        return

    comm = MPI.COMM_WORLD
    backend = "gloo" if not th.cuda.is_available() else "nccl"

    if backend == "gloo":
        hostname = "localhost"
    else:
        hostname = socket.gethostbyname(socket.getfqdn())
    os.environ["MASTER_ADDR"] = comm.bcast(hostname, root=0)
    os.environ["RANK"] = str(comm.rank)
    os.environ["WORLD_SIZE"] = str(comm.size)

    port = comm.bcast(_find_free_port(), root=0)
    os.environ["MASTER_PORT"] = str(port)
    dist.init_process_group(backend=backend, init_method="env://")


def dev():
    """
    Get the device to use for torch.distributed.
    """
    if th.cuda.is_available():
        return th.device(f"cuda:{MPI.COMM_WORLD.Get_rank() % GPUS_PER_NODE}")
    return th.device("cpu")


def load_state_dict(path, **kwargs):
    """
    Load a PyTorch file without redundant fetches across MPI ranks.

    Rank 0 reads the file and broadcasts the raw bytes to all other ranks.
    """
    if MPI.COMM_WORLD.Get_rank() == 0:
        with bf.BlobFile(path, "rb") as f:
            data = f.read()
    else:
        data = None
    data = MPI.COMM_WORLD.bcast(data)
    return th.load(io.BytesIO(data), **kwargs)


def sync_params(params):
    """
    Synchronize a sequence of Tensors across ranks from rank 0.
    """
    for p in params:
        with th.no_grad():
            dist.broadcast(p, 0)


def _find_free_port():
    # Create the socket outside the try block so that the finally clause
    # never references an unbound variable if socket creation fails.
    s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        s.bind(("", 0))
        s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        return s.getsockname()[1]
    finally:
        s.close()
```
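For orientation, here is a hypothetical sketch (not part of the commit) of how these helpers compose in a training script launched under MPI:

```python
# Hypothetical usage sketch, assuming a launch like:
#   mpiexec -n 8 python train.py
import torch.nn as nn

from improved_diffusion import dist_util

dist_util.setup_dist()  # one torch.distributed rank per MPI process
model = nn.Linear(16, 16).to(dist_util.dev())  # rank-appropriate GPU (or CPU)
dist_util.sync_params(model.parameters())  # broadcast rank 0's weights to all ranks
```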

improved_diffusion/fp16_util.py (+76)
```python
"""
Helpers to train with 16-bit precision.
"""

import torch.nn as nn
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors


def convert_module_to_f16(l):
    """
    Convert primitive modules to float16.
    """
    if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
        l.weight.data = l.weight.data.half()
        l.bias.data = l.bias.data.half()


def convert_module_to_f32(l):
    """
    Convert primitive modules to float32, undoing convert_module_to_f16().
    """
    if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
        l.weight.data = l.weight.data.float()
        l.bias.data = l.bias.data.float()


def make_master_params(model_params):
    """
    Copy model parameters into a (differently-shaped) list of full-precision
    parameters.
    """
    master_params = _flatten_dense_tensors(
        [param.detach().float() for param in model_params]
    )
    master_params = nn.Parameter(master_params)
    master_params.requires_grad = True
    return [master_params]


def model_grads_to_master_grads(model_params, master_params):
    """
    Copy the gradients from the model parameters into the master parameters
    from make_master_params().
    """
    master_params[0].grad = _flatten_dense_tensors(
        [param.grad.data.detach().float() for param in model_params]
    )


def master_params_to_model_params(model_params, master_params):
    """
    Copy the master parameter data back into the model parameters.
    """
    # Without copying to a list, if a generator is passed, this will
    # silently not copy any parameters.
    model_params = list(model_params)

    for param, master_param in zip(
        model_params, unflatten_master_params(model_params, master_params)
    ):
        param.detach().copy_(master_param)


def unflatten_master_params(model_params, master_params):
    """
    Unflatten the master parameters to look like model_params.
    """
    return _unflatten_dense_tensors(master_params[0].detach(), model_params)


def zero_grad(model_params):
    """
    Zero the gradients of the given parameters in place.
    """
    for param in model_params:
        # Taken from https://pytorch.org/docs/stable/_modules/torch/optim/optimizer.html#Optimizer.add_param_group
        if param.grad is not None:
            param.grad.detach_()
            param.grad.zero_()
```
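Below is a hypothetical sketch (not part of the commit) of the fp16 master-weights training step these helpers support. It assumes a CUDA device, since fp16 convolutions target the GPU; the toy model and hyperparameters are illustrative only:

```python
# Hypothetical usage sketch of the fp16 master-weights pattern.
import torch as th
import torch.nn as nn

from improved_diffusion import fp16_util

model = nn.Conv2d(3, 8, 3).cuda()
fp16_util.convert_module_to_f16(model)  # weights and bias become fp16
model_params = list(model.parameters())

# The optimizer works on a single flattened fp32 copy of the weights.
master_params = fp16_util.make_master_params(model_params)
opt = th.optim.Adam(master_params, lr=1e-4)

x = th.randn(1, 3, 16, 16, device="cuda").half()
loss = model(x).float().mean()

fp16_util.zero_grad(model_params)
loss.backward()
fp16_util.model_grads_to_master_grads(model_params, master_params)  # fp16 -> fp32
opt.step()
fp16_util.master_params_to_model_params(model_params, master_params)  # fp32 -> fp16
```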
