Skip to content

Commit af3b15b

Browse files
Spandan Tiwari authored and facebook-github-bot committed
Setting automatic default selection for ONNX IR v4 semantics in ONNX export API (pytorch#26146)
Summary: This is a follow-up PR for pytorch#23284. In that PR we had removed changing the default behavior for the `keep_initializers_as_inputs` argument to the export API. With this PR we are enabling that change: if `keep_initializers_as_inputs` is not specified, then the value/behavior for this argument is chosen automatically depending on whether the export type is ONNX or not. This part of the earlier PR was removed for further review. The test points have also been updated. This change may fail some internal tests, which may require explicitly setting `keep_initializers_as_inputs=True` to preserve the old behavior. Pull Request resolved: pytorch#26146 Reviewed By: hl475 Differential Revision: D17369677 Pulled By: houseroad fbshipit-source-id: 2aec2cff50d215714ee8769505ef24d2b7865a11
1 parent 8b12602 commit af3b15b

File tree

3 files changed

+22
-12
lines changed

3 files changed

+22
-12
lines changed

test/onnx/test_operators.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -247,7 +247,7 @@ def test_batchnorm(self):
247247

248248
def test_batchnorm_onnx_irv4(self):
249249
x = torch.ones(2, 2, 2, 2, requires_grad=True)
250-
self.assertONNX(nn.BatchNorm2d(2), x, keep_initializers_as_inputs=False)
250+
self.assertONNX(nn.BatchNorm2d(2), x)
251251

252252
def test_batchnorm_1d(self):
253253
x = torch.ones(2, 2, requires_grad=True)
@@ -263,7 +263,7 @@ def test_conv(self):
263263

264264
def test_conv_onnx_irv4(self):
265265
x = torch.ones(20, 16, 50, 40, requires_grad=True)
266-
self.assertONNX(nn.Conv2d(16, 13, 3, bias=False), x, keep_initializers_as_inputs=False)
266+
self.assertONNX(nn.Conv2d(16, 13, 3, bias=False), x)
267267

268268
def test_conv_variable_length(self):
269269
x = torch.ones(5, 3, 6, 6, requires_grad=True)

torch/onnx/__init__.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ def export(model, args, f, export_params=True, verbose=False, training=False,
3131
input_names=None, output_names=None, aten=False, export_raw_ir=False,
3232
operator_export_type=None, opset_version=None, _retain_param_name=True,
3333
do_constant_folding=False, example_outputs=None, strip_doc_string=True,
34-
dynamic_axes=None, keep_initializers_as_inputs=True):
34+
dynamic_axes=None, keep_initializers_as_inputs=None):
3535
r"""
3636
Export a model into ONNX format. This exporter runs your model
3737
once in order to get a trace of its execution to be exported;
@@ -123,12 +123,16 @@ def export(model, args, f, export_params=True, verbose=False, training=False,
123123
124124
(c). MIXED MODE OF (a) and (b)
125125
dynamic_axes = {'input_1':[0, 2, 3], 'input_2':{0:'batch'}, 'output':[0,1]}
126-
keep_initializers_as_inputs (bool, default True): If True, all the initializers
126+
keep_initializers_as_inputs (bool, default None): If True, all the initializers
127127
(typically corresponding to parameters) in the exported graph will also be
128128
added as inputs to the graph. If False, then initializers are not added as
129129
inputs to the graph, and only the non-parameter inputs are added as inputs.
130130
This may allow for better optimizations (such as constant folding etc.) by
131-
backends/runtimes that execute these graphs.
131+
backends/runtimes that execute these graphs. If unspecified (default None),
132+
then the behavior is chosen automatically as follows. If operator_export_type
133+
is OperatorExportTypes.ONNX, the behavior is equivalent to setting this
134+
argument to False. For other values of operator_export_type, the behavior is
135+
equivalent to setting this argument to True.
132136
"""
133137

134138
from torch.onnx import utils

