
Commit 9309e26

support multiple val datasets; best metric results support multiple val datasets
1 parent 6bffc94 commit 9309e26

File tree

7 files changed: +53 -35 lines


basicsr/models/base_model.py

+21 -15

@@ -47,24 +47,30 @@ def validation(self, dataloader, current_iter, tb_logger, save_img=False):
         else:
             self.nondist_validation(dataloader, current_iter, tb_logger, save_img)
 
-    def _initialize_best_metric_results(self):
+    def _initialize_best_metric_results(self, dataset_name):
         """Initialize the best metric results dict for recording the best metric value and iteration."""
-        if not hasattr(self, 'best_metric_results'):
+        if hasattr(self, 'best_metric_results') and dataset_name in self.best_metric_results:
+            return
+        elif not hasattr(self, 'best_metric_results'):
             self.best_metric_results = dict()
-            for metric, content in self.opt['val']['metrics'].items():
-                better = content.get('better', 'higher')
-                init_val = float('-inf') if better == 'higher' else float('inf')
-                self.best_metric_results[metric] = dict(better=better, val=init_val, iter=-1)
-
-    def _update_best_metric_result(self, metric, val, current_iter):
-        if self.best_metric_results[metric]['better'] == 'higher':
-            if val >= self.best_metric_results[metric]['val']:
-                self.best_metric_results[metric]['val'] = val
-                self.best_metric_results[metric]['iter'] = current_iter
+
+        # add a dataset record
+        record = dict()
+        for metric, content in self.opt['val']['metrics'].items():
+            better = content.get('better', 'higher')
+            init_val = float('-inf') if better == 'higher' else float('inf')
+            record[metric] = dict(better=better, val=init_val, iter=-1)
+        self.best_metric_results[dataset_name] = record
+
+    def _update_best_metric_result(self, dataset_name, metric, val, current_iter):
+        if self.best_metric_results[dataset_name][metric]['better'] == 'higher':
+            if val >= self.best_metric_results[dataset_name][metric]['val']:
+                self.best_metric_results[dataset_name][metric]['val'] = val
+                self.best_metric_results[dataset_name][metric]['iter'] = current_iter
         else:
-            if val <= self.best_metric_results[metric]['val']:
-                self.best_metric_results[metric]['val'] = val
-                self.best_metric_results[metric]['iter'] = current_iter
+            if val <= self.best_metric_results[dataset_name][metric]['val']:
+                self.best_metric_results[dataset_name][metric]['val'] = val
+                self.best_metric_results[dataset_name][metric]['iter'] = current_iter
 
     def model_ema(self, decay=0.999):
         net_g = self.get_bare_model(self.net_g)
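
For orientation, a minimal sketch of the structure these helpers now maintain, assuming two validation datasets named Set5 and Set14 and a single psnr metric (the names are illustrative; only the extra nesting by dataset_name comes from this diff):

# Illustrative only: shape of self.best_metric_results after _initialize_best_metric_results
# has run once per validation dataset (dataset and metric names are assumptions).
best_metric_results = {
    'Set5': {'psnr': {'better': 'higher', 'val': float('-inf'), 'iter': -1}},
    'Set14': {'psnr': {'better': 'higher', 'val': float('-inf'), 'iter': -1}},
}
# A call like _update_best_metric_result('Set14', 'psnr', 28.31, 5000) would then update only
# best_metric_results['Set14']['psnr'], leaving the 'Set5' record untouched.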

basicsr/models/sr_model.py

+5 -5

@@ -139,8 +139,8 @@ def nondist_validation(self, dataloader, current_iter, tb_logger, save_img):
 
         if with_metrics and not hasattr(self, 'metric_results'):  # only execute in the first run
             self.metric_results = {metric: 0 for metric in self.opt['val']['metrics'].keys()}
-            # initialize the best metric results
-            self._initialize_best_metric_results()
+        # initialize the best metric results for each dataset_name (supporting multiple validation datasets)
+        self._initialize_best_metric_results(dataset_name)
         # zero self.metric_results
         if with_metrics:
             self.metric_results = {metric: 0 for metric in self.metric_results}

@@ -191,7 +191,7 @@ def nondist_validation(self, dataloader, current_iter, tb_logger, save_img):
             for metric in self.metric_results.keys():
                 self.metric_results[metric] /= (idx + 1)
                 # update the best metric result
-                self._update_best_metric_result(metric, self.metric_results[metric], current_iter)
+                self._update_best_metric_result(dataset_name, metric, self.metric_results[metric], current_iter)
 
             self._log_validation_metric_values(current_iter, dataset_name, tb_logger)
 

@@ -200,8 +200,8 @@ def _log_validation_metric_values(self, current_iter, dataset_name, tb_logger):
         for metric, value in self.metric_results.items():
             log_str += f'\t # {metric}: {value:.4f}'
             if hasattr(self, 'best_metric_results'):
-                log_str += (f'\tBest: {self.best_metric_results[metric]["val"]:.4f} @ '
-                            f'{self.best_metric_results[metric]["iter"]} iter')
+                log_str += (f'\tBest: {self.best_metric_results[dataset_name][metric]["val"]:.4f} @ '
+                            f'{self.best_metric_results[dataset_name][metric]["iter"]} iter')
             log_str += '\n'
 
         logger = get_root_logger()
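
To make the format concrete, here is a small self-contained sketch of the per-dataset log line the code above assembles; the dataset/metric names and numbers are invented:

# Illustrative only: re-creates the log format from the hunk above with made-up values.
best_metric_results = {'Set14': {'psnr': {'better': 'higher', 'val': 28.3456, 'iter': 250000}}}
dataset_name, metric, value = 'Set14', 'psnr', 28.0123

log_str = f'\t # {metric}: {value:.4f}'
log_str += (f'\tBest: {best_metric_results[dataset_name][metric]["val"]:.4f} @ '
            f'{best_metric_results[dataset_name][metric]["iter"]} iter')
print(log_str)  # prints: "\t # psnr: 28.0123\tBest: 28.3456 @ 250000 iter" (tabs shown escaped)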

basicsr/models/video_base_model.py

+5 -5

@@ -30,8 +30,8 @@ def dist_validation(self, dataloader, current_iter, tb_logger, save_img):
                 for folder, num_frame in num_frame_each_folder.items():
                     self.metric_results[folder] = torch.zeros(
                         num_frame, len(self.opt['val']['metrics']), dtype=torch.float32, device='cuda')
-            # initialize the best metric results
-            self._initialize_best_metric_results()
+            # initialize the best metric results
+            self._initialize_best_metric_results(dataset_name)
         # zero self.metric_results
         rank, world_size = get_dist_info()
         if with_metrics:

@@ -137,7 +137,7 @@ def _log_validation_metric_values(self, current_iter, dataset_name, tb_logger):
         for metric in total_avg_results.keys():
             total_avg_results[metric] /= len(metric_results_avg)
             # update the best metric result
-            self._update_best_metric_result(metric, total_avg_results[metric], current_iter)
+            self._update_best_metric_result(dataset_name, metric, total_avg_results[metric], current_iter)
 
         # ------------------------------------------ log the metric ------------------------------------------ #
         log_str = f'Validation {dataset_name}\n'

@@ -146,8 +146,8 @@ def _log_validation_metric_values(self, current_iter, dataset_name, tb_logger):
             for folder, tensor in metric_results_avg.items():
                 log_str += f'\t # {folder}: {tensor[metric_idx].item():.4f}'
             if hasattr(self, 'best_metric_results'):
-                log_str += (f'\n\t Best: {self.best_metric_results[metric]["val"]:.4f} @ '
-                            f'{self.best_metric_results[metric]["iter"]} iter')
+                log_str += (f'\n\t Best: {self.best_metric_results[dataset_name][metric]["val"]:.4f} @ '
+                            f'{self.best_metric_results[dataset_name][metric]["iter"]} iter')
             log_str += '\n'
 
         logger = get_root_logger()

basicsr/models/video_recurrent_model.py

+2 -2

@@ -78,8 +78,8 @@ def dist_validation(self, dataloader, current_iter, tb_logger, save_img):
                 for folder, num_frame in num_frame_each_folder.items():
                     self.metric_results[folder] = torch.zeros(
                         num_frame, len(self.opt['val']['metrics']), dtype=torch.float32, device='cuda')
-            # initialize the best metric results
-            self._initialize_best_metric_results()
+            # initialize the best metric results
+            self._initialize_best_metric_results(dataset_name)
         # zero self.metric_results
         rank, world_size = get_dist_info()
         if with_metrics:

basicsr/train.py

+11 -7

@@ -28,7 +28,7 @@ def init_tb_loggers(opt):
 
 def create_train_val_dataloader(opt, logger):
     # create train and val dataloaders
-    train_loader, val_loader = None, None
+    train_loader, val_loaders = None, []
     for phase, dataset_opt in opt['datasets'].items():
         if phase == 'train':
             dataset_enlarge_ratio = dataset_opt.get('dataset_enlarge_ratio', 1)

@@ -53,16 +53,16 @@ def create_train_val_dataloader(opt, logger):
                         f'\n\tWorld size (gpu number): {opt["world_size"]}'
                         f'\n\tRequire iter number per epoch: {num_iter_per_epoch}'
                         f'\n\tTotal epochs: {total_epochs}; iters: {total_iters}.')
-
-        elif phase == 'val':
+        elif phase.split('_')[0] == 'val':
             val_set = build_dataset(dataset_opt)
             val_loader = build_dataloader(
                 val_set, dataset_opt, num_gpu=opt['num_gpu'], dist=opt['dist'], sampler=None, seed=opt['manual_seed'])
             logger.info(f'Number of val images/folders in {dataset_opt["name"]}: {len(val_set)}')
+            val_loaders.append(val_loader)
         else:
             raise ValueError(f'Dataset phase {phase} is not recognized.')
 
-    return train_loader, train_sampler, val_loader, total_epochs, total_iters
+    return train_loader, train_sampler, val_loaders, total_epochs, total_iters
 
 
 def load_resume_state(opt):

@@ -118,7 +118,7 @@ def train_pipeline(root_path):
 
     # create train and validation dataloaders
     result = create_train_val_dataloader(opt, logger)
-    train_loader, train_sampler, val_loader, total_epochs, total_iters = result
+    train_loader, train_sampler, val_loaders, total_epochs, total_iters = result
 
     # create model
     model = build_model(opt)

@@ -187,7 +187,10 @@ def train_pipeline(root_path):
 
             # validation
             if opt.get('val') is not None and (current_iter % opt['val']['val_freq'] == 0):
-                model.validation(val_loader, current_iter, tb_logger, opt['val']['save_img'])
+                if len(val_loaders) > 1:
+                    logger.warning('Multiple validation datasets are *only* supported by SRModel.')
+                for val_loader in val_loaders:
+                    model.validation(val_loader, current_iter, tb_logger, opt['val']['save_img'])
 
             data_timer.start()
             iter_timer.start()

@@ -201,7 +204,8 @@ def train_pipeline(root_path):
     logger.info('Save the latest model.')
     model.save(epoch=-1, current_iter=-1)  # -1 stands for the latest
     if opt.get('val') is not None:
-        model.validation(val_loader, current_iter, tb_logger, opt['val']['save_img'])
+        for val_loader in val_loaders:
+            model.validation(val_loader, current_iter, tb_logger, opt['val']['save_img'])
     if tb_logger:
         tb_logger.close()
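
A minimal, runnable sketch of the phase routing that create_train_val_dataloader now relies on; the dataset keys below are hypothetical, and only the split('_')[0] == 'val' rule and the one-loader-per-entry behaviour come from this diff:

# Illustrative only: which dataset keys of a config are routed to validation dataloaders.
datasets = {'train': {}, 'val': {}, 'val_2': {}}  # hypothetical keys, as in the yml change below

val_keys = [phase for phase in datasets if phase.split('_')[0] == 'val']
print(val_keys)  # ['val', 'val_2'] -> one dataloader is built and appended to val_loaders per key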

basicsr/utils/options.py

+1 -1

@@ -139,7 +139,7 @@ def parse_options(root_path, is_train=True):
 
     # datasets
     for phase, dataset in opt['datasets'].items():
-        # for several datasets, e.g., test_1, test_2
+        # for multiple datasets, e.g., val_1, val_2; test_1, test_2
         phase = phase.split('_')[0]
         dataset['phase'] = phase
         if 'scale' in opt:

options/train/SRResNet_SRGAN/train_MSRResNet_x4.yml

+8 -0

@@ -44,6 +44,14 @@ datasets:
     io_backend:
       type: disk
 
+  val_2:
+    name: Set14
+    type: PairedImageDataset
+    dataroot_gt: datasets/Set14/GTmod12
+    dataroot_lq: datasets/Set14/LRbicx4
+    io_backend:
+      type: disk
+
 # network structures
 network_g:
   type: MSRResNet
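
With this second block, the parser (see options.py above) keeps the val and val_2 keys distinct while normalizing both phases to 'val'. A rough sketch of the parsed result; the 'train' and 'val' entries are assumptions, only 'val_2' (Set14) appears in this diff:

# Illustrative only: approximate opt['datasets'] after parse_options for this config.
opt_datasets = {
    'train': {'phase': 'train', 'name': 'DIV2K'},                                 # assumed
    'val':   {'phase': 'val',   'name': 'Set5',  'type': 'PairedImageDataset'},   # assumed
    'val_2': {'phase': 'val',   'name': 'Set14', 'type': 'PairedImageDataset'},   # added here
}
# Each 'val'-phase entry gets its own dataloader, and its 'name' is presumably the dataset_name
# under which best metrics are tracked and the 'Validation <name>' log section is written.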
