run_eval_search.py
#!/usr/bin/env python
import argparse
import itertools
import operator
import sys
from collections import OrderedDict

from run_eval import datetime_now, run_generate
from utils import ROUGE_KEYS


# A table of supported tasks and the list of scores in the order of importance to be sorted by.
# To add a new task, simply list the score names that `run_eval.run_generate()` returns
task_score_names = {
    "translation": ["bleu"],
    "summarization": ROUGE_KEYS,
}


def parse_search_arg(search):
    groups = search.split()
    entries = {k: vs for k, vs in (g.split("=") for g in groups)}
    entry_names = list(entries.keys())
    sets = [list((f"--{k} {v}") for v in vs.split(":")) for k, vs in entries.items()]
    matrix = [list(x) for x in itertools.product(*sets)]
    return matrix, entry_names
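

# For illustration, a search string expands into the cross-product of its values, e.g.
# parse_search_arg("num_beams=5:10 early_stopping=true:false") returns:
#
#   matrix = [
#       ["--num_beams 5", "--early_stopping true"],
#       ["--num_beams 5", "--early_stopping false"],
#       ["--num_beams 10", "--early_stopping true"],
#       ["--num_beams 10", "--early_stopping false"],
#   ]
#   entry_names = ["num_beams", "early_stopping"]
#
# (the hparam names and values above are illustrative only)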


def run_search():
    """
    Run parametric search over the desired hparam space with help of ``run_eval.py``.

    All the arguments except ``--search`` are passed to ``run_eval.py`` as is. The values inside of ``--search`` are
    parsed, reformatted and fed to ``run_eval.py`` as additional args.

    The format for the ``--search`` value is a simple string with hparams and colon-separated values to try, e.g.:

    ```
    --search "num_beams=5:10 length_penalty=0.8:1.0:1.2 early_stopping=true:false"
    ```

    which will generate ``12`` (``2*3*2``) searches, one for each combination in the product of the hparam values.
    The search above will invoke ``run_eval.py`` repeatedly with:

    ```
    --num_beams 5 --length_penalty 0.8 --early_stopping true
    --num_beams 5 --length_penalty 0.8 --early_stopping false
    [...]
    --num_beams 10 --length_penalty 1.2 --early_stopping false
    ```

    On completion, this function prints a markdown table of the results, sorted by the task's scores (best first),
    followed by the arguments of the winning run.
    """
    prog = sys.argv[0]

    parser = argparse.ArgumentParser(
        usage="\n\nImportant: this script accepts all arguments `run_eval.py` accepts and then a few extra, therefore refer to `run_eval.py -h` for the complete list."
    )
    parser.add_argument(
        "--search",
        type=str,
        required=False,
        help='param space to search, e.g. "num_beams=5:10 length_penalty=0.8:1.0:1.2"',
    )
    parser.add_argument(
        "--bs", type=int, default=8, required=False, help="initial batch size (may get reduced if it's too big)"
    )
    parser.add_argument("--task", type=str, help="used for task_specific_params + metrics")
    parser.add_argument(
        "--info",
        nargs="?",
        type=str,
        const=datetime_now(),
        help="add custom notes to be printed before the results table. If no value is passed, the current datetime string will be used.",
    )
    args, args_main = parser.parse_known_args()
    # we share some of the args
    args_main.extend(["--task", args.task])
    args_normal = [prog] + args_main

    # to support variations like "translation_en_to_de"
    task = "translation" if "translation" in args.task else "summarization"

    matrix, col_names = parse_search_arg(args.search)
    col_names[0:0] = task_score_names[task]  # score cols first
    col_widths = {col: len(str(col)) for col in col_names}
    results = []
    for r in matrix:
        hparams = {k: v for k, v in (x.replace("--", "").split() for x in r)}
        args_exp = " ".join(r).split()
        args_exp.extend(["--bs", str(args.bs)])  # in case we need to reduce its size due to CUDA OOM
        sys.argv = args_normal + args_exp

        # XXX: need to trap CUDA OOM and lower args.bs if that happens and retry
        scores = run_generate(verbose=False)

        # make sure scores are first in the table
        result = OrderedDict()
        for score in task_score_names[task]:
            result[score] = scores[score]
        result.update(hparams)
        results.append(result)

        # find widest entries
        for k, v in result.items():
            l = len(str(v))
            if l > col_widths[k]:
                col_widths[k] = l

    results_sorted = sorted(results, key=operator.itemgetter(*task_score_names[task]), reverse=True)
    print(" | ".join([f"{col:{col_widths[col]}}" for col in col_names]))
    print(" | ".join([f"{'-' * col_widths[col]}" for col in col_names]))
    for row in results_sorted:
        print(" | ".join([f"{row[col]:{col_widths[col]}}" for col in col_names]))

    best = results_sorted[0]
    for score in task_score_names[task]:
        del best[score]
    best_args = [f"--{k} {v}" for k, v in best.items()]
    dyn_args = ["--bs", str(args.bs)]

    if args.info:
        print(f"\nInfo: {args.info}")
    print("\nBest score args:")
    print(" ".join(args_main + best_args + dyn_args))

    return results_sorted
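

# Shape of the printed results table, shown for a hypothetical
# `--task translation --search "num_beams=5:10 early_stopping=true:false"` run: the task's score
# columns come first, followed by one column per searched hparam, with rows sorted best-first:
#
#   bleu | num_beams | early_stopping
#   ---- | --------- | --------------
#   ...
#
# (column widths adjust to the widest header or value in each column)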


if __name__ == "__main__":
    # Usage:
    # [normal run_eval_search.py cmd plus] \
    # --search="num_beams=1:5:10 length_penalty=0.8:1:1.2 early_stopping=true:false"
    #
    # Example:
    # PYTHONPATH="src:examples/seq2seq" python examples/seq2seq/run_eval_search.py $MODEL_NAME \
    # $DATA_DIR/val.source $SAVE_DIR/test_translations.txt --reference_path $DATA_DIR/val.target \
    # --score_path $SAVE_DIR/test_bleu.json --bs $BS --task translation \
    # --search="num_beams=1:5:10 length_penalty=0.8:1:1.2 early_stopping=true:false"

    run_search()