Commit 9e79ada

Use MMCV's EvalHook in MMDetection (#4806)
* Use EvalHook in MMCV
* Add DeprecationWarning
* update
* fix unit test
* add comment for unit test
1 parent 7d49b7b commit 9e79ada
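For context, and not part of this commit: MMDetection builds EvalHook (single GPU) or DistEvalHook (distributed) from the "evaluation" field of a training config, so a user-facing setup might look roughly like the sketch below; the metric names are illustrative.

# Hypothetical config excerpt; tools/train.py forwards these keys to the
# evaluation hook that it registers on the validation dataloader.
evaluation = dict(
    interval=1,            # evaluate once per epoch
    metric='bbox',         # passed through **eval_kwargs to dataset.evaluate()
    save_best='bbox_mAP')  # optionally track the best checkpoint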

File tree

3 files changed: +26 -278 lines changed


mmdet/core/evaluation/eval_hooks.py

+15 -258
@@ -1,250 +1,30 @@
 import os.path as osp
-import warnings
-from math import inf

-import mmcv
 import torch.distributed as dist
-from mmcv.runner import Hook
+from mmcv.runner import DistEvalHook as BaseDistEvalHook
+from mmcv.runner import EvalHook as BaseEvalHook
 from torch.nn.modules.batchnorm import _BatchNorm
-from torch.utils.data import DataLoader

-from mmdet.utils import get_root_logger

+class EvalHook(BaseEvalHook):

-class EvalHook(Hook):
-    """Evaluation hook.
-
-    Notes:
-        If new arguments are added for EvalHook, tools/test.py,
-        tools/analysis_tools/eval_metric.py may be effected.
-
-    Attributes:
-        dataloader (DataLoader): A PyTorch dataloader.
-        start (int, optional): Evaluation starting epoch. It enables evaluation
-            before the training starts if ``start`` <= the resuming epoch.
-            If None, whether to evaluate is merely decided by ``interval``.
-            Default: None.
-        interval (int): Evaluation interval (by epochs). Default: 1.
-        save_best (str, optional): If a metric is specified, it would measure
-            the best checkpoint during evaluation. The information about best
-            checkpoint would be save in best.json.
-            Options are the evaluation metrics to the test dataset. e.g.,
-            ``bbox_mAP``, ``segm_mAP`` for bbox detection and instance
-            segmentation. ``AR@100`` for proposal recall. If ``save_best`` is
-            ``auto``, the first key will be used. The interval of
-            ``CheckpointHook`` should device EvalHook. Default: None.
-        rule (str, optional): Comparison rule for best score. If set to None,
-            it will infer a reasonable rule. Keys such as 'mAP' or 'AR' will
-            be inferred by 'greater' rule. Keys contain 'loss' will be inferred
-            by 'less' rule. Options are 'greater', 'less'. Default: None.
-        **eval_kwargs: Evaluation arguments fed into the evaluate function of
-            the dataset.
-    """
-
-    rule_map = {'greater': lambda x, y: x > y, 'less': lambda x, y: x < y}
-    init_value_map = {'greater': -inf, 'less': inf}
-    greater_keys = ['mAP', 'AR']
-    less_keys = ['loss']
-
-    def __init__(self,
-                 dataloader,
-                 start=None,
-                 interval=1,
-                 by_epoch=True,
-                 save_best=None,
-                 rule=None,
-                 **eval_kwargs):
-        if not isinstance(dataloader, DataLoader):
-            raise TypeError('dataloader must be a pytorch DataLoader, but got'
-                            f' {type(dataloader)}')
-        if not interval > 0:
-            raise ValueError(f'interval must be positive, but got {interval}')
-        if start is not None and start < 0:
-            warnings.warn(
-                f'The evaluation start epoch {start} is smaller than 0, '
-                f'use 0 instead', UserWarning)
-            start = 0
-        self.dataloader = dataloader
-        self.interval = interval
-        self.by_epoch = by_epoch
-        self.start = start
-        assert isinstance(save_best, str) or save_best is None
-        self.save_best = save_best
-        self.eval_kwargs = eval_kwargs
-        self.initial_epoch_flag = True
-
-        self.logger = get_root_logger()
-
-        if self.save_best is not None:
-            self._init_rule(rule, self.save_best)
-
-    def _init_rule(self, rule, key_indicator):
-        """Initialize rule, key_indicator, comparison_func, and best score.
-
-        Args:
-            rule (str | None): Comparison rule for best score.
-            key_indicator (str | None): Key indicator to determine the
-                comparison rule.
-        """
-        if rule not in self.rule_map and rule is not None:
-            raise KeyError(f'rule must be greater, less or None, '
-                           f'but got {rule}.')
-
-        if rule is None:
-            if key_indicator != 'auto':
-                if any(key in key_indicator for key in self.greater_keys):
-                    rule = 'greater'
-                elif any(key in key_indicator for key in self.less_keys):
-                    rule = 'less'
-                else:
-                    raise ValueError(f'Cannot infer the rule for key '
-                                     f'{key_indicator}, thus a specific rule '
-                                     f'must be specified.')
-        self.rule = rule
-        self.key_indicator = key_indicator
-        if self.rule is not None:
-            self.compare_func = self.rule_map[self.rule]
-
-    def before_run(self, runner):
-        if self.save_best is not None:
-            if runner.meta is None:
-                warnings.warn('runner.meta is None. Creating a empty one.')
-                runner.meta = dict()
-            runner.meta.setdefault('hook_msgs', dict())
-
-    def before_train_epoch(self, runner):
-        """Evaluate the model only at the start of training."""
-        if not self.initial_epoch_flag:
+    def _do_evaluate(self, runner):
+        """perform evaluation and save ckpt."""
+        if not self._should_evaluate(runner):
             return
-        if self.start is not None and runner.epoch >= self.start:
-            self.after_train_epoch(runner)
-        self.initial_epoch_flag = False
-
-    def evaluation_flag(self, runner):
-        """Judge whether to perform_evaluation after this epoch.

-        Returns:
-            bool: The flag indicating whether to perform evaluation.
-        """
-        if self.start is None:
-            if not self.every_n_epochs(runner, self.interval):
-                # No evaluation during the interval epochs.
-                return False
-        elif (runner.epoch + 1) < self.start:
-            # No evaluation if start is larger than the current epoch.
-            return False
-        else:
-            # Evaluation only at epochs 3, 5, 7... if start==3 and interval==2
-            if (runner.epoch + 1 - self.start) % self.interval:
-                return False
-        return True
-
-    def after_train_epoch(self, runner):
-        if not self.by_epoch or not self.evaluation_flag(runner):
-            return
         from mmdet.apis import single_gpu_test
         results = single_gpu_test(runner.model, self.dataloader, show=False)
+        runner.log_buffer.output['eval_iter_num'] = len(self.dataloader)
         key_score = self.evaluate(runner, results)
         if self.save_best:
-            self.save_best_checkpoint(runner, key_score)
-
-    def after_train_iter(self, runner):
-        if self.by_epoch or not self.every_n_iters(runner, self.interval):
-            return
-        from mmdet.apis import single_gpu_test
-        results = single_gpu_test(runner.model, self.dataloader, show=False)
-        key_score = self.evaluate(runner, results)
-        if self.save_best:
-            self.save_best_checkpoint(runner, key_score)
-
-    def save_best_checkpoint(self, runner, key_score):
-        best_score = runner.meta['hook_msgs'].get(
-            'best_score', self.init_value_map[self.rule])
-        if self.compare_func(key_score, best_score):
-            best_score = key_score
-            runner.meta['hook_msgs']['best_score'] = best_score
-            last_ckpt = runner.meta['hook_msgs']['last_ckpt']
-            runner.meta['hook_msgs']['best_ckpt'] = last_ckpt
-            mmcv.symlink(
-                last_ckpt,
-                osp.join(runner.work_dir, f'best_{self.key_indicator}.pth'))
-            time_stamp = runner.epoch + 1 if self.by_epoch else runner.iter + 1
-            self.logger.info(f'Now best checkpoint is epoch_{time_stamp}.pth.'
-                             f'Best {self.key_indicator} is {best_score:0.4f}')
-
-    def evaluate(self, runner, results):
-        eval_res = self.dataloader.dataset.evaluate(
-            results, logger=runner.logger, **self.eval_kwargs)
-        for name, val in eval_res.items():
-            runner.log_buffer.output[name] = val
-        runner.log_buffer.ready = True
-        if self.save_best is not None:
-            if self.key_indicator == 'auto':
-                # infer from eval_results
-                self._init_rule(self.rule, list(eval_res.keys())[0])
-            return eval_res[self.key_indicator]
-        else:
-            return None
-
+            self._save_ckpt(runner, key_score)

-class DistEvalHook(EvalHook):
-    """Distributed evaluation hook.

-    Notes:
-        If new arguments are added, tools/test.py may be effected.
+class DistEvalHook(BaseDistEvalHook):

-    Attributes:
-        dataloader (DataLoader): A PyTorch dataloader.
-        start (int, optional): Evaluation starting epoch. It enables evaluation
-            before the training starts if ``start`` <= the resuming epoch.
-            If None, whether to evaluate is merely decided by ``interval``.
-            Default: None.
-        interval (int): Evaluation interval (by epochs). Default: 1.
-        tmpdir (str | None): Temporary directory to save the results of all
-            processes. Default: None.
-        gpu_collect (bool): Whether to use gpu or cpu to collect results.
-            Default: False.
-        save_best (str, optional): If a metric is specified, it would measure
-            the best checkpoint during evaluation. The information about best
-            checkpoint would be save in best.json.
-            Options are the evaluation metrics to the test dataset. e.g.,
-            ``bbox_mAP``, ``segm_mAP`` for bbox detection and instance
-            segmentation. ``AR@100`` for proposal recall. If ``save_best`` is
-            ``auto``, the first key will be used. The interval of
-            ``CheckpointHook`` should device EvalHook. Default: None.
-        rule (str | None): Comparison rule for best score. If set to None,
-            it will infer a reasonable rule. Default: 'None'.
-        broadcast_bn_buffer (bool): Whether to broadcast the
-            buffer(running_mean and running_var) of rank 0 to other rank
-            before evaluation. Default: True.
-        **eval_kwargs: Evaluation arguments fed into the evaluate function of
-            the dataset.
-    """
-
-    def __init__(self,
-                 dataloader,
-                 start=None,
-                 interval=1,
-                 by_epoch=True,
-                 tmpdir=None,
-                 gpu_collect=False,
-                 save_best=None,
-                 rule=None,
-                 broadcast_bn_buffer=True,
-                 **eval_kwargs):
-        super().__init__(
-            dataloader,
-            start=start,
-            interval=interval,
-            by_epoch=by_epoch,
-            save_best=save_best,
-            rule=rule,
-            **eval_kwargs)
-        self.broadcast_bn_buffer = broadcast_bn_buffer
-        self.tmpdir = tmpdir
-        self.gpu_collect = gpu_collect
-
-    def _broadcast_bn_buffer(self, runner):
+    def _do_evaluate(self, runner):
+        """perform evaluation and save ckpt."""
         # Synchronization of BatchNorm's buffer (running_mean
         # and running_var) is not supported in the DDP of pytorch,
         # which may cause the inconsistent performance of models in

@@ -258,46 +38,23 @@ def _broadcast_bn_buffer(self, runner):
                     dist.broadcast(module.running_var, 0)
                     dist.broadcast(module.running_mean, 0)

-    def after_train_epoch(self, runner):
-        if not self.by_epoch or not self.evaluation_flag(runner):
+        if not self._should_evaluate(runner):
             return

-        if self.broadcast_bn_buffer:
-            self._broadcast_bn_buffer(runner)
-
-        from mmdet.apis import multi_gpu_test
         tmpdir = self.tmpdir
         if tmpdir is None:
             tmpdir = osp.join(runner.work_dir, '.eval_hook')
-        results = multi_gpu_test(
-            runner.model,
-            self.dataloader,
-            tmpdir=tmpdir,
-            gpu_collect=self.gpu_collect)
-        if runner.rank == 0:
-            print('\n')
-            key_score = self.evaluate(runner, results)
-            if self.save_best:
-                self.save_best_checkpoint(runner, key_score)
-
-    def after_train_iter(self, runner):
-        if self.by_epoch or not self.every_n_iters(runner, self.interval):
-            return
-
-        if self.broadcast_bn_buffer:
-            self._broadcast_bn_buffer(runner)

         from mmdet.apis import multi_gpu_test
-        tmpdir = self.tmpdir
-        if tmpdir is None:
-            tmpdir = osp.join(runner.work_dir, '.eval_hook')
         results = multi_gpu_test(
             runner.model,
             self.dataloader,
             tmpdir=tmpdir,
             gpu_collect=self.gpu_collect)
         if runner.rank == 0:
             print('\n')
+            runner.log_buffer.output['eval_iter_num'] = len(self.dataloader)
             key_score = self.evaluate(runner, results)
+
             if self.save_best:
-                self.save_best_checkpoint(runner, key_score)
+                self._save_ckpt(runner, key_score)
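A minimal usage sketch (assumed, not part of this diff) of the refactored hook: evaluation scheduling (_should_evaluate) and best-checkpoint bookkeeping (_save_ckpt) are now inherited from mmcv.runner.EvalHook, so MMDetection only overrides _do_evaluate.

# Sketch: register the refactored hook on an existing mmcv runner.
# `val_dataset` and `runner` are assumed to be built elsewhere.
from torch.utils.data import DataLoader

from mmdet.core.evaluation import EvalHook

val_dataloader = DataLoader(val_dataset, batch_size=1, shuffle=False)
eval_hook = EvalHook(
    val_dataloader,
    interval=1,            # checked by the inherited _should_evaluate()
    save_best='bbox_mAP')  # best checkpoint handled by the inherited _save_ckpt()
runner.register_hook(eval_hook)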

tests/test_data/test_datasets/test_common.py

+5 -2
@@ -274,9 +274,12 @@ def test_evaluation_hook(EvalHookParam):
     runner.run([dataloader], [('train', 1)], 2)
     assert evalhook.evaluate.call_count == 3  # before epoch1 and after e1 & e2

+    # the evaluation start epoch cannot be less than 0
     runner = _build_demo_runner()
-    with pytest.warns(UserWarning):
-        evalhook = EvalHookParam(dataloader, start=-2)
+    with pytest.raises(ValueError):
+        EvalHookParam(dataloader, start=-2)
+
+    evalhook = EvalHookParam(dataloader, start=0)
     evalhook.evaluate = MagicMock()
     runner.register_hook(evalhook)
     runner.run([dataloader], [('train', 1)], 2)