Skip to content

Commit 40f8568

Browse files
committed
update
1 parent 2a7c745 commit 40f8568

15 files changed

+133
-729
lines changed

README.md

+8-9
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,15 @@ For manipulation: `manipulate.ipynb`
1818

1919
### Checkpoints
2020

21-
Checkpoints ought to be put into a separate directory `checkpoints`.
21+
We provide checkpoints for the following models:
2222

23-
The directory tree may look like:
23+
1. DDIM: **FFHQ128** ([72M](https://drive.google.com/drive/folders/1-J8FPNZOQxSqpfTpwRXawLi2KKGL1qlK?usp=sharing), [130M](https://drive.google.com/drive/folders/17T5YJXpYdgE6cWltN8gZFxRsJzpVxnLh?usp=sharing)), [**Bedroom128**](https://drive.google.com/drive/folders/19s-lAiK7fGD5Meo5obNV5o0L3MfqU0Sk?usp=sharing), [**Horse128**](https://drive.google.com/drive/folders/1PiC5JWLcd8mZW9cghDCR0V4Hx0QCXOor?usp=sharing)
24+
2. DiffAE (autoencoding only): [**FFHQ256**](https://drive.google.com/drive/folders/1hTP9QbYXwv_Nl5sgcZNH0yKprJx7ivC5?usp=sharing), **FFHQ128** ([72M](https://drive.google.com/drive/folders/15QHmZP1G5jEMh80R1Nbtdb4ZKb6VvfII?usp=sharing), [130M](https://drive.google.com/drive/folders/1UlwLwgv16cEqxTn7g-V2ykIyopmY_fVz?usp=sharing)), [**Bedroom128**](https://drive.google.com/drive/folders/1okhCb1RezlWmDbdEAGWMHMkUBRRXmey0?usp=sharing), [**Horse128**](https://drive.google.com/drive/folders/1Ujmv3ajeiJLOT6lF2zrQb4FimfDkMhcP?usp=sharing)
25+
3. DiffAE (with latent DPM, can sample): [**FFHQ256**](https://drive.google.com/drive/folders/1MonJKYwVLzvCFYuVhp-l9mChq5V2XI6w?usp=sharing), [**FFHQ128**](https://drive.google.com/drive/folders/1E3Ew1p9h42h7UA1DJNK7jnb2ERybg9ji?usp=sharing), [**Bedroom128**](https://drive.google.com/drive/folders/1okhCb1RezlWmDbdEAGWMHMkUBRRXmey0?usp=sharing), [**Horse128**](https://drive.google.com/drive/folders/1Ujmv3ajeiJLOT6lF2zrQb4FimfDkMhcP?usp=sharing)
26+
4. DiffAE's classifiers (for manipulation): [**FFHQ256's latent on CelebAHQ**](https://drive.google.com/drive/folders/1QGkTfvNhgi_TbbV8GbX1Emrp0lStsqLj?usp=sharing), [**FFHQ128's latent on CelebAHQ**](https://drive.google.com/drive/folders/1E3Ew1p9h42h7UA1DJNK7jnb2ERybg9ji?usp=sharing)
27+
28+
Checkpoints ought to be put into a separate directory `checkpoints`.
29+
Download the checkpoints and put them into `checkpoints` directory. It should look like this:
2430

2531
```
2632
checkpoints/
@@ -33,13 +39,6 @@ checkpoints/
3339
- ...
3440
```
3541

36-
We provide checkpoints for the following models:
37-
38-
1. DDIM: FFHQ128 ([72M](https://drive.google.com/drive/folders/1-J8FPNZOQxSqpfTpwRXawLi2KKGL1qlK?usp=sharing), [130M](https://drive.google.com/drive/folders/17T5YJXpYdgE6cWltN8gZFxRsJzpVxnLh?usp=sharing)), [Bedroom128](https://drive.google.com/drive/folders/19s-lAiK7fGD5Meo5obNV5o0L3MfqU0Sk?usp=sharing), [Horse128](https://drive.google.com/drive/folders/1PiC5JWLcd8mZW9cghDCR0V4Hx0QCXOor?usp=sharing)
39-
2. DiffAE (autoencoding only): [FFHQ256](https://drive.google.com/drive/folders/1hTP9QbYXwv_Nl5sgcZNH0yKprJx7ivC5?usp=sharing), FFHQ128 ([72M](https://drive.google.com/drive/folders/15QHmZP1G5jEMh80R1Nbtdb4ZKb6VvfII?usp=sharing), [130M](https://drive.google.com/drive/folders/1UlwLwgv16cEqxTn7g-V2ykIyopmY_fVz?usp=sharing)), [Bedroom128](https://drive.google.com/drive/folders/1okhCb1RezlWmDbdEAGWMHMkUBRRXmey0?usp=sharing), [Horse128](https://drive.google.com/drive/folders/1Ujmv3ajeiJLOT6lF2zrQb4FimfDkMhcP?usp=sharing)
40-
3. DiffAE (with latent DPM, can sample): [FFHQ256](https://drive.google.com/drive/folders/1MonJKYwVLzvCFYuVhp-l9mChq5V2XI6w?usp=sharing), [FFHQ128](https://drive.google.com/drive/folders/1E3Ew1p9h42h7UA1DJNK7jnb2ERybg9ji?usp=sharing), [Bedroom128](https://drive.google.com/drive/folders/1okhCb1RezlWmDbdEAGWMHMkUBRRXmey0?usp=sharing), [Horse128](https://drive.google.com/drive/folders/1Ujmv3ajeiJLOT6lF2zrQb4FimfDkMhcP?usp=sharing)
41-
4. DiffAE's classifiers (for manipulation): [FFHQ256's latent on CelebAHQ](https://drive.google.com/drive/folders/1QGkTfvNhgi_TbbV8GbX1Emrp0lStsqLj?usp=sharing), [FFHQ128's latent on CelebAHQ](https://drive.google.com/drive/folders/1E3Ew1p9h42h7UA1DJNK7jnb2ERybg9ji?usp=sharing)
42-
4342

4443
### LMDB Datasets
4544

choices.py

-61
Original file line numberDiff line numberDiff line change
@@ -102,11 +102,6 @@ def can_sample(self):
102102
return self in [ModelType.ddpm]
103103

104104

105-
class ChamferType(Enum):
106-
chamfer = 'chamfer'
107-
stochastic = 'stochastic'
108-
109-
110105
class ModelName(Enum):
111106
"""
112107
List of all supported model classes
@@ -116,24 +111,12 @@ class ModelName(Enum):
116111
beatgans_autoenc = 'beatgans_autoenc'
117112

118113

119-
class EncoderName(Enum):
120-
"""
121-
List of all encoders for ddpm models
122-
"""
123-
124-
v1 = 'v1'
125-
v2 = 'v2'
126-
127-
128114
class ModelMeanType(Enum):
129115
"""
130116
Which type of output the model predicts.
131117
"""
132118

133-
prev_x = 'x_prev' # the model predicts x_{t-1}
134-
start_x = 'x_start' # the model predicts x_0
135119
eps = 'eps' # the model predicts epsilon
136-
scaled_start_x = 'scaledxstart' # the model predicts sqrt(alphacum) x_0
137120

138121

139122
class ModelVarType(Enum):
@@ -144,59 +127,15 @@ class ModelVarType(Enum):
144127
values between FIXED_SMALL and FIXED_LARGE, making its job easier.
145128
"""
146129

147-
# learned directly
148-
learned = 'learned'
149130
# posterior beta_t
150131
fixed_small = 'fixed_small'
151132
# beta_t
152133
fixed_large = 'fixed_large'
153-
# predict values between FIXED_SMALL and FIXED_LARGE, making its job easier
154-
learned_range = 'learned_range'
155134

156135

157136
class LossType(Enum):
158137
mse = 'mse' # use raw MSE loss (and KL when learning variances)
159138
l1 = 'l1'
160-
# mse weighted by the variance, somewhat like in kl
161-
mse_var_weighted = 'mse_weighted'
162-
mse_rescaled = 'mse_rescaled' # use raw MSE loss (with RESCALED_KL when learning variances)
163-
kl = 'kl' # use the variational lower-bound
164-
kl_rescaled = 'kl_rescaled' # like KL, but rescale to estimate the full VLB
165-
166-
def is_vb(self):
167-
return self == LossType.kl or self == LossType.kl_rescaled
168-
169-
170-
class MSEWeightType(Enum):
171-
# use the ddpm's default variance (either analytical or learned)
172-
var = 'var'
173-
# optimal variance by deriving the min kl per image (based on mse of epsilon)
174-
# = small sigma + mse
175-
var_min_kl_img = 'varoptimg'
176-
# optimal variance regardless of the posterior sigmas
177-
# = mse only
178-
var_min_kl_mse_img = 'varoptmseimg'
179-
# same as the above but is based on mse of mu of xprev
180-
var_min_kl_xprev_img = 'varoptxprevimg'
181-
182-
183-
class XStartWeightType(Enum):
184-
# weights for the mse of the xstart
185-
# unweighted x start
186-
uniform = 'uniform'
187-
# reciprocal 1 - alpha_bar
188-
reciprocal_alphabar = 'recipalpha'
189-
# same as the above but not exceeding mse = 1
190-
reciprocal_alphabar_safe = 'recipalphasafe'
191-
# turning x0 into eps and using the mse(eps)
192-
eps = 'eps'
193-
# the same as above but not turning into eps
194-
eps2 = 'eps2'
195-
# same as the above but not exceeding mse = 1
196-
eps2_safe = 'eps2safe'
197-
eps_huber = 'epshuber'
198-
unit_mse_x0 = 'unitmsex0'
199-
unit_mse_eps = 'unitmseeps'
200139

201140

202141
class GenerativeType(Enum):

config.py

+15-49
Original file line numberDiff line numberDiff line change
@@ -72,33 +72,24 @@ class TrainConfig(BaseConfig):
7272
autoenc_mid_attn: bool = True
7373
batch_size: int = 16
7474
batch_size_eval: int = None
75-
beatgans_gen_type: GenerativeType = GenerativeType.ddpm
75+
beatgans_gen_type: GenerativeType = GenerativeType.ddim
7676
beatgans_loss_type: LossType = LossType.mse
7777
beatgans_model_mean_type: ModelMeanType = ModelMeanType.eps
7878
beatgans_model_var_type: ModelVarType = ModelVarType.fixed_large
79-
beatgans_model_mse_weight_type: MSEWeightType = MSEWeightType.var
80-
beatgans_xstart_weight_type: XStartWeightType = XStartWeightType.uniform
8179
beatgans_rescale_timesteps: bool = False
8280
latent_infer_path: str = None
8381
latent_znormalize: bool = False
84-
latent_gen_type: GenerativeType = GenerativeType.ddpm
82+
latent_gen_type: GenerativeType = GenerativeType.ddim
8583
latent_loss_type: LossType = LossType.mse
8684
latent_model_mean_type: ModelMeanType = ModelMeanType.eps
8785
latent_model_var_type: ModelVarType = ModelVarType.fixed_large
88-
latent_model_mse_weight_type: MSEWeightType = MSEWeightType.var
89-
latent_xstart_weight_type: XStartWeightType = XStartWeightType.uniform
9086
latent_rescale_timesteps: bool = False
9187
latent_T_eval: int = 1_000
9288
latent_clip_sample: bool = False
9389
latent_beta_scheduler: str = 'linear'
9490
beta_scheduler: str = 'linear'
95-
data_name: str = 'ffhq'
91+
data_name: str = ''
9692
data_val_name: str = None
97-
def_beta_1: float = 1e-4
98-
def_beta_T: float = 0.02
99-
def_mean_type: str = 'epsilon'
100-
def_var_type: str = 'fixedlarge'
101-
device: str = 'cuda:0'
10293
diffusion_type: str = None
10394
dropout: float = 0.1
10495
ema_decay: float = 0.9999
@@ -109,10 +100,7 @@ class TrainConfig(BaseConfig):
109100
fp16: bool = False
110101
grad_clip: float = 1
111102
img_size: int = 64
112-
kl_coef: float = None
113-
chamfer_coef: float = 1
114-
chamfer_type: ChamferType = ChamferType.chamfer
115-
lr: float = 0.0002
103+
lr: float = 0.0001
116104
optimizer: OptimizerType = OptimizerType.adam
117105
weight_decay: float = 0
118106
model_conf: ModelConfig = None
@@ -124,49 +112,32 @@ class TrainConfig(BaseConfig):
124112
net_beatgans_embed_channels: int = 512
125113
net_resblock_updown: bool = True
126114
net_enc_use_time: bool = False
127-
net_enc_pool: str = 'depthconv'
128-
net_enc_pool_tail_layer: int = None
115+
net_enc_pool: str = 'adaptivenonzero'
129116
net_beatgans_gradient_checkpoint: bool = False
130117
net_beatgans_resnet_two_cond: bool = False
131118
net_beatgans_resnet_use_zero_module: bool = True
132119
net_beatgans_resnet_scale_at: ScaleAt = ScaleAt.after_norm
133120
net_beatgans_resnet_cond_channels: int = None
134-
mmd_alphas: Tuple[float] = (0.5, )
135-
mmd_coef: float = 0.1
136-
latent_detach: bool = True
137-
latent_unit_normalize: bool = False
138121
net_ch_mult: Tuple[int] = None
139122
net_ch: int = 64
140123
net_enc_attn: Tuple[int] = None
141124
net_enc_k: int = None
142-
net_enc_name: EncoderName = EncoderName.v1
143125
# number of resblocks for the encoder (half-unet)
144126
net_enc_num_res_blocks: int = 2
145-
net_enc_tail_depth: int = 2
146127
net_enc_channel_mult: Tuple[int] = None
147128
net_enc_grad_checkpoint: bool = False
148129
net_autoenc_stochastic: bool = False
149130
net_latent_activation: Activation = Activation.silu
150-
net_latent_attn_resolutions: Tuple[int] = tuple()
151-
net_latent_blocks: int = None
152131
net_latent_channel_mult: Tuple[int] = (1, 2, 4)
153-
net_latent_cond_both: bool = True
154132
net_latent_condition_bias: float = 0
155133
net_latent_dropout: float = 0
156134
net_latent_layers: int = None
157135
net_latent_net_last_act: Activation = Activation.none
158136
net_latent_net_type: LatentNetType = LatentNetType.none
159137
net_latent_num_hid_channels: int = 1024
160-
net_latent_num_res_blocks: int = 2
161138
net_latent_num_time_layers: int = 2
162-
net_latent_pooling: str = 'linear'
163-
net_latent_project_size: int = 4
164-
net_latent_residual: bool = False
165139
net_latent_skip_layers: Tuple[int] = None
166140
net_latent_time_emb_channels: int = 64
167-
net_latent_time_layer_init: bool = False
168-
net_latent_unpool: str = 'conv'
169-
net_latent_use_mid_attn: bool = True
170141
net_latent_use_norm: bool = False
171142
net_latent_time_last_act: bool = False
172143
net_num_res_blocks: int = 2
@@ -190,12 +161,11 @@ class TrainConfig(BaseConfig):
190161
eval_programs: Tuple[str] = None
191162
# if present load the checkpoint from this path instead
192163
eval_path: str = None
193-
base_dir: str = 'logs'
164+
base_dir: str = 'checkpoints'
194165
use_cache_dataset: bool = False
195166
data_cache_dir: str = os.path.expanduser('~/cache')
196167
work_cache_dir: str = os.path.expanduser('~/mycache')
197-
# data_cache_dir: str = os.path.expanduser('/scratch/konpat')
198-
# work_cache_dir: str = os.path.expanduser('/scratch/konpat')
168+
# to be overridden
199169
name: str = ''
200170

201171
def __post_init__(self):
@@ -265,15 +235,11 @@ def _make_diffusion_conf(self, T=None):
265235
betas=get_named_beta_schedule(self.beta_scheduler, self.T),
266236
model_mean_type=self.beatgans_model_mean_type,
267237
model_var_type=self.beatgans_model_var_type,
268-
model_mse_weight_type=self.beatgans_model_mse_weight_type,
269-
xstart_weight_type=self.beatgans_xstart_weight_type,
270238
loss_type=self.beatgans_loss_type,
271239
rescale_timesteps=self.beatgans_rescale_timesteps,
272240
use_timesteps=space_timesteps(num_timesteps=self.T,
273241
section_counts=section_counts),
274242
fp16=self.fp16,
275-
mmd_alphas=self.mmd_alphas,
276-
mmd_coef=self.mmd_coef,
277243
)
278244
else:
279245
raise NotImplementedError()
@@ -298,8 +264,6 @@ def _make_latent_diffusion_conf(self, T=None):
298264
betas=get_named_beta_schedule(self.latent_beta_scheduler, self.T),
299265
model_mean_type=self.latent_model_mean_type,
300266
model_var_type=self.latent_model_var_type,
301-
model_mse_weight_type=self.latent_model_mse_weight_type,
302-
xstart_weight_type=self.latent_xstart_weight_type,
303267
loss_type=self.latent_loss_type,
304268
rescale_timesteps=self.latent_rescale_timesteps,
305269
use_timesteps=space_timesteps(num_timesteps=self.T,
@@ -348,10 +312,15 @@ def make_dataset(self, path=None, **kwargs):
348312
return Horse_lmdb(path=path or self.data_path,
349313
image_size=self.img_size,
350314
**kwargs)
315+
elif self.data_name == 'celebalmdb':
316+
# always use d2c crop
317+
return CelebAlmdb(path=path or self.data_path,
318+
image_size=self.img_size,
319+
original_resolution=None,
320+
crop_d2c=True,
321+
**kwargs)
351322
else:
352-
return ImageDataset(folder=path or self.data_path,
353-
image_size=self.img_size,
354-
**kwargs)
323+
raise NotImplementedError()
355324

356325
def make_loader(self,
357326
dataset,
@@ -431,8 +400,6 @@ def make_model_conf(self):
431400
dropout=self.net_latent_dropout,
432401
last_act=self.net_latent_net_last_act,
433402
num_time_layers=self.net_latent_num_time_layers,
434-
time_layer_init=self.net_latent_time_layer_init,
435-
residual=self.net_latent_residual,
436403
time_last_act=self.net_latent_time_last_act,
437404
)
438405
else:
@@ -447,7 +414,6 @@ def make_model_conf(self):
447414
embed_channels=self.net_beatgans_embed_channels,
448415
enc_out_channels=self.style_ch,
449416
enc_pool=self.net_enc_pool,
450-
enc_pool_tail_layer=self.net_enc_pool_tail_layer,
451417
enc_num_res_block=self.net_enc_num_res_blocks,
452418
enc_channel_mult=self.net_enc_channel_mult,
453419
enc_grad_checkpoint=self.net_enc_grad_checkpoint,

dataset.py

+65-10
Original file line numberDiff line numberDiff line change
@@ -217,6 +217,62 @@ def d2c_crop():
217217
return Crop(x1, x2, y1, y2)
218218

219219

220+
class CelebAlmdb(Dataset):
221+
"""
222+
also supports d2c crop.
223+
"""
224+
def __init__(self,
225+
path,
226+
image_size,
227+
original_resolution=128,
228+
split=None,
229+
as_tensor: bool = True,
230+
do_augment: bool = True,
231+
do_normalize: bool = True,
232+
crop_d2c: bool = False,
233+
**kwargs):
234+
self.original_resolution = original_resolution
235+
self.data = BaseLMDB(path, original_resolution, zfill=7)
236+
self.length = len(self.data)
237+
self.crop_d2c = crop_d2c
238+
239+
if split is None:
240+
self.offset = 0
241+
else:
242+
raise NotImplementedError()
243+
244+
if crop_d2c:
245+
transform = [
246+
d2c_crop(),
247+
transforms.Resize(image_size),
248+
]
249+
else:
250+
transform = [
251+
transforms.Resize(image_size),
252+
transforms.CenterCrop(image_size),
253+
]
254+
255+
if do_augment:
256+
transform.append(transforms.RandomHorizontalFlip())
257+
if as_tensor:
258+
transform.append(transforms.ToTensor())
259+
if do_normalize:
260+
transform.append(
261+
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)))
262+
self.transform = transforms.Compose(transform)
263+
264+
def __len__(self):
265+
return self.length
266+
267+
def __getitem__(self, index):
268+
assert index < self.length
269+
index = index + self.offset
270+
img = self.data[index]
271+
if self.transform is not None:
272+
img = self.transform(img)
273+
return {'img': img, 'index': index}
274+
275+
220276
class Horse_lmdb(Dataset):
221277
def __init__(self,
222278
path=os.path.expanduser('datasets/horse256.lmdb'),
@@ -534,16 +590,15 @@ class CelebHQAttrDataset(Dataset):
534590
]
535591
cls_to_id = {v: k for k, v in enumerate(id_to_cls)}
536592

537-
def __init__(
538-
self,
539-
path=os.path.expanduser('datasets/celebahq256.lmdb'),
540-
image_size=None,
541-
attr_path=os.path.expanduser(
542-
'datasets/celeba_anno/CelebAMask-HQ-attribute-anno.txt'),
543-
original_resolution=256,
544-
do_augment: bool = False,
545-
do_transform: bool = True,
546-
do_normalize: bool = True):
593+
def __init__(self,
594+
path=os.path.expanduser('datasets/celebahq256.lmdb'),
595+
image_size=None,
596+
attr_path=os.path.expanduser(
597+
'datasets/celeba_anno/CelebAMask-HQ-attribute-anno.txt'),
598+
original_resolution=256,
599+
do_augment: bool = False,
600+
do_transform: bool = True,
601+
do_normalize: bool = True):
547602
super().__init__()
548603
self.image_size = image_size
549604
self.data = BaseLMDB(path, original_resolution, zfill=5)

0 commit comments

Comments
 (0)