torch/onnx/utils.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def export(model, args, f, export_params=True, verbose=False, training=False,
4949
input_names=None, output_names=None, aten=False, export_raw_ir=False,
5050
operator_export_type=None, opset_version=None, _retain_param_name=True,
5151
do_constant_folding=False, example_outputs=None, strip_doc_string=True,
52-
dynamic_axes=None, keep_initializers_as_inputs=True):
52+
dynamic_axes=None, keep_initializers_as_inputs=None):
5353
if aten or export_raw_ir:
5454
assert operator_export_type is None
5555
assert aten ^ export_raw_ir
@@ -276,7 +276,7 @@ def export_to_pretty_string(model, args, f, export_params=True, verbose=False, t
276276
operator_export_type=None, export_type=ExportTypes.PROTOBUF_FILE,
277277
example_outputs=None, propagate=False, google_printer=False,
278278
opset_version=None, _retain_param_name=True,
279-
keep_initializers_as_inputs=True):
279+
keep_initializers_as_inputs=None):
280280
if aten or export_raw_ir:
281281
assert operator_export_type is None
282282
assert aten ^ export_raw_ir
@@ -294,13 +294,16 @@ def _export_to_pretty_string(model, args, f, export_params=True, verbose=False,
294294
input_names=None, output_names=None, operator_export_type=OperatorExportTypes.ONNX,
295295
export_type=ExportTypes.PROTOBUF_FILE, example_outputs=None, propagate=False,
296296
google_printer=False, opset_version=None, _retain_param_name=False,
297-
do_constant_folding=False, keep_initializers_as_inputs=True, fixed_batch_size=False):
297+
do_constant_folding=False, keep_initializers_as_inputs=None, fixed_batch_size=False):
298298
from torch.onnx.symbolic_helper import _default_onnx_opset_version, _set_opset_version
299299
from torch.onnx.symbolic_helper import _set_operator_export_type
300300
if opset_version is None:
301301
opset_version = _default_onnx_opset_version
302302
_set_opset_version(opset_version)
303303
_set_operator_export_type(operator_export_type)
304+
val_keep_init_as_ip = True if keep_initializers_as_inputs is None else keep_initializers_as_inputs
305+
if keep_initializers_as_inputs is None and operator_export_type is OperatorExportTypes.ONNX:
306+
val_keep_init_as_ip = False
304307
graph, params_dict, torch_out = _model_to_graph(model, args, verbose,
305308
training, input_names,
306309
output_names, operator_export_type,
@@ -309,7 +312,7 @@ def _export_to_pretty_string(model, args, f, export_params=True, verbose=False,
309312

310313
return graph._pretty_print_onnx(params_dict, opset_version, False,
311314
operator_export_type, google_printer,
312-
keep_initializers_as_inputs)
315+
val_keep_init_as_ip)
313316

314317

315318
# NOTE: the output `torch_out` will contain the output tensors resulting from
@@ -320,7 +323,7 @@ def _export(model, args, f, export_params=True, verbose=False, training=False,
320323
input_names=None, output_names=None, operator_export_type=OperatorExportTypes.ONNX,
321324
export_type=ExportTypes.PROTOBUF_FILE, example_outputs=None, propagate=False,
322325
opset_version=None, _retain_param_name=False, do_constant_folding=False,
323-
strip_doc_string=True, dynamic_axes=None, keep_initializers_as_inputs=True, fixed_batch_size=False):
326+
strip_doc_string=True, dynamic_axes=None, keep_initializers_as_inputs=None, fixed_batch_size=False):
324327
if isinstance(model, torch.nn.DataParallel):
325328
raise ValueError('torch.nn.DataParallel is not supported by ONNX '
326329
'exporter, please use \'attribute\' module to '
@@ -336,6 +339,9 @@ def _export(model, args, f, export_params=True, verbose=False, training=False,
336339
opset_version = _default_onnx_opset_version
337340
_set_opset_version(opset_version)
338341
_set_operator_export_type(operator_export_type)
342+
val_keep_init_as_ip = True if keep_initializers_as_inputs is None else keep_initializers_as_inputs
343+
if keep_initializers_as_inputs is None and operator_export_type is OperatorExportTypes.ONNX:
344+
val_keep_init_as_ip = False
339345
graph, params_dict, torch_out = _model_to_graph(model, args, verbose,
340346
training, input_names,
341347
output_names, operator_export_type,
@@ -353,11 +359,11 @@ def _export(model, args, f, export_params=True, verbose=False, training=False,
353359
if export_params:
354360
proto, export_map = graph._export_onnx(
355361
params_dict, opset_version, dynamic_axes, defer_weight_export,
356-
operator_export_type, strip_doc_string, keep_initializers_as_inputs)
362+
operator_export_type, strip_doc_string, val_keep_init_as_ip)
357363
else:
358364
proto, export_map = graph._export_onnx(
359365
{}, opset_version, dynamic_axes, False, operator_export_type,
360-
strip_doc_string, keep_initializers_as_inputs)
366+
strip_doc_string, val_keep_init_as_ip)
361367

362368
if export_type == ExportTypes.PROTOBUF_FILE:
363369
assert(len(export_map) == 0)

0 commit comments

Comments
 (0)