Skip to content

Commit 818e211

Browse files
committed
Support best metric values during validation
1 parent 2a590e9 commit 818e211

8 files changed

+66
-13
lines changed

basicsr/metrics/psnr_ssim.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77

88
@METRIC_REGISTRY.register()
9-
def calculate_psnr(img, img2, crop_border, input_order='HWC', test_y_channel=False):
9+
def calculate_psnr(img, img2, crop_border, input_order='HWC', test_y_channel=False, **kwargs):
1010
"""Calculate PSNR (Peak Signal-to-Noise Ratio).
1111
1212
Ref: https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio
@@ -81,7 +81,7 @@ def _ssim(img, img2):
8181

8282

8383
@METRIC_REGISTRY.register()
84-
def calculate_ssim(img, img2, crop_border, input_order='HWC', test_y_channel=False):
84+
def calculate_ssim(img, img2, crop_border, input_order='HWC', test_y_channel=False, **kwargs):
8585
"""Calculate SSIM (structural similarity).
8686
8787
Ref:

basicsr/models/base_model.py

+19
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,25 @@ def validation(self, dataloader, current_iter, tb_logger, save_img=False):
4747
else:
4848
self.nondist_validation(dataloader, current_iter, tb_logger, save_img)
4949

50+
def _initialize_best_metric_results(self):
51+
"""Initialize the best metric results dict for recording the best metric value and iteration."""
52+
if not hasattr(self, 'best_metric_results'):
53+
self.best_metric_results = dict()
54+
for metric, content in self.opt['val']['metrics'].items():
55+
better = content.get('better', 'higher')
56+
init_val = float('-inf') if better == 'higher' else float('inf')
57+
self.best_metric_results[metric] = dict(better=better, val=init_val, iter=-1)
58+
59+
def _update_best_metric_result(self, metric, val, current_iter):
60+
if self.best_metric_results[metric]['better'] == 'higher':
61+
if val >= self.best_metric_results[metric]['val']:
62+
self.best_metric_results[metric]['val'] = val
63+
self.best_metric_results[metric]['iter'] = current_iter
64+
else:
65+
if val <= self.best_metric_results[metric]['val']:
66+
self.best_metric_results[metric]['val'] = val
67+
self.best_metric_results[metric]['iter'] = current_iter
68+
5069
def model_ema(self, decay=0.999):
5170
net_g = self.get_bare_model(self.net_g)
5271

basicsr/models/sr_model.py

+16-2
Original file line numberDiff line numberDiff line change
@@ -136,8 +136,15 @@ def dist_validation(self, dataloader, current_iter, tb_logger, save_img):
136136
def nondist_validation(self, dataloader, current_iter, tb_logger, save_img):
137137
dataset_name = dataloader.dataset.opt['name']
138138
with_metrics = self.opt['val'].get('metrics') is not None
139-
if with_metrics:
139+
140+
if with_metrics and not hasattr(self, 'metric_results'): # only execute in the first run
140141
self.metric_results = {metric: 0 for metric in self.opt['val']['metrics'].keys()}
142+
# initialize the best metric results
143+
self._initialize_best_metric_results()
144+
# zero self.metric_results
145+
if with_metrics:
146+
self.metric_results = {metric: 0 for metric in self.metric_results}
147+
141148
metric_data = dict()
142149
pbar = tqdm(total=len(dataloader), unit='image')
143150

@@ -183,13 +190,20 @@ def nondist_validation(self, dataloader, current_iter, tb_logger, save_img):
183190
if with_metrics:
184191
for metric in self.metric_results.keys():
185192
self.metric_results[metric] /= (idx + 1)
193+
# update the best metric result
194+
self._update_best_metric_result(metric, self.metric_results[metric], current_iter)
186195

187196
self._log_validation_metric_values(current_iter, dataset_name, tb_logger)
188197

189198
def _log_validation_metric_values(self, current_iter, dataset_name, tb_logger):
190199
log_str = f'Validation {dataset_name}\n'
191200
for metric, value in self.metric_results.items():
192-
log_str += f'\t # {metric}: {value:.4f}\n'
201+
log_str += f'\t # {metric}: {value:.4f}'
202+
if hasattr(self, 'best_metric_results'):
203+
log_str += (f'\tBest: {self.best_metric_results[metric]["val"]:.4f} @ '
204+
f'{self.best_metric_results[metric]["iter"]} iter')
205+
log_str += '\n'
206+
193207
logger = get_root_logger()
194208
logger.info(log_str)
195209
if tb_logger:

basicsr/models/video_base_model.py

+12-1
Original file line numberDiff line numberDiff line change
@@ -24,16 +24,20 @@ def dist_validation(self, dataloader, current_iter, tb_logger, save_img):
2424
# 'folder1': tensor (num_frame x len(metrics)),
2525
# 'folder2': tensor (num_frame x len(metrics))
2626
# }
27-
if with_metrics and not hasattr(self, 'metric_results'):
27+
if with_metrics and not hasattr(self, 'metric_results'): # only execute in the first run
2828
self.metric_results = {}
2929
num_frame_each_folder = Counter(dataset.data_info['folder'])
3030
for folder, num_frame in num_frame_each_folder.items():
3131
self.metric_results[folder] = torch.zeros(
3232
num_frame, len(self.opt['val']['metrics']), dtype=torch.float32, device='cuda')
33+
# initialize the best metric results
34+
self._initialize_best_metric_results()
35+
# zero self.metric_results
3336
rank, world_size = get_dist_info()
3437
if with_metrics:
3538
for _, tensor in self.metric_results.items():
3639
tensor.zero_()
40+
3741
metric_data = dict()
3842
# record all frames (border and center frames)
3943
if rank == 0:
@@ -111,6 +115,7 @@ def nondist_validation(self, dataloader, current_iter, tb_logger, save_img):
111115
self.dist_validation(dataloader, current_iter, tb_logger, save_img)
112116

