Skip to content

Commit 2ca1323

Browse files
authored
[Feature]: Add parser argument 'cfg-options' for some tools (#4691)
1 parent e1fe9c8 commit 2ca1323

File tree

8 files changed

+135
-15
lines changed

8 files changed

+135
-15
lines changed

mmdet/core/export/pytorch2onnx.py

+12-3
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,10 @@
66
from mmcv.runner import load_checkpoint
77

88

9-
def generate_inputs_and_wrap_model(config_path, checkpoint_path, input_config):
9+
def generate_inputs_and_wrap_model(config_path,
10+
checkpoint_path,
11+
input_config,
12+
cfg_options=None):
1013
"""Prepare sample input and wrap model for ONNX export.
1114
1215
The ONNX export API only accept args, and all inputs should be
@@ -37,7 +40,8 @@ def generate_inputs_and_wrap_model(config_path, checkpoint_path, input_config):
3740
the model while exporting.
3841
"""
3942

40-
model = build_model_from_cfg(config_path, checkpoint_path)
43+
model = build_model_from_cfg(
44+
config_path, checkpoint_path, cfg_options=cfg_options)
4145
one_img, one_meta = preprocess_example_input(input_config)
4246
tensor_data = [one_img]
4347
model.forward = partial(
@@ -57,7 +61,7 @@ def generate_inputs_and_wrap_model(config_path, checkpoint_path, input_config):
5761
return model, tensor_data
5862

5963

60-
def build_model_from_cfg(config_path, checkpoint_path):
64+
def build_model_from_cfg(config_path, checkpoint_path, cfg_options=None):
6165
"""Build a model from config and load the given checkpoint.
6266
6367
Args:
@@ -71,10 +75,15 @@ def build_model_from_cfg(config_path, checkpoint_path):
7175
from mmdet.models import build_detector
7276

7377
cfg = mmcv.Config.fromfile(config_path)
78+
if cfg_options is not None:
79+
cfg.merge_from_dict(cfg_options)
7480
# import modules from string list.
7581
if cfg.get('custom_imports', None):
7682
from mmcv.utils import import_modules_from_strings
7783
import_modules_from_strings(**cfg['custom_imports'])
84+
# set cudnn_benchmark
85+
if cfg.get('cudnn_benchmark', False):
86+
torch.backends.cudnn.benchmark = True
7887
cfg.model.pretrained = None
7988
cfg.data.test.test_mode = True
8089

tools/analysis_tools/benchmark.py

+17-1
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import time
33

44
import torch
5-
from mmcv import Config
5+
from mmcv import Config, DictAction
66
from mmcv.cnn import fuse_conv_bn
77
from mmcv.parallel import MMDataParallel
88
from mmcv.runner import load_checkpoint, wrap_fp16_model
@@ -23,6 +23,16 @@ def parse_args():
2323
action='store_true',
2424
help='Whether to fuse conv and bn, this will slightly increase'
2525
'the inference speed')
26+
parser.add_argument(
27+
'--cfg-options',
28+
nargs='+',
29+
action=DictAction,
30+
help='override some settings in the used config, the key-value pair '
31+
'in xxx=yyy format will be merged into config file. If the value to '
32+
'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
33+
'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
34+
'Note that the quotation marks are necessary and that no white space '
35+
'is allowed.')
2636
args = parser.parse_args()
2737
return args
2838

@@ -31,6 +41,12 @@ def main():
3141
args = parse_args()
3242

3343
cfg = Config.fromfile(args.config)
44+
if args.cfg_options is not None:
45+
cfg.merge_from_dict(args.cfg_options)
46+
# import modules from string list.
47+
if cfg.get('custom_imports', None):
48+
from mmcv.utils import import_modules_from_strings
49+
import_modules_from_strings(**cfg['custom_imports'])
3450
# set cudnn_benchmark
3551
if cfg.get('cudnn_benchmark', False):
3652
torch.backends.cudnn.benchmark = True

tools/analysis_tools/eval_metric.py

+4
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,10 @@ def main():
5555

5656
if args.cfg_options is not None:
5757
cfg.merge_from_dict(args.cfg_options)
58+
# import modules from string list.
59+
if cfg.get('custom_imports', None):
60+
from mmcv.utils import import_modules_from_strings
61+
import_modules_from_strings(**cfg['custom_imports'])
5862
cfg.data.test.test_mode = True
5963

6064
dataset = build_dataset(cfg.data.test)

tools/analysis_tools/get_flops.py

+17-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import argparse
22

33
import torch
4-
from mmcv import Config
4+
from mmcv import Config, DictAction
55

66
from mmdet.models import build_detector
77

@@ -20,6 +20,16 @@ def parse_args():
2020
nargs='+',
2121
default=[1280, 800],
2222
help='input image size')
23+
parser.add_argument(
24+
'--cfg-options',
25+
nargs='+',
26+
action=DictAction,
27+
help='override some settings in the used config, the key-value pair '
28+
'in xxx=yyy format will be merged into config file. If the value to '
29+
'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
30+
'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
31+
'Note that the quotation marks are necessary and that no white space '
32+
'is allowed.')
2333
args = parser.parse_args()
2434
return args
2535

@@ -36,6 +46,12 @@ def main():
3646
raise ValueError('invalid input shape')
3747

3848
cfg = Config.fromfile(args.config)
49+
if args.cfg_options is not None:
50+
cfg.merge_from_dict(args.cfg_options)
51+
# import modules from string list.
52+
if cfg.get('custom_imports', None):
53+
from mmcv.utils import import_modules_from_strings
54+
import_modules_from_strings(**cfg['custom_imports'])
3955

4056
model = build_detector(
4157
cfg.model,

tools/analysis_tools/test_robustness.py

+17
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
import mmcv
77
import torch
8+
from mmcv import DictAction
89
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
910
from mmcv.runner import (get_dist_info, init_dist, load_checkpoint,
1011
wrap_fp16_model)
@@ -160,6 +161,16 @@ def parse_args():
160161
choices=['all', 'benchmark'],
161162
default='benchmark',
162163
help='aggregate all results or only those for benchmark corruptions')
164+
parser.add_argument(
165+
'--cfg-options',
166+
nargs='+',
167+
action=DictAction,
168+
help='override some settings in the used config, the key-value pair '
169+
'in xxx=yyy format will be merged into config file. If the value to '
170+
'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
171+
'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
172+
'Note that the quotation marks are necessary and that no white space '
173+
'is allowed.')
163174
args = parser.parse_args()
164175
if 'LOCAL_RANK' not in os.environ:
165176
os.environ['LOCAL_RANK'] = str(args.local_rank)
@@ -177,6 +188,12 @@ def main():
177188
raise ValueError('The output file must be a pkl file.')
178189

179190
cfg = mmcv.Config.fromfile(args.config)
191+
if args.cfg_options is not None:
192+
cfg.merge_from_dict(args.cfg_options)
193+
# import modules from string list.
194+
if cfg.get('custom_imports', None):
195+
from mmcv.utils import import_modules_from_strings
196+
import_modules_from_strings(**cfg['custom_imports'])
180197
# set cudnn_benchmark
181198
if cfg.get('cudnn_benchmark', False):
182199
torch.backends.cudnn.benchmark = True

tools/deployment/pytorch2onnx.py

+18-4
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import onnx
77
import onnxruntime as rt
88
import torch
9+
from mmcv import DictAction
910

1011
from mmdet.core import (build_model_from_cfg, generate_inputs_and_wrap_model,
1112
preprocess_example_input)
@@ -22,7 +23,8 @@ def pytorch2onnx(config_path,
2223
normalize_cfg=None,
2324
dataset='coco',
2425
test_img=None,
25-
do_simplify=False):
26+
do_simplify=False,
27+
cfg_options=None):
2628

2729
input_config = {
2830
'input_shape': input_shape,
@@ -31,10 +33,11 @@ def pytorch2onnx(config_path,
3133
}
3234

3335
# prepare original model and meta for verifying the onnx model
34-
orig_model = build_model_from_cfg(config_path, checkpoint_path)
36+
orig_model = build_model_from_cfg(
37+
config_path, checkpoint_path, cfg_options=cfg_options)
3538
one_img, one_meta = preprocess_example_input(input_config)
3639
model, tensor_data = generate_inputs_and_wrap_model(
37-
config_path, checkpoint_path, input_config)
40+
config_path, checkpoint_path, input_config, cfg_options=cfg_options)
3841
output_names = ['boxes']
3942
if model.with_bbox:
4043
output_names.append('labels')
@@ -189,6 +192,16 @@ def parse_args():
189192
nargs='+',
190193
default=[58.395, 57.12, 57.375],
191194
help='variance value used for preprocess input data')
195+
parser.add_argument(
196+
'--cfg-options',
197+
nargs='+',
198+
action=DictAction,
199+
help='override some settings in the used config, the key-value pair '
200+
'in xxx=yyy format will be merged into config file. If the value to '
201+
'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
202+
'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
203+
'Note that the quotation marks are necessary and that no white space '
204+
'is allowed.')
192205
args = parser.parse_args()
193206
return args
194207

@@ -227,4 +240,5 @@ def parse_args():
227240
normalize_cfg=normalize_cfg,
228241
dataset=args.dataset,
229242
test_img=args.test_img,
230-
do_simplify=args.simplify)
243+
do_simplify=args.simplify,
244+
cfg_options=args.cfg_options)

tools/misc/browse_dataset.py

+19-3
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from pathlib import Path
44

55
import mmcv
6-
from mmcv import Config
6+
from mmcv import Config, DictAction
77

88
from mmdet.core.utils import mask2ndarray
99
from mmdet.core.visualization import imshow_det_bboxes
@@ -30,12 +30,28 @@ def parse_args():
3030
type=float,
3131
default=2,
3232
help='the interval of show (s)')
33+
parser.add_argument(
34+
'--cfg-options',
35+
nargs='+',
36+
action=DictAction,
37+
help='override some settings in the used config, the key-value pair '
38+
'in xxx=yyy format will be merged into config file. If the value to '
39+
'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
40+
'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
41+
'Note that the quotation marks are necessary and that no white space '
42+
'is allowed.')
3343
args = parser.parse_args()
3444
return args
3545

3646

37-
def retrieve_data_cfg(config_path, skip_type):
47+
def retrieve_data_cfg(config_path, skip_type, cfg_options):
3848
cfg = Config.fromfile(config_path)
49+
if cfg_options is not None:
50+
cfg.merge_from_dict(cfg_options)
51+
# import modules from string list.
52+
if cfg.get('custom_imports', None):
53+
from mmcv.utils import import_modules_from_strings
54+
import_modules_from_strings(**cfg['custom_imports'])
3955
train_data_cfg = cfg.data.train
4056
train_data_cfg['pipeline'] = [
4157
x for x in train_data_cfg.pipeline if x['type'] not in skip_type
@@ -46,7 +62,7 @@ def retrieve_data_cfg(config_path, skip_type):
4662

4763
def main():
4864
args = parse_args()
49-
cfg = retrieve_data_cfg(args.config, args.skip_type)
65+
cfg = retrieve_data_cfg(args.config, args.skip_type, args.cfg_options)
5066

5167
dataset = build_dataset(cfg.data.train)
5268

tools/misc/print_config.py

+31-3
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import argparse
2+
import warnings
23

34
from mmcv import Config, DictAction
45

@@ -7,18 +8,45 @@ def parse_args():
78
parser = argparse.ArgumentParser(description='Print the whole config')
89
parser.add_argument('config', help='config file path')
910
parser.add_argument(
10-
'--options', nargs='+', action=DictAction, help='arguments in dict')
11+
'--options',
12+
nargs='+',
13+
action=DictAction,
14+
help='override some settings in the used config, the key-value pair '
15+
'in xxx=yyy format will be merged into config file (deprecate), '
16+
'change to --cfg-options instead.')
17+
parser.add_argument(
18+
'--cfg-options',
19+
nargs='+',
20+
action=DictAction,
21+
help='override some settings in the used config, the key-value pair '
22+
'in xxx=yyy format will be merged into config file. If the value to '
23+
'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
24+
'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
25+
'Note that the quotation marks are necessary and that no white space '
26+
'is allowed.')
1127
args = parser.parse_args()
1228

29+
if args.options and args.cfg_options:
30+
raise ValueError(
31+
'--options and --cfg-options cannot be both '
32+
'specified, --options is deprecated in favor of --cfg-options')
33+
if args.options:
34+
warnings.warn('--options is deprecated in favor of --cfg-options')
35+
args.cfg_options = args.options
36+
1337
return args
1438

1539

1640
def main():
1741
args = parse_args()
1842

1943
cfg = Config.fromfile(args.config)
20-
if args.options is not None:
21-
cfg.merge_from_dict(args.options)
44+
if args.cfg_options is not None:
45+
cfg.merge_from_dict(args.cfg_options)
46+
# import modules from string list.
47+
if cfg.get('custom_imports', None):
48+
from mmcv.utils import import_modules_from_strings
49+
import_modules_from_strings(**cfg['custom_imports'])
2250
print(f'Config:\n{cfg.pretty_text}')
2351

2452

0 commit comments

Comments
 (0)