
Commit 09a2f40

sshleifer and Pradhy729 authored
Seq2SeqDataset uses linecache to save memory by @Pradhy729 (#5792)
Co-authored-by: Pradhy729 <49659913+Pradhy729@users.noreply.github.com>
1 parent 4b506a3 commit 09a2f40

File tree: 6 files changed (+181, -169 lines)

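The commit title says the new Seq2SeqDataset reads examples through `linecache` instead of holding whole files in memory. The actual implementation lives in `examples/seq2seq/utils.py`, which is one of the six changed files but is not shown in this extract, so the snippet below is only a minimal sketch of the idea under stated assumptions: the class name `LinecacheSeq2SeqDataset`, its constructor, and the returned batch keys are illustrative, not the code from the commit.

```python
# Minimal sketch of the linecache idea (illustrative; not the utils.py code).
# linecache.getline reads one line from disk on demand and caches it, so the
# dataset never materializes the full .source/.target files in memory.
import linecache
from pathlib import Path

from torch.utils.data import Dataset


class LinecacheSeq2SeqDataset(Dataset):  # hypothetical name
    def __init__(self, tokenizer, data_dir, type_path="train", max_source_length=1024, max_target_length=56):
        self.src_file = Path(data_dir) / f"{type_path}.source"
        self.tgt_file = Path(data_dir) / f"{type_path}.target"
        self.tokenizer = tokenizer
        self.max_source_length = max_source_length
        self.max_target_length = max_target_length
        # Only the number of examples is computed up front, not the text itself.
        with open(self.src_file) as f:
            self.n_examples = sum(1 for _ in f)

    def __len__(self):
        return self.n_examples

    def __getitem__(self, index):
        # linecache is 1-indexed; lines are paged in lazily and cached.
        source_line = linecache.getline(str(self.src_file), index + 1).rstrip("\n")
        target_line = linecache.getline(str(self.tgt_file), index + 1).rstrip("\n")
        source = self.tokenizer(source_line, max_length=self.max_source_length,
                                truncation=True, padding="max_length", return_tensors="pt")
        target = self.tokenizer(target_line, max_length=self.max_target_length,
                                truncation=True, padding="max_length", return_tensors="pt")
        return {
            "input_ids": source["input_ids"].squeeze(0),
            "attention_mask": source["attention_mask"].squeeze(0),
            "decoder_input_ids": target["input_ids"].squeeze(0),
        }
```

The memory saving comes from reading one line per `__getitem__` call rather than loading CNN/DailyMail-sized `.source`/`.target` files at construction time.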

Diff for: examples/seq2seq/README.md (+14, -32)

@@ -7,27 +7,24 @@ For `bertabs` instructions, see `bertabs/README.md`.
 
 
 ### Data
-
-CNN/DailyMail data
+XSUM Data:
 ```bash
 cd examples/seq2seq
-wget https://s3.amazonaws.com/datasets.huggingface.co/summarization/cnn_dm.tgz
-tar -xzvf cnn_dm.tgz
-
-export CNN_DIR=${PWD}/cnn_dm
+wget https://s3.amazonaws.com/datasets.huggingface.co/summarization/xsum.tar.gz
+tar -xzvf xsum.tar.gz
+export XSUM_DIR=${PWD}/xsum
 ```
-
 this should make a directory called cnn_dm/ with files like `test.source`.
 To use your own data, copy that files format. Each article to be summarized is on its own line.
 
-XSUM Data:
+CNN/DailyMail data
 ```bash
 cd examples/seq2seq
-wget https://s3.amazonaws.com/datasets.huggingface.co/summarization/xsum.tar.gz
-tar -xzvf xsum.tar.gz
-export XSUM_DIR=${PWD}/xsum
-```
+wget https://s3.amazonaws.com/datasets.huggingface.co/summarization/cnn_dm.tgz
+tar -xzvf cnn_dm.tgz
 
+export CNN_DIR=${PWD}/cnn_dm
+```
 
 WMT16 English-Romanian Translation Data:
 ```bash
@@ -40,7 +37,7 @@ export ENRO_DIR=${PWD}/wmt_en_ro
 If you are using your own data, it must be formatted as one directory with 6 files: train.source, train.target, val.source, val.target, test.source, test.target.
 The `.source` files are the input, the `.target` files are the desired output.
 
-
+
 ### Tips and Tricks
 
 General Tips:
@@ -64,6 +61,10 @@ Summarization Tips:
 - If you are finetuning on your own dataset, start from `distilbart-cnn-12-6` if you want long summaries and `distilbart-xsum-12-6` if you want short summaries.
 (It rarely makes sense to start from `bart-large` unless you are a researching finetuning methods).
 
+**Update 2018-07-18**
+Datasets: Seq2SeqDataset will be used for all models besides MBart, for which MBartDataset will be used.**
+A new dataset is needed to support multilingual tasks.
+
 ### Summarization Finetuning
 Run/modify `finetune.sh`
 
@@ -78,8 +79,6 @@ The following command should work on a 16GB GPU:
     --model_name_or_path facebook/bart-large
 ```
 
-
-
 ### Translation Finetuning
 
 First, follow the wmt_en_ro download instructions.
@@ -124,23 +123,6 @@ from transformers import AutoModelForSeq2SeqLM
 model = AutoModelForSeq2SeqLM.from_pretrained(f'{output_dir}/best_tfmr')
 ```
 
-#### XSUM Shared Task
-Compare XSUM results with others by using `--logger_name wandb_shared`. This requires `wandb` registration.
-
-Here is an example command, but you can do whatever you want. Hopefully this will make debugging and collaboration easier!
-```bash
-WANDB_PROJECT='hf_xsum' ./finetune.sh \
-    --data_dir $XSUM_DIR \
-    --output_dir xsum_frozen_embs \
-    --model_name_or_path facebook/bart-large \
-    --train_batch_size 16 --eval_batch_size 16 --freeze_embeds --freeze_encoder \
-    --num_train_epochs 6 \
-    --max_target_length=60 --val_max_target_length=60 --test_max_target_length=100 \
-    --logger_name wandb
-```
-
-You can see your wandb logs [here](https://app.wandb.ai/sshleifer/hf_xsum?workspace=user-)
-
 ### Evaluation Commands
 
 To create summaries for each article in dataset, we use `run_eval.py`, here are a few commands that run eval for different tasks and models.

Diff for: examples/seq2seq/distillation.py (+4, -22)

@@ -15,28 +15,15 @@
 
 try:
     from .finetune import SummarizationModule
-    from .initialization_utils import init_student, copy_layers
-    from .utils import (
-        use_task_specific_params,
-        SummarizationDataset,
-        pickle_load,
-        freeze_params,
-        assert_all_frozen,
-        any_requires_grad,
-    )
     from .finetune import main as ft_main
+    from .initialization_utils import init_student, copy_layers
+    from .utils import use_task_specific_params, pickle_load, freeze_params, assert_all_frozen, any_requires_grad
+
 except ImportError:
     from finetune import SummarizationModule
     from finetune import main as ft_main
     from initialization_utils import init_student, copy_layers
-    from utils import (
-        use_task_specific_params,
-        SummarizationDataset,
-        pickle_load,
-        freeze_params,
-        assert_all_frozen,
-        any_requires_grad,
-    )
+    from utils import use_task_specific_params, pickle_load, freeze_params, assert_all_frozen, any_requires_grad
 
 
 class BartSummarizationDistiller(SummarizationModule):
@@ -115,11 +102,6 @@ def copy_t5_to_student(self, d_layers_to_copy, e_layers_to_copy, hparams, studen
         if self.different_encoder:
             copy_layers(teacher.encoder.block, student.encoder.block, e_layers_to_copy)
 
-    def get_dataset(self, type_path) -> SummarizationDataset:
-        n_obs = self.n_obs[type_path]
-        dataset = SummarizationDataset(self.tokenizer, type_path=type_path, n_obs=n_obs, **self.dataset_kwargs)
-        return dataset
-
     def calc_mse_loss(self, teacher_outputs: torch.Tensor, student_outputs: torch.Tensor, mask) -> torch.FloatTensor:
         if mask is not None:
             # mask has False at padding_idx

Diff for: examples/seq2seq/finetune.py (+12, -6)

@@ -21,7 +21,6 @@
     from .utils import (
         assert_all_frozen,
         use_task_specific_params,
-        SummarizationDataset,
         lmap,
         flatten_list,
         pickle_save,
@@ -32,12 +31,17 @@
         get_git_info,
         ROUGE_KEYS,
         calculate_bleu_score,
+        Seq2SeqDataset,
+        MBartDataset,
     )
+
     from .callbacks import Seq2SeqLoggingCallback, get_checkpoint_callback
 except ImportError:
     from utils import (
+        Seq2SeqDataset,
+        MBartDataset,
+        assert_all_frozen,
         use_task_specific_params,
-        SummarizationDataset,
         lmap,
         flatten_list,
         pickle_save,
@@ -48,7 +52,6 @@
         get_git_info,
         ROUGE_KEYS,
         calculate_bleu_score,
-        assert_all_frozen,
     )
     from callbacks import Seq2SeqLoggingCallback, get_checkpoint_callback
 
@@ -100,6 +103,7 @@ def __init__(self, hparams, **kwargs):
         self.hparams.git_sha = get_git_info()["repo_sha"]
         self.num_workers = hparams.num_workers
         self.decoder_start_token_id = None
+        self.dataset_class = Seq2SeqDataset
 
     def freeze_embeds(self):
         """Freeze token embeddings and positional embeddings for bart, just token embeddings for t5."""
@@ -163,7 +167,7 @@ def calc_generative_metrics(self, preds, target) -> Dict:
 
     def _generative_step(self, batch: dict) -> dict:
         pad_token_id = self.tokenizer.pad_token_id
-        source_ids, source_mask, y = SummarizationDataset.trim_seq2seq_batch(batch, pad_token_id)
+        source_ids, source_mask, y = Seq2SeqDataset.trim_seq2seq_batch(batch, pad_token_id)
         t0 = time.time()
         generated_ids = self.model.generate(
             input_ids=source_ids,
@@ -187,10 +191,10 @@ def test_step(self, batch, batch_idx):
     def test_epoch_end(self, outputs):
         return self.validation_epoch_end(outputs, prefix="test")
 
-    def get_dataset(self, type_path) -> SummarizationDataset:
+    def get_dataset(self, type_path) -> Seq2SeqDataset:
         n_obs = self.n_obs[type_path]
         max_target_length = self.target_lens[type_path]
-        dataset = SummarizationDataset(
+        dataset = self.dataset_class(
             self.tokenizer,
             type_path=type_path,
             n_obs=n_obs,
@@ -303,6 +307,8 @@ def __init__(self, hparams, **kwargs):
             self.dataset_kwargs["tgt_lang"] = hparams.tgt_lang
         if self.model.config.decoder_start_token_id is None and isinstance(self.tokenizer, MBartTokenizer):
             self.decoder_start_token_id = self.tokenizer.lang_code_to_id[hparams.tgt_lang]
+        if isinstance(self.tokenizer, MBartTokenizer):
+            self.dataset_class = MBartDataset
 
     def calc_generative_metrics(self, preds, target) -> dict:
         return calculate_bleu_score(preds, target)
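The finetune.py hunks above introduce a `self.dataset_class` attribute: SummarizationModule defaults it to `Seq2SeqDataset`, `get_dataset` builds whatever class that attribute names, and TranslationModule swaps in `MBartDataset` when the tokenizer is an `MBartTokenizer`, matching the README note. Below is a runnable toy of that dispatch pattern only; the stub classes, the `is_mbart` flag, and the simplified constructors are stand-ins, not the real module or dataset code.

```python
# Toy version of the dataset-dispatch pattern from the diff above.
# Seq2SeqDataset / MBartDataset are stubs; only the selection logic is the point.
class Seq2SeqDataset:
    def __init__(self, tokenizer, type_path):
        self.tokenizer, self.type_path = tokenizer, type_path


class MBartDataset(Seq2SeqDataset):
    pass


class SummarizationModule:
    def __init__(self, tokenizer):
        self.tokenizer = tokenizer
        self.dataset_class = Seq2SeqDataset  # default, as in __init__ above

    def get_dataset(self, type_path):
        # get_dataset no longer names a concrete class, so subclasses only
        # need to reassign self.dataset_class.
        return self.dataset_class(self.tokenizer, type_path=type_path)


class TranslationModule(SummarizationModule):
    def __init__(self, tokenizer, is_mbart=False):
        super().__init__(tokenizer)
        if is_mbart:  # the real code checks isinstance(self.tokenizer, MBartTokenizer)
            self.dataset_class = MBartDataset


print(type(TranslationModule("tok", is_mbart=True).get_dataset("train")).__name__)  # MBartDataset
```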

Diff for: examples/seq2seq/test_seq2seq_examples.py (+47, -16)

@@ -9,23 +9,25 @@
 
 import pytest
 import torch
+from pytest import param
 from torch.utils.data import DataLoader
 
-from transformers import AutoTokenizer
+from transformers import AutoTokenizer, MBartTokenizer
 from transformers.testing_utils import require_multigpu
 
 from .distillation import distill_main, evaluate_checkpoint
 from .finetune import main
 from .pack_dataset import pack_data_dir
 from .run_eval import generate_summaries_or_translations, run_generate
-from .utils import SummarizationDataset, lmap, load_json
+from .utils import MBartDataset, Seq2SeqDataset, lmap, load_json
 
 
 logging.basicConfig(level=logging.DEBUG)
 
 logger = logging.getLogger()
 CUDA_AVAILABLE = torch.cuda.is_available()
 CHEAP_ARGS = {
+    "label_smoothing_eps": 0.2,
     "logger_name": "default",
     "length_penalty": 0.5,
     "cache_dir": "",
@@ -80,11 +82,11 @@
 
 
 def _dump_articles(path: Path, articles: list):
-    with path.open("w") as f:
-        f.write("\n".join(articles))
+    content = "\n".join(articles)
+    Path(path).open("w").writelines(content)
 
 
-ARTICLES = [" Sam ate lunch today", "Sams lunch ingredients"]
+ARTICLES = [" Sam ate lunch today.", "Sams lunch ingredients."]
 SUMMARIES = ["A very interesting story about what I ate for lunch.", "Avocado, celery, turkey, coffee"]
 T5_TINY = "patrickvonplaten/t5-tiny-random"
 BART_TINY = "sshleifer/bart-tiny-random"
@@ -208,7 +210,7 @@ def test_run_eval_bart(model):
 
 
 @pytest.mark.parametrize(
-    ["model"], [pytest.param(T5_TINY), pytest.param(BART_TINY), pytest.param(MBART_TINY), pytest.param(MARIAN_TINY)]
+    ["model"], [pytest.param(T5_TINY), pytest.param(BART_TINY), pytest.param(MBART_TINY), pytest.param(MARIAN_TINY)],
 )
 def test_finetune(model):
     args_d: dict = CHEAP_ARGS.copy()
@@ -260,22 +262,50 @@ def test_pack_dataset():
     assert orig_paths == new_paths
 
 
-@pytest.mark.parametrize(
-    ["tok"], [pytest.param(T5_TINY), pytest.param(BART_TINY), pytest.param(MBART_TINY), pytest.param(MARIAN_TINY)]
-)
-def test_dataset(tok):
-    tokenizer = AutoTokenizer.from_pretrained(tok)
+def test_mbart_dataset_truncation():
+    tokenizer = MBartTokenizer.from_pretrained(MBART_TINY)
     tmp_dir = make_test_data_dir()
     max_len_source = max(len(tokenizer.encode(a)) for a in ARTICLES)
     max_len_target = max(len(tokenizer.encode(a)) for a in SUMMARIES)
-    trunc_target = 4
-    train_dataset = SummarizationDataset(
+    trunc = 4
+    src_lang, tgt_lang = "ro_RO", "de_DE"  # NOT WHAT IT WAS TRAINED ON
+    train_dataset = MBartDataset(
         tokenizer,
         data_dir=tmp_dir,
         type_path="train",
-        max_source_length=20,
-        max_target_length=trunc_target,
-        tgt_lang="ro_RO",
+        max_source_length=trunc,
+        max_target_length=1000,  # ignored
+        src_lang=src_lang,
+        tgt_lang=tgt_lang,
+    )
+    dataloader = DataLoader(train_dataset, batch_size=2, collate_fn=train_dataset.collate_fn)
+    for batch in dataloader:
+        assert isinstance(batch, dict)
+        assert batch["attention_mask"].shape == batch["input_ids"].shape
+        # show that articles were trimmed.
+        assert batch["input_ids"].shape[1] == trunc
+        # show that targets are the same len
+        assert batch["decoder_input_ids"].shape[1] == trunc
+        # check language codes in correct place
+        assert batch["decoder_input_ids"][0, 0].item() == tokenizer.lang_code_to_id[tgt_lang]
+        assert batch["decoder_input_ids"][0, -1].item() == tokenizer.eos_token_id
+        assert batch["input_ids"][0, -2].item() == tokenizer.eos_token_id
+        assert batch["input_ids"][0, -1].item() == tokenizer.lang_code_to_id[src_lang]
+
+        assert max_len_target > trunc  # Truncated
+        assert max_len_source > trunc
+        break  # No need to test every batch
+
+
+@pytest.mark.parametrize(["tok"], [pytest.param(T5_TINY), pytest.param(BART_TINY), param(MARIAN_TINY)])
+def test_summarization_dataset_truncation(tok):
+    tokenizer = AutoTokenizer.from_pretrained(tok)
+    tmp_dir = make_test_data_dir()
+    max_len_source = max(len(tokenizer.encode(a)) for a in ARTICLES)
+    max_len_target = max(len(tokenizer.encode(a)) for a in SUMMARIES)
+    trunc_target = 4
+    train_dataset = Seq2SeqDataset(
+        tokenizer, data_dir=tmp_dir, type_path="train", max_source_length=20, max_target_length=trunc_target,
     )
     dataloader = DataLoader(train_dataset, batch_size=2, collate_fn=train_dataset.collate_fn)
     for batch in dataloader:
@@ -286,3 +316,4 @@ def test_dataset(tok):
         # show that targets were truncated
         assert batch["decoder_input_ids"].shape[1] == trunc_target  # Truncated
         assert max_len_target > trunc_target  # Truncated
+        break  # No need to test every batch
