17
17
import argparse
18
18
import logging
19
19
import os
20
- import shutil
21
20
import sys
22
- import unittest
23
21
from unittest .mock import patch
24
22
25
23
import torch
26
24
25
+ from transformers .testing_utils import TestCasePlus
26
+
27
27
28
28
SRC_DIRS = [
29
29
os .path .join (os .path .dirname (__file__ ), dirname )
@@ -52,19 +52,18 @@ def get_setup_file():
52
52
return args .f
53
53
54
54
55
- def clean_test_dir (path ):
56
- shutil .rmtree (path , ignore_errors = True )
57
-
58
-
59
- class ExamplesTests (unittest .TestCase ):
55
+ class ExamplesTests (TestCasePlus ):
60
56
def test_run_glue (self ):
61
57
stream_handler = logging .StreamHandler (sys .stdout )
62
58
logger .addHandler (stream_handler )
63
59
64
- testargs = """
60
+ tmp_dir = self .get_auto_remove_tmp_dir ()
61
+ testargs = f"""
65
62
run_glue.py
66
63
--model_name_or_path distilbert-base-uncased
67
64
--data_dir ./tests/fixtures/tests_samples/MRPC/
65
+ --output_dir { tmp_dir }
66
+ --overwrite_output_dir
68
67
--task_name mrpc
69
68
--do_train
70
69
--do_eval
@@ -73,28 +72,26 @@ def test_run_glue(self):
73
72
--learning_rate=1e-4
74
73
--max_steps=10
75
74
--warmup_steps=2
76
- --overwrite_output_dir
77
75
--seed=42
78
76
--max_seq_length=128
79
- """
80
- output_dir = "./tests/fixtures/tests_samples/temp_dir_{}" .format (hash (testargs ))
81
- testargs += "--output_dir " + output_dir
82
- testargs = testargs .split ()
77
+ """ .split ()
78
+
83
79
with patch .object (sys , "argv" , testargs ):
84
80
result = run_glue .main ()
85
81
del result ["eval_loss" ]
86
82
for value in result .values ():
87
83
self .assertGreaterEqual (value , 0.75 )
88
- clean_test_dir (output_dir )
89
84
90
85
def test_run_pl_glue (self ):
91
86
stream_handler = logging .StreamHandler (sys .stdout )
92
87
logger .addHandler (stream_handler )
93
88
94
- testargs = """
89
+ tmp_dir = self .get_auto_remove_tmp_dir ()
90
+ testargs = f"""
95
91
run_pl_glue.py
96
92
--model_name_or_path bert-base-cased
97
93
--data_dir ./tests/fixtures/tests_samples/MRPC/
94
+ --output_dir { tmp_dir }
98
95
--task mrpc
99
96
--do_train
100
97
--do_predict
@@ -103,11 +100,7 @@ def test_run_pl_glue(self):
103
100
--num_train_epochs=1
104
101
--seed=42
105
102
--max_seq_length=128
106
- """
107
- output_dir = "./tests/fixtures/tests_samples/temp_dir_{}" .format (hash (testargs ))
108
- testargs += "--output_dir " + output_dir
109
- testargs = testargs .split ()
110
-
103
+ """ .split ()
111
104
if torch .cuda .is_available ():
112
105
testargs += ["--fp16" , "--gpus=1" ]
113
106
@@ -123,43 +116,44 @@ def test_run_pl_glue(self):
123
116
# for k, v in result.items():
124
117
# self.assertGreaterEqual(v, 0.75, f"({k})")
125
118
#
126
- clean_test_dir (output_dir )
127
119
128
120
def test_run_language_modeling (self ):
129
121
stream_handler = logging .StreamHandler (sys .stdout )
130
122
logger .addHandler (stream_handler )
131
123
132
- testargs = """
124
+ tmp_dir = self .get_auto_remove_tmp_dir ()
125
+ testargs = f"""
133
126
run_language_modeling.py
134
127
--model_name_or_path distilroberta-base
135
128
--model_type roberta
136
129
--mlm
137
130
--line_by_line
138
131
--train_data_file ./tests/fixtures/sample_text.txt
139
132
--eval_data_file ./tests/fixtures/sample_text.txt
133
+ --output_dir { tmp_dir }
140
134
--overwrite_output_dir
141
135
--do_train
142
136
--do_eval
143
137
--num_train_epochs=1
144
138
--no_cuda
145
- """
146
- output_dir = "./tests/fixtures/tests_samples/temp_dir_{}" .format (hash (testargs ))
147
- testargs += "--output_dir " + output_dir
148
- testargs = testargs .split ()
139
+ """ .split ()
140
+
149
141
with patch .object (sys , "argv" , testargs ):
150
142
result = run_language_modeling .main ()
151
143
self .assertLess (result ["perplexity" ], 35 )
152
- clean_test_dir (output_dir )
153
144
154
145
def test_run_squad (self ):
155
146
stream_handler = logging .StreamHandler (sys .stdout )
156
147
logger .addHandler (stream_handler )
157
148
158
- testargs = """
149
+ tmp_dir = self .get_auto_remove_tmp_dir ()
150
+ testargs = f"""
159
151
run_squad.py
160
152
--model_type=distilbert
161
153
--model_name_or_path=sshleifer/tiny-distilbert-base-cased-distilled-squad
162
154
--data_dir=./tests/fixtures/tests_samples/SQUAD
155
+ --output_dir { tmp_dir }
156
+ --overwrite_output_dir
163
157
--max_steps=10
164
158
--warmup_steps=2
165
159
--do_train
@@ -168,17 +162,13 @@ def test_run_squad(self):
168
162
--learning_rate=2e-4
169
163
--per_gpu_train_batch_size=2
170
164
--per_gpu_eval_batch_size=1
171
- --overwrite_output_dir
172
165
--seed=42
173
- """
174
- output_dir = "./tests/fixtures/tests_samples/temp_dir_{}" .format (hash (testargs ))
175
- testargs += "--output_dir " + output_dir
176
- testargs = testargs .split ()
166
+ """ .split ()
167
+
177
168
with patch .object (sys , "argv" , testargs ):
178
169
result = run_squad .main ()
179
170
self .assertGreaterEqual (result ["f1" ], 25 )
180
171
self .assertGreaterEqual (result ["exact" ], 21 )
181
- clean_test_dir (output_dir )
182
172
183
173
def test_generation (self ):
184
174
stream_handler = logging .StreamHandler (sys .stdout )
0 commit comments