@@ -52,7 +52,7 @@ def get_setup_file():
     return args.f


-def clean_test_dir(path="./tests/fixtures/tests_samples/temp_dir"):
+def clean_test_dir(path):
     shutil.rmtree(path, ignore_errors=True)

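For orientation, the rest of this diff applies one pattern to every test: the shared, hard-coded `temp_dir` is dropped, each test derives its own output directory from `hash()` of its argument string, appends it via `--output_dir`, and passes it to `clean_test_dir` afterwards. Below is a minimal, self-contained sketch of that pattern; `DummyExampleTest`, `run_dummy.py`, and the argv contents are illustrative placeholders, not part of the patch.

```python
# Condensed sketch of the per-test output-dir pattern introduced by this diff.
# `DummyExampleTest` and the argv contents below are placeholders only.
import shutil
import sys
import unittest
from unittest.mock import patch


def clean_test_dir(path):
    # No shared default any more: each caller passes the directory it used.
    shutil.rmtree(path, ignore_errors=True)


class DummyExampleTest(unittest.TestCase):
    def test_example(self):
        testargs = """
            run_dummy.py
            --do_train
            --max_seq_length=128
            """
        # Derive a directory name unique to this argument string, so tests in
        # the same checkout no longer share (and clobber) a single temp_dir.
        output_dir = "./tests/fixtures/tests_samples/temp_dir_{}".format(hash(testargs))
        testargs += "--output_dir " + output_dir
        # str.split() with no arguments splits on any whitespace, so the
        # appended "--output_dir <path>" becomes two clean argv tokens.
        testargs = testargs.split()

        with patch.object(sys, "argv", testargs):
            pass  # the real tests invoke the example script's main() here

        clean_test_dir(output_dir)
```

Because the directory is created and removed within the same interpreter run, the fact that `hash()` of a string varies between runs (hash randomization) does not affect the tests.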
@@ -68,7 +68,6 @@ def test_run_glue(self):
             --task_name mrpc
             --do_train
             --do_eval
-            --output_dir ./tests/fixtures/tests_samples/temp_dir
             --per_device_train_batch_size=2
             --per_device_eval_batch_size=1
             --learning_rate=1e-4
@@ -77,13 +76,16 @@ def test_run_glue(self):
             --overwrite_output_dir
             --seed=42
             --max_seq_length=128
-            """.split()
+            """
+        output_dir = "./tests/fixtures/tests_samples/temp_dir_{}".format(hash(testargs))
+        testargs += "--output_dir " + output_dir
+        testargs = testargs.split()
         with patch.object(sys, "argv", testargs):
             result = run_glue.main()
             del result["eval_loss"]
             for value in result.values():
                 self.assertGreaterEqual(value, 0.75)
-        clean_test_dir()
+        clean_test_dir(output_dir)

     def test_run_pl_glue(self):
         stream_handler = logging.StreamHandler(sys.stdout)
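For comparison only (not what this patch does), the same isolation could also be obtained from the standard library's tempfile module, which hands out a fresh directory without deriving a name from `hash()`. A rough sketch under that assumption:

```python
# Alternative sketch, NOT used by this patch: let tempfile create a unique
# directory and let unittest's addCleanup remove it even if the test fails.
import shutil
import tempfile
import unittest


class DummyTempDirTest(unittest.TestCase):
    def test_example(self):
        output_dir = tempfile.mkdtemp(prefix="tests_samples_")
        self.addCleanup(shutil.rmtree, output_dir, ignore_errors=True)
        # ...append "--output_dir " + output_dir to testargs as in the tests above...
```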
@@ -96,13 +98,15 @@ def test_run_pl_glue(self):
             --task mrpc
             --do_train
             --do_predict
-            --output_dir ./tests/fixtures/tests_samples/temp_dir
             --train_batch_size=32
             --learning_rate=1e-4
             --num_train_epochs=1
             --seed=42
             --max_seq_length=128
-            """.split()
+            """
+        output_dir = "./tests/fixtures/tests_samples/temp_dir_{}".format(hash(testargs))
+        testargs += "--output_dir " + output_dir
+        testargs = testargs.split()

         if torch.cuda.is_available():
             testargs += ["--fp16", "--gpus=1"]
@@ -119,7 +123,7 @@ def test_run_pl_glue(self):
         # for k, v in result.items():
         #     self.assertGreaterEqual(v, 0.75, f"({k})")
         #
-        clean_test_dir()
+        clean_test_dir(output_dir)

     def test_run_language_modeling(self):
         stream_handler = logging.StreamHandler(sys.stdout)
@@ -133,17 +137,19 @@ def test_run_language_modeling(self):
             --line_by_line
             --train_data_file ./tests/fixtures/sample_text.txt
             --eval_data_file ./tests/fixtures/sample_text.txt
-            --output_dir ./tests/fixtures/tests_samples/temp_dir
             --overwrite_output_dir
             --do_train
             --do_eval
             --num_train_epochs=1
             --no_cuda
-            """.split()
+            """
+        output_dir = "./tests/fixtures/tests_samples/temp_dir_{}".format(hash(testargs))
+        testargs += "--output_dir " + output_dir
+        testargs = testargs.split()
         with patch.object(sys, "argv", testargs):
             result = run_language_modeling.main()
             self.assertLess(result["perplexity"], 35)
-        clean_test_dir()
+        clean_test_dir(output_dir)

     def test_run_squad(self):
         stream_handler = logging.StreamHandler(sys.stdout)
@@ -154,7 +160,6 @@ def test_run_squad(self):
             --model_type=distilbert
             --model_name_or_path=sshleifer/tiny-distilbert-base-cased-distilled-squad
             --data_dir=./tests/fixtures/tests_samples/SQUAD
-            --output_dir=./tests/fixtures/tests_samples/temp_dir
             --max_steps=10
             --warmup_steps=2
             --do_train
@@ -165,12 +170,15 @@ def test_run_squad(self):
             --per_gpu_eval_batch_size=1
             --overwrite_output_dir
             --seed=42
-            """.split()
+            """
+        output_dir = "./tests/fixtures/tests_samples/temp_dir_{}".format(hash(testargs))
+        testargs += "--output_dir " + output_dir
+        testargs = testargs.split()
         with patch.object(sys, "argv", testargs):
             result = run_squad.main()
             self.assertGreaterEqual(result["f1"], 25)
             self.assertGreaterEqual(result["exact"], 21)
-        clean_test_dir()
+        clean_test_dir(output_dir)

     def test_generation(self):
         stream_handler = logging.StreamHandler(sys.stdout)