@@ -14,80 +14,80 @@
 # limitations under the License.
 
 
+import json
 import logging
+import os
 import sys
-import unittest
 from time import time
 from unittest.mock import patch
 
-from transformers.testing_utils import require_torch_tpu
+from transformers.testing_utils import TestCasePlus, require_torch_tpu
 
 
 logging.basicConfig(level=logging.DEBUG)
 
 logger = logging.getLogger()
 
 
+def get_results(output_dir):
+    results = {}
+    path = os.path.join(output_dir, "all_results.json")
+    if os.path.exists(path):
+        with open(path, "r") as f:
+            results = json.load(f)
+    else:
+        raise ValueError(f"can't find {path}")
+    return results
+
+
 @require_torch_tpu
-class TorchXLAExamplesTests(unittest.TestCase):
+class TorchXLAExamplesTests(TestCasePlus):
     def test_run_glue(self):
         import xla_spawn
 
         stream_handler = logging.StreamHandler(sys.stdout)
         logger.addHandler(stream_handler)
 
-        output_directory = "run_glue_output"
-
+        tmp_dir = self.get_auto_remove_tmp_dir()
         testargs = f"""
-            transformers/examples/text-classification/run_glue.py
+            ./examples/pytorch/text-classification/run_glue.py
             --num_cores=8
-            transformers/examples/text-classification/run_glue.py
+            ./examples/pytorch/text-classification/run_glue.py
+            --model_name_or_path distilbert-base-uncased
+            --output_dir {tmp_dir}
+            --overwrite_output_dir
+            --train_file ./tests/fixtures/tests_samples/MRPC/train.csv
+            --validation_file ./tests/fixtures/tests_samples/MRPC/dev.csv
             --do_train
             --do_eval
-            --task_name=mrpc
-            --cache_dir=./cache_dir
-            --num_train_epochs=1
+            --debug tpu_metrics_debug
+            --per_device_train_batch_size=2
+            --per_device_eval_batch_size=1
+            --learning_rate=1e-4
+            --max_steps=10
+            --warmup_steps=2
+            --seed=42
             --max_seq_length=128
-            --learning_rate=3e-5
-            --output_dir={output_directory}
-            --overwrite_output_dir
-            --logging_steps=5
-            --save_steps=5
-            --overwrite_cache
-            --tpu_metrics_debug
-            --model_name_or_path=bert-base-cased
-            --per_device_train_batch_size=64
-            --per_device_eval_batch_size=64
-            --evaluation_strategy steps
-            --overwrite_cache
         """.split()
+
         with patch.object(sys, "argv", testargs):
             start = time()
             xla_spawn.main()
             end = time()
 
-        result = {}
-        with open(f"{output_directory}/eval_results_mrpc.txt") as f:
-            lines = f.readlines()
-            for line in lines:
-                key, value = line.split(" = ")
-                result[key] = float(value)
-
-        del result["eval_loss"]
-        for value in result.values():
-            # Assert that the model trains
-            self.assertGreaterEqual(value, 0.70)
+        result = get_results(tmp_dir)
+        self.assertGreaterEqual(result["eval_accuracy"], 0.75)
 
-        # Assert that the script takes less than 300 seconds to make sure it doesn't hang.
+        # Assert that the script takes less than 500 seconds to make sure it doesn't hang.
         self.assertLess(end - start, 500)
 
     def test_trainer_tpu(self):
         import xla_spawn
 
         testargs = """
-            transformers/tests/test_trainer_tpu.py
+            ./tests/test_trainer_tpu.py
             --num_cores=8
-            transformers/tests/test_trainer_tpu.py
+            ./tests/test_trainer_tpu.py
         """.split()
         with patch.object(sys, "argv", testargs):
             xla_spawn.main()