Skip to content

Commit eb613b5

Browse files
authored
Use hash to clean the test dirs (#6475)
* Use hash to clean the test dirs * Use hash to clean the test dirs * Use hash to clean the test dirs * fix
1 parent 680f133 commit eb613b5

File tree

2 files changed

+27
-17
lines changed

2 files changed

+27
-17
lines changed

examples/bert-loses-patience/test_run_glue_with_pabee.py

+6-4
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ def get_setup_file():
2020
return args.f
2121

2222

23-
def clean_test_dir(path="./tests/fixtures/tests_samples/temp_dir"):
23+
def clean_test_dir(path):
2424
shutil.rmtree(path, ignore_errors=True)
2525

2626

@@ -37,7 +37,6 @@ def test_run_glue(self):
3737
--task_name mrpc
3838
--do_train
3939
--do_eval
40-
--output_dir ./tests/fixtures/tests_samples/temp_dir
4140
--per_gpu_train_batch_size=2
4241
--per_gpu_eval_batch_size=1
4342
--learning_rate=2e-5
@@ -46,10 +45,13 @@ def test_run_glue(self):
4645
--overwrite_output_dir
4746
--seed=42
4847
--max_seq_length=128
49-
""".split()
48+
"""
49+
output_dir = "./tests/fixtures/tests_samples/temp_dir_{}".format(hash(testargs))
50+
testargs += "--output_dir " + output_dir
51+
testargs = testargs.split()
5052
with patch.object(sys, "argv", testargs):
5153
result = run_glue_with_pabee.main()
5254
for value in result.values():
5355
self.assertGreaterEqual(value, 0.75)
5456

55-
clean_test_dir()
57+
clean_test_dir(output_dir)

examples/test_examples.py

+21-13
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ def get_setup_file():
5252
return args.f
5353

5454

55-
def clean_test_dir(path="./tests/fixtures/tests_samples/temp_dir"):
55+
def clean_test_dir(path):
5656
shutil.rmtree(path, ignore_errors=True)
5757

5858

@@ -68,7 +68,6 @@ def test_run_glue(self):
6868
--task_name mrpc
6969
--do_train
7070
--do_eval
71-
--output_dir ./tests/fixtures/tests_samples/temp_dir
7271
--per_device_train_batch_size=2
7372
--per_device_eval_batch_size=1
7473
--learning_rate=1e-4
@@ -77,13 +76,16 @@ def test_run_glue(self):
7776
--overwrite_output_dir
7877
--seed=42
7978
--max_seq_length=128
80-
""".split()
79+
"""
80+
output_dir = "./tests/fixtures/tests_samples/temp_dir_{}".format(hash(testargs))
81+
testargs += "--output_dir " + output_dir
82+
testargs = testargs.split()
8183
with patch.object(sys, "argv", testargs):
8284
result = run_glue.main()
8385
del result["eval_loss"]
8486
for value in result.values():
8587
self.assertGreaterEqual(value, 0.75)
86-
clean_test_dir()
88+
clean_test_dir(output_dir)
8789

8890
def test_run_pl_glue(self):
8991
stream_handler = logging.StreamHandler(sys.stdout)
@@ -96,13 +98,15 @@ def test_run_pl_glue(self):
9698
--task mrpc
9799
--do_train
98100
--do_predict
99-
--output_dir ./tests/fixtures/tests_samples/temp_dir
100101
--train_batch_size=32
101102
--learning_rate=1e-4
102103
--num_train_epochs=1
103104
--seed=42
104105
--max_seq_length=128
105-
""".split()
106+
"""
107+
output_dir = "./tests/fixtures/tests_samples/temp_dir_{}".format(hash(testargs))
108+
testargs += "--output_dir " + output_dir
109+
testargs = testargs.split()
106110

107111
if torch.cuda.is_available():
108112
testargs += ["--fp16", "--gpus=1"]
@@ -119,7 +123,7 @@ def test_run_pl_glue(self):
119123
# for k, v in result.items():
120124
# self.assertGreaterEqual(v, 0.75, f"({k})")
121125
#
122-
clean_test_dir()
126+
clean_test_dir(output_dir)
123127

124128
def test_run_language_modeling(self):
125129
stream_handler = logging.StreamHandler(sys.stdout)
@@ -133,17 +137,19 @@ def test_run_language_modeling(self):
133137
--line_by_line
134138
--train_data_file ./tests/fixtures/sample_text.txt
135139
--eval_data_file ./tests/fixtures/sample_text.txt
136-
--output_dir ./tests/fixtures/tests_samples/temp_dir
137140
--overwrite_output_dir
138141
--do_train
139142
--do_eval
140143
--num_train_epochs=1
141144
--no_cuda
142-
""".split()
145+
"""
146+
output_dir = "./tests/fixtures/tests_samples/temp_dir_{}".format(hash(testargs))
147+
testargs += "--output_dir " + output_dir
148+
testargs = testargs.split()
143149
with patch.object(sys, "argv", testargs):
144150
result = run_language_modeling.main()
145151
self.assertLess(result["perplexity"], 35)
146-
clean_test_dir()
152+
clean_test_dir(output_dir)
147153

148154
def test_run_squad(self):
149155
stream_handler = logging.StreamHandler(sys.stdout)
@@ -154,7 +160,6 @@ def test_run_squad(self):
154160
--model_type=distilbert
155161
--model_name_or_path=sshleifer/tiny-distilbert-base-cased-distilled-squad
156162
--data_dir=./tests/fixtures/tests_samples/SQUAD
157-
--output_dir=./tests/fixtures/tests_samples/temp_dir
158163
--max_steps=10
159164
--warmup_steps=2
160165
--do_train
@@ -165,12 +170,15 @@ def test_run_squad(self):
165170
--per_gpu_eval_batch_size=1
166171
--overwrite_output_dir
167172
--seed=42
168-
""".split()
173+
"""
174+
output_dir = "./tests/fixtures/tests_samples/temp_dir_{}".format(hash(testargs))
175+
testargs += "--output_dir " + output_dir
176+
testargs = testargs.split()
169177
with patch.object(sys, "argv", testargs):
170178
result = run_squad.main()
171179
self.assertGreaterEqual(result["f1"], 25)
172180
self.assertGreaterEqual(result["exact"], 21)
173-
clean_test_dir()
181+
clean_test_dir(output_dir)
174182

175183
def test_generation(self):
176184
stream_handler = logging.StreamHandler(sys.stdout)

0 commit comments

Comments
 (0)