@@ -72,33 +72,24 @@ class TrainConfig(BaseConfig):
72
72
autoenc_mid_attn : bool = True
73
73
batch_size : int = 16
74
74
batch_size_eval : int = None
75
- beatgans_gen_type : GenerativeType = GenerativeType .ddpm
75
+ beatgans_gen_type : GenerativeType = GenerativeType .ddim
76
76
beatgans_loss_type : LossType = LossType .mse
77
77
beatgans_model_mean_type : ModelMeanType = ModelMeanType .eps
78
78
beatgans_model_var_type : ModelVarType = ModelVarType .fixed_large
79
- beatgans_model_mse_weight_type : MSEWeightType = MSEWeightType .var
80
- beatgans_xstart_weight_type : XStartWeightType = XStartWeightType .uniform
81
79
beatgans_rescale_timesteps : bool = False
82
80
latent_infer_path : str = None
83
81
latent_znormalize : bool = False
84
- latent_gen_type : GenerativeType = GenerativeType .ddpm
82
+ latent_gen_type : GenerativeType = GenerativeType .ddim
85
83
latent_loss_type : LossType = LossType .mse
86
84
latent_model_mean_type : ModelMeanType = ModelMeanType .eps
87
85
latent_model_var_type : ModelVarType = ModelVarType .fixed_large
88
- latent_model_mse_weight_type : MSEWeightType = MSEWeightType .var
89
- latent_xstart_weight_type : XStartWeightType = XStartWeightType .uniform
90
86
latent_rescale_timesteps : bool = False
91
87
latent_T_eval : int = 1_000
92
88
latent_clip_sample : bool = False
93
89
latent_beta_scheduler : str = 'linear'
94
90
beta_scheduler : str = 'linear'
95
- data_name : str = 'ffhq '
91
+ data_name : str = ''
96
92
data_val_name : str = None
97
- def_beta_1 : float = 1e-4
98
- def_beta_T : float = 0.02
99
- def_mean_type : str = 'epsilon'
100
- def_var_type : str = 'fixedlarge'
101
- device : str = 'cuda:0'
102
93
diffusion_type : str = None
103
94
dropout : float = 0.1
104
95
ema_decay : float = 0.9999
@@ -109,10 +100,7 @@ class TrainConfig(BaseConfig):
109
100
fp16 : bool = False
110
101
grad_clip : float = 1
111
102
img_size : int = 64
112
- kl_coef : float = None
113
- chamfer_coef : float = 1
114
- chamfer_type : ChamferType = ChamferType .chamfer
115
- lr : float = 0.0002
103
+ lr : float = 0.0001
116
104
optimizer : OptimizerType = OptimizerType .adam
117
105
weight_decay : float = 0
118
106
model_conf : ModelConfig = None
@@ -124,49 +112,32 @@ class TrainConfig(BaseConfig):
124
112
net_beatgans_embed_channels : int = 512
125
113
net_resblock_updown : bool = True
126
114
net_enc_use_time : bool = False
127
- net_enc_pool : str = 'depthconv'
128
- net_enc_pool_tail_layer : int = None
115
+ net_enc_pool : str = 'adaptivenonzero'
129
116
net_beatgans_gradient_checkpoint : bool = False
130
117
net_beatgans_resnet_two_cond : bool = False
131
118
net_beatgans_resnet_use_zero_module : bool = True
132
119
net_beatgans_resnet_scale_at : ScaleAt = ScaleAt .after_norm
133
120
net_beatgans_resnet_cond_channels : int = None
134
- mmd_alphas : Tuple [float ] = (0.5 , )
135
- mmd_coef : float = 0.1
136
- latent_detach : bool = True
137
- latent_unit_normalize : bool = False
138
121
net_ch_mult : Tuple [int ] = None
139
122
net_ch : int = 64
140
123
net_enc_attn : Tuple [int ] = None
141
124
net_enc_k : int = None
142
- net_enc_name : EncoderName = EncoderName .v1
143
125
# number of resblocks for the encoder (half-unet)
144
126
net_enc_num_res_blocks : int = 2
145
- net_enc_tail_depth : int = 2
146
127
net_enc_channel_mult : Tuple [int ] = None
147
128
net_enc_grad_checkpoint : bool = False
148
129
net_autoenc_stochastic : bool = False
149
130
net_latent_activation : Activation = Activation .silu
150
- net_latent_attn_resolutions : Tuple [int ] = tuple ()
151
- net_latent_blocks : int = None
152
131
net_latent_channel_mult : Tuple [int ] = (1 , 2 , 4 )
153
- net_latent_cond_both : bool = True
154
132
net_latent_condition_bias : float = 0
155
133
net_latent_dropout : float = 0
156
134
net_latent_layers : int = None
157
135
net_latent_net_last_act : Activation = Activation .none
158
136
net_latent_net_type : LatentNetType = LatentNetType .none
159
137
net_latent_num_hid_channels : int = 1024
160
- net_latent_num_res_blocks : int = 2
161
138
net_latent_num_time_layers : int = 2
162
- net_latent_pooling : str = 'linear'
163
- net_latent_project_size : int = 4
164
- net_latent_residual : bool = False
165
139
net_latent_skip_layers : Tuple [int ] = None
166
140
net_latent_time_emb_channels : int = 64
167
- net_latent_time_layer_init : bool = False
168
- net_latent_unpool : str = 'conv'
169
- net_latent_use_mid_attn : bool = True
170
141
net_latent_use_norm : bool = False
171
142
net_latent_time_last_act : bool = False
172
143
net_num_res_blocks : int = 2
@@ -190,12 +161,11 @@ class TrainConfig(BaseConfig):
190
161
eval_programs : Tuple [str ] = None
191
162
# if present load the checkpoint from this path instead
192
163
eval_path : str = None
193
- base_dir : str = 'logs '
164
+ base_dir : str = 'checkpoints '
194
165
use_cache_dataset : bool = False
195
166
data_cache_dir : str = os .path .expanduser ('~/cache' )
196
167
work_cache_dir : str = os .path .expanduser ('~/mycache' )
197
- # data_cache_dir: str = os.path.expanduser('/scratch/konpat')
198
- # work_cache_dir: str = os.path.expanduser('/scratch/konpat')
168
+ # to be overridden
199
169
name : str = ''
200
170
201
171
def __post_init__ (self ):
@@ -265,15 +235,11 @@ def _make_diffusion_conf(self, T=None):
265
235
betas = get_named_beta_schedule (self .beta_scheduler , self .T ),
266
236
model_mean_type = self .beatgans_model_mean_type ,
267
237
model_var_type = self .beatgans_model_var_type ,
268
- model_mse_weight_type = self .beatgans_model_mse_weight_type ,
269
- xstart_weight_type = self .beatgans_xstart_weight_type ,
270
238
loss_type = self .beatgans_loss_type ,
271
239
rescale_timesteps = self .beatgans_rescale_timesteps ,
272
240
use_timesteps = space_timesteps (num_timesteps = self .T ,
273
241
section_counts = section_counts ),
274
242
fp16 = self .fp16 ,
275
- mmd_alphas = self .mmd_alphas ,
276
- mmd_coef = self .mmd_coef ,
277
243
)
278
244
else :
279
245
raise NotImplementedError ()
@@ -298,8 +264,6 @@ def _make_latent_diffusion_conf(self, T=None):
298
264
betas = get_named_beta_schedule (self .latent_beta_scheduler , self .T ),
299
265
model_mean_type = self .latent_model_mean_type ,
300
266
model_var_type = self .latent_model_var_type ,
301
- model_mse_weight_type = self .latent_model_mse_weight_type ,
302
- xstart_weight_type = self .latent_xstart_weight_type ,
303
267
loss_type = self .latent_loss_type ,
304
268
rescale_timesteps = self .latent_rescale_timesteps ,
305
269
use_timesteps = space_timesteps (num_timesteps = self .T ,
@@ -348,10 +312,15 @@ def make_dataset(self, path=None, **kwargs):
348
312
return Horse_lmdb (path = path or self .data_path ,
349
313
image_size = self .img_size ,
350
314
** kwargs )
315
+ elif self .data_name == 'celebalmdb' :
316
+ # always use d2c crop
317
+ return CelebAlmdb (path = path or self .data_path ,
318
+ image_size = self .img_size ,
319
+ original_resolution = None ,
320
+ crop_d2c = True ,
321
+ ** kwargs )
351
322
else :
352
- return ImageDataset (folder = path or self .data_path ,
353
- image_size = self .img_size ,
354
- ** kwargs )
323
+ raise NotImplementedError ()
355
324
356
325
def make_loader (self ,
357
326
dataset ,
@@ -431,8 +400,6 @@ def make_model_conf(self):
431
400
dropout = self .net_latent_dropout ,
432
401
last_act = self .net_latent_net_last_act ,
433
402
num_time_layers = self .net_latent_num_time_layers ,
434
- time_layer_init = self .net_latent_time_layer_init ,
435
- residual = self .net_latent_residual ,
436
403
time_last_act = self .net_latent_time_last_act ,
437
404
)
438
405
else :
@@ -447,7 +414,6 @@ def make_model_conf(self):
447
414
embed_channels = self .net_beatgans_embed_channels ,
448
415
enc_out_channels = self .style_ch ,
449
416
enc_pool = self .net_enc_pool ,
450
- enc_pool_tail_layer = self .net_enc_pool_tail_layer ,
451
417
enc_num_res_block = self .net_enc_num_res_blocks ,
452
418
enc_channel_mult = self .net_enc_channel_mult ,
453
419
enc_grad_checkpoint = self .net_enc_grad_checkpoint ,
0 commit comments