Skip to content

Commit c8d3fa0

Browse files
jplu and LysandreJik authored
Check TF ops for ONNX compliance (#10025)
* Add check-ops script * Finish to implement check_tf_ops and start the test * Make the test mandatory only for BERT * Update tf_ops folder * Remove useless classes * Add the ONNX test for GPT2 and BART * Add a onnxruntime slow test + better opset flexibility * Fix test + apply style * fix tests * Switch min opset from 12 to 10 * Update src/transformers/file_utils.py Co-authored-by: Lysandre Debut <lysandre@huggingface.co> * Fix GPT2 * Remove extra shape_list usage * Fix GPT2 * Address Morgan's comments Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
1 parent 93bd2f7 commit c8d3fa0

33 files changed

+468
-17
lines changed

src/transformers/file_utils.py

+14
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,16 @@
151151
_faiss_available = False
152152

153153

154+
_onnx_available = (
155+
importlib.util.find_spec("keras2onnx") is not None and importlib.util.find_spec("onnxruntime") is not None
156+
)
157+
try:
158+
_onxx_version = importlib_metadata.version("onnx")
159+
logger.debug(f"Successfully imported onnx version {_onxx_version}")
160+
except importlib_metadata.PackageNotFoundError:
161+
_onnx_available = False
162+
163+
154164
_scatter_available = importlib.util.find_spec("torch_scatter") is not None
155165
try:
156166
_scatter_version = importlib_metadata.version("torch_scatter")
@@ -230,6 +240,10 @@ def is_tf_available():
230240
return _tf_available
231241

232242

243+
def is_onnx_available():
244+
return _onnx_available
245+
246+
233247
def is_flax_available():
234248
return _flax_available
235249

src/transformers/models/gpt2/modeling_tf_gpt2.py

+3-16
Original file line numberDiff line numberDiff line change
@@ -1030,16 +1030,7 @@ def call(
10301030
)
10311031
- 1
10321032
)
1033-
1034-
def get_seq_element(sequence_position, input_batch):
1035-
return tf.strided_slice(
1036-
input_batch, [sequence_position, 0], [sequence_position + 1, input_batch.shape[-1]], [1, 1]
1037-
)
1038-
1039-
result = tf.map_fn(
1040-
fn=lambda t: get_seq_element(t[0], t[1]), elems=[sequence_lengths, logits], dtype="float"
1041-
)
1042-
in_logits = tf.reshape(result, [logits_shape[0], logits_shape[-1]])
1033+
in_logits = tf.gather(logits, sequence_lengths, batch_dims=1, axis=1)
10431034
else:
10441035
sequence_lengths = -1
10451036
logger.warning(
@@ -1049,16 +1040,12 @@ def get_seq_element(sequence_position, input_batch):
10491040
loss = None
10501041

10511042
if inputs["labels"] is not None:
1052-
if input_ids is not None:
1053-
batch_size, sequence_length = shape_list(inputs["input_ids"])[:2]
1054-
else:
1055-
batch_size, sequence_length = shape_list(inputs["inputs_embeds"])[:2]
10561043
assert (
1057-
self.config.pad_token_id is not None or batch_size == 1
1044+
self.config.pad_token_id is not None or logits_shape[0] == 1
10581045
), "Cannot handle batch sizes > 1 if no padding token is defined."
10591046

10601047
if not tf.is_tensor(sequence_lengths):
1061-
in_logits = logits[0:batch_size, sequence_lengths]
1048+
in_logits = logits[0 : logits_shape[0], sequence_lengths]
10621049

10631050
loss = self.compute_loss(tf.reshape(inputs["labels"], [-1]), tf.reshape(in_logits, [-1, self.num_labels]))
10641051
pooled_logits = in_logits if in_logits is not None else logits

src/transformers/testing_utils.py

+8
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
is_datasets_available,
2929
is_faiss_available,
3030
is_flax_available,
31+
is_onnx_available,
3132
is_pandas_available,
3233
is_scatter_available,
3334
is_sentencepiece_available,
@@ -160,6 +161,13 @@ def require_git_lfs(test_case):
160161
return test_case
161162

162163

164+
def require_onnx(test_case):
165+
if not is_onnx_available():
166+
return unittest.skip("test requires ONNX")(test_case)
167+
else:
168+
return test_case
169+
170+
163171
def require_torch(test_case):
164172
"""
165173
Decorator marking a test that requires PyTorch.

tests/test_modeling_tf_albert.py

+1
Original file line numberDiff line numberDiff line change
@@ -241,6 +241,7 @@ class TFAlbertModelTest(TFModelTesterMixin, unittest.TestCase):
241241
else ()
242242
)
243243
test_head_masking = False
244+
test_onnx = False
244245

245246
def setUp(self):
246247
self.model_tester = TFAlbertModelTester(self)

tests/test_modeling_tf_bart.py

+2
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,8 @@ class TFBartModelTest(TFModelTesterMixin, unittest.TestCase):
178178
all_generative_model_classes = (TFBartForConditionalGeneration,) if is_tf_available() else ()
179179
is_encoder_decoder = True
180180
test_pruning = False
181+
test_onnx = True
182+
onnx_min_opset = 10
181183

182184
def setUp(self):
183185
self.model_tester = TFBartModelTester(self)

tests/test_modeling_tf_bert.py

+2
Original file line numberDiff line numberDiff line change
@@ -274,6 +274,8 @@ class TFBertModelTest(TFModelTesterMixin, unittest.TestCase):
274274
else ()
275275
)
276276
test_head_masking = False
277+
test_onnx = True
278+
onnx_min_opset = 10
277279

278280
# special case for ForPreTraining model
279281
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):

tests/test_modeling_tf_blenderbot.py

+1
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,7 @@ class TFBlenderbotModelTest(TFModelTesterMixin, unittest.TestCase):
177177
all_generative_model_classes = (TFBlenderbotForConditionalGeneration,) if is_tf_available() else ()
178178
is_encoder_decoder = True
179179
test_pruning = False
180+
test_onnx = False
180181

181182
def setUp(self):
182183
self.model_tester = TFBlenderbotModelTester(self)

tests/test_modeling_tf_blenderbot_small.py

+1
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,7 @@ class TFBlenderbotSmallModelTest(TFModelTesterMixin, unittest.TestCase):
179179
all_generative_model_classes = (TFBlenderbotSmallForConditionalGeneration,) if is_tf_available() else ()
180180
is_encoder_decoder = True
181181
test_pruning = False
182+
test_onnx = False
182183

183184
def setUp(self):
184185
self.model_tester = TFBlenderbotSmallModelTester(self)

tests/test_modeling_tf_common.py

+63-1
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
import copy
1818
import inspect
19+
import json
1920
import os
2021
import random
2122
import tempfile
@@ -24,7 +25,7 @@
2425
from typing import List, Tuple
2526

2627
from transformers import is_tf_available
27-
from transformers.testing_utils import _tf_gpu_memory_limit, is_pt_tf_cross_test, require_tf, slow
28+
from transformers.testing_utils import _tf_gpu_memory_limit, is_pt_tf_cross_test, require_onnx, require_tf, slow
2829

2930

3031
if is_tf_available():
@@ -201,6 +202,67 @@ def test_saved_model_creation(self):
201202
saved_model_dir = os.path.join(tmpdirname, "saved_model", "1")
202203
self.assertTrue(os.path.exists(saved_model_dir))
203204

205+
def test_onnx_compliancy(self):
206+
if not self.test_onnx:
207+
return
208+
209+
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
210+
INTERNAL_OPS = [
211+
"Assert",
212+
"AssignVariableOp",
213+
"EmptyTensorList",
214+
"ReadVariableOp",
215+
"ResourceGather",
216+
"TruncatedNormal",
217+
"VarHandleOp",
218+
"VarIsInitializedOp",
219+
]
220+
onnx_ops = []
221+
222+
with open(os.path.join(".", "utils", "tf_ops", "onnx.json")) as f:
223+
onnx_opsets = json.load(f)["opsets"]
224+
225+
for i in range(1, self.onnx_min_opset + 1):
226+
onnx_ops.extend(onnx_opsets[str(i)])
227+
228+
for model_class in self.all_model_classes:
229+
model_op_names = set()
230+
231+
with tf.Graph().as_default() as g:
232+
model = model_class(config)
233+
model(model.dummy_inputs)
234+
235+
for op in g.get_operations():
236+
model_op_names.add(op.node_def.op)
237+
238+
model_op_names = sorted(model_op_names)
239+
incompatible_ops = []
240+
241+
for op in model_op_names:
242+
if op not in onnx_ops and op not in INTERNAL_OPS:
243+
incompatible_ops.append(op)
244+
245+
self.assertEqual(len(incompatible_ops), 0, incompatible_ops)
246+
247+
@require_onnx
248+
@slow
249+
def test_onnx_runtime_optimize(self):
250+
if not self.test_onnx:
251+
return
252+
253+
import keras2onnx
254+
import onnxruntime
255+
256+
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
257+
258+
for model_class in self.all_model_classes:
259+
model = model_class(config)
260+
model(model.dummy_inputs)
261+
262+
onnx_model = keras2onnx.convert_keras(model, model.name, target_opset=self.onnx_min_opset)
263+
264+
onnxruntime.InferenceSession(onnx_model.SerializeToString())
265+
204266
@slow
205267
def test_saved_model_creation_extended(self):
206268
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

tests/test_modeling_tf_convbert.py

+1
Original file line numberDiff line numberDiff line change
@@ -239,6 +239,7 @@ class TFConvBertModelTest(TFModelTesterMixin, unittest.TestCase):
239239
)
240240
test_pruning = False
241241
test_head_masking = False
242+
test_onnx = False
242243

243244
def setUp(self):
244245
self.model_tester = TFConvBertModelTester(self)

tests/test_modeling_tf_ctrl.py

+1
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,7 @@ class TFCTRLModelTest(TFModelTesterMixin, unittest.TestCase):
174174
all_model_classes = (TFCTRLModel, TFCTRLLMHeadModel, TFCTRLForSequenceClassification) if is_tf_available() else ()
175175
all_generative_model_classes = (TFCTRLLMHeadModel,) if is_tf_available() else ()
176176
test_head_masking = False
177+
test_onnx = False
177178

178179
def setUp(self):
179180
self.model_tester = TFCTRLModelTester(self)

tests/test_modeling_tf_distilbert.py

+1
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,7 @@ class TFDistilBertModelTest(TFModelTesterMixin, unittest.TestCase):
184184
else None
185185
)
186186
test_head_masking = False
187+
test_onnx = False
187188

188189
def setUp(self):
189190
self.model_tester = TFDistilBertModelTester(self)

tests/test_modeling_tf_dpr.py

+1
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,7 @@ class TFDPRModelTest(TFModelTesterMixin, unittest.TestCase):
188188
test_missing_keys = False
189189
test_pruning = False
190190
test_head_masking = False
191+
test_onnx = False
191192

192193
def setUp(self):
193194
self.model_tester = TFDPRModelTester(self)

tests/test_modeling_tf_electra.py

+1
Original file line numberDiff line numberDiff line change
@@ -206,6 +206,7 @@ class TFElectraModelTest(TFModelTesterMixin, unittest.TestCase):
206206
else ()
207207
)
208208
test_head_masking = False
209+
test_onnx = False
209210

210211
def setUp(self):
211212
self.model_tester = TFElectraModelTester(self)

tests/test_modeling_tf_flaubert.py

+1
Original file line numberDiff line numberDiff line change
@@ -292,6 +292,7 @@ class TFFlaubertModelTest(TFModelTesterMixin, unittest.TestCase):
292292
(TFFlaubertWithLMHeadModel,) if is_tf_available() else ()
293293
) # TODO (PVP): Check other models whether language generation is also applicable
294294
test_head_masking = False
295+
test_onnx = False
295296

296297
def setUp(self):
297298
self.model_tester = TFFlaubertModelTester(self)

tests/test_modeling_tf_funnel.py

+2
Original file line numberDiff line numberDiff line change
@@ -339,6 +339,7 @@ class TFFunnelModelTest(TFModelTesterMixin, unittest.TestCase):
339339
else ()
340340
)
341341
test_head_masking = False
342+
test_onnx = False
342343

343344
def setUp(self):
344345
self.model_tester = TFFunnelModelTester(self)
@@ -382,6 +383,7 @@ class TFFunnelBaseModelTest(TFModelTesterMixin, unittest.TestCase):
382383
(TFFunnelBaseModel, TFFunnelForMultipleChoice, TFFunnelForSequenceClassification) if is_tf_available() else ()
383384
)
384385
test_head_masking = False
386+
test_onnx = False
385387

386388
def setUp(self):
387389
self.model_tester = TFFunnelModelTester(self, base=True)

tests/test_modeling_tf_gpt2.py

+2
Original file line numberDiff line numberDiff line change
@@ -333,6 +333,8 @@ class TFGPT2ModelTest(TFModelTesterMixin, unittest.TestCase):
333333
)
334334
all_generative_model_classes = (TFGPT2LMHeadModel,) if is_tf_available() else ()
335335
test_head_masking = False
336+
test_onnx = True
337+
onnx_min_opset = 10
336338

337339
def setUp(self):
338340
self.model_tester = TFGPT2ModelTester(self)

tests/test_modeling_tf_led.py

+2
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,8 @@ class TFLEDModelTest(TFModelTesterMixin, unittest.TestCase):
195195
all_generative_model_classes = (TFLEDForConditionalGeneration,) if is_tf_available() else ()
196196
is_encoder_decoder = True
197197
test_pruning = False
198+
test_head_masking = False
199+
test_onnx = False
198200

199201
def setUp(self):
200202
self.model_tester = TFLEDModelTester(self)

tests/test_modeling_tf_longformer.py

+2
Original file line numberDiff line numberDiff line change
@@ -297,6 +297,8 @@ class TFLongformerModelTest(TFModelTesterMixin, unittest.TestCase):
297297
if is_tf_available()
298298
else ()
299299
)
300+
test_head_masking = False
301+
test_onnx = False
300302

301303
def setUp(self):
302304
self.model_tester = TFLongformerModelTester(self)

tests/test_modeling_tf_lxmert.py

+1
Original file line numberDiff line numberDiff line change
@@ -362,6 +362,7 @@ class TFLxmertModelTest(TFModelTesterMixin, unittest.TestCase):
362362

363363
all_model_classes = (TFLxmertModel, TFLxmertForPreTraining) if is_tf_available() else ()
364364
test_head_masking = False
365+
test_onnx = False
365366

366367
def setUp(self):
367368
self.model_tester = TFLxmertModelTester(self)

tests/test_modeling_tf_marian.py

+1
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,7 @@ class TFMarianModelTest(TFModelTesterMixin, unittest.TestCase):
179179
all_generative_model_classes = (TFMarianMTModel,) if is_tf_available() else ()
180180
is_encoder_decoder = True
181181
test_pruning = False
182+
test_onnx = False
182183

183184
def setUp(self):
184185
self.model_tester = TFMarianModelTester(self)

tests/test_modeling_tf_mbart.py

+1
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,7 @@ class TFMBartModelTest(TFModelTesterMixin, unittest.TestCase):
181181
all_generative_model_classes = (TFMBartForConditionalGeneration,) if is_tf_available() else ()
182182
is_encoder_decoder = True
183183
test_pruning = False
184+
test_onnx = False
184185

185186
def setUp(self):
186187
self.model_tester = TFMBartModelTester(self)

tests/test_modeling_tf_mobilebert.py

+1
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ class TFMobileBertModelTest(TFModelTesterMixin, unittest.TestCase):
5656
else ()
5757
)
5858
test_head_masking = False
59+
test_onnx = False
5960

6061
class TFMobileBertModelTester(object):
6162
def __init__(

tests/test_modeling_tf_mpnet.py

+1
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,7 @@ class TFMPNetModelTest(TFModelTesterMixin, unittest.TestCase):
199199
else ()
200200
)
201201
test_head_masking = False
202+
test_onnx = False
202203

203204
def setUp(self):
204205
self.model_tester = TFMPNetModelTester(self)

tests/test_modeling_tf_openai.py

+1
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,7 @@ class TFOpenAIGPTModelTest(TFModelTesterMixin, unittest.TestCase):
203203
(TFOpenAIGPTLMHeadModel,) if is_tf_available() else ()
204204
) # TODO (PVP): Add Double HeadsModel when generate() function is changed accordingly
205205
test_head_masking = False
206+
test_onnx = False
206207

207208
def setUp(self):
208209
self.model_tester = TFOpenAIGPTModelTester(self)

tests/test_modeling_tf_pegasus.py

+1
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,7 @@ class TFPegasusModelTest(TFModelTesterMixin, unittest.TestCase):
177177
all_generative_model_classes = (TFPegasusForConditionalGeneration,) if is_tf_available() else ()
178178
is_encoder_decoder = True
179179
test_pruning = False
180+
test_onnx = False
180181

181182
def setUp(self):
182183
self.model_tester = TFPegasusModelTester(self)

tests/test_modeling_tf_roberta.py

+1
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,7 @@ class TFRobertaModelTest(TFModelTesterMixin, unittest.TestCase):
186186
else ()
187187
)
188188
test_head_masking = False
189+
test_onnx = False
189190

190191
def setUp(self):
191192
self.model_tester = TFRobertaModelTester(self)

tests/test_modeling_tf_t5.py

+2
Original file line numberDiff line numberDiff line change
@@ -249,6 +249,7 @@ class TFT5ModelTest(TFModelTesterMixin, unittest.TestCase):
249249
all_model_classes = (TFT5Model, TFT5ForConditionalGeneration) if is_tf_available() else ()
250250
all_generative_model_classes = (TFT5ForConditionalGeneration,) if is_tf_available() else ()
251251
test_head_masking = False
252+
test_onnx = False
252253

253254
def setUp(self):
254255
self.model_tester = TFT5ModelTester(self)
@@ -427,6 +428,7 @@ class TFT5EncoderOnlyModelTest(TFModelTesterMixin, unittest.TestCase):
427428
is_encoder_decoder = False
428429
all_model_classes = (TFT5EncoderModel,) if is_tf_available() else ()
429430
test_head_masking = False
431+
test_onnx = False
430432

431433
def setUp(self):
432434
self.model_tester = TFT5EncoderOnlyModelTester(self)

tests/test_modeling_tf_transfo_xl.py

+1
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,7 @@ class TFTransfoXLModelTest(TFModelTesterMixin, unittest.TestCase):
164164
# TODO: add this test when TFTransfoXLLMHead has a linear output layer implemented
165165
test_resize_embeddings = False
166166
test_head_masking = False
167+
test_onnx = False
167168

168169
def setUp(self):
169170
self.model_tester = TFTransfoXLModelTester(self)

0 commit comments

Comments (0)