Skip to content

Commit 0a2a4a3

Browse files
ppwwyyxx authored and facebook-github-bot committed
support iterable dataset in build_detection_test_loader
Summary: Like D24677397 (facebookresearch@05bc843), but for the test loader. Differential Revision: D31161853. fbshipit-source-id: 6c101843a2be681fc23b2ff241070876a77be80f
1 parent 2a1cec4 commit 0a2a4a3

File tree

2 files changed

+41
-16
lines changed

2 files changed

+41
-16
lines changed

detectron2/data/build.py

+12-9
Original file line numberDiff line numberDiff line change
@@ -439,14 +439,15 @@ def build_detection_test_loader(dataset, *, mapper, sampler=None, num_workers=0)
439439
440440
Args:
441441
dataset (list or torch.utils.data.Dataset): a list of dataset dicts,
442-
or a map-style pytorch dataset. They can be obtained by using
443-
:func:`DatasetCatalog.get` or :func:`get_detection_dataset_dicts`.
442+
or a pytorch dataset (either map-style or iterable). They can be obtained
443+
by using :func:`DatasetCatalog.get` or :func:`get_detection_dataset_dicts`.
444444
mapper (callable): a callable which takes a sample (dict) from dataset
445445
and returns the format to be consumed by the model.
446446
When using cfg, the default choice is ``DatasetMapper(cfg, is_train=False)``.
447447
sampler (torch.utils.data.sampler.Sampler or None): a sampler that produces
448448
indices to be applied on ``dataset``. Default to :class:`InferenceSampler`,
449-
which splits the dataset across all workers.
449+
which splits the dataset across all workers. Sampler must be None
450+
if `dataset` is iterable.
450451
num_workers (int): number of parallel data loading workers
451452
452453
Returns:
@@ -466,18 +467,20 @@ def build_detection_test_loader(dataset, *, mapper, sampler=None, num_workers=0)
466467
dataset = DatasetFromList(dataset, copy=False)
467468
if mapper is not None:
468469
dataset = MapDataset(dataset, mapper)
469-
if sampler is None:
470-
sampler = InferenceSampler(len(dataset))
470+
if isinstance(dataset, torchdata.IterableDataset):
471+
assert sampler is None, "sampler must be None if dataset is IterableDataset"
472+
else:
473+
if sampler is None:
474+
sampler = InferenceSampler(len(dataset))
471475
# Always use 1 image per worker during inference since this is the
472476
# standard when reporting inference time in papers.
473-
batch_sampler = torchdata.sampler.BatchSampler(sampler, 1, drop_last=False)
474-
data_loader = torchdata.DataLoader(
477+
return torchdata.DataLoader(
475478
dataset,
479+
batch_size=1,
480+
sampler=sampler,
476481
num_workers=num_workers,
477-
batch_sampler=batch_sampler,
478482
collate_fn=trivial_batch_collator,
479483
)
480-
return data_loader
481484

482485

483486
def trivial_batch_collator(batch):

tests/data/test_dataset.py

+29-7
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
MapDataset,
1616
ToIterableDataset,
1717
build_batch_data_loader,
18+
build_detection_test_loader,
1819
build_detection_train_loader,
1920
)
2021
from detectron2.data.samplers import InferenceSampler, TrainingSampler
@@ -82,25 +83,46 @@ def _get_kwargs(self):
8283
kwargs = {k: instantiate(v) for k, v in cfg.items()}
8384
return kwargs
8485

85-
def test_build_dataloader(self):
86+
def test_build_dataloader_train(self):
8687
kwargs = self._get_kwargs()
8788
dl = build_detection_train_loader(**kwargs)
8889
next(iter(dl))
8990

90-
def test_build_iterable_dataloader(self):
91+
def test_build_iterable_dataloader_train(self):
9192
kwargs = self._get_kwargs()
9293
ds = DatasetFromList(kwargs.pop("dataset"))
9394
ds = ToIterableDataset(ds, TrainingSampler(len(ds)))
9495
dl = build_detection_train_loader(dataset=ds, **kwargs)
9596
next(iter(dl))
9697

97-
def test_build_dataloader_inference(self):
98+
def _check_is_range(self, data_loader, N):
99+
# check that data_loader produces range(N)
100+
data = list(iter(data_loader))
101+
data = [x for batch in data for x in batch] # flatten the batches
102+
self.assertEqual(len(data), N)
103+
self.assertEqual(set(data), set(range(N)))
104+
105+
def test_build_batch_dataloader_inference(self):
106+
# Test that build_batch_data_loader can be used for inference
98107
N = 96
99108
ds = DatasetFromList(list(range(N)))
100109
sampler = InferenceSampler(len(ds))
101110
dl = build_batch_data_loader(ds, sampler, 8, num_workers=3)
111+
self._check_is_range(dl, N)
102112

103-
data = list(iter(dl))
104-
data = [x for batch in data for x in batch] # flatten the batches
105-
self.assertEqual(len(data), N)
106-
self.assertEqual(set(data), set(range(N)))
113+
def test_build_dataloader_inference(self):
114+
N = 50
115+
ds = DatasetFromList(list(range(N)))
116+
sampler = InferenceSampler(len(ds))
117+
dl = build_detection_test_loader(
118+
dataset=ds, sampler=sampler, mapper=lambda x: x, num_workers=3
119+
)
120+
self._check_is_range(dl, N)
121+
122+
def test_build_iterable_dataloader_inference(self):
123+
# Test that build_detection_test_loader supports iterable dataset
124+
N = 50
125+
ds = DatasetFromList(list(range(N)))
126+
ds = ToIterableDataset(ds, InferenceSampler(len(ds)))
127+
dl = build_detection_test_loader(dataset=ds, mapper=lambda x: x, num_workers=3)
128+
self._check_is_range(dl, N)

0 commit comments

Comments
 (0)