 import argparse
 import os
 import sys
-from pathlib import Path
 from unittest.mock import patch
 
-import pytest
 import pytorch_lightning as pl
 import timeout_decorator
 import torch
 
 from distillation import BartSummarizationDistiller, distill_main
 from finetune import SummarizationModule, main
-from test_seq2seq_examples import CUDA_AVAILABLE, MBART_TINY
-from transformers import BartForConditionalGeneration, MarianMTModel
-from transformers.testing_utils import TestCasePlus, slow
+from transformers import MarianMTModel
+from transformers.file_utils import cached_path
+from transformers.testing_utils import TestCasePlus, require_torch_gpu, slow
 from utils import load_json
 
 
-MODEL_NAME = MBART_TINY
-MARIAN_MODEL = "sshleifer/student_marian_en_ro_6_1"
+MARIAN_MODEL = "sshleifer/mar_enro_6_3_student"
 
 
-class TestAll(TestCasePlus):
+class TestMbartCc25Enro(TestCasePlus):
+    def setUp(self):
+        super().setUp()
+
+        data_cached = cached_path(
+            "https://cdn-datasets.huggingface.co/translation/wmt_en_ro-tr40k-va0.5k-te0.5k.tar.gz",
+            extract_compressed_file=True,
+        )
+        self.data_dir = f"{data_cached}/wmt_en_ro-tr40k-va0.5k-te0.5k"
+
     @slow
-    @pytest.mark.skipif(not CUDA_AVAILABLE, reason="too slow to run on CPU")
+    @require_torch_gpu
     def test_model_download(self):
         """This warms up the cache so that we can time the next test without including download time, which varies between machines."""
-        BartForConditionalGeneration.from_pretrained(MODEL_NAME)
         MarianMTModel.from_pretrained(MARIAN_MODEL)
 
-    @timeout_decorator.timeout(120)
+    # @timeout_decorator.timeout(1200)
     @slow
-    @pytest.mark.skipif(not CUDA_AVAILABLE, reason="too slow to run on CPU")
+    @require_torch_gpu
     def test_train_mbart_cc25_enro_script(self):
-        data_dir = "examples/seq2seq/test_data/wmt_en_ro"
         env_vars_to_replace = {
-            "--fp16_opt_level=O1": "",
-            "$MAX_LEN": 128,
-            "$BS": 4,
+            "$MAX_LEN": 64,
+            "$BS": 64,
             "$GAS": 1,
-            "$ENRO_DIR": data_dir,
-            "facebook/mbart-large-cc25": MODEL_NAME,
-            # Download is 120MB in previous test.
-            "val_check_interval=0.25": "val_check_interval=1.0",
+            "$ENRO_DIR": self.data_dir,
+            "facebook/mbart-large-cc25": MARIAN_MODEL,
+            # "val_check_interval=0.25": "val_check_interval=1.0",
+            "--learning_rate=3e-5": "--learning_rate 3e-4",
+            "--num_train_epochs 6": "--num_train_epochs 1",
         }
 
         # Clean up bash script
-        bash_script = Path("examples/seq2seq/train_mbart_cc25_enro.sh").open().read().split("finetune.py")[1].strip()
+        bash_script = (self.test_file_dir / "train_mbart_cc25_enro.sh").open().read().split("finetune.py")[1].strip()
         bash_script = bash_script.replace("\\\n", "").strip().replace('"$@"', "")
         for k, v in env_vars_to_replace.items():
             bash_script = bash_script.replace(k, str(v))
         output_dir = self.get_auto_remove_tmp_dir()
 
-        bash_script = bash_script.replace("--fp16 ", "")
-        testargs = (
-            ["finetune.py"]
-            + bash_script.split()
-            + [
-                f"--output_dir={output_dir}",
-                "--gpus=1",
-                "--learning_rate=3e-1",
-                "--warmup_steps=0",
-                "--val_check_interval=1.0",
-                "--tokenizer_name=facebook/mbart-large-en-ro",
-            ]
-        )
+        # bash_script = bash_script.replace("--fp16 ", "")
+        args = f"""
+            --output_dir {output_dir}
+            --tokenizer_name Helsinki-NLP/opus-mt-en-ro
+            --sortish_sampler
+            --do_predict
+            --gpus 1
+            --freeze_encoder
+            --n_train 40000
+            --n_val 500
+            --n_test 500
+            --fp16_opt_level O1
+            --num_sanity_val_steps 0
+            --eval_beams 2
+        """.split()
+        # XXX: args.gpus > 1 : handle multigpu in the future
+
+        testargs = ["finetune.py"] + bash_script.split() + args
         with patch.object(sys, "argv", testargs):
             parser = argparse.ArgumentParser()
             parser = pl.Trainer.add_argparse_args(parser)
             parser = SummarizationModule.add_model_specific_args(parser, os.getcwd())
             args = parser.parse_args()
-            args.do_predict = False
-            # assert args.gpus == gpus THIS BREAKS for multigpu
             model = main(args)
 
         # Check metrics
         metrics = load_json(model.metrics_save_path)
         first_step_stats = metrics["val"][0]
         last_step_stats = metrics["val"][-1]
-        assert (
-            len(metrics["val"]) == (args.max_epochs / args.val_check_interval) + 1
-        )  # +1 accounts for val_sanity_check
+        self.assertEqual(len(metrics["val"]), (args.max_epochs / args.val_check_interval))
+        assert isinstance(last_step_stats[f"val_avg_{model.val_metric}"], float)
 
-        assert last_step_stats["val_avg_gen_time"] >= 0.01
+        self.assertGreater(last_step_stats["val_avg_gen_time"], 0.01)
+        # model hanging on generate. Maybe bad config was saved. (XXX: old comment/assert?)
+        self.assertLessEqual(last_step_stats["val_avg_gen_time"], 1.0)
 
-        assert first_step_stats["val_avg_bleu"] < last_step_stats["val_avg_bleu"]  # model learned nothing
-        assert 1.0 >= last_step_stats["val_avg_gen_time"]  # model hanging on generate. Maybe bad config was saved.
-        assert isinstance(last_step_stats[f"val_avg_{model.val_metric}"], float)
+        # test learning requirements:
+
+        # 1. BLEU improves over the course of training by more than 2 pts
+        self.assertGreater(last_step_stats["val_avg_bleu"] - first_step_stats["val_avg_bleu"], 2)
+
+        # 2. BLEU finishes above 17
+        self.assertGreater(last_step_stats["val_avg_bleu"], 17)
+
+        # 3. test BLEU and val BLEU within ~1.1 pt.
+        self.assertLess(abs(metrics["val"][-1]["val_avg_bleu"] - metrics["test"][-1]["test_avg_bleu"]), 1.1)
 
         # check lightning ckpt can be loaded and has a reasonable statedict
         contents = os.listdir(output_dir)
@@ -107,11 +122,13 @@ def test_train_mbart_cc25_enro_script(self):
         # assert len(metrics["val"]) == desired_n_evals
         assert len(metrics["test"]) == 1
 
+
+class TestDistilMarianNoTeacher(TestCasePlus):
     @timeout_decorator.timeout(600)
     @slow
-    @pytest.mark.skipif(not CUDA_AVAILABLE, reason="too slow to run on CPU")
+    @require_torch_gpu
     def test_opus_mt_distill_script(self):
-        data_dir = "examples/seq2seq/test_data/wmt_en_ro"
+        data_dir = f"{self.test_file_dir_str}/test_data/wmt_en_ro"
         env_vars_to_replace = {
             "--fp16_opt_level=O1": "",
             "$MAX_LEN": 128,
@@ -124,7 +141,7 @@ def test_opus_mt_distill_script(self):
 
         # Clean up bash script
         bash_script = (
-            Path("examples/seq2seq/distil_marian_no_teacher.sh").open().read().split("distillation.py")[1].strip()
+            (self.test_file_dir / "distil_marian_no_teacher.sh").open().read().split("distillation.py")[1].strip()
         )
         bash_script = bash_script.replace("\\\n", "").strip().replace('"$@"', "")
         bash_script = bash_script.replace("--fp16 ", " ")
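The resulting testargs list is then consumed by patching sys.argv, which lets the test drive an argparse-based entrypoint in-process instead of spawning a subprocess. A runnable sketch of just that pattern, with a hypothetical demo_main standing in for finetune.main:

import argparse
import sys
from unittest.mock import patch


def demo_main():
    # Stands in for finetune.main: an entrypoint that reads sys.argv
    # through argparse rather than taking parameters directly.
    parser = argparse.ArgumentParser()
    parser.add_argument("--steps", type=int, required=True)
    parser.add_argument("--output_dir", default="/tmp/demo")
    return parser.parse_args()


def test_cli_entrypoint():
    testargs = ["demo.py", "--steps", "5", "--output_dir", "/tmp/out"]
    # patch.object restores the real sys.argv when the block exits,
    # so nothing leaks into other tests in the same process.
    with patch.object(sys, "argv", testargs):
        args = demo_main()
    assert args.steps == 5
    assert args.output_dir == "/tmp/out"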