@@ -63,7 +63,7 @@ class SummarizationModule(BaseTransformer):
63
63
mode = "summarization"
64
64
loss_names = ["loss" ]
65
65
metric_names = ROUGE_KEYS
66
- val_metric = "rouge2"
66
+ default_val_metric = "rouge2"
67
67
68
68
def __init__ (self , hparams , ** kwargs ):
69
69
super ().__init__ (hparams , num_labels = None , mode = self .mode , ** kwargs )
@@ -110,6 +110,9 @@ def __init__(self, hparams, **kwargs):
110
110
self .dataset_class = (
111
111
Seq2SeqDataset if hasattr (self .tokenizer , "prepare_seq2seq_batch" ) else LegacySeq2SeqDataset
112
112
)
113
+ self .eval_beams = self .model .config .num_beams if self .hparams .eval_beams is None else self .hparams .eval_beams
114
+ assert self .eval_beams >= 1 , f"got self.eval_beams={ self .eval_beams } . Need an integer > 1"
115
+ self .val_metric = self .default_val_metric if self .hparams .val_metric is None else self .hparams .val_metric
113
116
114
117
def freeze_embeds (self ):
115
118
"""Freeze token embeddings and positional embeddings for bart, just token embeddings for t5."""
@@ -301,6 +304,8 @@ def add_model_specific_args(parser, root_dir):
301
304
parser .add_argument ("--label_smoothing" , type = float , default = 0.0 , required = False )
302
305
parser .add_argument ("--src_lang" , type = str , default = "" , required = False )
303
306
parser .add_argument ("--tgt_lang" , type = str , default = "" , required = False )
307
+ parser .add_argument ("--eval_beams" , type = int , default = None , required = False )
308
+ parser .add_argument ("--val_metric" , type = str , default = None , required = False )
304
309
parser .add_argument (
305
310
"--early_stopping_patience" ,
306
311
type = int ,
@@ -315,7 +320,7 @@ class TranslationModule(SummarizationModule):
315
320
mode = "translation"
316
321
loss_names = ["loss" ]
317
322
metric_names = ["bleu" ]
318
- val_metric = "bleu"
323
+ default_val_metric = "bleu"
319
324
320
325
def __init__ (self , hparams , ** kwargs ):
321
326
super ().__init__ (hparams , ** kwargs )
0 commit comments