import os.path as osp
-import warnings
-from math import inf

-import mmcv
import torch.distributed as dist
-from mmcv.runner import Hook
+from mmcv.runner import DistEvalHook as BaseDistEvalHook
+from mmcv.runner import EvalHook as BaseEvalHook
from torch.nn.modules.batchnorm import _BatchNorm
-from torch.utils.data import DataLoader

-from mmdet.utils import get_root_logger

+class EvalHook(BaseEvalHook):

-class EvalHook(Hook):
-    """Evaluation hook.
-
-    Notes:
-        If new arguments are added for EvalHook, tools/test.py,
-        tools/analysis_tools/eval_metric.py may be effected.
-
-    Attributes:
-        dataloader (DataLoader): A PyTorch dataloader.
-        start (int, optional): Evaluation starting epoch. It enables evaluation
-            before the training starts if ``start`` <= the resuming epoch.
-            If None, whether to evaluate is merely decided by ``interval``.
-            Default: None.
-        interval (int): Evaluation interval (by epochs). Default: 1.
-        save_best (str, optional): If a metric is specified, it would measure
-            the best checkpoint during evaluation. The information about best
-            checkpoint would be save in best.json.
-            Options are the evaluation metrics to the test dataset. e.g.,
-            ``bbox_mAP``, ``segm_mAP`` for bbox detection and instance
-            segmentation. ``AR@100`` for proposal recall. If ``save_best`` is
-            ``auto``, the first key will be used. The interval of
-            ``CheckpointHook`` should device EvalHook. Default: None.
-        rule (str, optional): Comparison rule for best score. If set to None,
-            it will infer a reasonable rule. Keys such as 'mAP' or 'AR' will
-            be inferred by 'greater' rule. Keys contain 'loss' will be inferred
-            by 'less' rule. Options are 'greater', 'less'. Default: None.
-        **eval_kwargs: Evaluation arguments fed into the evaluate function of
-            the dataset.
-    """
-
-    rule_map = {'greater': lambda x, y: x > y, 'less': lambda x, y: x < y}
-    init_value_map = {'greater': -inf, 'less': inf}
-    greater_keys = ['mAP', 'AR']
-    less_keys = ['loss']
-
-    def __init__(self,
-                 dataloader,
-                 start=None,
-                 interval=1,
-                 by_epoch=True,
-                 save_best=None,
-                 rule=None,
-                 **eval_kwargs):
-        if not isinstance(dataloader, DataLoader):
-            raise TypeError('dataloader must be a pytorch DataLoader, but got'
-                            f' {type(dataloader)}')
-        if not interval > 0:
-            raise ValueError(f'interval must be positive, but got {interval}')
-        if start is not None and start < 0:
-            warnings.warn(
-                f'The evaluation start epoch {start} is smaller than 0, '
-                f'use 0 instead', UserWarning)
-            start = 0
-        self.dataloader = dataloader
-        self.interval = interval
-        self.by_epoch = by_epoch
-        self.start = start
-        assert isinstance(save_best, str) or save_best is None
-        self.save_best = save_best
-        self.eval_kwargs = eval_kwargs
-        self.initial_epoch_flag = True
-
-        self.logger = get_root_logger()
-
-        if self.save_best is not None:
-            self._init_rule(rule, self.save_best)
-
-    def _init_rule(self, rule, key_indicator):
-        """Initialize rule, key_indicator, comparison_func, and best score.
-
-        Args:
-            rule (str | None): Comparison rule for best score.
-            key_indicator (str | None): Key indicator to determine the
-                comparison rule.
-        """
-        if rule not in self.rule_map and rule is not None:
-            raise KeyError(f'rule must be greater, less or None, '
-                           f'but got {rule}.')
-
-        if rule is None:
-            if key_indicator != 'auto':
-                if any(key in key_indicator for key in self.greater_keys):
-                    rule = 'greater'
-                elif any(key in key_indicator for key in self.less_keys):
-                    rule = 'less'
-                else:
-                    raise ValueError(f'Cannot infer the rule for key '
-                                     f'{key_indicator}, thus a specific rule '
-                                     f'must be specified.')
-        self.rule = rule
-        self.key_indicator = key_indicator
-        if self.rule is not None:
-            self.compare_func = self.rule_map[self.rule]
-
-    def before_run(self, runner):
-        if self.save_best is not None:
-            if runner.meta is None:
-                warnings.warn('runner.meta is None. Creating a empty one.')
-                runner.meta = dict()
-            runner.meta.setdefault('hook_msgs', dict())
-
-    def before_train_epoch(self, runner):
-        """Evaluate the model only at the start of training."""
-        if not self.initial_epoch_flag:
+    def _do_evaluate(self, runner):
+        """perform evaluation and save ckpt."""
+        if not self._should_evaluate(runner):
            return
-        if self.start is not None and runner.epoch >= self.start:
-            self.after_train_epoch(runner)
-        self.initial_epoch_flag = False
-
-    def evaluation_flag(self, runner):
-        """Judge whether to perform_evaluation after this epoch.

-        Returns:
-            bool: The flag indicating whether to perform evaluation.
-        """
-        if self.start is None:
-            if not self.every_n_epochs(runner, self.interval):
-                # No evaluation during the interval epochs.
-                return False
-        elif (runner.epoch + 1) < self.start:
-            # No evaluation if start is larger than the current epoch.
-            return False
-        else:
-            # Evaluation only at epochs 3, 5, 7... if start==3 and interval==2
-            if (runner.epoch + 1 - self.start) % self.interval:
-                return False
-        return True
-
-    def after_train_epoch(self, runner):
-        if not self.by_epoch or not self.evaluation_flag(runner):
-            return
        from mmdet.apis import single_gpu_test
        results = single_gpu_test(runner.model, self.dataloader, show=False)
+        runner.log_buffer.output['eval_iter_num'] = len(self.dataloader)
        key_score = self.evaluate(runner, results)
        if self.save_best:
-            self.save_best_checkpoint(runner, key_score)
-
-    def after_train_iter(self, runner):
-        if self.by_epoch or not self.every_n_iters(runner, self.interval):
-            return
-        from mmdet.apis import single_gpu_test
-        results = single_gpu_test(runner.model, self.dataloader, show=False)
-        key_score = self.evaluate(runner, results)
-        if self.save_best:
-            self.save_best_checkpoint(runner, key_score)
-
-    def save_best_checkpoint(self, runner, key_score):
-        best_score = runner.meta['hook_msgs'].get(
-            'best_score', self.init_value_map[self.rule])
-        if self.compare_func(key_score, best_score):
-            best_score = key_score
-            runner.meta['hook_msgs']['best_score'] = best_score
-            last_ckpt = runner.meta['hook_msgs']['last_ckpt']
-            runner.meta['hook_msgs']['best_ckpt'] = last_ckpt
-            mmcv.symlink(
-                last_ckpt,
-                osp.join(runner.work_dir, f'best_{self.key_indicator}.pth'))
-            time_stamp = runner.epoch + 1 if self.by_epoch else runner.iter + 1
-            self.logger.info(f'Now best checkpoint is epoch_{time_stamp}.pth.'
-                             f'Best {self.key_indicator} is {best_score:0.4f}')
-
-    def evaluate(self, runner, results):
-        eval_res = self.dataloader.dataset.evaluate(
-            results, logger=runner.logger, **self.eval_kwargs)
-        for name, val in eval_res.items():
-            runner.log_buffer.output[name] = val
-        runner.log_buffer.ready = True
-        if self.save_best is not None:
-            if self.key_indicator == 'auto':
-                # infer from eval_results
-                self._init_rule(self.rule, list(eval_res.keys())[0])
-            return eval_res[self.key_indicator]
-        else:
-            return None
-
+            self._save_ckpt(runner, key_score)

-class DistEvalHook(EvalHook):
-    """Distributed evaluation hook.

-    Notes:
-        If new arguments are added, tools/test.py may be effected.
+class DistEvalHook(BaseDistEvalHook):

-    Attributes:
-        dataloader (DataLoader): A PyTorch dataloader.
-        start (int, optional): Evaluation starting epoch. It enables evaluation
-            before the training starts if ``start`` <= the resuming epoch.
-            If None, whether to evaluate is merely decided by ``interval``.
-            Default: None.
-        interval (int): Evaluation interval (by epochs). Default: 1.
-        tmpdir (str | None): Temporary directory to save the results of all
-            processes. Default: None.
-        gpu_collect (bool): Whether to use gpu or cpu to collect results.
-            Default: False.
-        save_best (str, optional): If a metric is specified, it would measure
-            the best checkpoint during evaluation. The information about best
-            checkpoint would be save in best.json.
-            Options are the evaluation metrics to the test dataset. e.g.,
-            ``bbox_mAP``, ``segm_mAP`` for bbox detection and instance
-            segmentation. ``AR@100`` for proposal recall. If ``save_best`` is
-            ``auto``, the first key will be used. The interval of
-            ``CheckpointHook`` should device EvalHook. Default: None.
-        rule (str | None): Comparison rule for best score. If set to None,
-            it will infer a reasonable rule. Default: 'None'.
-        broadcast_bn_buffer (bool): Whether to broadcast the
-            buffer(running_mean and running_var) of rank 0 to other rank
-            before evaluation. Default: True.
-        **eval_kwargs: Evaluation arguments fed into the evaluate function of
-            the dataset.
-    """
-
-    def __init__(self,
-                 dataloader,
-                 start=None,
-                 interval=1,
-                 by_epoch=True,
-                 tmpdir=None,
-                 gpu_collect=False,
-                 save_best=None,
-                 rule=None,
-                 broadcast_bn_buffer=True,
-                 **eval_kwargs):
-        super().__init__(
-            dataloader,
-            start=start,
-            interval=interval,
-            by_epoch=by_epoch,
-            save_best=save_best,
-            rule=rule,
-            **eval_kwargs)
-        self.broadcast_bn_buffer = broadcast_bn_buffer
-        self.tmpdir = tmpdir
-        self.gpu_collect = gpu_collect
-
-    def _broadcast_bn_buffer(self, runner):
+    def _do_evaluate(self, runner):
+        """perform evaluation and save ckpt."""
        # Synchronization of BatchNorm's buffer (running_mean
        # and running_var) is not supported in the DDP of pytorch,
        # which may cause the inconsistent performance of models in
@@ -258,46 +38,23 @@ def _broadcast_bn_buffer(self, runner):
                    dist.broadcast(module.running_var, 0)
                    dist.broadcast(module.running_mean, 0)

-    def after_train_epoch(self, runner):
-        if not self.by_epoch or not self.evaluation_flag(runner):
+        if not self._should_evaluate(runner):
            return

-        if self.broadcast_bn_buffer:
-            self._broadcast_bn_buffer(runner)
-
-        from mmdet.apis import multi_gpu_test
        tmpdir = self.tmpdir
        if tmpdir is None:
            tmpdir = osp.join(runner.work_dir, '.eval_hook')
-        results = multi_gpu_test(
-            runner.model,
-            self.dataloader,
-            tmpdir=tmpdir,
-            gpu_collect=self.gpu_collect)
-        if runner.rank == 0:
-            print('\n')
-            key_score = self.evaluate(runner, results)
-            if self.save_best:
-                self.save_best_checkpoint(runner, key_score)
-
-    def after_train_iter(self, runner):
-        if self.by_epoch or not self.every_n_iters(runner, self.interval):
-            return
-
-        if self.broadcast_bn_buffer:
-            self._broadcast_bn_buffer(runner)

        from mmdet.apis import multi_gpu_test
-        tmpdir = self.tmpdir
-        if tmpdir is None:
-            tmpdir = osp.join(runner.work_dir, '.eval_hook')
        results = multi_gpu_test(
            runner.model,
            self.dataloader,
            tmpdir=tmpdir,
            gpu_collect=self.gpu_collect)
        if runner.rank == 0:
            print('\n')
+            runner.log_buffer.output['eval_iter_num'] = len(self.dataloader)
            key_score = self.evaluate(runner, results)
+
            if self.save_best:
-                self.save_best_checkpoint(runner, key_score)
+                self._save_ckpt(runner, key_score)
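
# Illustrative sketch (an assumption, not part of the diff above): with the
# refactored hooks, registration follows roughly the pattern used by
# mmdet/apis/train.py. `runner`, `val_dataloader` and `distributed` are
# assumed to be provided by the training script.
from mmdet.core import DistEvalHook, EvalHook

eval_cfg = dict(interval=1, save_best='bbox_mAP')  # example arguments only
eval_hook = DistEvalHook if distributed else EvalHook
runner.register_hook(eval_hook(val_dataloader, **eval_cfg))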