33
33
lmap ,
34
34
pickle_save ,
35
35
save_git_info ,
36
+ save_json ,
36
37
use_task_specific_params ,
37
38
)
38
39
@@ -105,13 +106,25 @@ def __init__(self, hparams, **kwargs):
105
106
self .dataset_class = (
106
107
Seq2SeqDataset if hasattr (self .tokenizer , "prepare_seq2seq_batch" ) else LegacySeq2SeqDataset
107
108
)
109
+ self .already_saved_batch = False
108
110
self .eval_beams = self .model .config .num_beams if self .hparams .eval_beams is None else self .hparams .eval_beams
109
111
if self .hparams .eval_max_gen_length is not None :
110
112
self .eval_max_length = self .hparams .eval_max_gen_length
111
113
else :
112
114
self .eval_max_length = self .model .config .max_length
113
115
self .val_metric = self .default_val_metric if self .hparams .val_metric is None else self .hparams .val_metric
114
116
117
+ def save_readable_batch (self , batch : Dict [str , torch .Tensor ]) -> Dict [str , List [str ]]:
118
+ """A debugging utility"""
119
+ readable_batch = {
120
+ k : self .tokenizer .batch_decode (v .tolist ()) if "mask" not in k else v .shape for k , v in batch .items ()
121
+ }
122
+ save_json (readable_batch , Path (self .output_dir ) / "text_batch.json" )
123
+ save_json ({k : v .tolist () for k , v in batch .items ()}, Path (self .output_dir ) / "tok_batch.json" )
124
+
125
+ self .already_saved_batch = True
126
+ return readable_batch
127
+
115
128
def forward (self , input_ids , ** kwargs ):
116
129
return self .model (input_ids , ** kwargs )
117
130
@@ -129,6 +142,9 @@ def _step(self, batch: dict) -> Tuple:
129
142
decoder_input_ids = self .model ._shift_right (tgt_ids )
130
143
else :
131
144
decoder_input_ids = shift_tokens_right (tgt_ids , pad_token_id )
145
+ if not self .already_saved_batch : # This would be slightly better if it only happened on rank zero
146
+ batch ["decoder_input_ids" ] = decoder_input_ids
147
+ self .save_readable_batch (batch )
132
148
133
149
outputs = self (src_ids , attention_mask = src_mask , decoder_input_ids = decoder_input_ids , use_cache = False )
134
150
lm_logits = outputs [0 ]
0 commit comments