@@ -28,7 +28,7 @@ def init_tb_loggers(opt):
28
28
29
29
def create_train_val_dataloader (opt , logger ):
30
30
# create train and val dataloaders
31
- train_loader , val_loader = None , None
31
+ train_loader , val_loaders = None , []
32
32
for phase , dataset_opt in opt ['datasets' ].items ():
33
33
if phase == 'train' :
34
34
dataset_enlarge_ratio = dataset_opt .get ('dataset_enlarge_ratio' , 1 )
@@ -53,16 +53,16 @@ def create_train_val_dataloader(opt, logger):
53
53
f'\n \t World size (gpu number): { opt ["world_size" ]} '
54
54
f'\n \t Require iter number per epoch: { num_iter_per_epoch } '
55
55
f'\n \t Total epochs: { total_epochs } ; iters: { total_iters } .' )
56
-
57
- elif phase == 'val' :
56
+ elif phase .split ('_' )[0 ] == 'val' :
58
57
val_set = build_dataset (dataset_opt )
59
58
val_loader = build_dataloader (
60
59
val_set , dataset_opt , num_gpu = opt ['num_gpu' ], dist = opt ['dist' ], sampler = None , seed = opt ['manual_seed' ])
61
60
logger .info (f'Number of val images/folders in { dataset_opt ["name" ]} : { len (val_set )} ' )
61
+ val_loaders .append (val_loader )
62
62
else :
63
63
raise ValueError (f'Dataset phase { phase } is not recognized.' )
64
64
65
- return train_loader , train_sampler , val_loader , total_epochs , total_iters
65
+ return train_loader , train_sampler , val_loaders , total_epochs , total_iters
66
66
67
67
68
68
def load_resume_state (opt ):
@@ -118,7 +118,7 @@ def train_pipeline(root_path):
118
118
119
119
# create train and validation dataloaders
120
120
result = create_train_val_dataloader (opt , logger )
121
- train_loader , train_sampler , val_loader , total_epochs , total_iters = result
121
+ train_loader , train_sampler , val_loaders , total_epochs , total_iters = result
122
122
123
123
# create model
124
124
model = build_model (opt )
@@ -187,7 +187,10 @@ def train_pipeline(root_path):
187
187
188
188
# validation
189
189
if opt .get ('val' ) is not None and (current_iter % opt ['val' ]['val_freq' ] == 0 ):
190
- model .validation (val_loader , current_iter , tb_logger , opt ['val' ]['save_img' ])
190
+ if len (val_loaders ) > 1 :
191
+ logger .warning ('Multiple validation datasets are *only* supported by SRModel.' )
192
+ for val_loader in val_loaders :
193
+ model .validation (val_loader , current_iter , tb_logger , opt ['val' ]['save_img' ])
191
194
192
195
data_timer .start ()
193
196
iter_timer .start ()
@@ -201,7 +204,8 @@ def train_pipeline(root_path):
201
204
logger .info ('Save the latest model.' )
202
205
model .save (epoch = - 1 , current_iter = - 1 ) # -1 stands for the latest
203
206
if opt .get ('val' ) is not None :
204
- model .validation (val_loader , current_iter , tb_logger , opt ['val' ]['save_img' ])
207
+ for val_loader in val_loaders :
208
+ model .validation (val_loader , current_iter , tb_logger , opt ['val' ]['save_img' ])
205
209
if tb_logger :
206
210
tb_logger .close ()
207
211
0 commit comments