 import warnings
 from contextlib import contextmanager
 from dataclasses import dataclass
-from typing import List, Optional, Union
+from typing import Iterator, List, Optional, Union

 import numpy as np
 import torch
+from torch.utils.data.dataset import Dataset
 from torch.utils.data.distributed import DistributedSampler
 from torch.utils.data.sampler import RandomSampler, Sampler

@@ -390,3 +391,110 @@ def __call__(self, model_output, labels):
         # Take the mean over the label dimensions, then divide by the number of active elements (i.e. not-padded):
         smoothed_loss = log_probs.mean(dim=-1).sum() / (padding_mask.numel() - padding_mask.long().sum())
         return (1 - self.epsilon) * model_loss + self.epsilon * smoothed_loss
+
+
+def get_length_grouped_indices(lengths, batch_size, mega_batch_mult=None, generator=None):
+    """
+    Return a list of indices so that each slice of :obj:`batch_size` consecutive indices corresponds to elements of
+    similar lengths. To do this, the indices are:
+
+    - randomly permuted
+    - grouped in mega-batches of size :obj:`mega_batch_mult * batch_size`
+    - sorted by length in each mega-batch
+
+    The result is the concatenation of all mega-batches, with the batch of :obj:`batch_size` containing the element of
+    maximum length placed first, so that an OOM happens sooner rather than later.
+    """
+    # Default for mega_batch_mult: 50 or the number to get 4 megabatches, whichever is smaller.
+    if mega_batch_mult is None:
+        mega_batch_mult = min(len(lengths) // (batch_size * 4), 50)
+        # Just in case, for tiny datasets
+        if mega_batch_mult == 0:
+            mega_batch_mult = 1
+
+    # We need to use torch for the random part as a distributed sampler will set the random seed for torch.
+    indices = torch.randperm(len(lengths), generator=generator)
+    megabatch_size = mega_batch_mult * batch_size
+    megabatches = [indices[i : i + megabatch_size].tolist() for i in range(0, len(lengths), megabatch_size)]
+    megabatches = [list(sorted(megabatch, key=lambda i: lengths[i], reverse=True)) for megabatch in megabatches]
+
+    # The rest is to get the biggest batch first.
+    # Since each megabatch is sorted by descending length, the longest element is the first
+    megabatch_maximums = [lengths[megabatch[0]] for megabatch in megabatches]
+    max_idx = torch.argmax(torch.tensor(megabatch_maximums)).item()
+    # Switch to put the longest element in first position
+    megabatches[0][0], megabatches[max_idx][0] = megabatches[max_idx][0], megabatches[0][0]
+
+    return sum(megabatches, [])
+
+
+class DistributedLengthGroupedSampler(DistributedSampler):
+    r"""
+    Distributed Sampler that samples indices in a way that groups together features of the dataset of roughly the same
+    length while keeping a bit of randomness.
+    """
+    # Copied and adapted from PyTorch DistributedSampler.
+    def __init__(
+        self,
+        dataset: Dataset,
+        batch_size: int,
+        num_replicas: Optional[int] = None,
+        rank: Optional[int] = None,
+        seed: int = 0,
+        drop_last: bool = False,
+        lengths: Optional[List[int]] = None,
+    ):
+        if num_replicas is None:
+            if not torch.distributed.is_available():
+                raise RuntimeError("Requires distributed package to be available")
+            num_replicas = torch.distributed.get_world_size()
+        if rank is None:
+            if not torch.distributed.is_available():
+                raise RuntimeError("Requires distributed package to be available")
+            rank = torch.distributed.get_rank()
+        self.dataset = dataset
+        self.batch_size = batch_size
+        self.num_replicas = num_replicas
+        self.rank = rank
+        self.epoch = 0
+        self.drop_last = drop_last
+        # If the dataset length is evenly divisible by # of replicas, then there
+        # is no need to drop any data, since the dataset will be split equally.
+        if self.drop_last and len(self.dataset) % self.num_replicas != 0:
+            # Split to nearest available length that is evenly divisible.
+            # This is to ensure each rank receives the same amount of data when
+            # using this Sampler.
+            self.num_samples = math.ceil((len(self.dataset) - self.num_replicas) / self.num_replicas)
+        else:
+            self.num_samples = math.ceil(len(self.dataset) / self.num_replicas)
+        self.total_size = self.num_samples * self.num_replicas
+        self.seed = seed
+
+        if lengths is None:
+            if not isinstance(dataset[0], dict) or "input_ids" not in dataset[0]:
+                raise ValueError(
+                    "Can only automatically infer lengths for datasets whose items are dictionaries with an "
+                    "'input_ids' key."
+                )
+            lengths = [len(feature["input_ids"]) for feature in dataset]
+        self.lengths = lengths
+
+    def __iter__(self) -> Iterator:
+        # Deterministically shuffle based on epoch and seed
+        g = torch.Generator()
+        g.manual_seed(self.seed + self.epoch)
+        indices = get_length_grouped_indices(self.lengths, self.batch_size, generator=g)
+
+        if not self.drop_last:
+            # add extra samples to make it evenly divisible
+            indices += indices[: (self.total_size - len(indices))]
+        else:
+            # remove tail of data to make it evenly divisible.
+            indices = indices[: self.total_size]
+        assert len(indices) == self.total_size
+
+        # subsample
+        indices = indices[self.rank : self.total_size : self.num_replicas]
+        assert len(indices) == self.num_samples
+
+        return iter(indices)
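
To make the grouping concrete, here is a small sketch that exercises get_length_grouped_indices on a made-up list of lengths. The lengths, batch size, and seed are invented for illustration, and the function itself is assumed to be in scope (e.g. imported from the module this diff patches). Each consecutive slice of batch_size indices should then hold examples of similar length, with the longest example landing in the very first batch.

import torch

# Toy lengths (illustrative only): 16 examples of varying token counts.
lengths = [5, 12, 7, 30, 3, 18, 25, 9, 11, 4, 22, 6, 14, 2, 27, 8]
batch_size = 4

g = torch.Generator()
g.manual_seed(0)  # same seeding idea the sampler uses (seed + epoch)
indices = get_length_grouped_indices(lengths, batch_size, generator=g)

# Each consecutive slice of `batch_size` indices contains examples of similar length,
# and the first batch contains the longest example (the one of length 30).
batches = [indices[i : i + batch_size] for i in range(0, len(indices), batch_size)]
for batch in batches:
    print(batch, [lengths[i] for i in batch])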
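
Similarly, a minimal sketch of driving DistributedLengthGroupedSampler without an initialized process group, by passing num_replicas and rank explicitly. The toy dict dataset and the two-rank loop are invented for illustration; in real training the sampler would instead be handed to a DataLoader (together with a padding collator) and advanced with set_epoch each epoch.

import torch

# Invented toy dataset: each item is a dict with an "input_ids" list, the format the
# sampler expects when it has to infer lengths itself.
dataset = [{"input_ids": list(range(n))} for n in [5, 12, 7, 30, 3, 18, 25, 9, 11, 4, 22, 6]]

world_size = 2  # pretend two processes; in real distributed training this comes from the process group
for rank in range(world_size):
    sampler = DistributedLengthGroupedSampler(
        dataset, batch_size=4, num_replicas=world_size, rank=rank, seed=42
    )
    sampler.set_epoch(0)  # inherited from DistributedSampler; reshuffles deterministically per epoch
    per_rank = list(iter(sampler))
    print(f"rank {rank}: {per_rank} -> lengths {[len(dataset[i]['input_ids']) for i in per_rank]}")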