
Commit 0533cf4

Authored Jul 9, 2020
Test XLA examples (#5583)
* Test XLA examples
* Style
* Using `require_torch_tpu`
* Style
* No need for pytest
1 parent 3bd5519 commit 0533cf4

2 files changed: +102 additions, −1 deletion

 

examples/test_xla_examples.py

+91 lines (new file)

```python
# coding=utf-8
# Copyright 2018 HuggingFace Inc..
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import argparse
import logging
import sys
import unittest
from time import time
from unittest.mock import patch

from transformers.testing_utils import require_torch_tpu


logging.basicConfig(level=logging.DEBUG)

logger = logging.getLogger()


def get_setup_file():
    parser = argparse.ArgumentParser()
    parser.add_argument("-f")
    args = parser.parse_args()
    return args.f


@require_torch_tpu
class TorchXLAExamplesTests(unittest.TestCase):
    def test_run_glue(self):
        import xla_spawn

        stream_handler = logging.StreamHandler(sys.stdout)
        logger.addHandler(stream_handler)

        output_directory = "run_glue_output"

        testargs = f"""
            text-classification/run_glue.py
            --num_cores=8
            text-classification/run_glue.py
            --do_train
            --do_eval
            --task_name=MRPC
            --data_dir=../glue_data/MRPC
            --cache_dir=./cache_dir
            --num_train_epochs=1
            --max_seq_length=128
            --learning_rate=3e-5
            --output_dir={output_directory}
            --overwrite_output_dir
            --logging_steps=5
            --save_steps=5
            --overwrite_cache
            --tpu_metrics_debug
            --model_name_or_path=bert-base-cased
            --per_device_train_batch_size=64
            --per_device_eval_batch_size=64
            --evaluate_during_training
            --overwrite_cache
            """.split()
        with patch.object(sys, "argv", testargs):
            start = time()
            xla_spawn.main()
            end = time()

            result = {}
            with open(f"{output_directory}/eval_results_mrpc.txt") as f:
                lines = f.readlines()
                for line in lines:
                    key, value = line.split(" = ")
                    result[key] = float(value)

            del result["eval_loss"]
            for value in result.values():
                # Assert that the model trains
                self.assertGreaterEqual(value, 0.70)

            # Assert that the script takes less than 100 seconds to make sure it doesn't hang.
            self.assertLess(end - start, 100)
```
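The `sys.argv` patching used in the test can be illustrated in isolation. The following is a minimal sketch and not part of the commit: `fake_main` and its argument names are hypothetical stand-ins for `xla_spawn.main()`, which likewise reads the launcher arguments (`--num_cores` and the training script path) from `sys.argv`.

```python
# Minimal, self-contained sketch of the sys.argv patching pattern used in the test.
# "fake_main" and "launcher.py" are hypothetical; only the mechanism mirrors how
# TorchXLAExamplesTests.test_run_glue drives xla_spawn.main().
import argparse
import sys
from unittest.mock import patch


def fake_main():
    # Reads sys.argv at call time, exactly as a script's main() would from the CLI.
    parser = argparse.ArgumentParser()
    parser.add_argument("--num_cores", type=int, default=1)
    parser.add_argument("training_script")
    args = parser.parse_args()
    print(f"would spawn {args.num_cores} processes running {args.training_script}")


# The first entry plays the role of sys.argv[0] (the launcher name) and is ignored by
# argparse; the remaining tokens are the arguments the launcher actually parses.
testargs = """
    launcher.py
    --num_cores=8
    text-classification/run_glue.py
    """.split()

with patch.object(sys, "argv", testargs):
    fake_main()  # prints: would spawn 8 processes running text-classification/run_glue.py
```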

src/transformers/testing_utils.py

+11 −1 lines

```diff
@@ -2,7 +2,7 @@
 import unittest
 from distutils.util import strtobool
 
-from transformers.file_utils import _tf_available, _torch_available
+from transformers.file_utils import _tf_available, _torch_available, _torch_tpu_available
 
 
 SMALL_MODEL_IDENTIFIER = "julien-c/bert-xsmall-dummy"
@@ -113,6 +113,16 @@ def require_multigpu(test_case):
     return test_case
 
 
+def require_torch_tpu(test_case):
+    """
+    Decorator marking a test that requires a TPU (in PyTorch).
+    """
+    if not _torch_tpu_available:
+        return unittest.skip("test requires PyTorch TPU")
+
+    return test_case
+
+
 if _torch_available:
     # Set the USE_CUDA environment variable to select a GPU.
     torch_device = "cuda" if parse_flag_from_env("USE_CUDA") else "cpu"
```
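The new decorator relies on the `_torch_tpu_available` flag imported from `transformers.file_utils`, which is not defined in this diff. As a hedged sketch, such a flag is usually derived by probing for the `torch_xla` package; the exact logic in `file_utils.py` may differ.

```python
# Hedged sketch of how a _torch_tpu_available flag can be derived; the real
# definition lives in transformers.file_utils and may differ in detail.
try:
    import torch_xla.core.xla_model as xm  # noqa: F401  (only probing for availability)

    _torch_tpu_available = True
except ImportError:
    _torch_tpu_available = False
```

In `test_xla_examples.py` above, the decorator wraps the whole `TorchXLAExamplesTests` class, so the GLUE run is only attempted when a TPU setup is available.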