113117
def _log_validation_metric_values(self, current_iter, dataset_name, tb_logger):
118+
# ----------------- calculate the average values for each folder, and for each metric ----------------- #
114119
# average all frames for each sub-folder
115120
# metric_results_avg is a dict:{
116121
# 'folder1': tensor (len(metrics)),
@@ -131,12 +136,18 @@ def _log_validation_metric_values(self, current_iter, dataset_name, tb_logger):
131136
# average among folders
132137
for metric in total_avg_results.keys():
133138
total_avg_results[metric] /= len(metric_results_avg)
139+
# update the best metric result
140+
self._update_best_metric_result(metric, total_avg_results[metric], current_iter)
134141

142+
# ------------------------------------------ log the metric ------------------------------------------ #
135143
log_str = f'Validation {dataset_name}\n'
136144
for metric_idx, (metric, value) in enumerate(total_avg_results.items()):
137145
log_str += f'\t # {metric}: {value:.4f}'
138146
for folder, tensor in metric_results_avg.items():
139147
log_str += f'\t # {folder}: {tensor[metric_idx].item():.4f}'
148+
if hasattr(self, 'best_metric_results'):
149+
log_str += (f'\n\t Best: {self.best_metric_results[metric]["val"]:.4f} @ '
150+
f'{self.best_metric_results[metric]["iter"]} iter')
140151
log_str += '\n'
141152

142153
logger = get_root_logger()

basicsr/models/video_recurrent_model.py

+7-4
Original file line numberDiff line numberDiff line change
@@ -72,24 +72,27 @@ def dist_validation(self, dataloader, current_iter, tb_logger, save_img):
7272
# 'folder1': tensor (num_frame x len(metrics)),
7373
# 'folder2': tensor (num_frame x len(metrics))
7474
# }
75-
if with_metrics and not hasattr(self, 'metric_results'):
75+
if with_metrics and not hasattr(self, 'metric_results'): # only execute in the first run
7676
self.metric_results = {}
7777
num_frame_each_folder = Counter(dataset.data_info['folder'])
7878
for folder, num_frame in num_frame_each_folder.items():
7979
self.metric_results[folder] = torch.zeros(
8080
num_frame, len(self.opt['val']['metrics']), dtype=torch.float32, device='cuda')
81-
81+
# initialize the best metric results
82+
self._initialize_best_metric_results()
83+
# zero self.metric_results
8284
rank, world_size = get_dist_info()
8385
if with_metrics:
8486
for _, tensor in self.metric_results.items():
8587
tensor.zero_()
88+
8689
metric_data = dict()
8790
num_folders = len(dataset)
8891
num_pad = (world_size - (num_folders % world_size)) % world_size
8992
if rank == 0:
9093
pbar = tqdm(total=len(dataset), unit='folder')
91-
# Will evaluate (num_folders + num_pad) times, but only the first
92-
# num_folders results will be recorded. (To avoid wait-dead)
94+
# Will evaluate (num_folders + num_pad) times, but only the first num_folders results will be recorded.
95+
# (To avoid wait-dead)
9396
for i in range(rank, num_folders + num_pad, world_size):
9497
idx = min(i, num_folders - 1)
9598
val_data = dataset[idx]

options/train/BasicVSR/train_BasicVSR_REDS.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
name: BasicVSR_REDS
33
model_type: VideoRecurrentModel
44
scale: 4
5-
num_gpu: 8 # set num_gpu: 0 for cpu mode
5+
num_gpu: auto # official: 8 GPUs
66
manual_seed: 0
77

88
# dataset and data loader settings

options/train/EDVR/train_EDVR_M_x4_SR_REDS_woTSA.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
name: 101_EDVR_M_x4_SR_REDS_woTSA_600k_B4G8_valREDS4_wandb
33
model_type: EDVRModel
44
scale: 4
5-
num_gpu: 8 # set num_gpu: 0 for cpu mode
5+
num_gpu: auto # official: 8 GPUs
66
manual_seed: 10
77

88
# dataset and data loader settings

options/train/SRResNet_SRGAN/train_MSRResNet_x4.yml

+8-2
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,9 @@ datasets:
1313
train:
1414
name: DIV2K
1515
type: PairedImageDataset
16-
dataroot_gt: datasets/DIV2K/DIV2K_train_HR_sub
17-
dataroot_lq: datasets/DIV2K/DIV2K_train_LR_bicubic/X4_sub
16+
dataroot_gt: datasets/DF2K/DIV2K_train_HR_sub
17+
dataroot_lq: datasets/DF2K/DIV2K_train_LR_bicubic_X4_sub
18+
meta_info_file: basicsr/data/meta_info/meta_info_DIV2K800sub_GT.txt
1819
# (for lmdb)
1920
# dataroot_gt: datasets/DIV2K/DIV2K_train_HR_sub.lmdb
2021
# dataroot_lq: datasets/DIV2K/DIV2K_train_LR_bicubic_X4_sub.lmdb
@@ -92,6 +93,11 @@ val:
9293
type: calculate_psnr
9394
crop_border: 4
9495
test_y_channel: false
96+
better: higher # the higher, the better. Default: higher
97+
niqe:
98+
type: calculate_niqe
99+
crop_border: 4
100+
better: lower # the lower, the better
95101

96102
# logging settings
97103
logger:

0 commit comments

Comments
 (0)