Skip to content

Commit 05849ea

Browse files
justinchubypytorchmergebot
authored andcommitted
[ONNX] Create empty opset 17 symbolic file (#83287)
The PR - Creates an empty symbolic file to house the new ops defined in ONNX 17 - Increments the max version to 17 and fixes the doc for version 16 - Enables tests for opset 17 - Updates the IR version in `export.cpp` Pull Request resolved: #83287 Approved by: https://github.com/thiagocrepaldi, https://github.com/AllenTiTaiWang, https://github.com/BowenBao
1 parent 1f2efdc commit 05849ea

File tree

5 files changed

+64
-25
lines changed

5 files changed

+64
-25
lines changed

test/onnx/test_pytorch_onnx_onnxruntime.py

+22-5
Original file line numberDiff line numberDiff line change
@@ -36,10 +36,15 @@
3636
)
3737
from torch import Tensor
3838
from torch.nn.utils import rnn as rnn_utils
39-
from torch.onnx import verification
39+
from torch.onnx import _constants, verification
4040
from torch.testing._internal import common_utils
4141
from torch.testing._internal.common_utils import skipIfNoLapack
4242

43+
# The min onnx opset version to test for
44+
MIN_ONNX_OPSET_VERSION = 9
45+
# The max onnx opset version to test for
46+
MAX_ONNX_OPSET_VERSION = _constants.onnx_main_opset
47+
4348

4449
def _init_test_generalized_rcnn_transform():
4550
min_size = 100
@@ -112,13 +117,23 @@ def _construct_tensor_for_quantization_test(
112117
return tensor
113118

114119

115-
def _parameterized_class_attrs_and_values():
120+
def _parameterized_class_attrs_and_values(
121+
min_opset_version: int, max_opset_version: int
122+
):
116123
attrs = ("opset_version", "is_script", "keep_initializers_as_inputs")
117124
input_values = []
118125
input_values.extend(itertools.product((7, 8), (True, False), (True,)))
119126
# Valid opset versions are defined in torch/onnx/_constants.py.
120-
# Versions are intentionally set statically, to not be affected by elsewhere changes.
121-
input_values.extend(itertools.product(range(9, 17), (True, False), (True, False)))
127+
# Versions are intentionally set statically, to not be affected by changes elsewhere.
128+
if min_opset_version < 9:
129+
raise ValueError("min_opset_version must be >= 9")
130+
input_values.extend(
131+
itertools.product(
132+
range(min_opset_version, max_opset_version + 1),
133+
(True, False),
134+
(True, False),
135+
)
136+
)
122137
return {"attrs": attrs, "input_values": input_values}
123138

124139

@@ -143,7 +158,9 @@ def _parametrize_rnn_args(arg_name):
143158

144159

145160
@parameterized.parameterized_class(
146-
**_parameterized_class_attrs_and_values(),
161+
**_parameterized_class_attrs_and_values(
162+
MIN_ONNX_OPSET_VERSION, MAX_ONNX_OPSET_VERSION
163+
),
147164
class_name_func=onnx_test_common.parameterize_class_name,
148165
)
149166
@common_utils.instantiate_parametrized_tests

torch/csrc/jit/serialization/export.cpp

+22-18
Original file line numberDiff line numberDiff line change
@@ -57,26 +57,30 @@ namespace onnx_torch = ::torch::onnx;
5757
namespace onnx = ::ONNX_NAMESPACE;
5858

5959
const static int kInvalidOpsetVersion = -1;
60+
const static int kMainOpsetVersion = 17;
6061
// Based on OP_SET_ID_VERSION_MAP in
6162
// https://github.com/onnx/onnx/blob/master/onnx/helper.py.
62-
constexpr static std::array<int64_t, 17> kOpsetVersionToIRVersion = {
63-
kInvalidOpsetVersion,
64-
3,
65-
kInvalidOpsetVersion,
66-
kInvalidOpsetVersion,
67-
kInvalidOpsetVersion,
68-
3,
69-
3,
70-
3,
71-
3,
72-
4,
73-
5,
74-
6,
75-
7,
76-
7,
77-
7,
78-
8,
79-
8};
63+
constexpr static std::array<int64_t, kMainOpsetVersion + 1>
64+
kOpsetVersionToIRVersion = {
65+
kInvalidOpsetVersion,
66+
3, // opset 1
67+
kInvalidOpsetVersion,
68+
kInvalidOpsetVersion,
69+
kInvalidOpsetVersion,
70+
3, // opset 5
71+
3, // opset 6
72+
3, // opset 7
73+
3, // opset 8
74+
4, // opset 9
75+
5, // opset 10
76+
6, // opset 11
77+
7, // opset 12
78+
7, // opset 13
79+
7, // opset 14
80+
8, // opset 15
81+
8, // opset 16
82+
8, // opset 17
83+
};
8084

8185
std::string getNodeStackTraceString(const Node* n) {
8286
return n->sourceRange().str();

torch/onnx/_constants.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
ONNX_ARCHIVE_MODEL_PROTO_NAME = "__MODEL_PROTO"
44
onnx_default_opset = 14
5-
onnx_main_opset = 16
5+
onnx_main_opset = 17
66
onnx_stable_opsets = tuple(range(7, onnx_main_opset))
77
onnx_constant_folding_opsets = tuple(range(9, onnx_main_opset + 1))
88

torch/onnx/symbolic_opset16.py

-1
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
Where
2121
GreaterOrEqual
2222
LessOrEqual
23-
SequenceMap
2423
"""
2524

2625
# EDITING THIS FILE? READ THIS FIRST!

torch/onnx/symbolic_opset17.py

+19
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
"""This file exports ONNX ops for opset 17.
2+
3+
Note [ONNX Operators that are added/updated in opset 17]
4+
5+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
6+
https://github.com/onnx/onnx/blob/main/docs/Changelog.md#version-17-of-the-default-onnx-operator-set
7+
New operators:
8+
BlackmanWindow
9+
DFT
10+
HammingWindow
11+
HannWindow
12+
LayerNormalization
13+
MelWeightMatrix
14+
STFT
15+
SequenceMap
16+
"""
17+
18+
# EDITING THIS FILE? READ THIS FIRST!
19+
# see Note [Edit Symbolic Files] in symbolic_helper.py

0 commit comments

Comments
 (0)