36
36
)
37
37
from torch import Tensor
38
38
from torch .nn .utils import rnn as rnn_utils
39
- from torch .onnx import verification
39
+ from torch .onnx import _constants , verification
40
40
from torch .testing ._internal import common_utils
41
41
from torch .testing ._internal .common_utils import skipIfNoLapack
42
42
43
+ # The min onnx opset version to test for
44
+ MIN_ONNX_OPSET_VERSION = 9
45
+ # The max onnx opset version to test for
46
+ MAX_ONNX_OPSET_VERSION = _constants .onnx_main_opset
47
+
43
48
44
49
def _init_test_generalized_rcnn_transform ():
45
50
min_size = 100
@@ -112,13 +117,23 @@ def _construct_tensor_for_quantization_test(
112
117
return tensor
113
118
114
119
115
- def _parameterized_class_attrs_and_values ():
120
+ def _parameterized_class_attrs_and_values (
121
+ min_opset_version : int , max_opset_version : int
122
+ ):
116
123
attrs = ("opset_version" , "is_script" , "keep_initializers_as_inputs" )
117
124
input_values = []
118
125
input_values .extend (itertools .product ((7 , 8 ), (True , False ), (True ,)))
119
126
# Valid opset versions are defined in torch/onnx/_constants.py.
120
- # Versions are intentionally set statically, to not be affected by elsewhere changes.
121
- input_values .extend (itertools .product (range (9 , 17 ), (True , False ), (True , False )))
127
+ # Versions are intentionally set statically, to not be affected by changes elsewhere.
128
+ if min_opset_version < 9 :
129
+ raise ValueError ("min_opset_version must be >= 9" )
130
+ input_values .extend (
131
+ itertools .product (
132
+ range (min_opset_version , max_opset_version + 1 ),
133
+ (True , False ),
134
+ (True , False ),
135
+ )
136
+ )
122
137
return {"attrs" : attrs , "input_values" : input_values }
123
138
124
139
@@ -143,7 +158,9 @@ def _parametrize_rnn_args(arg_name):
143
158
144
159
145
160
@parameterized .parameterized_class (
146
- ** _parameterized_class_attrs_and_values (),
161
+ ** _parameterized_class_attrs_and_values (
162
+ MIN_ONNX_OPSET_VERSION , MAX_ONNX_OPSET_VERSION
163
+ ),
147
164
class_name_func = onnx_test_common .parameterize_class_name ,
148
165
)
149
166
@common_utils .instantiate_parametrized_tests
0 commit comments