@@ -16,6 +16,7 @@
 
 import copy
 import inspect
+import json
 import os
 import random
 import tempfile
@@ -24,7 +25,7 @@
 from typing import List, Tuple
 
 from transformers import is_tf_available
-from transformers.testing_utils import _tf_gpu_memory_limit, is_pt_tf_cross_test, require_tf, slow
+from transformers.testing_utils import _tf_gpu_memory_limit, is_pt_tf_cross_test, require_onnx, require_tf, slow
 
 
 if is_tf_available():
@@ -201,6 +202,67 @@ def test_saved_model_creation(self):
             saved_model_dir = os.path.join(tmpdirname, "saved_model", "1")
             self.assertTrue(os.path.exists(saved_model_dir))
 
+    def test_onnx_compliancy(self):
+        if not self.test_onnx:
+            return
+
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+        INTERNAL_OPS = [
+            "Assert",
+            "AssignVariableOp",
+            "EmptyTensorList",
+            "ReadVariableOp",
+            "ResourceGather",
+            "TruncatedNormal",
+            "VarHandleOp",
+            "VarIsInitializedOp",
+        ]
+        onnx_ops = []
+
+        with open(os.path.join(".", "utils", "tf_ops", "onnx.json")) as f:
+            onnx_opsets = json.load(f)["opsets"]
+
+        for i in range(1, self.onnx_min_opset + 1):
+            onnx_ops.extend(onnx_opsets[str(i)])
+
+        for model_class in self.all_model_classes:
+            model_op_names = set()
+
+            with tf.Graph().as_default() as g:
+                model = model_class(config)
+                model(model.dummy_inputs)
+
+                for op in g.get_operations():
+                    model_op_names.add(op.node_def.op)
+
+            model_op_names = sorted(model_op_names)
+            incompatible_ops = []
+
+            for op in model_op_names:
+                if op not in onnx_ops and op not in INTERNAL_OPS:
+                    incompatible_ops.append(op)
+
+            self.assertEqual(len(incompatible_ops), 0, incompatible_ops)
+
+    @require_onnx
+    @slow
+    def test_onnx_runtime_optimize(self):
+        if not self.test_onnx:
+            return
+
+        import keras2onnx
+        import onnxruntime
+
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+        for model_class in self.all_model_classes:
+            model = model_class(config)
+            model(model.dummy_inputs)
+
+            onnx_model = keras2onnx.convert_keras(model, model.name, target_opset=self.onnx_min_opset)
+
+            onnxruntime.InferenceSession(onnx_model.SerializeToString())
+
     @slow
     def test_saved_model_creation_extended(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
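For context on the data that `test_onnx_compliancy` consumes: a minimal sketch, assuming `utils/tf_ops/onnx.json` has the shape implied by the `json.load(f)["opsets"]` and `onnx_opsets[str(i)]` accesses above (the file itself is not part of this diff), of how the cumulative whitelist of ONNX-convertible TF ops is built up to `onnx_min_opset`. The `example_onnx_json` contents and the `collect_supported_ops` helper are hypothetical, not code from the repository.

```python
# Sketch only: the layout of utils/tf_ops/onnx.json is assumed from how the
# test reads it, i.e. a top-level "opsets" dict keyed by opset-version strings,
# each mapping to the list of TF op names that become convertible at that opset.
example_onnx_json = {
    "opsets": {
        "1": ["Add", "Identity", "MatMul"],  # hypothetical op names
        "12": ["Einsum"],                    # hypothetical op names
    }
}


def collect_supported_ops(onnx_json: dict, min_opset: int) -> set:
    """Accumulate every op listed at or below `min_opset`, mirroring the
    cumulative loop in test_onnx_compliancy."""
    ops = set()
    for version in range(1, min_opset + 1):
        ops.update(onnx_json["opsets"].get(str(version), []))
    return ops


# Ops from opset 1 and opset 12 are both included when min_opset is 12;
# only the opset 1 ops are included when min_opset is 11.
assert collect_supported_ops(example_onnx_json, 12) == {"Add", "Identity", "MatMul", "Einsum"}
assert collect_supported_ops(example_onnx_json, 11) == {"Add", "Identity", "MatMul"}
```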