Skip to content

Commit b977289

Browse files
authored
[s2s] command line args for faster val steps (#6833)
1 parent 8af1970 commit b977289

File tree

3 files changed

+10
-3
lines changed

3 files changed

+10
-3
lines changed

Diff for: examples/seq2seq/distillation.py

+1 -1
Original file line number | Diff line number | Diff line change
@@ -262,7 +262,7 @@ class BartTranslationDistiller(BartSummarizationDistiller):
262262

263263
mode = "translation"
264264
metric_names = ["bleu"]
265-
val_metric = "bleu"
265+
default_val_metric = "bleu"
266266

267267
def __init__(self, hparams, **kwargs):
268268
super().__init__(hparams, **kwargs)

Diff for: examples/seq2seq/finetune.py

+7 -2
Original file line number | Diff line number | Diff line change
@@ -63,7 +63,7 @@ class SummarizationModule(BaseTransformer):
6363
mode = "summarization"
6464
loss_names = ["loss"]
6565
metric_names = ROUGE_KEYS
66-
val_metric = "rouge2"
66+
default_val_metric = "rouge2"
6767

6868
def __init__(self, hparams, **kwargs):
6969
super().__init__(hparams, num_labels=None, mode=self.mode, **kwargs)
@@ -110,6 +110,9 @@ def __init__(self, hparams, **kwargs):
110110
self.dataset_class = (
111111
Seq2SeqDataset if hasattr(self.tokenizer, "prepare_seq2seq_batch") else LegacySeq2SeqDataset
112112
)
113+
self.eval_beams = self.model.config.num_beams if self.hparams.eval_beams is None else self.hparams.eval_beams
114+
assert self.eval_beams >= 1, f"got self.eval_beams={self.eval_beams}. Need an integer > 1"
115+
self.val_metric = self.default_val_metric if self.hparams.val_metric is None else self.hparams.val_metric
113116

114117
def freeze_embeds(self):
115118
"""Freeze token embeddings and positional embeddings for bart, just token embeddings for t5."""
@@ -301,6 +304,8 @@ def add_model_specific_args(parser, root_dir):
301304
parser.add_argument("--label_smoothing", type=float, default=0.0, required=False)
302305
parser.add_argument("--src_lang", type=str, default="", required=False)
303306
parser.add_argument("--tgt_lang", type=str, default="", required=False)
307+
parser.add_argument("--eval_beams", type=int, default=None, required=False)
308+
parser.add_argument("--val_metric", type=str, default=None, required=False)
304309
parser.add_argument(
305310
"--early_stopping_patience",
306311
type=int,
@@ -315,7 +320,7 @@ class TranslationModule(SummarizationModule):
315320
mode = "translation"
316321
loss_names = ["loss"]
317322
metric_names = ["bleu"]
318-
val_metric = "bleu"
323+
default_val_metric = "bleu"
319324

320325
def __init__(self, hparams, **kwargs):
321326
super().__init__(hparams, **kwargs)

Diff for: examples/seq2seq/test_seq2seq_examples.py

+2
Original file line number | Diff line number | Diff line change
@@ -31,6 +31,8 @@
3131
CUDA_AVAILABLE = torch.cuda.is_available()
3232
CHEAP_ARGS = {
3333
"label_smoothing": 0.2,
34+
"eval_beams": 1,
35+
"val_metric": None,
3436
"adafactor": True,
3537
"early_stopping_patience": 2,
3638
"logger_name": "default",

0 commit comments

Comments
 (0)