
Commit 826d44f

update
1 parent a9c4dde commit 826d44f

17 files changed, with 263,224 additions and 446 deletions.

.gitignore

+7
@@ -4,3 +4,10 @@ log-latent
44
__pycache__
55
generated
66
latent_infer
7+
datasets/bedroom256.lmdb
8+
datasets/horse256.lmdb
9+
datasets/celebahq
10+
datasets/celebahq256.lmdb
11+
datasets/ffhq
12+
datasets/ffhq256.lmdb
13+
checkpoints

README.md

+84
@@ -0,0 +1,84 @@
1+
# Official implementation of Diffusion Autoencoders
2+
3+
A CVPR 2022 paper:
4+
5+
> Preechakul, Konpat, Nattanat Chatthee, Suttisak Wizadwongsa, and Supasorn Suwajanakorn. 2021. “Diffusion Autoencoders: Toward a Meaningful and Decodable Representation.” arXiv [cs.CV]. arXiv. http://arxiv.org/abs/2111.15640.
6+
7+
## Usage
8+
9+
Note: We expect many changes to the codebase, so please fork the repo before using it.
10+
11+
### Quick start
12+
13+
Jupyter notebooks are provided:
14+
15+
For unconditional generation: `sample.ipynb`
16+
17+
For manipulation: `manipulate.ipynb`
18+
19+
### Checkpoints
20+
21+
Put checkpoints in a separate directory named `checkpoints`.
22+
23+
The directory tree may look like:
24+
25+
```
26+
checkpoints/
27+
- bedroom128_autoenc
28+
- last.ckpt # diffae checkpoint
29+
- latent.ckpt # predicted z_sem on the dataset
30+
- bedroom128_autoenc_latent
31+
- last.ckpt # diffae + latent DPM checkpoint
32+
- bedroom128_ddpm
33+
- ...
34+
```
35+
36+
We provide checkpoints for the following models:
37+
38+
1. DDIM: FFHQ128 ([72M](https://drive.google.com/drive/folders/1-J8FPNZOQxSqpfTpwRXawLi2KKGL1qlK?usp=sharing), [130M](https://drive.google.com/drive/folders/17T5YJXpYdgE6cWltN8gZFxRsJzpVxnLh?usp=sharing)), [Bedroom128](https://drive.google.com/drive/folders/19s-lAiK7fGD5Meo5obNV5o0L3MfqU0Sk?usp=sharing), [Horse128](https://drive.google.com/drive/folders/1PiC5JWLcd8mZW9cghDCR0V4Hx0QCXOor?usp=sharing)
39+
2. DiffAE (autoencoding only): [FFHQ256](https://drive.google.com/drive/folders/1hTP9QbYXwv_Nl5sgcZNH0yKprJx7ivC5?usp=sharing), FFHQ128 ([72M](https://drive.google.com/drive/folders/15QHmZP1G5jEMh80R1Nbtdb4ZKb6VvfII?usp=sharing), [130M](https://drive.google.com/drive/folders/1UlwLwgv16cEqxTn7g-V2ykIyopmY_fVz?usp=sharing)), [Bedroom128](https://drive.google.com/drive/folders/1okhCb1RezlWmDbdEAGWMHMkUBRRXmey0?usp=sharing), [Horse128](https://drive.google.com/drive/folders/1Ujmv3ajeiJLOT6lF2zrQb4FimfDkMhcP?usp=sharing)
40+
3. DiffAE (with latent DPM, can sample): [FFHQ256](https://drive.google.com/drive/folders/1MonJKYwVLzvCFYuVhp-l9mChq5V2XI6w?usp=sharing), [FFHQ128](https://drive.google.com/drive/folders/1E3Ew1p9h42h7UA1DJNK7jnb2ERybg9ji?usp=sharing), [Bedroom128](https://drive.google.com/drive/folders/1okhCb1RezlWmDbdEAGWMHMkUBRRXmey0?usp=sharing), [Horse128](https://drive.google.com/drive/folders/1Ujmv3ajeiJLOT6lF2zrQb4FimfDkMhcP?usp=sharing)
41+
4. DiffAE's classifiers (for manipulation): [FFHQ256's latent on CelebAHQ](https://drive.google.com/drive/folders/1QGkTfvNhgi_TbbV8GbX1Emrp0lStsqLj?usp=sharing), [FFHQ128's latent on CelebAHQ](https://drive.google.com/drive/folders/1E3Ew1p9h42h7UA1DJNK7jnb2ERybg9ji?usp=sharing)
42+
43+
44+
### LMDB Datasets
45+
46+
We do not own any of the following datasets. We provide ready-to-use LMDB versions for convenience.
47+
48+
- [FFHQ](https://drive.google.com/drive/folders/1ww7itaSo53NDMa0q-wn-3HWZ3HHqK1IK?usp=sharing)
49+
- [CelebAHQ](https://drive.google.com/drive/folders/1SX3JuVHjYA8sA28EGxr_IoHJ63s4Btbl?usp=sharing)
50+
- [LSUN Bedroom](https://drive.google.com/drive/folders/1O_3aT3LtY1YDE2pOQCp6MFpCk7Pcpkhb?usp=sharing)
51+
- [LSUN Horse](https://drive.google.com/drive/folders/1ooHW7VivZUs4i5CarPaWxakCwfeqAK8l?usp=sharing)
52+
53+
The directory tree should be:
54+
55+
```
56+
datasets/
57+
- bedroom256.lmdb
58+
- celebahq256.lmdb
59+
- ffhq256.lmdb
60+
- horse256.lmdb
61+
```
62+
63+
You can also download the datasets from their original sources and use the provided scripts to package them as LMDB files.
64+
The original sources for each dataset are as follows:
65+
66+
- FFHQ (https://github.com/NVlabs/ffhq-dataset)
67+
- CelebAHQ (https://github.com/switchablenorms/CelebAMask-HQ)
68+
- LSUN (https://github.com/fyu/lsun)
69+
70+
The conversion scripts are:
71+
72+
```
73+
data_resize_bedroom.py
74+
data_resize_celebhq.py
75+
data_resize_ffhq.py
76+
data_resize_horse.py
77+
```
78+
79+
Google drive: https://drive.google.com/drive/folders/1abNP4QKGbNnymjn8607BF0cwxX2L23jh?usp=sharing
80+
81+
82+
## Training
83+
84+
Soon ...

choices.py

-25
@@ -5,8 +5,6 @@
55
class TrainMode(Enum):
66
# manipulate mode = training the classifier
77
manipulate = 'manipulate'
8-
# the classifier on the image domain
9-
manipulate_img = 'manipulateimg'
108
# default training mode!
119
diffusion = 'diffusion'
1210
# default latent training mode!
@@ -16,12 +14,6 @@ class TrainMode(Enum):
1614
def is_manipulate(self):
1715
return self in [
1816
TrainMode.manipulate,
19-
TrainMode.manipulate_img,
20-
]
21-
22-
def is_manipluate_img(self):
23-
return self in [
24-
TrainMode.manipulate_img,
2517
]
2618

2719
def is_diffusion(self):
@@ -61,49 +53,32 @@ class ManipulateMode(Enum):
6153
how to train the classifier to manipulate
6254
"""
6355
# train on whole celeba attr dataset
64-
celeba_all = 'all'
6556
celebahq_all = 'celebahq_all'
66-
# train on a few show subset
67-
celeba_fewshot = 'fewshot'
68-
celeba_fewshot_allneg = 'fewshotallneg'
6957
# celeba with D2C's crop
7058
d2c_fewshot = 'd2cfewshot'
7159
d2c_fewshot_allneg = 'd2cfewshotallneg'
72-
celebahq_fewshot = 'celebahq_fewshot'
73-
relighting = 'light'
7460

7561
def is_celeba_attr(self):
7662
return self in [
77-
ManipulateMode.celeba_all,
78-
ManipulateMode.celeba_fewshot,
79-
ManipulateMode.celeba_fewshot_allneg,
8063
ManipulateMode.d2c_fewshot,
8164
ManipulateMode.d2c_fewshot_allneg,
8265
ManipulateMode.celebahq_all,
83-
ManipulateMode.celebahq_fewshot,
8466
]
8567

8668
def is_single_class(self):
8769
return self in [
88-
ManipulateMode.celeba_fewshot,
89-
ManipulateMode.celeba_fewshot_allneg,
9070
ManipulateMode.d2c_fewshot,
9171
ManipulateMode.d2c_fewshot_allneg,
92-
ManipulateMode.celebahq_fewshot,
9372
]
9473

9574
def is_fewshot(self):
9675
return self in [
97-
ManipulateMode.celeba_fewshot,
98-
ManipulateMode.celeba_fewshot_allneg,
9976
ManipulateMode.d2c_fewshot,
10077
ManipulateMode.d2c_fewshot_allneg,
101-
ManipulateMode.celebahq_fewshot,
10278
]
10379

10480
def is_fewshot_allneg(self):
10581
return self in [
106-
ManipulateMode.celeba_fewshot_allneg,
10782
ManipulateMode.d2c_fewshot_allneg,
10883
]
10984
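
Review note: the hunks above leave `ManipulateMode` with three members. A minimal sketch of the resulting enum, reconstructed only from the lines visible in this diff, is given below. The string values and method names are copied from the hunk; anything outside the shown range (the class header, the docstring wording, any methods past line 84) is an assumption.

```python
from enum import Enum


class ManipulateMode(Enum):
    """How to train the classifier used for manipulation (docstring paraphrased)."""
    # train on the whole CelebA-HQ attribute dataset
    celebahq_all = 'celebahq_all'
    # CelebA with D2C's crop
    d2c_fewshot = 'd2cfewshot'
    d2c_fewshot_allneg = 'd2cfewshotallneg'

    def is_celeba_attr(self):
        return self in [
            ManipulateMode.d2c_fewshot,
            ManipulateMode.d2c_fewshot_allneg,
            ManipulateMode.celebahq_all,
        ]

    def is_single_class(self):
        return self in [
            ManipulateMode.d2c_fewshot,
            ManipulateMode.d2c_fewshot_allneg,
        ]

    def is_fewshot(self):
        return self in [
            ManipulateMode.d2c_fewshot,
            ManipulateMode.d2c_fewshot_allneg,
        ]

    def is_fewshot_allneg(self):
        return self in [
            ManipulateMode.d2c_fewshot_allneg,
        ]
```

One consequence worth checking in callers: configs or saved hyperparameters that still carry the removed values (`'all'`, `'fewshot'`, `'fewshotallneg'`, `'celebahq_fewshot'`, `'light'`) would no longer resolve to a member, since `Enum` lookup by value raises `ValueError` for unknown values.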

config.py

+1-5
@@ -62,7 +62,7 @@ class TrainConfig(BaseConfig):
6262
train_pred_xstart_detach: bool = True
6363
train_interpolate_prob: float = 0
6464
train_interpolate_img: bool = False
65-
manipulate_mode: ManipulateMode = ManipulateMode.celeba_all
65+
manipulate_mode: ManipulateMode = ManipulateMode.celebahq_all
6666
manipulate_cls: str = None
6767
manipulate_shots: int = None
6868
manipulate_loss: ManipulateLossType = ManipulateLossType.bce
@@ -365,10 +365,6 @@ def make_dataset(self, path=None, **kwargs):
365365
image_size=self.img_size,
366366
split='train',
367367
**kwargs)
368-
elif self.data_name == 'horse':
369-
return LSUNHorse(path=path or self.data_path,
370-
image_size=self.img_size,
371-
**kwargs)
372368
elif self.data_name == 'horse256':
373369
return Horse_lmdb(path=path or self.data_path,
374370
image_size=self.img_size,

data_resize_bedroom.py

+101
@@ -0,0 +1,101 @@
1+
import argparse
2+
import multiprocessing
3+
import os
4+
from os.path import join, exists
5+
from functools import partial
6+
from io import BytesIO
7+
import shutil
8+
9+
import lmdb
10+
from PIL import Image
11+
from torchvision.datasets import LSUNClass
12+
from torchvision.transforms import functional as trans_fn
13+
from tqdm import tqdm
14+
15+
from multiprocessing import Process, Queue
16+
17+
18+
def resize_and_convert(img, size, resample, quality=100):
19+
img = trans_fn.resize(img, size, resample)
20+
img = trans_fn.center_crop(img, size)
21+
buffer = BytesIO()
22+
img.save(buffer, format="webp", quality=quality)
23+
val = buffer.getvalue()
24+
25+
return val
26+
27+
28+
def resize_multiple(img,
29+
sizes=(128, 256, 512, 1024),
30+
resample=Image.LANCZOS,
31+
quality=100):
32+
imgs = []
33+
34+
for size in sizes:
35+
imgs.append(resize_and_convert(img, size, resample, quality))
36+
37+
return imgs
38+
39+
40+
def resize_worker(idx, img, sizes, resample):
41+
img = img.convert("RGB")
42+
out = resize_multiple(img, sizes=sizes, resample=resample)
43+
return idx, out
44+
45+
46+
from torch.utils.data import Dataset, DataLoader
47+
48+
49+
class ConvertDataset(Dataset):
50+
def __init__(self, data) -> None:
51+
self.data = data
52+
53+
def __len__(self):
54+
return len(self.data)
55+
56+
def __getitem__(self, index):
57+
img, _ = self.data[index]
58+
data = resize_and_convert(img, 256, Image.LANCZOS, quality=90)  # avoid shadowing the built-in `bytes`
59+
return data
60+
61+
62+
if __name__ == "__main__":
63+
"""
64+
Convert LSUN's original LMDB to our LMDB layout, which is somehow more performant.
65+
"""
66+
from tqdm import tqdm
67+
68+
# path to the original lsun's lmdb
69+
src_path = 'datasets/bedroom_train_lmdb'
70+
out_path = 'datasets/bedroom256.lmdb'
71+
72+
dataset = LSUNClass(root=os.path.expanduser(src_path))
73+
dataset = ConvertDataset(dataset)
74+
loader = DataLoader(dataset,
75+
batch_size=50,
76+
num_workers=12,
77+
collate_fn=lambda x: x,
78+
shuffle=False)
79+
80+
target = os.path.expanduser(out_path)
81+
if os.path.exists(target):
82+
shutil.rmtree(target)
83+
84+
with lmdb.open(target, map_size=1024**4, readahead=False) as env:
85+
with tqdm(total=len(dataset)) as progress:
86+
i = 0
87+
for batch in loader:
88+
with env.begin(write=True) as txn:
89+
for img in batch:
90+
key = f"{256}-{str(i).zfill(7)}".encode("utf-8")
91+
# print(key)
92+
txn.put(key, img)
93+
i += 1
94+
progress.update()
95+
# if i == 1000:
96+
# break
97+
# if total == len(imgset):
98+
# break
99+
100+
with env.begin(write=True) as txn:
101+
txn.put("length".encode("utf-8"), str(i).encode("utf-8"))
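
Review note: the script above stores each WebP-encoded image under a key of the form `256-0000042` and the total count under `length`. A minimal sketch of reading such a file back is given below. The helper names (`image_key`, `read_length`, `read_image`, `open_readonly`) are hypothetical and not part of this commit; only the key layout and the `length` entry come from the script.

```python
from io import BytesIO


def image_key(index, size=256):
    """Build the key used by the conversion script, e.g. b'256-0000042'."""
    return f"{size}-{str(index).zfill(7)}".encode("utf-8")


def read_length(env):
    """Return the image count stored under the 'length' key."""
    with env.begin(write=False) as txn:
        return int(txn.get(b"length").decode("utf-8"))


def read_image(env, index, size=256):
    """Return a PIL image for `index`, or None when the key is missing."""
    from PIL import Image  # lazy import so the key helpers work without Pillow
    with env.begin(write=False) as txn:
        raw = txn.get(image_key(index, size))
    return None if raw is None else Image.open(BytesIO(raw))


def open_readonly(path):
    """Open an LMDB environment for reading only (hypothetical helper)."""
    import lmdb  # lazy import; requires the `lmdb` package
    return lmdb.open(path, readonly=True, lock=False, readahead=False)
```

For example, `env = open_readonly('datasets/bedroom256.lmdb')` followed by `read_image(env, 0)` should return the first converted image, assuming the conversion ran to completion and wrote the `length` entry.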
