Skip to content

Commit 05bc843

Browse files
ppwwyyxx authored and facebook-github-bot committed
support IterableDataset in build_detection_train_loader
Differential Revision: D24677397 fbshipit-source-id: 1e4a991c521da1e139ccc7fe40b715dc921d3294
1 parent 2a345c1 commit 05bc843

File tree

3 files changed

+63
-24
lines changed

3 files changed

+63
-24
lines changed

detectron2/data/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
print_instances_class_histogram,
1111
)
1212
from .catalog import DatasetCatalog, MetadataCatalog, Metadata
13-
from .common import DatasetFromList, MapDataset
13+
from .common import DatasetFromList, MapDataset, ToIterableDataset
1414
from .dataset_mapper import DatasetMapper
1515

1616
# ensure the builtin datasets are registered

detectron2/data/build.py

+30-22
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44
import numpy as np
55
import operator
66
import pickle
7-
import torch.utils.data
7+
import torch
8+
import torch.utils.data as torchdata
89
from tabulate import tabulate
910
from termcolor import colored
1011

@@ -16,7 +17,7 @@
1617
from detectron2.utils.logger import _log_api_usage, log_first_n
1718

1819
from .catalog import DatasetCatalog, MetadataCatalog
19-
from .common import AspectRatioGroupedDataset, DatasetFromList, MapDataset
20+
from .common import AspectRatioGroupedDataset, DatasetFromList, MapDataset, ToIterableDataset
2021
from .dataset_mapper import DatasetMapper
2122
from .detection_utils import check_metadata_consistency
2223
from .samplers import (
@@ -270,8 +271,9 @@ def build_batch_data_loader(
270271
2. use no "batch collation", because this is common for detection training
271272
272273
Args:
273-
dataset (torch.utils.data.Dataset): map-style PyTorch dataset. Can be indexed.
274-
sampler (torch.utils.data.sampler.Sampler): a sampler that produces indices
274+
dataset (torch.utils.data.Dataset): a pytorch map-style or iterable dataset.
275+
sampler (torch.utils.data.sampler.Sampler or None): a sampler that produces indices.
276+
Must be provided iff. ``dataset`` is a map-style dataset.
275277
total_batch_size, aspect_ratio_grouping, num_workers): see
276278
:func:`build_detection_train_loader`.
277279
@@ -285,26 +287,27 @@ def build_batch_data_loader(
285287
), "Total batch size ({}) must be divisible by the number of gpus ({}).".format(
286288
total_batch_size, world_size
287289
)
288-
289290
batch_size = total_batch_size // world_size
291+
292+
if isinstance(dataset, torchdata.IterableDataset):
293+
assert sampler is None, "sampler must be None if dataset is IterableDataset"
294+
else:
295+
dataset = ToIterableDataset(dataset, sampler)
296+
290297
if aspect_ratio_grouping:
291-
data_loader = torch.utils.data.DataLoader(
298+
data_loader = torchdata.DataLoader(
292299
dataset,
293-
sampler=sampler,
294300
num_workers=num_workers,
295-
batch_sampler=None,
296301
collate_fn=operator.itemgetter(0), # don't batch, but yield individual elements
297302
worker_init_fn=worker_init_reset_seed,
298303
) # yield individual mapped dict
299304
return AspectRatioGroupedDataset(data_loader, batch_size)
300305
else:
301-
batch_sampler = torch.utils.data.sampler.BatchSampler(
302-
sampler, batch_size, drop_last=True
303-
) # drop_last so the batch always have the same size
304-
return torch.utils.data.DataLoader(
306+
return torchdata.DataLoader(
305307
dataset,
308+
batch_size=batch_size,
309+
drop_last=True,
306310
num_workers=num_workers,
307-
batch_sampler=batch_sampler,
308311
collate_fn=trivial_batch_collator,
309312
worker_init_fn=worker_init_reset_seed,
310313
)
@@ -351,7 +354,6 @@ def _train_loader_from_config(cfg, mapper=None, *, dataset=None, sampler=None):
351354
}
352355

353356

354-
# TODO can allow dataset as an iterable or IterableDataset to make this function more general
355357
@configurable(from_config=_train_loader_from_config)
356358
def build_detection_train_loader(
357359
dataset, *, mapper, sampler=None, total_batch_size, aspect_ratio_grouping=True, num_workers=0
@@ -362,14 +364,16 @@ def build_detection_train_loader(
362364
363365
Args:
364366
dataset (list or torch.utils.data.Dataset): a list of dataset dicts,
365-
or a map-style pytorch dataset. They can be obtained by using
366-
:func:`DatasetCatalog.get` or :func:`get_detection_dataset_dicts`.
367+
or a pytorch dataset (either map-style or iterable). It can be obtained
368+
by using :func:`DatasetCatalog.get` or :func:`get_detection_dataset_dicts`.
367369
mapper (callable): a callable which takes a sample (dict) from dataset and
368370
returns the format to be consumed by the model.
369371
When using cfg, the default choice is ``DatasetMapper(cfg, is_train=True)``.
370372
sampler (torch.utils.data.sampler.Sampler or None): a sampler that produces
371-
indices to be applied on ``dataset``. Default to :class:`TrainingSampler`,
373+
indices to be applied on ``dataset``.
374+
If ``dataset`` is map-style, the default sampler is a :class:`TrainingSampler`,
372375
which coordinates an infinite random shuffle sequence across all workers.
376+
Sampler must be None if ``dataset`` is iterable.
373377
total_batch_size (int): total batch size across all workers. Batching
374378
simply puts data into a list.
375379
aspect_ratio_grouping (bool): whether to group images with similar
@@ -387,9 +391,13 @@ def build_detection_train_loader(
387391
dataset = DatasetFromList(dataset, copy=False)
388392
if mapper is not None:
389393
dataset = MapDataset(dataset, mapper)
390-
if sampler is None:
391-
sampler = TrainingSampler(len(dataset))
392-
assert isinstance(sampler, torch.utils.data.sampler.Sampler)
394+
395+
if isinstance(dataset, torchdata.IterableDataset):
396+
assert sampler is None, "sampler must be None if dataset is IterableDataset"
397+
else:
398+
if sampler is None:
399+
sampler = TrainingSampler(len(dataset))
400+
assert isinstance(sampler, torchdata.Sampler), f"Expect a Sampler but got {type(sampler)}"
393401
return build_batch_data_loader(
394402
dataset,
395403
sampler,
@@ -462,8 +470,8 @@ def build_detection_test_loader(dataset, *, mapper, sampler=None, num_workers=0)
462470
sampler = InferenceSampler(len(dataset))
463471
# Always use 1 image per worker during inference since this is the
464472
# standard when reporting inference time in papers.
465-
batch_sampler = torch.utils.data.sampler.BatchSampler(sampler, 1, drop_last=False)
466-
data_loader = torch.utils.data.DataLoader(
473+
batch_sampler = torchdata.sampler.BatchSampler(sampler, 1, drop_last=False)
474+
data_loader = torchdata.DataLoader(
467475
dataset,
468476
num_workers=num_workers,
469477
batch_sampler=batch_sampler,

tests/data/test_dataset.py

+32-1
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,15 @@
88
import torch
99
from iopath.common.file_io import LazyPath
1010

11-
from detectron2.data.build import DatasetFromList, MapDataset
11+
from detectron2 import model_zoo
12+
from detectron2.config import instantiate
13+
from detectron2.data import (
14+
DatasetFromList,
15+
MapDataset,
16+
ToIterableDataset,
17+
build_detection_train_loader,
18+
)
19+
from detectron2.data.samplers import TrainingSampler
1220

1321

1422
def _a_slow_func(x):
@@ -61,3 +69,26 @@ def test_pickleability(self):
6169
ds = MapDataset(ds, lambda x: x * 2)
6270
ds = pickle.loads(pickle.dumps(ds))
6371
self.assertEqual(ds[0], 2)
72+
73+
74+
@unittest.skipIf(os.environ.get("CI"), "Skipped OSS testing due to COCO data requirement.")
75+
class TestDataLoader(unittest.TestCase):
76+
def _get_kwargs(self):
77+
# get kwargs of build_detection_train_loader
78+
cfg = model_zoo.get_config("common/data/coco.py").dataloader.train
79+
cfg.dataset.names = "coco_2017_val_100"
80+
cfg.pop("_target_")
81+
kwargs = {k: instantiate(v) for k, v in cfg.items()}
82+
return kwargs
83+
84+
def test_build_dataloader(self):
85+
kwargs = self._get_kwargs()
86+
dl = build_detection_train_loader(**kwargs)
87+
next(iter(dl))
88+
89+
def test_build_iterable_dataloader(self):
90+
kwargs = self._get_kwargs()
91+
ds = DatasetFromList(kwargs.pop("dataset"))
92+
ds = ToIterableDataset(ds, TrainingSampler(len(ds)))
93+
dl = build_detection_train_loader(dataset=ds, **kwargs)
94+
next(iter(dl))

0 commit comments

Comments (0)