
Commit 6a3a197

kding1 and mfuntowicz authored
Add SigOpt HPO to transformers trainer api (#13572)
* add sigopt hpo to transformers. Signed-off-by: Ding, Ke <ke.ding@intel.com>
* extend sigopt changes to test code and others. Signed-off-by: Ding, Ke <ke.ding@intel.com>
* Style.
* fix style for sigopt integration. Signed-off-by: Ding, Ke <ke.ding@intel.com>
* Add necessary information to run unittests on SigOpt.

Co-authored-by: Morgan Funtowicz <funtowiczmo@gmail.com>
1 parent 62832c9 commit 6a3a197

File tree

10 files changed: +168 -15 lines


.github/workflows/self-nightly-scheduled.yml

Lines changed: 1 addition & 0 deletions
@@ -15,6 +15,7 @@ env:
   OMP_NUM_THREADS: 16
   MKL_NUM_THREADS: 16
   PYTEST_TIMEOUT: 600
+  SIGOPT_API_TOKEN: ${{ secrets.SIGOPT_API_TOKEN }}
 
 jobs:
   run_all_tests_torch_gpu:

.github/workflows/self-scheduled.yml

Lines changed: 1 addition & 0 deletions
@@ -15,6 +15,7 @@ env:
   OMP_NUM_THREADS: 16
   MKL_NUM_THREADS: 16
   PYTEST_TIMEOUT: 600
+  SIGOPT_API_TOKEN: ${{ secrets.SIGOPT_API_TOKEN }}
 
 jobs:
   run_all_tests_torch_gpu:
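Note: the two workflow changes above only expose the SigOpt API token to the scheduled CI runs. To exercise the new backend or the SigOpt tests locally, the same credential has to be present; a minimal sketch, assuming the sigopt client picks SIGOPT_API_TOKEN up from the environment when no explicit token is passed (the placeholder value below is not part of this commit):

import os

from sigopt import Connection

# Placeholder token for illustration; in CI the same variable is populated from the
# repository secret referenced in the two workflow files above.
os.environ.setdefault("SIGOPT_API_TOKEN", "<your-sigopt-api-token>")

# Connection() is expected to read the token from the environment; this is also how
# run_hp_search_sigopt (added to integrations.py below) authenticates.
conn = Connection()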

setup.py

Lines changed: 3 additions & 1 deletion
@@ -135,6 +135,7 @@
     "sagemaker>=2.31.0",
     "scikit-learn",
     "sentencepiece>=0.1.91,!=0.1.92",
+    "sigopt",
     "soundfile",
     "sphinx-copybutton",
     "sphinx-markdown-tables",
@@ -248,8 +249,9 @@ def run(self):
 extras["fairscale"] = deps_list("fairscale")
 extras["optuna"] = deps_list("optuna")
 extras["ray"] = deps_list("ray[tune]")
+extras["sigopt"] = deps_list("sigopt")
 
-extras["integrations"] = extras["optuna"] + extras["ray"]
+extras["integrations"] = extras["optuna"] + extras["ray"] + extras["sigopt"]
 
 extras["serving"] = deps_list("pydantic", "uvicorn", "fastapi", "starlette")
 extras["audio"] = deps_list("soundfile")

src/transformers/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -130,6 +130,7 @@
         "is_optuna_available",
         "is_ray_available",
         "is_ray_tune_available",
+        "is_sigopt_available",
         "is_tensorboard_available",
         "is_wandb_available",
     ],
@@ -1951,6 +1952,7 @@
         is_optuna_available,
         is_ray_available,
         is_ray_tune_available,
+        is_sigopt_available,
         is_tensorboard_available,
         is_wandb_available,
     )

src/transformers/dependency_versions_table.py

Lines changed: 1 addition & 0 deletions
@@ -53,6 +53,7 @@
     "sagemaker": "sagemaker>=2.31.0",
     "scikit-learn": "scikit-learn",
     "sentencepiece": "sentencepiece>=0.1.91,!=0.1.92",
+    "sigopt": "sigopt",
     "soundfile": "soundfile",
     "sphinx-copybutton": "sphinx-copybutton",
     "sphinx-markdown-tables": "sphinx-markdown-tables",

src/transformers/integrations.py

Lines changed: 49 additions & 0 deletions
@@ -83,6 +83,10 @@ def is_ray_tune_available():
     return importlib.util.find_spec("ray.tune") is not None
 
 
+def is_sigopt_available():
+    return importlib.util.find_spec("sigopt") is not None
+
+
 def is_azureml_available():
     if importlib.util.find_spec("azureml") is None:
         return False
@@ -117,6 +121,10 @@ def hp_params(trial):
         if isinstance(trial, dict):
             return trial
 
+    if is_sigopt_available():
+        if isinstance(trial, dict):
+            return trial
+
     raise RuntimeError(f"Unknown type for trial {trial.__class__}")
 
 
@@ -125,6 +133,8 @@ def default_hp_search_backend():
         return "optuna"
     elif is_ray_tune_available():
         return "ray"
+    elif is_sigopt_available():
+        return "sigopt"
 
 
 def run_hp_search_optuna(trainer, n_trials: int, direction: str, **kwargs) -> BestRun:
@@ -288,6 +298,45 @@ def dynamic_modules_import_trainable(*args, **kwargs):
     return best_run
 
 
+def run_hp_search_sigopt(trainer, n_trials: int, direction: str, **kwargs) -> BestRun:
+
+    from sigopt import Connection
+
+    conn = Connection()
+    proxies = kwargs.pop("proxies", None)
+    if proxies is not None:
+        conn.set_proxies(proxies)
+
+    experiment = conn.experiments().create(
+        name="huggingface-tune",
+        parameters=trainer.hp_space(None),
+        metrics=[dict(name="objective", objective=direction, strategy="optimize")],
+        parallel_bandwidth=1,
+        observation_budget=n_trials,
+        project="huggingface",
+    )
+    logger.info(f"created experiment: https://app.sigopt.com/experiment/{experiment.id}")
+
+    while experiment.progress.observation_count < experiment.observation_budget:
+        suggestion = conn.experiments(experiment.id).suggestions().create()
+        trainer.objective = None
+        trainer.train(resume_from_checkpoint=None, trial=suggestion)
+        # If there hasn't been any evaluation during the training loop.
+        if getattr(trainer, "objective", None) is None:
+            metrics = trainer.evaluate()
+            trainer.objective = trainer.compute_objective(metrics)
+
+        values = [dict(name="objective", value=trainer.objective)]
+        obs = conn.experiments(experiment.id).observations().create(suggestion=suggestion.id, values=values)
+        logger.info(f"[suggestion_id, observation_id]: [{suggestion.id}, {obs.id}]")
+        experiment = conn.experiments(experiment.id).fetch()
+
+    best = list(conn.experiments(experiment.id).best_assignments().fetch().iterate_pages())[0]
+    best_run = BestRun(best.id, best.value, best.assignments)
+
+    return best_run
+
+
 def get_available_reporting_integrations():
     integrations = []
     if is_azureml_available():
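Note: run_hp_search_sigopt above follows SigOpt's standard suggest/observe loop, with Trainer.train supplying the objective. A minimal standalone sketch of the same pattern, assuming SIGOPT_API_TOKEN is set and using a toy quadratic in place of a training run:

from sigopt import Connection

conn = Connection()  # reads SIGOPT_API_TOKEN from the environment
experiment = conn.experiments().create(
    name="toy-quadratic",
    parameters=[{"bounds": {"min": -5.0, "max": 5.0}, "name": "x", "type": "double"}],
    metrics=[dict(name="objective", objective="minimize", strategy="optimize")],
    observation_budget=5,
)

while experiment.progress.observation_count < experiment.observation_budget:
    suggestion = conn.experiments(experiment.id).suggestions().create()
    x = suggestion.assignments["x"]
    objective = (x - 1.0) ** 2  # stand-in for the evaluation loss the Trainer would report
    conn.experiments(experiment.id).observations().create(
        suggestion=suggestion.id, values=[dict(name="objective", value=objective)]
    )
    experiment = conn.experiments(experiment.id).fetch()

best = list(conn.experiments(experiment.id).best_assignments().fetch().iterate_pages())[0]
print(best.assignments, best.value)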

src/transformers/testing_utils.py

Lines changed: 14 additions & 1 deletion
@@ -51,7 +51,7 @@
     is_torchaudio_available,
     is_vision_available,
 )
-from .integrations import is_optuna_available, is_ray_available
+from .integrations import is_optuna_available, is_ray_available, is_sigopt_available
 
 
 SMALL_MODEL_IDENTIFIER = "julien-c/bert-xsmall-dummy"
@@ -511,6 +511,19 @@ def require_ray(test_case):
         return test_case
 
 
+def require_sigopt(test_case):
+    """
+    Decorator marking a test that requires SigOpt.
+
+    These tests are skipped when SigOpt isn't installed.
+
+    """
+    if not is_sigopt_available():
+        return unittest.skip("test requires SigOpt")(test_case)
+    else:
+        return test_case
+
+
 def require_soundfile(test_case):
     """
     Decorator marking a test that requires soundfile
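Note: require_sigopt is used like the existing require_optuna / require_ray decorators. A hypothetical sketch, for illustration only (the real test added by this commit is in tests/test_trainer.py below):

import unittest

from transformers.testing_utils import require_sigopt


@require_sigopt
class SigOptSmokeTest(unittest.TestCase):  # hypothetical test, not part of this commit
    def test_client_import(self):
        # Skipped automatically when the `sigopt` package is not installed.
        import sigopt  # noqa: F401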

src/transformers/trainer.py

Lines changed: 30 additions & 13 deletions
@@ -40,8 +40,10 @@
     is_fairscale_available,
     is_optuna_available,
     is_ray_tune_available,
+    is_sigopt_available,
     run_hp_search_optuna,
     run_hp_search_ray,
+    run_hp_search_sigopt,
 )
 
 import numpy as np
@@ -231,9 +233,9 @@ class Trainer:
             A function that instantiates the model to be used. If provided, each call to
             :meth:`~transformers.Trainer.train` will start from a new instance of the model as given by this function.
 
-            The function may have zero argument, or a single one containing the optuna/Ray Tune trial object, to be
-            able to choose different architectures according to hyper parameters (such as layer count, sizes of inner
-            layers, dropout probabilities etc).
+            The function may have zero argument, or a single one containing the optuna/Ray Tune/SigOpt trial object, to
+            be able to choose different architectures according to hyper parameters (such as layer count, sizes of
+            inner layers, dropout probabilities etc).
         compute_metrics (:obj:`Callable[[EvalPrediction], Dict]`, `optional`):
             The function that will be used to compute metrics at evaluation. Must take a
             :class:`~transformers.EvalPrediction` and return a dictionary string to metric values.
@@ -869,6 +871,8 @@ def _hp_search_setup(self, trial: Union["optuna.Trial", Dict[str, Any]]):
         elif self.hp_search_backend == HPSearchBackend.RAY:
             params = trial
             params.pop("wandb", None)
+        elif self.hp_search_backend == HPSearchBackend.SIGOPT:
+            params = {k: int(v) if isinstance(v, str) else v for k, v in trial.assignments.items()}
 
         for key, value in params.items():
             if not hasattr(self.args, key):
@@ -883,6 +887,8 @@ def _hp_search_setup(self, trial: Union["optuna.Trial", Dict[str, Any]]):
             setattr(self.args, key, value)
         if self.hp_search_backend == HPSearchBackend.OPTUNA:
             logger.info("Trial:", trial.params)
+        if self.hp_search_backend == HPSearchBackend.SIGOPT:
+            logger.info(f"SigOpt Assignments: {trial.assignments}")
         if self.args.deepspeed:
             # Rebuild the deepspeed config to reflect the updated training parameters
             from transformers.deepspeed import HfDeepSpeedConfig
@@ -1232,7 +1238,7 @@ def train(
         self.callback_handler.lr_scheduler = self.lr_scheduler
         self.callback_handler.train_dataloader = train_dataloader
         self.state.trial_name = self.hp_name(trial) if self.hp_name is not None else None
-        self.state.trial_params = hp_params(trial) if trial is not None else None
+        self.state.trial_params = hp_params(trial.assignments) if trial is not None else None
         # This should be the same if the state has been saved but in case the training arguments changed, it's safer
         # to set this after the load.
         self.state.max_steps = max_steps
@@ -1524,10 +1530,12 @@ def _save_checkpoint(self, model, trial, metrics=None):
         if self.hp_search_backend is not None and trial is not None:
             if self.hp_search_backend == HPSearchBackend.OPTUNA:
                 run_id = trial.number
-            else:
+            elif self.hp_search_backend == HPSearchBackend.RAY:
                 from ray import tune
 
                 run_id = tune.get_trial_id()
+            elif self.hp_search_backend == HPSearchBackend.SIGOPT:
+                run_id = trial.id
             run_name = self.hp_name(trial) if self.hp_name is not None else f"run-{run_id}"
             run_dir = os.path.join(self.args.output_dir, run_name)
         else:
@@ -1671,9 +1679,9 @@ def hyperparameter_search(
         **kwargs,
     ) -> BestRun:
         """
-        Launch an hyperparameter search using ``optuna`` or ``Ray Tune``. The optimized quantity is determined by
-        :obj:`compute_objective`, which defaults to a function returning the evaluation loss when no metric is
-        provided, the sum of all metrics otherwise.
+        Launch an hyperparameter search using ``optuna`` or ``Ray Tune`` or ``SigOpt``. The optimized quantity is
+        determined by :obj:`compute_objective`, which defaults to a function returning the evaluation loss when no
+        metric is provided, the sum of all metrics otherwise.
 
         .. warning::
@@ -1686,7 +1694,8 @@ def hyperparameter_search(
             hp_space (:obj:`Callable[["optuna.Trial"], Dict[str, float]]`, `optional`):
                 A function that defines the hyperparameter search space. Will default to
                 :func:`~transformers.trainer_utils.default_hp_space_optuna` or
-                :func:`~transformers.trainer_utils.default_hp_space_ray` depending on your backend.
+                :func:`~transformers.trainer_utils.default_hp_space_ray` or
+                :func:`~transformers.trainer_utils.default_hp_space_sigopt` depending on your backend.
             compute_objective (:obj:`Callable[[Dict[str, float]], float]`, `optional`):
                 A function computing the objective to minimize or maximize from the metrics returned by the
                 :obj:`evaluate` method. Will default to :func:`~transformers.trainer_utils.default_compute_objective`.
@@ -1697,8 +1706,8 @@ def hyperparameter_search(
                 pick :obj:`"minimize"` when optimizing the validation loss, :obj:`"maximize"` when optimizing one or
                 several metrics.
             backend(:obj:`str` or :class:`~transformers.training_utils.HPSearchBackend`, `optional`):
-                The backend to use for hyperparameter search. Will default to optuna or Ray Tune, depending on which
-                one is installed. If both are installed, will default to optuna.
+                The backend to use for hyperparameter search. Will default to optuna or Ray Tune or SigOpt, depending
+                on which one is installed. If all are installed, will default to optuna.
             kwargs:
                 Additional keyword arguments passed along to :obj:`optuna.create_study` or :obj:`ray.tune.run`. For
                 more information see:
@@ -1707,6 +1716,7 @@ def hyperparameter_search(
                   <https://optuna.readthedocs.io/en/stable/reference/generated/optuna.study.create_study.html>`__
                 - the documentation of `tune.run
                   <https://docs.ray.io/en/latest/tune/api_docs/execution.html#tune-run>`__
+                - the documentation of `sigopt <https://app.sigopt.com/docs/endpoints/experiments/create>`__
 
         Returns:
             :class:`transformers.trainer_utils.BestRun`: All the information about the best run.
@@ -1718,6 +1728,7 @@ def hyperparameter_search(
                 "At least one of optuna or ray should be installed. "
                 "To install optuna run `pip install optuna`."
                 "To install ray run `pip install ray[tune]`."
+                "To install sigopt run `pip install sigopt`."
             )
         backend = HPSearchBackend(backend)
         if backend == HPSearchBackend.OPTUNA and not is_optuna_available():
@@ -1726,6 +1737,8 @@ def hyperparameter_search(
             raise RuntimeError(
                 "You picked the Ray Tune backend, but it is not installed. Use `pip install 'ray[tune]'`."
             )
+        if backend == HPSearchBackend.SIGOPT and not is_sigopt_available():
+            raise RuntimeError("You picked the sigopt backend, but it is not installed. Use `pip install sigopt`.")
         self.hp_search_backend = backend
         if self.model_init is None:
             raise RuntimeError(
@@ -1736,8 +1749,12 @@ def hyperparameter_search(
         self.hp_name = hp_name
         self.compute_objective = default_compute_objective if compute_objective is None else compute_objective
 
-        run_hp_search = run_hp_search_optuna if backend == HPSearchBackend.OPTUNA else run_hp_search_ray
-        best_run = run_hp_search(self, n_trials, direction, **kwargs)
+        backend_dict = {
+            HPSearchBackend.OPTUNA: run_hp_search_optuna,
+            HPSearchBackend.RAY: run_hp_search_ray,
+            HPSearchBackend.SIGOPT: run_hp_search_sigopt,
+        }
+        best_run = backend_dict[backend](self, n_trials, direction, **kwargs)
 
         self.hp_search_backend = None
         return best_run
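Note: from the user's side, the SigOpt backend is driven through the existing Trainer.hyperparameter_search API shown above. A hedged usage sketch, assuming `pip install sigopt`, a valid SIGOPT_API_TOKEN, and pre-built train_dataset / eval_dataset objects (the model name and search space here are illustrative, not part of the commit):

from transformers import AutoModelForSequenceClassification, Trainer, TrainingArguments


def model_init(trial=None):
    # hyperparameter_search requires model_init so each trial starts from a fresh model
    return AutoModelForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2)


def sigopt_hp_space(trial):
    # SigOpt takes a list of parameter definitions, as in default_hp_space_sigopt
    return [
        {"bounds": {"min": 1e-6, "max": 1e-4}, "name": "learning_rate", "type": "double"},
        {"bounds": {"min": 1, "max": 4}, "name": "num_train_epochs", "type": "int"},
    ]


trainer = Trainer(
    args=TrainingArguments(output_dir="sigopt_hpo", evaluation_strategy="epoch"),
    model_init=model_init,
    train_dataset=train_dataset,  # assumed to exist
    eval_dataset=eval_dataset,    # assumed to exist
)

best_run = trainer.hyperparameter_search(
    direction="minimize",      # optimizes the evaluation loss by default
    backend="sigopt",          # resolved to HPSearchBackend.SIGOPT
    hp_space=sigopt_hp_space,  # omit to fall back to default_hp_space_sigopt
    n_trials=4,                # becomes the SigOpt experiment's observation_budget
)
print(best_run.run_id, best_run.objective, best_run.hyperparameters)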

src/transformers/trainer_utils.py

Lines changed: 15 additions & 0 deletions
@@ -198,14 +198,29 @@ def default_hp_space_ray(trial) -> Dict[str, float]:
     }
 
 
+def default_hp_space_sigopt(trial):
+    return [
+        {"bounds": {"min": 1e-6, "max": 1e-4}, "name": "learning_rate", "type": "double", "transformation": "log"},
+        {"bounds": {"min": 1, "max": 6}, "name": "num_train_epochs", "type": "int"},
+        {"bounds": {"min": 1, "max": 40}, "name": "seed", "type": "int"},
+        {
+            "categorical_values": ["4", "8", "16", "32", "64"],
+            "name": "per_device_train_batch_size",
+            "type": "categorical",
+        },
+    ]
+
+
 class HPSearchBackend(ExplicitEnum):
     OPTUNA = "optuna"
     RAY = "ray"
+    SIGOPT = "sigopt"
 
 
 default_hp_space = {
     HPSearchBackend.OPTUNA: default_hp_space_optuna,
     HPSearchBackend.RAY: default_hp_space_ray,
+    HPSearchBackend.SIGOPT: default_hp_space_sigopt,
 }
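Note on the default space above: SigOpt categorical values are strings, so per_device_train_batch_size comes back from a suggestion as e.g. "16". The SIGOPT branch added to _hp_search_setup in trainer.py converts such string assignments back to ints before writing them onto TrainingArguments; a toy illustration of that conversion:

# hypothetical assignments dict, mirroring what a SigOpt suggestion would return
assignments = {"learning_rate": 3e-5, "num_train_epochs": 3, "per_device_train_batch_size": "16"}
params = {k: int(v) if isinstance(v, str) else v for k, v in assignments.items()}
print(params)  # {'learning_rate': 3e-05, 'num_train_epochs': 3, 'per_device_train_batch_size': 16}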

tests/test_trainer.py

Lines changed: 52 additions & 0 deletions
@@ -50,6 +50,7 @@
     require_optuna,
     require_ray,
     require_sentencepiece,
+    require_sigopt,
     require_tokenizers,
     require_torch,
     require_torch_gpu,
@@ -1522,3 +1523,54 @@ def test_hyperparameter_search_ray_client(self):
         with ray_start_client_server():
             assert ray.util.client.ray.is_connected()
             self.ray_hyperparameter_search()
+
+
+@require_torch
+@require_sigopt
+class TrainerHyperParameterSigOptIntegrationTest(unittest.TestCase):
+    def setUp(self):
+        args = TrainingArguments(".")
+        self.n_epochs = args.num_train_epochs
+        self.batch_size = args.train_batch_size
+
+    def test_hyperparameter_search(self):
+        class MyTrialShortNamer(TrialShortNamer):
+            DEFAULTS = {"a": 0, "b": 0}
+
+        def hp_space(trial):
+            return [
+                {"bounds": {"min": -4, "max": 4}, "name": "a", "type": "int"},
+                {"bounds": {"min": -4, "max": 4}, "name": "b", "type": "int"},
+            ]
+
+        def model_init(trial):
+            if trial is not None:
+                a = trial.assignments["a"]
+                b = trial.assignments["b"]
+            else:
+                a = 0
+                b = 0
+            config = RegressionModelConfig(a=a, b=b, double_output=False)
+
+            return RegressionPreTrainedModel(config)
+
+        def hp_name(trial):
+            return MyTrialShortNamer.shortname(trial.assignments)
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            trainer = get_regression_trainer(
+                output_dir=tmp_dir,
+                learning_rate=0.1,
+                logging_steps=1,
+                evaluation_strategy=IntervalStrategy.EPOCH,
+                save_strategy=IntervalStrategy.EPOCH,
+                num_train_epochs=4,
+                disable_tqdm=True,
+                load_best_model_at_end=True,
+                logging_dir="runs",
+                run_name="test",
+                model_init=model_init,
+            )
+            trainer.hyperparameter_search(
+                direction="minimize", hp_space=hp_space, hp_name=hp_name, backend="sigopt", n_trials=4
+            )
