6
6
import onnx
7
7
import onnxruntime as rt
8
8
import torch
9
+ from mmcv import DictAction
9
10
10
11
from mmdet .core import (build_model_from_cfg , generate_inputs_and_wrap_model ,
11
12
preprocess_example_input )
@@ -22,7 +23,8 @@ def pytorch2onnx(config_path,
22
23
normalize_cfg = None ,
23
24
dataset = 'coco' ,
24
25
test_img = None ,
25
- do_simplify = False ):
26
+ do_simplify = False ,
27
+ cfg_options = None ):
26
28
27
29
input_config = {
28
30
'input_shape' : input_shape ,
@@ -31,10 +33,11 @@ def pytorch2onnx(config_path,
31
33
}
32
34
33
35
# prepare original model and meta for verifying the onnx model
34
- orig_model = build_model_from_cfg (config_path , checkpoint_path )
36
+ orig_model = build_model_from_cfg (
37
+ config_path , checkpoint_path , cfg_options = cfg_options )
35
38
one_img , one_meta = preprocess_example_input (input_config )
36
39
model , tensor_data = generate_inputs_and_wrap_model (
37
- config_path , checkpoint_path , input_config )
40
+ config_path , checkpoint_path , input_config , cfg_options = cfg_options )
38
41
output_names = ['boxes' ]
39
42
if model .with_bbox :
40
43
output_names .append ('labels' )
@@ -189,6 +192,16 @@ def parse_args():
189
192
nargs = '+' ,
190
193
default = [58.395 , 57.12 , 57.375 ],
191
194
help = 'variance value used for preprocess input data' )
195
+ parser .add_argument (
196
+ '--cfg-options' ,
197
+ nargs = '+' ,
198
+ action = DictAction ,
199
+ help = 'override some settings in the used config, the key-value pair '
200
+ 'in xxx=yyy format will be merged into config file. If the value to '
201
+ 'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
202
+ 'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
203
+ 'Note that the quotation marks are necessary and that no white space '
204
+ 'is allowed.' )
192
205
args = parser .parse_args ()
193
206
return args
194
207
@@ -227,4 +240,5 @@ def parse_args():
227
240
normalize_cfg = normalize_cfg ,
228
241
dataset = args .dataset ,
229
242
test_img = args .test_img ,
230
- do_simplify = args .simplify )
243
+ do_simplify = args .simplify ,
244
+ cfg_options = args .cfg_options )
0 commit comments