@@ -49,7 +49,7 @@ def export(model, args, f, export_params=True, verbose=False, training=False,
4949 input_names = None , output_names = None , aten = False , export_raw_ir = False ,
5050 operator_export_type = None , opset_version = None , _retain_param_name = True ,
5151 do_constant_folding = False , example_outputs = None , strip_doc_string = True ,
52- dynamic_axes = None , keep_initializers_as_inputs = True ):
52+ dynamic_axes = None , keep_initializers_as_inputs = None ):
5353 if aten or export_raw_ir :
5454 assert operator_export_type is None
5555 assert aten ^ export_raw_ir
@@ -276,7 +276,7 @@ def export_to_pretty_string(model, args, f, export_params=True, verbose=False, t
276276 operator_export_type = None , export_type = ExportTypes .PROTOBUF_FILE ,
277277 example_outputs = None , propagate = False , google_printer = False ,
278278 opset_version = None , _retain_param_name = True ,
279- keep_initializers_as_inputs = True ):
279+ keep_initializers_as_inputs = None ):
280280 if aten or export_raw_ir :
281281 assert operator_export_type is None
282282 assert aten ^ export_raw_ir
@@ -294,13 +294,16 @@ def _export_to_pretty_string(model, args, f, export_params=True, verbose=False,
294294 input_names = None , output_names = None , operator_export_type = OperatorExportTypes .ONNX ,
295295 export_type = ExportTypes .PROTOBUF_FILE , example_outputs = None , propagate = False ,
296296 google_printer = False , opset_version = None , _retain_param_name = False ,
297- do_constant_folding = False , keep_initializers_as_inputs = True , fixed_batch_size = False ):
297+ do_constant_folding = False , keep_initializers_as_inputs = None , fixed_batch_size = False ):
298298 from torch .onnx .symbolic_helper import _default_onnx_opset_version , _set_opset_version
299299 from torch .onnx .symbolic_helper import _set_operator_export_type
300300 if opset_version is None :
301301 opset_version = _default_onnx_opset_version
302302 _set_opset_version (opset_version )
303303 _set_operator_export_type (operator_export_type )
304+ val_keep_init_as_ip = True if keep_initializers_as_inputs is None else keep_initializers_as_inputs
305+ if keep_initializers_as_inputs is None and operator_export_type is OperatorExportTypes .ONNX :
306+ val_keep_init_as_ip = False
304307 graph , params_dict , torch_out = _model_to_graph (model , args , verbose ,
305308 training , input_names ,
306309 output_names , operator_export_type ,
@@ -309,7 +312,7 @@ def _export_to_pretty_string(model, args, f, export_params=True, verbose=False,
309312
310313 return graph ._pretty_print_onnx (params_dict , opset_version , False ,
311314 operator_export_type , google_printer ,
312- keep_initializers_as_inputs )
315+ val_keep_init_as_ip )
313316
314317
315318# NOTE: the output `torch_out` will contain the output tensors resulting from
@@ -320,7 +323,7 @@ def _export(model, args, f, export_params=True, verbose=False, training=False,
320323 input_names = None , output_names = None , operator_export_type = OperatorExportTypes .ONNX ,
321324 export_type = ExportTypes .PROTOBUF_FILE , example_outputs = None , propagate = False ,
322325 opset_version = None , _retain_param_name = False , do_constant_folding = False ,
323- strip_doc_string = True , dynamic_axes = None , keep_initializers_as_inputs = True , fixed_batch_size = False ):
326+ strip_doc_string = True , dynamic_axes = None , keep_initializers_as_inputs = None , fixed_batch_size = False ):
324327 if isinstance (model , torch .nn .DataParallel ):
325328 raise ValueError ('torch.nn.DataParallel is not supported by ONNX '
326329 'exporter, please use \' attribute\' module to '
@@ -336,6 +339,9 @@ def _export(model, args, f, export_params=True, verbose=False, training=False,
336339 opset_version = _default_onnx_opset_version
337340 _set_opset_version (opset_version )
338341 _set_operator_export_type (operator_export_type )
342+ val_keep_init_as_ip = True if keep_initializers_as_inputs is None else keep_initializers_as_inputs
343+ if keep_initializers_as_inputs is None and operator_export_type is OperatorExportTypes .ONNX :
344+ val_keep_init_as_ip = False
339345 graph , params_dict , torch_out = _model_to_graph (model , args , verbose ,
340346 training , input_names ,
341347 output_names , operator_export_type ,
@@ -353,11 +359,11 @@ def _export(model, args, f, export_params=True, verbose=False, training=False,
353359 if export_params :
354360 proto , export_map = graph ._export_onnx (
355361 params_dict , opset_version , dynamic_axes , defer_weight_export ,
356- operator_export_type , strip_doc_string , keep_initializers_as_inputs )
362+ operator_export_type , strip_doc_string , val_keep_init_as_ip )
357363 else :
358364 proto , export_map = graph ._export_onnx (
359365 {}, opset_version , dynamic_axes , False , operator_export_type ,
360- strip_doc_string , keep_initializers_as_inputs )
366+ strip_doc_string , val_keep_init_as_ip )
361367
362368 if export_type == ExportTypes .PROTOBUF_FILE :
363369 assert (len (export_map ) == 0 )
0 commit comments