Commit f410d5b

FCChen authored and facebook-github-bot committed
RandomSubsetTrainingSampler to randomly sample training data subset for accuracy-data curve
Summary: Add a sampler class `RandomSubsetTrainingSampler`, which is similar to `TrainingSampler` but samples only a random subset (e.g., 50%) of the indices. `RandomSubsetTrainingSampler` is useful when you want to estimate accuracy vs. data-volume curves by training the model with different values of `subset_ratio`.

Reviewed By: ppwwyyxx

Differential Revision: D29892290

fbshipit-source-id: a342a6f1aa7852feb6566c648bd673028a3e0668
1 parent 91f1d95 commit f410d5b
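
As a concrete illustration of the workflow the summary describes, one might sweep `subset_ratio` and record accuracy at each point. This is a hypothetical sketch: `train_and_evaluate` is a placeholder for a full train-plus-eval pipeline and is not part of this commit; only `RandomSubsetTrainingSampler` and its `subset_ratio` argument come from the diff below.

    # Hypothetical accuracy-vs-data-volume sweep; train_and_evaluate is a stub.
    def train_and_evaluate(subset_ratio):
        # Placeholder: build a train loader whose sampler is
        # RandomSubsetTrainingSampler(len(dataset), subset_ratio),
        # train to convergence, and return the evaluation metric (e.g., box AP).
        ...

    curve = {ratio: train_and_evaluate(ratio) for ratio in (0.1, 0.25, 0.5, 1.0)}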

File tree

3 files changed: +76 −2 lines changed

detectron2/data/build.py

+8 −1
@@ -19,7 +19,12 @@
 from .common import AspectRatioGroupedDataset, DatasetFromList, MapDataset
 from .dataset_mapper import DatasetMapper
 from .detection_utils import check_metadata_consistency
-from .samplers import InferenceSampler, RepeatFactorTrainingSampler, TrainingSampler
+from .samplers import (
+    InferenceSampler,
+    RandomSubsetTrainingSampler,
+    RepeatFactorTrainingSampler,
+    TrainingSampler,
+)

 """
 This file contains the default logic to build a dataloader for training or testing.
@@ -331,6 +336,8 @@ def _train_loader_from_config(cfg, mapper=None, *, dataset=None, sampler=None):
             dataset, cfg.DATALOADER.REPEAT_THRESHOLD
         )
         sampler = RepeatFactorTrainingSampler(repeat_factors)
+    elif sampler_name == "RandomSubsetTrainingSampler":
+        sampler = RandomSubsetTrainingSampler(len(dataset), cfg.DATALOADER.RANDOM_SUBSET_RATIO)
     else:
         raise ValueError("Unknown training sampler: {}".format(sampler_name))
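
With the branch above in place, the sampler can be selected through the dataloader config. A minimal sketch, assuming the standard `cfg.DATALOADER.SAMPLER_TRAIN` key is what feeds `sampler_name`; note that this commit does not appear to add `RANDOM_SUBSET_RATIO` to the config defaults, so the sketch registers the key itself:

    from detectron2.config import get_cfg
    from detectron2.data import build_detection_train_loader

    cfg = get_cfg()
    cfg.DATALOADER.SAMPLER_TRAIN = "RandomSubsetTrainingSampler"
    # RANDOM_SUBSET_RATIO is read by the new branch above; if it is not yet in
    # the config defaults, allow new keys before setting it.
    cfg.DATALOADER.set_new_allowed(True)
    cfg.DATALOADER.RANDOM_SUBSET_RATIO = 0.5  # train on a random 50% subset
    # ... set DATASETS.TRAIN etc. as usual, then:
    train_loader = build_detection_train_loader(cfg)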

detectron2/data/samplers/__init__.py

+8 −1
@@ -1,10 +1,17 @@
 # Copyright (c) Facebook, Inc. and its affiliates.
-from .distributed_sampler import InferenceSampler, RepeatFactorTrainingSampler, TrainingSampler
+from .distributed_sampler import (
+    InferenceSampler,
+    RandomSubsetTrainingSampler,
+    RepeatFactorTrainingSampler,
+    TrainingSampler,
+)
+
 from .grouped_batch_sampler import GroupedBatchSampler

 __all__ = [
     "GroupedBatchSampler",
     "TrainingSampler",
+    "RandomSubsetTrainingSampler",
     "InferenceSampler",
     "RepeatFactorTrainingSampler",
 ]
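
With this re-export in place, downstream code can import the new sampler from the package namespace, e.g.:

    from detectron2.data.samplers import RandomSubsetTrainingSampler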

detectron2/data/samplers/distributed_sampler.py

+60
@@ -1,5 +1,6 @@
 # Copyright (c) Facebook, Inc. and its affiliates.
 import itertools
+import logging
 import math
 from collections import defaultdict
 from typing import Optional
@@ -8,6 +9,8 @@

 from detectron2.utils import comm

+logger = logging.getLogger(__name__)
+

 class TrainingSampler(Sampler):
     """
@@ -66,6 +69,63 @@ def _infinite_indices(self):
             yield from torch.arange(self._size).tolist()


+class RandomSubsetTrainingSampler(TrainingSampler):
+    """
+    Similar to TrainingSampler, but only sample a random subset of indices.
+    This is useful when you want to estimate the accuracy vs data-number curves by
+    training the model with different subset_ratio.
+    """
+
+    def __init__(
+        self,
+        size: int,
+        subset_ratio: float,
+        shuffle: bool = True,
+        seed_shuffle: Optional[int] = None,
+        seed_subset: Optional[int] = None,
+    ):
+        """
+        Args:
+            size (int): the total number of data of the underlying dataset to sample from
+            subset_ratio (float): the ratio of subset data to sample from the underlying dataset
+            shuffle (bool): whether to shuffle the indices or not
+            seed_shuffle (int): the initial seed of the shuffle. Must be the same
+                across all workers. If None, will use a random seed shared
+                among workers (require synchronization among all workers).
+            seed_subset (int): the seed to randomize the subset to be sampled.
+                Must be the same across all workers. If None, will use a random seed shared
+                among workers (require synchronization among all workers).
+        """
+        super().__init__(size=size, shuffle=shuffle, seed=seed_shuffle)
+
+        assert 0.0 < subset_ratio <= 1.0
+        self._size_subset = int(size * subset_ratio)
+        assert self._size_subset > 0
+        if seed_subset is None:
+            seed_subset = comm.shared_random_seed()
+        self._seed_subset = int(seed_subset)
+
+        # randomly generate the subset indexes to be sampled from
+        g = torch.Generator()
+        g.manual_seed(self._seed_subset)
+        indexes_randperm = torch.randperm(self._size, generator=g)
+        self._indexes_subset = indexes_randperm[: self._size_subset]
+
+        logger.info("Using RandomSubsetTrainingSampler......")
+        logger.info(f"Randomly sample {self._size_subset} data from the original {self._size} data")
+
+    def _infinite_indices(self):
+        g = torch.Generator()
+        g.manual_seed(self._seed)  # self._seed equals seed_shuffle from __init__()
+        while True:
+            if self._shuffle:
+                # generate a random permutation to shuffle self._indexes_subset
+                randperm = torch.randperm(self._size_subset, generator=g)
+                yield from self._indexes_subset[randperm].tolist()
+            else:
+                yield from self._indexes_subset.tolist()
+
+
 class RepeatFactorTrainingSampler(Sampler):
     """
     Similar to TrainingSampler, but a sample may appear more times than others based
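
For readers who want to see the two seeds in action, here is a small sketch that exercises the new class directly in a single process (illustrative values; `_indexes_subset` is a private attribute, inspected here only to show that every yielded index stays inside the fixed subset):

    import itertools

    from detectron2.data.samplers import RandomSubsetTrainingSampler

    # 100-item dataset, keep a random 30% subset; fix both seeds so the
    # subset choice (seed_subset) and the shuffling (seed_shuffle) repeat.
    sampler = RandomSubsetTrainingSampler(
        size=100, subset_ratio=0.3, shuffle=True, seed_shuffle=0, seed_subset=1
    )

    # The index stream is infinite, so take a finite slice of it.
    indices = list(itertools.islice(iter(sampler), 60))

    # Every yielded index belongs to the same fixed 30-element subset,
    # which is reshuffled on each pass over the subset.
    assert set(indices) <= set(sampler._indexes_subset.tolist())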
