Commit 7b75aa9

[TPU] Doc, fix xla_spawn.py, only preprocess dataset once (huggingface#4223)
* [TPU] Doc, fix xla_spawn.py, only preprocess dataset once
* Update examples/README.md
* [xla_spawn] Add `_mp_fn` to other Trainer scripts
* [TPU] Fix: eval dataloader was None
1 parent 274d850 commit 7b75aa9

File tree

10 files changed: +88 −47 lines changed


examples/README.md

Lines changed: 25 additions & 1 deletion
````diff
@@ -53,4 +53,28 @@ pip install -r ./examples/requirements.txt
 
 ## Running on TPUs
 
-Documentation to come.
+When using TensorFlow, TPUs are supported out of the box as a `tf.distribute.Strategy`.
+
+When using PyTorch, we support TPUs thanks to `pytorch/xla`. For more context and information on how to set up your TPU environment, refer to Google's documentation and to the
+very detailed [pytorch/xla README](https://github.com/pytorch/xla/blob/master/README.md).
+
+In this repo, we provide a very simple launcher script named [xla_spawn.py](./xla_spawn.py) that lets you run our example scripts on multiple TPU cores without any boilerplate.
+Just pass a `--num_cores` flag to this script, then your regular training script with its arguments (this is similar to the `torch.distributed.launch` helper for torch.distributed).
+
+For example for `run_glue`:
+
+```bash
+python examples/xla_spawn.py --num_cores 8 \
+  examples/text-classification/run_glue.py \
+  --model_name_or_path bert-base-cased \
+  --task_name mnli \
+  --data_dir ./data/glue_data/MNLI \
+  --output_dir ./models/tpu \
+  --overwrite_output_dir \
+  --do_train \
+  --do_eval \
+  --num_train_epochs 1 \
+  --save_steps 20000
+```
+
+Feedback and more use cases and benchmarks involving TPUs are welcome, please share with the community.
````
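For a script to be launchable this way, it only needs the small `_mp_fn` hook that this commit adds to the example scripts. A minimal, hypothetical standalone sketch of that hook (the `xmp.spawn` call stands in for what `xla_spawn.py` does and requires a TPU VM with `torch_xla` installed):

```python
# Minimal sketch of the _mp_fn hook a training script exposes for xla_spawn.py.
# Hypothetical standalone example; assumes a TPU environment with torch_xla installed.
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp


def main():
    # Stand-in for an example script's real main(): Trainer setup, training, evaluation.
    print(f"Hello from TPU device {xm.xla_device()}")


def _mp_fn(index):
    # For xla_spawn (TPUs): xmp.spawn calls this once per process with the process index.
    main()


if __name__ == "__main__":
    # Roughly what the launcher does with --num_cores 8.
    xmp.spawn(_mp_fn, args=(), nprocs=8)
```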

examples/bertology/run_bertology.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -404,7 +404,7 @@ def main():
     logger.info("Training/evaluation parameters %s", args)
 
     # Prepare dataset for the GLUE task
-    eval_dataset = GlueDataset(args, tokenizer=tokenizer, evaluate=True, local_rank=args.local_rank)
+    eval_dataset = GlueDataset(args, tokenizer=tokenizer, evaluate=True)
     if args.data_subset > 0:
         eval_dataset = Subset(eval_dataset, list(range(min(args.data_subset, len(eval_dataset)))))
     eval_sampler = SequentialSampler(eval_dataset) if args.local_rank == -1 else DistributedSampler(eval_dataset)
```

examples/language-modeling/run_language_modeling.py

Lines changed: 5 additions & 0 deletions
```diff
@@ -280,5 +280,10 @@ def main():
     return results
 
 
+def _mp_fn(index):
+    # For xla_spawn (TPUs)
+    main()
+
+
 if __name__ == "__main__":
     main()
```

examples/multiple-choice/run_multiple_choice.py

Lines changed: 5 additions & 0 deletions
```diff
@@ -221,5 +221,10 @@ def compute_metrics(p: EvalPrediction) -> Dict:
     return results
 
 
+def _mp_fn(index):
+    # For xla_spawn (TPUs)
+    main()
+
+
 if __name__ == "__main__":
     main()
```

examples/text-classification/README.md

Lines changed: 6 additions & 6 deletions
````diff
@@ -85,10 +85,12 @@ CoLA, SST-2. The following section provides details on how to run half-precision
 said, there shouldn’t be any issues in running half-precision training with the remaining GLUE tasks as well,
 since the data processor for each task inherits from the base class DataProcessor.
 
-## Running on TPUs
+## Running on TPUs in PyTorch
 
-You can accelerate your workloads on Google's TPUs. For information on how to setup your TPU environment refer to this
-[README](https://github.com/pytorch/xla/blob/master/README.md).
+**Update**: read the more up-to-date [Running on TPUs](../README.md#running-on-tpus) in the main README.md instead.
+
+Even when running PyTorch, you can accelerate your workloads on Google's TPUs, using `pytorch/xla`. For information on how to set up your TPU environment, refer to the
+[pytorch/xla README](https://github.com/pytorch/xla/blob/master/README.md).
 
 The following are some examples of running the `*_tpu.py` finetuning scripts on TPUs. All steps for data preparation are
 identical to your normal GPU + Huggingface setup.
@@ -101,7 +103,6 @@ export GLUE_DIR=/path/to/glue
 export TASK_NAME=MNLI
 
 python run_glue_tpu.py \
-  --model_type bert \
   --model_name_or_path bert-base-cased \
   --task_name $TASK_NAME \
   --do_train \
@@ -115,8 +116,7 @@ python run_glue_tpu.py \
   --overwrite_output_dir \
   --logging_steps 50 \
   --save_steps 200 \
-  --num_cores=8 \
-  --only_log_master
+  --num_cores=8
 ```
 
 ### MRPC
````

examples/text-classification/run_glue.py

Lines changed: 3 additions & 13 deletions
```diff
@@ -134,16 +134,8 @@ def main():
     )
 
     # Get datasets
-    train_dataset = (
-        GlueDataset(data_args, tokenizer=tokenizer, local_rank=training_args.local_rank)
-        if training_args.do_train
-        else None
-    )
-    eval_dataset = (
-        GlueDataset(data_args, tokenizer=tokenizer, local_rank=training_args.local_rank, evaluate=True)
-        if training_args.do_eval
-        else None
-    )
+    train_dataset = GlueDataset(data_args, tokenizer=tokenizer) if training_args.do_train else None
+    eval_dataset = GlueDataset(data_args, tokenizer=tokenizer, evaluate=True) if training_args.do_eval else None
 
     def compute_metrics(p: EvalPrediction) -> Dict:
         if output_mode == "classification":
@@ -181,9 +173,7 @@ def compute_metrics(p: EvalPrediction) -> Dict:
         eval_datasets = [eval_dataset]
         if data_args.task_name == "mnli":
             mnli_mm_data_args = dataclasses.replace(data_args, task_name="mnli-mm")
-            eval_datasets.append(
-                GlueDataset(mnli_mm_data_args, tokenizer=tokenizer, local_rank=training_args.local_rank, evaluate=True)
-            )
+            eval_datasets.append(GlueDataset(mnli_mm_data_args, tokenizer=tokenizer, evaluate=True))
 
         for eval_dataset in eval_datasets:
             result = trainer.evaluate(eval_dataset=eval_dataset)
```

examples/token-classification/run_ner.py

Lines changed: 5 additions & 0 deletions
```diff
@@ -292,5 +292,10 @@ def compute_metrics(p: EvalPrediction) -> Dict:
     return results
 
 
+def _mp_fn(index):
+    # For xla_spawn (TPUs)
+    main()
+
+
 if __name__ == "__main__":
     main()
```

examples/xla_spawn.py

Lines changed: 5 additions & 7 deletions
```diff
@@ -12,17 +12,13 @@
 
 
 import importlib
-import os
 import sys
 from argparse import REMAINDER, ArgumentParser
+from pathlib import Path
 
 import torch_xla.distributed.xla_multiprocessing as xmp
 
 
-def trim_suffix(s: str, suffix: str):
-    return s if not s.endswith(suffix) or len(suffix) == 0 else s[: -len(suffix)]
-
-
 def parse_args():
     """
     Helper function parsing the command line options
@@ -44,7 +40,7 @@ def parse_args():
         "training_script",
         type=str,
         help=(
-            "The full module name to the single TPU training "
+            "The full path to the single TPU training "
             "program/script to be launched in parallel, "
            "followed by all the arguments for the "
            "training script"
@@ -61,7 +57,9 @@ def main():
     args = parse_args()
 
     # Import training_script as a module.
-    mod_name = trim_suffix(os.path.basename(args.training_script), ".py")
+    script_fpath = Path(args.training_script)
+    sys.path.append(str(script_fpath.parent.resolve()))
+    mod_name = script_fpath.stem
     mod = importlib.import_module(mod_name)
 
     # Patch sys.argv
```
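Put together, the launcher now resolves the training script by file path rather than by dotted module name. A condensed sketch of the flow, where the `sys.argv` patch and the `xmp.spawn` call are assumptions based on the unchanged remainder of `xla_spawn.py`, and the `launch` wrapper itself is hypothetical:

```python
# Condensed sketch of xla_spawn.py's flow after this change (not a verbatim excerpt).
import importlib
import sys
from pathlib import Path

import torch_xla.distributed.xla_multiprocessing as xmp


def launch(training_script: str, training_script_args: list, num_cores: int):
    # Import the training script as a module, located by file path (not module name).
    script_fpath = Path(training_script)
    sys.path.append(str(script_fpath.parent.resolve()))
    mod = importlib.import_module(script_fpath.stem)

    # Patch sys.argv so the script's own argument parser sees its flags.
    sys.argv = [training_script] + training_script_args

    # Spawn one process per TPU core, each entering the script via _mp_fn(index).
    xmp.spawn(mod._mp_fn, args=(), nprocs=num_cores)
```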

src/transformers/data/datasets/glue.py

Lines changed: 12 additions & 12 deletions
```diff
@@ -5,12 +5,12 @@
 from typing import List, Optional
 
 import torch
+from filelock import FileLock
 from torch.utils.data.dataset import Dataset
 
 from ...tokenization_roberta import RobertaTokenizer, RobertaTokenizerFast
 from ...tokenization_utils import PreTrainedTokenizer
 from ...tokenization_xlm_roberta import XLMRobertaTokenizer
-from ...trainer import torch_distributed_zero_first
 from ..processors.glue import glue_convert_examples_to_features, glue_output_modes, glue_processors
 from ..processors.utils import InputFeatures
 
@@ -63,7 +63,6 @@ def __init__(
         tokenizer: PreTrainedTokenizer,
         limit_length: Optional[int] = None,
         evaluate=False,
-        local_rank=-1,
     ):
         self.args = args
         processor = glue_processors[args.task_name]()
@@ -75,9 +74,11 @@
                 "dev" if evaluate else "train", tokenizer.__class__.__name__, str(args.max_seq_length), args.task_name,
             ),
         )
-        with torch_distributed_zero_first(local_rank):
-            # Make sure only the first process in distributed training processes the dataset,
-            # and the others will use the cache.
+
+        # Make sure only the first process in distributed training processes the dataset,
+        # and the others will use the cache.
+        lock_path = cached_features_file + ".lock"
+        with FileLock(lock_path):
 
             if os.path.exists(cached_features_file) and not args.overwrite_cache:
                 start = time.time()
@@ -109,13 +110,12 @@
                     label_list=label_list,
                     output_mode=self.output_mode,
                 )
-                if local_rank in [-1, 0]:
-                    start = time.time()
-                    torch.save(self.features, cached_features_file)
-                    # ^ This seems to take a lot of time so I want to investigate why and how we can improve.
-                    logger.info(
-                        f"Saving features into cached file %s [took %.3f s]", cached_features_file, time.time() - start
-                    )
+                start = time.time()
+                torch.save(self.features, cached_features_file)
+                # ^ This seems to take a lot of time so I want to investigate why and how we can improve.
+                logger.info(
+                    f"Saving features into cached file %s [took %.3f s]", cached_features_file, time.time() - start
+                )
 
     def __len__(self):
         return len(self.features)
```
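The change above swaps the `local_rank`-based `torch_distributed_zero_first` guard for a plain file lock around the feature cache, which also works when processes are spawned for TPUs. A minimal, self-contained sketch of that pattern (the cache path and the `featurize()` helper are made up for illustration):

```python
# Sketch of the file-lock-guarded cache pattern used by GlueDataset above:
# the first process to acquire the lock builds and saves the features,
# later processes find the cache file and just load it.
import os

import torch
from filelock import FileLock


def featurize():
    # Stand-in for glue_convert_examples_to_features(...)
    return list(range(10))


cached_features_file = "/tmp/cached_dev_demo"
lock_path = cached_features_file + ".lock"

with FileLock(lock_path):
    if os.path.exists(cached_features_file):
        features = torch.load(cached_features_file)
    else:
        features = featurize()
        torch.save(features, cached_features_file)

print(len(features))
```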

src/transformers/trainer.py

Lines changed: 21 additions & 7 deletions
```diff
@@ -6,7 +6,7 @@
 import shutil
 from contextlib import contextmanager
 from pathlib import Path
-from typing import Callable, Dict, List, Optional, Tuple
+from typing import Callable, Dict, List, Optional, Tuple, Union
 
 import numpy as np
 import torch
@@ -195,10 +195,12 @@ def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
         if eval_dataset is None and self.eval_dataset is None:
             raise ValueError("Trainer: evaluation requires an eval_dataset.")
 
+        eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
+
         sampler = get_tpu_sampler(eval_dataset) if is_tpu_available() else None
 
         data_loader = DataLoader(
-            eval_dataset if eval_dataset is not None else self.eval_dataset,
+            eval_dataset,
             sampler=sampler,
             batch_size=self.args.eval_batch_size,
             shuffle=False,
@@ -267,6 +269,16 @@ def _setup_wandb(self):
         # keep track of model topology and gradients
         wandb.watch(self.model)
 
+    def num_examples(self, dataloader: Union[DataLoader, "pl.PerDeviceLoader"]) -> int:
+        """
+        Helper to get num of examples from a DataLoader, by accessing its Dataset.
+        """
+        if is_tpu_available():
+            assert isinstance(dataloader, pl.PerDeviceLoader)
+            return len(dataloader._loader._loader.dataset)
+        else:
+            return len(dataloader.dataset)
+
     def train(self, model_path: Optional[str] = None):
         """
         Main training entry point.
@@ -326,17 +338,15 @@ def train(self, model_path: Optional[str] = None):
 
         # Train!
         if is_tpu_available():
-            num_examples = len(train_dataloader._loader._loader.dataset)
             total_train_batch_size = self.args.train_batch_size * xm.xrt_world_size()
         else:
-            num_examples = len(train_dataloader.dataset)
             total_train_batch_size = (
                 self.args.train_batch_size
                 * self.args.gradient_accumulation_steps
                 * (torch.distributed.get_world_size() if self.args.local_rank != -1 else 1)
             )
         logger.info("***** Running training *****")
-        logger.info("  Num examples = %d", num_examples)
+        logger.info("  Num examples = %d", self.num_examples(train_dataloader))
         logger.info("  Num Epochs = %d", num_train_epochs)
         logger.info("  Instantaneous batch size per device = %d", self.args.per_gpu_train_batch_size)
         logger.info("  Total train batch size (w. parallel, distributed & accumulation) = %d", total_train_batch_size)
@@ -606,9 +616,13 @@ def _prediction_loop(
         model = self.model
         model.to(self.args.device)
 
+        if is_tpu_available():
+            batch_size = dataloader._loader._loader.batch_size
+        else:
+            batch_size = dataloader.batch_size
         logger.info("***** Running %s *****", description)
-        logger.info("  Num examples = %d", len(dataloader.dataset))
-        logger.info("  Batch size = %d", dataloader.batch_size)
+        logger.info("  Num examples = %d", self.num_examples(dataloader))
+        logger.info("  Batch size = %d", batch_size)
         eval_losses: List[float] = []
         preds: np.ndarray = None
         label_ids: np.ndarray = None
```
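For context on the `._loader._loader.dataset` chain in `num_examples`: on TPU the trainer's `DataLoader` ends up wrapped in a `torch_xla` per-device loader that exposes neither `.dataset` nor `.batch_size`. A rough sketch of that situation, assuming a TPU VM with `torch_xla` installed; the wrapping call mirrors what the trainer is expected to do rather than being a verified excerpt:

```python
# Rough sketch (assumes a TPU environment with torch_xla): why num_examples and
# _prediction_loop reach through private attributes instead of dataloader.dataset.
import torch
from torch.utils.data import DataLoader, TensorDataset

import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl

dataset = TensorDataset(torch.zeros(100, 4))
loader = DataLoader(dataset, batch_size=8)

device = xm.xla_device()
# Wrap the DataLoader the way a TPU training loop does (assumed, based on the
# pl.PerDeviceLoader assert in num_examples above).
per_device_loader = pl.ParallelLoader(loader, [device]).per_device_loader(device)

# PerDeviceLoader has no .dataset / .batch_size, so the helper unwraps:
# PerDeviceLoader._loader -> ParallelLoader, ParallelLoader._loader -> DataLoader.
print(len(per_device_loader._loader._loader.dataset))   # 100
print(per_device_loader._loader._loader.batch_size)     # 8
```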
