Commit 500be01

[s2s] save first batch to json for debugging purposes (#6810)
1 parent: 2b574e7

File tree

2 files changed: +20 -0 lines changed


Diff for: examples/seq2seq/finetune.py (+16)
@@ -33,6 +33,7 @@
     lmap,
     pickle_save,
     save_git_info,
+    save_json,
     use_task_specific_params,
 )

@@ -105,13 +106,25 @@ def __init__(self, hparams, **kwargs):
         self.dataset_class = (
             Seq2SeqDataset if hasattr(self.tokenizer, "prepare_seq2seq_batch") else LegacySeq2SeqDataset
         )
+        self.already_saved_batch = False
         self.eval_beams = self.model.config.num_beams if self.hparams.eval_beams is None else self.hparams.eval_beams
         if self.hparams.eval_max_gen_length is not None:
             self.eval_max_length = self.hparams.eval_max_gen_length
         else:
             self.eval_max_length = self.model.config.max_length
         self.val_metric = self.default_val_metric if self.hparams.val_metric is None else self.hparams.val_metric

+    def save_readable_batch(self, batch: Dict[str, torch.Tensor]) -> Dict[str, List[str]]:
+        """A debugging utility"""
+        readable_batch = {
+            k: self.tokenizer.batch_decode(v.tolist()) if "mask" not in k else v.shape for k, v in batch.items()
+        }
+        save_json(readable_batch, Path(self.output_dir) / "text_batch.json")
+        save_json({k: v.tolist() for k, v in batch.items()}, Path(self.output_dir) / "tok_batch.json")
+
+        self.already_saved_batch = True
+        return readable_batch
+
     def forward(self, input_ids, **kwargs):
         return self.model(input_ids, **kwargs)

@@ -129,6 +142,9 @@ def _step(self, batch: dict) -> Tuple:
             decoder_input_ids = self.model._shift_right(tgt_ids)
         else:
             decoder_input_ids = shift_tokens_right(tgt_ids, pad_token_id)
+        if not self.already_saved_batch:  # This would be slightly better if it only happened on rank zero
+            batch["decoder_input_ids"] = decoder_input_ids
+            self.save_readable_batch(batch)

         outputs = self(src_ids, attention_mask=src_mask, decoder_input_ids=decoder_input_ids, use_cache=False)
         lm_logits = outputs[0]
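For context, the decode-and-dump pattern added above can be reproduced outside the training module. The sketch below assumes a BART tokenizer; the save_json helper and the dummy batch are stand-ins for the example's own utility and for real training data, not the module's actual code path.

# Standalone sketch of the save_readable_batch idea (assumptions: a BART
# tokenizer and a writable working directory; save_json is a stand-in for
# the utils.save_json helper imported in the diff).
import json
from pathlib import Path

from transformers import BartTokenizer


def save_json(content, path):
    with open(path, "w") as f:
        json.dump(content, f, indent=2)


tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
batch = tokenizer(["Hello world", "A second example"], return_tensors="pt", padding=True)

readable_batch = {
    # Decode token-id tensors back to strings; for attention masks,
    # record only the shape, since decoding a 0/1 mask is meaningless.
    k: tokenizer.batch_decode(v.tolist()) if "mask" not in k else list(v.shape)
    for k, v in batch.items()
}
save_json(readable_batch, Path(".") / "text_batch.json")
save_json({k: v.tolist() for k, v in batch.items()}, Path(".") / "tok_batch.json")

Writing both files means a misbehaving run can be diagnosed at the text level (did preprocessing produce sensible source/target strings?) or at the token level (are the ids and padding what the model expects?).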

Diff for: examples/seq2seq/test_seq2seq_examples.py (+4)
@@ -422,6 +422,10 @@ def test_finetune(model):
     assert bart.decoder.embed_tokens == bart.encoder.embed_tokens
     assert bart.decoder.embed_tokens == bart.shared

+    example_batch = load_json(module.output_dir / "text_batch.json")
+    assert isinstance(example_batch, dict)
+    assert len(example_batch) >= 4
+

 def test_finetune_extra_model_args():
     args_d: dict = CHEAP_ARGS.copy()
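The `len(example_batch) >= 4` assertion presumably reflects the keys saved from _step: input_ids, attention_mask, labels, and the decoder_input_ids added just before the dump. As a hedged sketch under those assumptions, one might inspect the files from a finished run like this (load_json is a stand-in for the example's utility, and the output path is hypothetical):

# Sketch: inspect the debugging dump after a finetuning run.
import json
from pathlib import Path


def load_json(path):
    with open(path) as f:
        return json.load(f)


output_dir = Path("./finetune_output")  # hypothetical path; use the run's real output_dir
text_batch = load_json(output_dir / "text_batch.json")
for key, value in text_batch.items():
    # Keys such as input_ids/labels map to decoded strings;
    # mask entries hold the recorded tensor shape instead.
    print(key, "->", value)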
