
Commit e07d0dc
Upstream (and rename) sortish_sampler
1 parent c8b1656
File tree: 5 files changed (+177, -12 lines)

examples/seq2seq/test_finetune_trainer.py

+1, -1

@@ -169,7 +169,7 @@ def run_trainer(
  --logging_steps 0
  --save_steps {str(eval_steps)}
  --eval_steps {str(eval_steps)}
- --sortish_sampler
+ --group_by_length
  --label_smoothing_factor 0.1
  --adafactor
  --task translation
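The renamed flag is an ordinary TrainingArguments field, so it reaches the trainer through normal argument parsing. A minimal sketch, not part of the commit, with a placeholder output directory:

# Minimal sketch (assumes a checkout containing this commit): the renamed
# --group_by_length flag is parsed straight into the new TrainingArguments field.
from transformers import HfArgumentParser, TrainingArguments

parser = HfArgumentParser(TrainingArguments)
(training_args,) = parser.parse_args_into_dataclasses(
    args=["--output_dir", "out", "--group_by_length"]  # "out" is a placeholder
)
print(training_args.group_by_length)  # True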

src/transformers/trainer.py

+29, -9

@@ -70,12 +70,13 @@
     TrainerState,
 )
 from .trainer_pt_utils import (
+    DistributedLengthGroupedSampler,
     DistributedTensorGatherer,
     LabelSmoother,
     SequentialDistributedSampler,
     distributed_broadcast_scalars,
     distributed_concat,
-    get_tpu_sampler,
+    get_length_grouped_indices,
     nested_concat,
     nested_detach,
     nested_numpify,
@@ -94,7 +95,7 @@
     set_seed,
     speed_metrics,
 )
-from .training_args import TrainingArguments
+from .training_args import ParallelMode, TrainingArguments
 from .utils import logging


@@ -448,14 +449,33 @@ def _get_train_sampler(self) -> Optional[torch.utils.data.sampler.Sampler]:
             self.train_dataset, collections.abc.Sized
         ):
             return None
-        elif is_torch_tpu_available():
-            return get_tpu_sampler(self.train_dataset)
+
+        # Gather the number of processes and this process index.
+        if self.args.parallel_mode == ParallelMode.TPU:
+            num_processes = xm.xrt_world_size()
+            process_index = xm.get_ordinal()
+        elif self.args.parallel_mode == ParallelMode.DISTRIBUTED:
+            num_processes = torch.distributed.get_world_size()
+            process_index = torch.distributed.get_rank()
         else:
-            return (
-                RandomSampler(self.train_dataset)
-                if self.args.local_rank == -1
-                else DistributedSampler(self.train_dataset)
-            )
+            num_processes = 1
+            process_index = 0
+
+        # Build the sampler.
+        if self.args.group_by_length:
+            if num_processes <= 1:
+                lengths = [len(feature["input_ids"]) for feature in self.train_dataset]
+                return get_length_grouped_indices(lengths, self.args.train_batch_size)
+            else:
+                return DistributedLengthGroupedSampler(
+                    self.train_dataset, self.args.train_batch_size, num_replicas=num_processes, rank=process_index
+                )
+
+        else:
+            if num_processes <= 1:
+                return RandomSampler(self.train_dataset)
+            else:
+                return DistributedSampler(self.train_dataset, num_replicas=num_processes, rank=process_index)

     def get_train_dataloader(self) -> DataLoader:
         """

src/transformers/trainer_pt_utils.py

+109, -1

@@ -20,10 +20,11 @@
 import warnings
 from contextlib import contextmanager
 from dataclasses import dataclass
-from typing import List, Optional, Union
+from typing import Iterator, List, Optional, Union

 import numpy as np
 import torch
+from torch.utils.data.dataset import Dataset
 from torch.utils.data.distributed import DistributedSampler
 from torch.utils.data.sampler import RandomSampler, Sampler

@@ -390,3 +391,110 @@ def __call__(self, model_output, labels):
         # Take the mean over the label dimensions, then divide by the number of active elements (i.e. not-padded):
         smoothed_loss = log_probs.mean(dim=-1).sum() / (padding_mask.numel() - padding_mask.long().sum())
         return (1 - self.epsilon) * model_loss + self.epsilon * smoothed_loss
+
+
+def get_length_grouped_indices(lengths, batch_size, mega_batch_mult=None, generator=None):
+    """
+    Return a list of indices so that each slice of :obj:`batch_size` consecutive indices corresponds to elements of
+    similar lengths. To do this, the indices are:
+
+    - randomly permuted
+    - grouped in mega-batches of size :obj:`mega_batch_mult * batch_size`
+    - sorted by length in each mega-batch
+
+    The result is the concatenation of all mega-batches, with the batch of :obj:`batch_size` containing the element of
+    maximum length placed first, so that an OOM happens sooner rather than later.
+    """
+    # Default for mega_batch_mult: 50 or the number to get 4 megabatches, whichever is smaller.
+    if mega_batch_mult is None:
+        mega_batch_mult = min(len(lengths) // (batch_size * 4), 50)
+        # Just in case, for tiny datasets
+        if mega_batch_mult == 0:
+            mega_batch_mult = 1

+    # We need to use torch for the random part as a distributed sampler will set the random seed for torch.
+    indices = torch.randperm(len(lengths), generator=generator)
+    megabatch_size = mega_batch_mult * batch_size
+    megabatches = [indices[i : i + megabatch_size].tolist() for i in range(0, len(lengths), megabatch_size)]
+    megabatches = [list(sorted(megabatch, key=lambda i: lengths[i], reverse=True)) for megabatch in megabatches]
+
+    # The rest is to get the biggest batch first.
+    # Since each megabatch is sorted by descending length, the longest element is the first
+    megabatch_maximums = [lengths[megabatch[0]] for megabatch in megabatches]
+    max_idx = torch.argmax(torch.tensor(megabatch_maximums)).item()
+    # Switch to put the longest element in first position
+    megabatches[0][0], megabatches[max_idx][0] = megabatches[max_idx][0], megabatches[0][0]
+
+    return sum(megabatches, [])
+
+
+class DistributedLengthGroupedSampler(DistributedSampler):
+    r"""
+    Distributed Sampler that samples indices in a way that groups together features of the dataset of roughly the same
+    length while keeping a bit of randomness.
+    """
+    # Copied and adapted from PyTorch DistributedSampler.
+    def __init__(
+        self,
+        dataset: Dataset,
+        batch_size: int,
+        num_replicas: Optional[int] = None,
+        rank: Optional[int] = None,
+        seed: int = 0,
+        drop_last: bool = False,
+        lengths: Optional[List[int]] = None,
+    ):
+        if num_replicas is None:
+            if not torch.distributed.is_available():
+                raise RuntimeError("Requires distributed package to be available")
+            num_replicas = torch.distributed.get_world_size()
+        if rank is None:
+            if not torch.distributed.is_available():
+                raise RuntimeError("Requires distributed package to be available")
+            rank = torch.distributed.get_rank()
+        self.dataset = dataset
+        self.batch_size = batch_size
+        self.num_replicas = num_replicas
+        self.rank = rank
+        self.epoch = 0
+        self.drop_last = drop_last
+        # If the dataset length is evenly divisible by # of replicas, then there
+        # is no need to drop any data, since the dataset will be split equally.
+        if self.drop_last and len(self.dataset) % self.num_replicas != 0:
+            # Split to nearest available length that is evenly divisible.
+            # This is to ensure each rank receives the same amount of data when
+            # using this Sampler.
+            self.num_samples = math.ceil((len(self.dataset) - self.num_replicas) / self.num_replicas)
+        else:
+            self.num_samples = math.ceil(len(self.dataset) / self.num_replicas)
+        self.total_size = self.num_samples * self.num_replicas
+        self.seed = seed
+
+        if lengths is None:
+            if not isinstance(dataset[0], dict) or "input_ids" not in dataset[0]:
+                raise ValueError(
+                    "Can only automatically infer lengths for datasets whose items are dictionaries with an "
+                    "'input_ids' key."
+                )
+            lengths = [len(feature["input_ids"]) for feature in dataset]
+        self.lengths = lengths
+
+    def __iter__(self) -> Iterator:
+        # Deterministically shuffle based on epoch and seed
+        g = torch.Generator()
+        g.manual_seed(self.seed + self.epoch)
+        indices = get_length_grouped_indices(self.lengths, self.batch_size, generator=g)
+
+        if not self.drop_last:
+            # add extra samples to make it evenly divisible
+            indices += indices[: (self.total_size - len(indices))]
+        else:
+            # remove tail of data to make it evenly divisible.
+            indices = indices[: self.total_size]
+        assert len(indices) == self.total_size
+
+        # subsample
+        indices = indices[self.rank : self.total_size : self.num_replicas]
+        assert len(indices) == self.num_samples
+
+        return iter(indices)
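Because num_replicas and rank can be given explicitly, the new sampler can be exercised without initializing torch.distributed, as the unit test further down also does. A short usage sketch (the toy dataset is made up; lengths are inferred from the "input_ids" key, as in the constructor above):

# Usage sketch, not part of the commit: rank 0 of 2 replicas, no process group needed.
import torch
from torch.utils.data import DataLoader

from transformers.trainer_pt_utils import DistributedLengthGroupedSampler

dataset = [{"input_ids": [0] * int(n)} for n in torch.randint(3, 20, (64,))]

sampler = DistributedLengthGroupedSampler(dataset, batch_size=4, num_replicas=2, rank=0)
print(len(sampler))  # 32: this rank sees half of the 64 examples

loader = DataLoader(dataset, batch_size=4, sampler=sampler, collate_fn=lambda batch: batch)
print([len(x["input_ids"]) for x in next(iter(loader))])  # roughly similar lengths per batch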

src/transformers/training_args.py

+7, -0

@@ -227,6 +227,9 @@ class TrainingArguments:
         adafactor (:obj:`bool`, `optional`, defaults to :obj:`False`):
             Whether or not to use the :class:`~transformers.Adafactor` optimizer instead of
             :class:`~transformers.AdamW`.
+        group_by_length (:obj:`bool`, `optional`, defaults to :obj:`False`):
+            Whether or not to group together samples of roughly the same length in the training dataset (to minimize
+            padding applied and be more efficient). Only useful if applying dynamic padding.
     """

     output_dir: str = field(
@@ -405,6 +408,10 @@
         default=0.0, metadata={"help": "The label smoothing epsilon to apply (zero means no label smoothing)."}
     )
     adafactor: bool = field(default=False, metadata={"help": "Whether or not to replace Adam by Adafactor."})
+    group_by_length: bool = field(
+        default=False,
+        metadata={"help": "Whether or not to group samples of roughly the same length together when batching."},
+    )
     _n_gpu: int = field(init=False, repr=False, default=-1)

     def __post_init__(self):
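Usage-wise, the new option is meant to be combined with dynamic padding, so that batches of similar-length samples are padded as little as possible. A hedged sketch (the checkpoint name, output directory, and batch size are placeholders):

# Sketch, not part of the commit: pair group_by_length with a collator that pads
# each batch to its own longest sample (dynamic padding).
from transformers import AutoTokenizer, DataCollatorWithPadding, TrainingArguments

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # placeholder checkpoint
data_collator = DataCollatorWithPadding(tokenizer)              # dynamic padding per batch

args = TrainingArguments(
    output_dir="out",                # placeholder
    per_device_train_batch_size=16,  # placeholder
    group_by_length=True,            # the option added by this commit
)
# Both would then be passed to Trainer(..., args=args, data_collator=data_collator).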

tests/test_trainer_utils.py

+31, -1

@@ -25,7 +25,12 @@
 import torch

 from transformers.modeling_outputs import SequenceClassifierOutput
-from transformers.trainer_pt_utils import DistributedTensorGatherer, LabelSmoother
+from transformers.trainer_pt_utils import (
+    DistributedLengthGroupedSampler,
+    DistributedTensorGatherer,
+    LabelSmoother,
+    get_length_grouped_indices,
+)


 @require_torch
@@ -87,3 +92,28 @@ def test_label_smoothing(self):
         log_probs[2, 3] = 0.0
         expected_loss = (1 - epsilon) * loss + epsilon * log_probs.sum() / (num_labels * 17)
         self.assertTrue(torch.allclose(label_smoothed_loss, expected_loss))
+
+    def test_group_by_length(self):
+        # Get some inputs of random lengths
+        lengths = torch.randint(0, 25, (100,)).tolist()
+        # Put one bigger than the others to check it ends up in first position
+        lengths[32] = 50
+
+        indices = get_length_grouped_indices(lengths, 4)
+        # The biggest element should be first
+        self.assertEqual(lengths[indices[0]], 50)
+        # The indices should be a permutation of range(100)
+        self.assertEqual(list(sorted(indices)), list(range(100)))
+
+    def test_distributed_length_grouped(self):
+        # Get some inputs of random lengths
+        lengths = torch.randint(0, 25, (100,)).tolist()
+        # Put one bigger than the others to check it ends up in first position
+        lengths[32] = 50
+
+        indices_process_0 = list(DistributedLengthGroupedSampler(lengths, 4, 2, 0, lengths=lengths))
+        indices_process_1 = list(DistributedLengthGroupedSampler(lengths, 4, 2, 1, lengths=lengths))
+        # The biggest element should be first
+        self.assertEqual(lengths[indices_process_0[0]], 50)
+        # The indices should be a permutation of range(100)
+        self.assertEqual(list(sorted(indices_process_0 + indices_process_1)), list(range(100)))
