Skip to content

Commit 9dbe409

Browse files
authored
[testing] a new TestCasePlus subclass + get_auto_remove_tmp_dir() (#6494)
* [testing] switch to a new TestCasePlus + get_auto_remove_tmp_dir() for auto-removal of tmp dirs * respect after=True for tempfile, simplify code * comments * comment fix * put `before` last in args, so can make debug even faster
1 parent 36010cb commit 9dbe409

File tree

3 files changed

+124
-49
lines changed

3 files changed

+124
-49
lines changed
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,10 @@
11
import argparse
22
import logging
3-
import shutil
43
import sys
5-
import unittest
64
from unittest.mock import patch
75

86
import run_glue_with_pabee
7+
from transformers.testing_utils import TestCasePlus
98

109

1110
logging.basicConfig(level=logging.DEBUG)
@@ -20,20 +19,19 @@ def get_setup_file():
2019
return args.f
2120

2221

23-
def clean_test_dir(path):
24-
shutil.rmtree(path, ignore_errors=True)
25-
26-
27-
class PabeeTests(unittest.TestCase):
22+
class PabeeTests(TestCasePlus):
2823
def test_run_glue(self):
2924
stream_handler = logging.StreamHandler(sys.stdout)
3025
logger.addHandler(stream_handler)
3126

32-
testargs = """
27+
tmp_dir = self.get_auto_remove_tmp_dir()
28+
testargs = f"""
3329
run_glue_with_pabee.py
3430
--model_type albert
3531
--model_name_or_path albert-base-v2
3632
--data_dir ./tests/fixtures/tests_samples/MRPC/
33+
--output_dir {tmp_dir}
34+
--overwrite_output_dir
3735
--task_name mrpc
3836
--do_train
3937
--do_eval
@@ -42,16 +40,11 @@ def test_run_glue(self):
4240
--learning_rate=2e-5
4341
--max_steps=50
4442
--warmup_steps=2
45-
--overwrite_output_dir
4643
--seed=42
4744
--max_seq_length=128
48-
"""
49-
output_dir = "./tests/fixtures/tests_samples/temp_dir_{}".format(hash(testargs))
50-
testargs += "--output_dir " + output_dir
51-
testargs = testargs.split()
45+
""".split()
46+
5247
with patch.object(sys, "argv", testargs):
5348
result = run_glue_with_pabee.main()
5449
for value in result.values():
5550
self.assertGreaterEqual(value, 0.75)
56-
57-
clean_test_dir(output_dir)

examples/test_examples.py

+24-34
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,13 @@
1717
import argparse
1818
import logging
1919
import os
20-
import shutil
2120
import sys
22-
import unittest
2321
from unittest.mock import patch
2422

2523
import torch
2624

25+
from transformers.testing_utils import TestCasePlus
26+
2727

2828
SRC_DIRS = [
2929
os.path.join(os.path.dirname(__file__), dirname)
@@ -52,19 +52,18 @@ def get_setup_file():
5252
return args.f
5353

5454

55-
def clean_test_dir(path):
56-
shutil.rmtree(path, ignore_errors=True)
57-
58-
59-
class ExamplesTests(unittest.TestCase):
55+
class ExamplesTests(TestCasePlus):
6056
def test_run_glue(self):
6157
stream_handler = logging.StreamHandler(sys.stdout)
6258
logger.addHandler(stream_handler)
6359

64-
testargs = """
60+
tmp_dir = self.get_auto_remove_tmp_dir()
61+
testargs = f"""
6562
run_glue.py
6663
--model_name_or_path distilbert-base-uncased
6764
--data_dir ./tests/fixtures/tests_samples/MRPC/
65+
--output_dir {tmp_dir}
66+
--overwrite_output_dir
6867
--task_name mrpc
6968
--do_train
7069
--do_eval
@@ -73,28 +72,26 @@ def test_run_glue(self):
7372
--learning_rate=1e-4
7473
--max_steps=10
7574
--warmup_steps=2
76-
--overwrite_output_dir
7775
--seed=42
7876
--max_seq_length=128
79-
"""
80-
output_dir = "./tests/fixtures/tests_samples/temp_dir_{}".format(hash(testargs))
81-
testargs += "--output_dir " + output_dir
82-
testargs = testargs.split()
77+
""".split()
78+
8379
with patch.object(sys, "argv", testargs):
8480
result = run_glue.main()
8581
del result["eval_loss"]
8682
for value in result.values():
8783
self.assertGreaterEqual(value, 0.75)
88-
clean_test_dir(output_dir)
8984

9085
def test_run_pl_glue(self):
9186
stream_handler = logging.StreamHandler(sys.stdout)
9287
logger.addHandler(stream_handler)
9388

94-
testargs = """
89+
tmp_dir = self.get_auto_remove_tmp_dir()
90+
testargs = f"""
9591
run_pl_glue.py
9692
--model_name_or_path bert-base-cased
9793
--data_dir ./tests/fixtures/tests_samples/MRPC/
94+
--output_dir {tmp_dir}
9895
--task mrpc
9996
--do_train
10097
--do_predict
@@ -103,11 +100,7 @@ def test_run_pl_glue(self):
103100
--num_train_epochs=1
104101
--seed=42
105102
--max_seq_length=128
106-
"""
107-
output_dir = "./tests/fixtures/tests_samples/temp_dir_{}".format(hash(testargs))
108-
testargs += "--output_dir " + output_dir
109-
testargs = testargs.split()
110-
103+
""".split()
111104
if torch.cuda.is_available():
112105
testargs += ["--fp16", "--gpus=1"]
113106

@@ -123,43 +116,44 @@ def test_run_pl_glue(self):
123116
# for k, v in result.items():
124117
# self.assertGreaterEqual(v, 0.75, f"({k})")
125118
#
126-
clean_test_dir(output_dir)
127119

128120
def test_run_language_modeling(self):
129121
stream_handler = logging.StreamHandler(sys.stdout)
130122
logger.addHandler(stream_handler)
131123

132-
testargs = """
124+
tmp_dir = self.get_auto_remove_tmp_dir()
125+
testargs = f"""
133126
run_language_modeling.py
134127
--model_name_or_path distilroberta-base
135128
--model_type roberta
136129
--mlm
137130
--line_by_line
138131
--train_data_file ./tests/fixtures/sample_text.txt
139132
--eval_data_file ./tests/fixtures/sample_text.txt
133+
--output_dir {tmp_dir}
140134
--overwrite_output_dir
141135
--do_train
142136
--do_eval
143137
--num_train_epochs=1
144138
--no_cuda
145-
"""
146-
output_dir = "./tests/fixtures/tests_samples/temp_dir_{}".format(hash(testargs))
147-
testargs += "--output_dir " + output_dir
148-
testargs = testargs.split()
139+
""".split()
140+
149141
with patch.object(sys, "argv", testargs):
150142
result = run_language_modeling.main()
151143
self.assertLess(result["perplexity"], 35)
152-
clean_test_dir(output_dir)
153144

154145
def test_run_squad(self):
155146
stream_handler = logging.StreamHandler(sys.stdout)
156147
logger.addHandler(stream_handler)
157148

158-
testargs = """
149+
tmp_dir = self.get_auto_remove_tmp_dir()
150+
testargs = f"""
159151
run_squad.py
160152
--model_type=distilbert
161153
--model_name_or_path=sshleifer/tiny-distilbert-base-cased-distilled-squad
162154
--data_dir=./tests/fixtures/tests_samples/SQUAD
155+
--output_dir {tmp_dir}
156+
--overwrite_output_dir
163157
--max_steps=10
164158
--warmup_steps=2
165159
--do_train
@@ -168,17 +162,13 @@ def test_run_squad(self):
168162
--learning_rate=2e-4
169163
--per_gpu_train_batch_size=2
170164
--per_gpu_eval_batch_size=1
171-
--overwrite_output_dir
172165
--seed=42
173-
"""
174-
output_dir = "./tests/fixtures/tests_samples/temp_dir_{}".format(hash(testargs))
175-
testargs += "--output_dir " + output_dir
176-
testargs = testargs.split()
166+
""".split()
167+
177168
with patch.object(sys, "argv", testargs):
178169
result = run_squad.main()
179170
self.assertGreaterEqual(result["f1"], 25)
180171
self.assertGreaterEqual(result["exact"], 21)
181-
clean_test_dir(output_dir)
182172

183173
def test_generation(self):
184174
stream_handler = logging.StreamHandler(sys.stdout)

src/transformers/testing_utils.py

+92
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
11
import os
22
import re
3+
import shutil
34
import sys
5+
import tempfile
46
import unittest
57
from distutils.util import strtobool
68
from io import StringIO
9+
from pathlib import Path
710

811
from .file_utils import _tf_available, _torch_available, _torch_tpu_available
912

@@ -255,3 +258,92 @@ class CaptureStderr(CaptureStd):
255258

256259
def __init__(self):
257260
super().__init__(out=False)
261+
262+
263+
class TestCasePlus(unittest.TestCase):
264+
"""This class extends `unittest.TestCase` with additional features.
265+
266+
Feature 1: Flexible auto-removable temp dirs which are guaranteed to get
267+
removed at the end of test.
268+
269+
In all the following scenarios the temp dir will be auto-removed at the end
270+
of test, unless `after=False`.
271+
272+
# 1. create a unique temp dir, `tmp_dir` will contain the path to the created temp dir
273+
def test_whatever(self):
274+
tmp_dir = self.get_auto_remove_tmp_dir()
275+
276+
# 2. create a temp dir of my choice and delete it at the end - useful for debug when you want to
277+
# monitor a specific directory
278+
def test_whatever(self):
279+
tmp_dir = self.get_auto_remove_tmp_dir(tmp_dir="./tmp/run/test")
280+
281+
# 3. create a temp dir of my choice and do not delete it at the end - useful for when you want
282+
# to look at the temp results
283+
def test_whatever(self):
284+
tmp_dir = self.get_auto_remove_tmp_dir(tmp_dir="./tmp/run/test", after=False)
285+
286+
# 4. create a temp dir of my choice and ensure to delete it right away - useful for when you
287+
# disabled deletion in the previous test run and want to make sure the that tmp dir is empty
288+
# before the new test is run
289+
def test_whatever(self):
290+
tmp_dir = self.get_auto_remove_tmp_dir(tmp_dir="./tmp/run/test", before=True)
291+
292+
Note 1: In order to run the equivalent of `rm -r` safely, only subdirs of the
293+
project repository checkout are allowed if an explicit `tmp_dir` is used, so
294+
that by mistake no `/tmp` or similar important part of the filesystem will
295+
get nuked. i.e. please always pass paths that start with `./`
296+
297+
Note 2: Each test can register multiple temp dirs and they all will get
298+
auto-removed, unless requested otherwise.
299+
300+
"""
301+
302+
def setUp(self):
303+
self.teardown_tmp_dirs = []
304+
305+
def get_auto_remove_tmp_dir(self, tmp_dir=None, after=True, before=False):
306+
"""
307+
Args:
308+
tmp_dir (:obj:`string`, `optional`, defaults to :obj:`None`):
309+
use this path, if None a unique path will be assigned
310+
before (:obj:`bool`, `optional`, defaults to :obj:`False`):
311+
if `True` and tmp dir already exists make sure to empty it right away
312+
after (:obj:`bool`, `optional`, defaults to :obj:`True`):
313+
delete the tmp dir at the end of the test
314+
315+
Returns:
316+
tmp_dir(:obj:`string`):
317+
either the same value as passed via `tmp_dir` or the path to the auto-created tmp dir
318+
"""
319+
if tmp_dir is not None:
320+
# using provided path
321+
path = Path(tmp_dir).resolve()
322+
323+
# to avoid nuking parts of the filesystem, only relative paths are allowed
324+
if not tmp_dir.startswith("./"):
325+
raise ValueError(
326+
f"`tmp_dir` can only be a relative path, i.e. `./some/path`, but received `{tmp_dir}`"
327+
)
328+
329+
# ensure the dir is empty to start with
330+
if before is True and path.exists():
331+
shutil.rmtree(tmp_dir, ignore_errors=True)
332+
333+
path.mkdir(parents=True, exist_ok=True)
334+
335+
else:
336+
# using unique tmp dir (always empty, regardless of `before`)
337+
tmp_dir = tempfile.mkdtemp()
338+
339+
if after is True:
340+
# register for deletion
341+
self.teardown_tmp_dirs.append(tmp_dir)
342+
343+
return tmp_dir
344+
345+
def tearDown(self):
346+
# remove registered temp dirs
347+
for path in self.teardown_tmp_dirs:
348+
shutil.rmtree(path, ignore_errors=True)
349+
self.teardown_tmp_dirs = []

0 commit comments

Comments
 (0)