@@ -24,16 +24,20 @@ def dist_validation(self, dataloader, current_iter, tb_logger, save_img):
24
24
# 'folder1': tensor (num_frame x len(metrics)),
25
25
# 'folder2': tensor (num_frame x len(metrics))
26
26
# }
27
- if with_metrics and not hasattr (self , 'metric_results' ):
27
+ if with_metrics and not hasattr (self , 'metric_results' ): # only execute in the first run
28
28
self .metric_results = {}
29
29
num_frame_each_folder = Counter (dataset .data_info ['folder' ])
30
30
for folder , num_frame in num_frame_each_folder .items ():
31
31
self .metric_results [folder ] = torch .zeros (
32
32
num_frame , len (self .opt ['val' ]['metrics' ]), dtype = torch .float32 , device = 'cuda' )
33
+ # initialize the best metric results
34
+ self ._initialize_best_metric_results ()
35
+ # zero self.metric_results
33
36
rank , world_size = get_dist_info ()
34
37
if with_metrics :
35
38
for _ , tensor in self .metric_results .items ():
36
39
tensor .zero_ ()
40
+
37
41
metric_data = dict ()
38
42
# record all frames (border and center frames)
39
43
if rank == 0 :
@@ -111,6 +115,7 @@ def nondist_validation(self, dataloader, current_iter, tb_logger, save_img):
111
115
self .dist_validation (dataloader , current_iter , tb_logger , save_img )
112
116
113
117
def _log_validation_metric_values (self , current_iter , dataset_name , tb_logger ):
118
+ # ----------------- calculate the average values for each folder, and for each metric ----------------- #
114
119
# average all frames for each sub-folder
115
120
# metric_results_avg is a dict:{
116
121
# 'folder1': tensor (len(metrics)),
@@ -131,12 +136,18 @@ def _log_validation_metric_values(self, current_iter, dataset_name, tb_logger):
131
136
# average among folders
132
137
for metric in total_avg_results .keys ():
133
138
total_avg_results [metric ] /= len (metric_results_avg )
139
+ # update the best metric result
140
+ self ._update_best_metric_result (metric , total_avg_results [metric ], current_iter )
134
141
142
+ # ------------------------------------------ log the metric ------------------------------------------ #
135
143
log_str = f'Validation { dataset_name } \n '
136
144
for metric_idx , (metric , value ) in enumerate (total_avg_results .items ()):
137
145
log_str += f'\t # { metric } : { value :.4f} '
138
146
for folder , tensor in metric_results_avg .items ():
139
147
log_str += f'\t # { folder } : { tensor [metric_idx ].item ():.4f} '
148
+ if hasattr (self , 'best_metric_results' ):
149
+ log_str += (f'\n \t Best: { self .best_metric_results [metric ]["val" ]:.4f} @ '
150
+ f'{ self .best_metric_results [metric ]["iter" ]} iter' )
140
151
log_str += '\n '
141
152
142
153
logger = get_root_logger ()
0 commit comments