Skip to content

Commit 01b1466

Browse files
authoredOct 27, 2021
[TPU tests] Enable first TPU examples pytorch (#14121)
* up * up * fix * up * Update examples/pytorch/test_xla_examples.py * correct labels * up * up * up * up * up * up
1 parent 232822f commit 01b1466

File tree

3 files changed

+77
-38
lines changed

3 files changed

+77
-38
lines changed
 

‎.github/workflows/self-scheduled.yml

+39
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,45 @@ jobs:
181181
name: run_all_tests_tf_gpu_test_reports
182182
path: reports
183183

184+
run_all_examples_torch_xla_tpu:
185+
runs-on: [self-hosted, docker-tpu-test, tpu-v3-8]
186+
container:
187+
image: gcr.io/tpu-pytorch/xla:nightly_3.8_tpuvm
188+
options: --privileged -v "/lib/libtpu.so:/lib/libtpu.so" -v /mnt/cache/.cache/huggingface:/mnt/cache/ --shm-size 16G
189+
steps:
190+
- name: Launcher docker
191+
uses: actions/checkout@v2
192+
193+
- name: Install dependencies
194+
run: |
195+
pip install --upgrade pip
196+
pip install .[testing]
197+
198+
- name: Are TPUs recognized by our DL frameworks
199+
env:
200+
XRT_TPU_CONFIG: localservice;0;localhost:51011
201+
run: |
202+
python -c "import torch_xla.core.xla_model as xm; print(xm.xla_device())"
203+
204+
- name: Run example tests on TPU
205+
env:
206+
XRT_TPU_CONFIG: "localservice;0;localhost:51011"
207+
MKL_SERVICE_FORCE_INTEL: "1" # See: https://github.com/pytorch/pytorch/issues/37377
208+
209+
run: |
210+
python -m pytest -n 1 -v --dist=loadfile --make-reports=tests_torch_xla_tpu examples/pytorch/test_xla_examples.py
211+
212+
- name: Failure short reports
213+
if: ${{ always() }}
214+
run: cat reports/tests_torch_xla_tpu_failures_short.txt
215+
216+
- name: Test suite reports artifacts
217+
if: ${{ always() }}
218+
uses: actions/upload-artifact@v2
219+
with:
220+
name: run_all_examples_torch_xla_tpu
221+
path: reports
222+
184223
run_all_tests_torch_multi_gpu:
185224
runs-on: [self-hosted, docker-gpu, multi-gpu]
186225
container:

‎examples/pytorch/test_xla_examples.py

+36-36
Original file line numberDiff line numberDiff line change
@@ -14,80 +14,80 @@
1414
# limitations under the License.
1515

1616

17+
import json
1718
import logging
19+
import os
1820
import sys
19-
import unittest
2021
from time import time
2122
from unittest.mock import patch
2223

23-
from transformers.testing_utils import require_torch_tpu
24+
from transformers.testing_utils import TestCasePlus, require_torch_tpu
2425

2526

2627
logging.basicConfig(level=logging.DEBUG)
2728

2829
logger = logging.getLogger()
2930

3031

32+
def get_results(output_dir):
33+
results = {}
34+
path = os.path.join(output_dir, "all_results.json")
35+
if os.path.exists(path):
36+
with open(path, "r") as f:
37+
results = json.load(f)
38+
else:
39+
raise ValueError(f"can't find {path}")
40+
return results
41+
42+
3143
@require_torch_tpu
32-
class TorchXLAExamplesTests(unittest.TestCase):
44+
class TorchXLAExamplesTests(TestCasePlus):
3345
def test_run_glue(self):
3446
import xla_spawn
3547

3648
stream_handler = logging.StreamHandler(sys.stdout)
3749
logger.addHandler(stream_handler)
3850

39-
output_directory = "run_glue_output"
40-
51+
tmp_dir = self.get_auto_remove_tmp_dir()
4152
testargs = f"""
42-
transformers/examples/text-classification/run_glue.py
53+
./examples/pytorch/text-classification/run_glue.py
4354
--num_cores=8
44-
transformers/examples/text-classification/run_glue.py
55+
./examples/pytorch/text-classification/run_glue.py
56+
--model_name_or_path distilbert-base-uncased
57+
--output_dir {tmp_dir}
58+
--overwrite_output_dir
59+
--train_file ./tests/fixtures/tests_samples/MRPC/train.csv
60+
--validation_file ./tests/fixtures/tests_samples/MRPC/dev.csv
4561
--do_train
4662
--do_eval
47-
--task_name=mrpc
48-
--cache_dir=./cache_dir
49-
--num_train_epochs=1
63+
--debug tpu_metrics_debug
64+
--per_device_train_batch_size=2
65+
--per_device_eval_batch_size=1
66+
--learning_rate=1e-4
67+
--max_steps=10
68+
--warmup_steps=2
69+
--seed=42
5070
--max_seq_length=128
51-
--learning_rate=3e-5
52-
--output_dir={output_directory}
53-
--overwrite_output_dir
54-
--logging_steps=5
55-
--save_steps=5
56-
--overwrite_cache
57-
--tpu_metrics_debug
58-
--model_name_or_path=bert-base-cased
59-
--per_device_train_batch_size=64
60-
--per_device_eval_batch_size=64
61-
--evaluation_strategy steps
62-
--overwrite_cache
6371
""".split()
72+
6473
with patch.object(sys, "argv", testargs):
6574
start = time()
6675
xla_spawn.main()
6776
end = time()
6877

69-
result = {}
70-
with open(f"{output_directory}/eval_results_mrpc.txt") as f:
71-
lines = f.readlines()
72-
for line in lines:
73-
key, value = line.split(" = ")
74-
result[key] = float(value)
75-
76-
del result["eval_loss"]
77-
for value in result.values():
78-
# Assert that the model trains
79-
self.assertGreaterEqual(value, 0.70)
78+
result = get_results(tmp_dir)
79+
self.assertGreaterEqual(result["eval_accuracy"], 0.75)
8080

81-
# Assert that the script takes less than 300 seconds to make sure it doesn't hang.
81+
# Assert that the script takes less than 500 seconds to make sure it doesn't hang.
8282
self.assertLess(end - start, 500)
8383

8484
def test_trainer_tpu(self):
8585
import xla_spawn
8686

8787
testargs = """
88-
transformers/tests/test_trainer_tpu.py
88+
./tests/test_trainer_tpu.py
8989
--num_cores=8
90-
transformers/tests/test_trainer_tpu.py
90+
./tests/test_trainer_tpu.py
9191
""".split()
9292
with patch.object(sys, "argv", testargs):
9393
xla_spawn.main()

‎tests/test_trainer_tpu.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ def compute_metrics(p: EvalPrediction) -> Dict:
9999

100100
p = trainer.predict(dataset)
101101
logger.info(p.metrics)
102-
if p.metrics["eval_success"] is not True:
102+
if p.metrics["test_success"] is not True:
103103
logger.error(p.metrics)
104104
exit(1)
105105

@@ -113,7 +113,7 @@ def compute_metrics(p: EvalPrediction) -> Dict:
113113

114114
p = trainer.predict(dataset)
115115
logger.info(p.metrics)
116-
if p.metrics["eval_success"] is not True:
116+
if p.metrics["test_success"] is not True:
117117
logger.error(p.metrics)
118118
exit(1)
119119

0 commit comments

Comments
 (0)
Please sign in to comment.