Commit 9edafae
[s2s] test_bash_script.py - actually learn something (#8318)
* use decorator
* remove hardcoded paths
* make the test use more data and do real quality tests
* shave off 10 secs
* add --eval_beams 2, reformat
* reduce train size, use smaller custom dataset
1 parent 1745039 commit 9edafae
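Context for reviewers: both tests in this file share one pattern — read the example bash script, strip the shell syntax, substitute the $VARIABLES, and feed the result to the argparse-based main() through a patched sys.argv. A minimal, self-contained sketch of that pattern, with a hypothetical script body standing in for train_mbart_cc25_enro.sh:

    import sys
    from unittest.mock import patch

    # Hypothetical stand-in for the real train_mbart_cc25_enro.sh contents.
    BASH_SCRIPT = """
    python finetune.py \\
        --data_dir $ENRO_DIR \\
        --train_batch_size=$BS \\
        --max_source_length $MAX_LEN \\
        "$@"
    """

    def script_to_argv(script, replacements):
        # Keep only the flags after "finetune.py"; drop line continuations and "$@".
        args = script.split("finetune.py")[1].strip()
        args = args.replace("\\\n", "").strip().replace('"$@"', "")
        # Substitute the variables the shell would normally provide.
        for k, v in replacements.items():
            args = args.replace(k, str(v))
        return args.split()

    argv = ["finetune.py"] + script_to_argv(
        BASH_SCRIPT, {"$ENRO_DIR": "/tmp/wmt_en_ro", "$BS": 64, "$MAX_LEN": 64}
    )
    with patch.object(sys, "argv", argv):
        print(sys.argv)  # an argparse-based main() now sees the script's flags
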

File tree: 1 file changed (+63, -46)
examples/seq2seq/test_bash_script.py

@@ -3,92 +3,107 @@
 import argparse
 import os
 import sys
-from pathlib import Path
 from unittest.mock import patch
 
-import pytest
 import pytorch_lightning as pl
 import timeout_decorator
 import torch
 
 from distillation import BartSummarizationDistiller, distill_main
 from finetune import SummarizationModule, main
-from test_seq2seq_examples import CUDA_AVAILABLE, MBART_TINY
-from transformers import BartForConditionalGeneration, MarianMTModel
-from transformers.testing_utils import TestCasePlus, slow
+from transformers import MarianMTModel
+from transformers.file_utils import cached_path
+from transformers.testing_utils import TestCasePlus, require_torch_gpu, slow
 from utils import load_json
 
 
-MODEL_NAME = MBART_TINY
-MARIAN_MODEL = "sshleifer/student_marian_en_ro_6_1"
+MARIAN_MODEL = "sshleifer/mar_enro_6_3_student"
 
 
-class TestAll(TestCasePlus):
+class TestMbartCc25Enro(TestCasePlus):
+    def setUp(self):
+        super().setUp()
+
+        data_cached = cached_path(
+            "https://cdn-datasets.huggingface.co/translation/wmt_en_ro-tr40k-va0.5k-te0.5k.tar.gz",
+            extract_compressed_file=True,
+        )
+        self.data_dir = f"{data_cached}/wmt_en_ro-tr40k-va0.5k-te0.5k"
+
     @slow
-    @pytest.mark.skipif(not CUDA_AVAILABLE, reason="too slow to run on CPU")
+    @require_torch_gpu
     def test_model_download(self):
         """This warms up the cache so that we can time the next test without including download time, which varies between machines."""
-        BartForConditionalGeneration.from_pretrained(MODEL_NAME)
         MarianMTModel.from_pretrained(MARIAN_MODEL)
 
-    @timeout_decorator.timeout(120)
+    # @timeout_decorator.timeout(1200)
     @slow
-    @pytest.mark.skipif(not CUDA_AVAILABLE, reason="too slow to run on CPU")
+    @require_torch_gpu
     def test_train_mbart_cc25_enro_script(self):
-        data_dir = "examples/seq2seq/test_data/wmt_en_ro"
         env_vars_to_replace = {
-            "--fp16_opt_level=O1": "",
-            "$MAX_LEN": 128,
-            "$BS": 4,
+            "$MAX_LEN": 64,
+            "$BS": 64,
             "$GAS": 1,
-            "$ENRO_DIR": data_dir,
-            "facebook/mbart-large-cc25": MODEL_NAME,
-            # Download is 120MB in previous test.
-            "val_check_interval=0.25": "val_check_interval=1.0",
+            "$ENRO_DIR": self.data_dir,
+            "facebook/mbart-large-cc25": MARIAN_MODEL,
+            # "val_check_interval=0.25": "val_check_interval=1.0",
+            "--learning_rate=3e-5": "--learning_rate 3e-4",
+            "--num_train_epochs 6": "--num_train_epochs 1",
         }
 
         # Clean up bash script
-        bash_script = Path("examples/seq2seq/train_mbart_cc25_enro.sh").open().read().split("finetune.py")[1].strip()
+        bash_script = (self.test_file_dir / "train_mbart_cc25_enro.sh").open().read().split("finetune.py")[1].strip()
         bash_script = bash_script.replace("\\\n", "").strip().replace('"$@"', "")
         for k, v in env_vars_to_replace.items():
            bash_script = bash_script.replace(k, str(v))
        output_dir = self.get_auto_remove_tmp_dir()
 
-        bash_script = bash_script.replace("--fp16 ", "")
-        testargs = (
-            ["finetune.py"]
-            + bash_script.split()
-            + [
-                f"--output_dir={output_dir}",
-                "--gpus=1",
-                "--learning_rate=3e-1",
-                "--warmup_steps=0",
-                "--val_check_interval=1.0",
-                "--tokenizer_name=facebook/mbart-large-en-ro",
-            ]
-        )
+        # bash_script = bash_script.replace("--fp16 ", "")
+        args = f"""
+            --output_dir {output_dir}
+            --tokenizer_name Helsinki-NLP/opus-mt-en-ro
+            --sortish_sampler
+            --do_predict
+            --gpus 1
+            --freeze_encoder
+            --n_train 40000
+            --n_val 500
+            --n_test 500
+            --fp16_opt_level O1
+            --num_sanity_val_steps 0
+            --eval_beams 2
+        """.split()
+        # XXX: args.gpus > 1 : handle multigpu in the future
+
+        testargs = ["finetune.py"] + bash_script.split() + args
         with patch.object(sys, "argv", testargs):
             parser = argparse.ArgumentParser()
             parser = pl.Trainer.add_argparse_args(parser)
             parser = SummarizationModule.add_model_specific_args(parser, os.getcwd())
             args = parser.parse_args()
-            args.do_predict = False
-            # assert args.gpus == gpus THIS BREAKS for multigpu
             model = main(args)
 
         # Check metrics
         metrics = load_json(model.metrics_save_path)
         first_step_stats = metrics["val"][0]
         last_step_stats = metrics["val"][-1]
-        assert (
-            len(metrics["val"]) == (args.max_epochs / args.val_check_interval) + 1
-        )  # +1 accounts for val_sanity_check
+        self.assertEqual(len(metrics["val"]), (args.max_epochs / args.val_check_interval))
+        assert isinstance(last_step_stats[f"val_avg_{model.val_metric}"], float)
 
-        assert last_step_stats["val_avg_gen_time"] >= 0.01
+        self.assertGreater(last_step_stats["val_avg_gen_time"], 0.01)
+        # model hanging on generate. Maybe bad config was saved. (XXX: old comment/assert?)
+        self.assertLessEqual(last_step_stats["val_avg_gen_time"], 1.0)
 
-        assert first_step_stats["val_avg_bleu"] < last_step_stats["val_avg_bleu"]  # model learned nothing
-        assert 1.0 >= last_step_stats["val_avg_gen_time"]  # model hanging on generate. Maybe bad config was saved.
-        assert isinstance(last_step_stats[f"val_avg_{model.val_metric}"], float)
+        # test learning requirements:
+
+        # 1. BLEU improves over the course of training by more than 2 pts
+        self.assertGreater(last_step_stats["val_avg_bleu"] - first_step_stats["val_avg_bleu"], 2)
+
+        # 2. BLEU finishes above 17
+        self.assertGreater(last_step_stats["val_avg_bleu"], 17)
+
+        # 3. test BLEU and val BLEU within ~1.1 pt.
+        self.assertLess(abs(metrics["val"][-1]["val_avg_bleu"] - metrics["test"][-1]["test_avg_bleu"]), 1.1)
 
         # check lightning ckpt can be loaded and has a reasonable statedict
         contents = os.listdir(output_dir)
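A note on the new setUp: cached_path downloads the tarball once into the local Hugging Face cache and, with extract_compressed_file=True, returns the directory it was extracted to, so repeated test runs skip the download. A standalone sketch of that call, assuming a transformers version of this era where cached_path lives in transformers.file_utils:

    import os

    from transformers.file_utils import cached_path

    # First call downloads and extracts the archive; later calls hit the cache.
    data_cached = cached_path(
        "https://cdn-datasets.huggingface.co/translation/wmt_en_ro-tr40k-va0.5k-te0.5k.tar.gz",
        extract_compressed_file=True,
    )
    data_dir = os.path.join(data_cached, "wmt_en_ro-tr40k-va0.5k-te0.5k")
    print(sorted(os.listdir(data_dir)))  # expected: train/val/test splits (illustrative)
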
@@ -107,11 +122,13 @@ def test_train_mbart_cc25_enro_script(self):
         # assert len(metrics["val"]) == desired_n_evals
         assert len(metrics["test"]) == 1
 
+
+class TestDistilMarianNoTeacher(TestCasePlus):
     @timeout_decorator.timeout(600)
     @slow
-    @pytest.mark.skipif(not CUDA_AVAILABLE, reason="too slow to run on CPU")
+    @require_torch_gpu
     def test_opus_mt_distill_script(self):
-        data_dir = "examples/seq2seq/test_data/wmt_en_ro"
+        data_dir = f"{self.test_file_dir_str}/test_data/wmt_en_ro"
         env_vars_to_replace = {
             "--fp16_opt_level=O1": "",
             "$MAX_LEN": 128,
@@ -124,7 +141,7 @@ def test_opus_mt_distill_script(self):
 
         # Clean up bash script
         bash_script = (
-            Path("examples/seq2seq/distil_marian_no_teacher.sh").open().read().split("distillation.py")[1].strip()
+            (self.test_file_dir / "distil_marian_no_teacher.sh").open().read().split("distillation.py")[1].strip()
         )
         bash_script = bash_script.replace("\\\n", "").strip().replace('"$@"', "")
         bash_script = bash_script.replace("--fp16 ", " ")
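Worth noting for anyone running this locally: both test classes are decorated with @slow and @require_torch_gpu, so they are skipped unless a CUDA device is visible and the usual transformers slow-test switch is set, e.g. RUN_SLOW=1 pytest examples/seq2seq/test_bash_script.py.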
