Skip to content

Commit cbbc1ce

Browse files
ppwwyyxxfacebook-github-bot
authored andcommitted
add __len__ for ToIterableDataset
Reviewed By: zhanghang1989 Differential Revision: D30635073 fbshipit-source-id: 664ba17b768a69fed97ca94b944e43d186336966
1 parent ea3b3f2 commit cbbc1ce

File tree

2 files changed

+16
-1
lines changed

2 files changed

+16
-1
lines changed

detectron2/data/common.py

+3
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,9 @@ def __iter__(self):
200200
for idx in sampler:
201201
yield self.dataset[idx]
202202

203+
def __len__(self):
204+
return len(self.sampler)
205+
203206

204207
class AspectRatioGroupedDataset(data.IterableDataset):
205208
"""

tests/data/test_dataset.py

+13-1
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,10 @@
1414
DatasetFromList,
1515
MapDataset,
1616
ToIterableDataset,
17+
build_batch_data_loader,
1718
build_detection_train_loader,
1819
)
19-
from detectron2.data.samplers import TrainingSampler
20+
from detectron2.data.samplers import InferenceSampler, TrainingSampler
2021

2122

2223
def _a_slow_func(x):
@@ -92,3 +93,14 @@ def test_build_iterable_dataloader(self):
9293
ds = ToIterableDataset(ds, TrainingSampler(len(ds)))
9394
dl = build_detection_train_loader(dataset=ds, **kwargs)
9495
next(iter(dl))
96+
97+
def test_build_dataloader_inference(self):
98+
N = 96
99+
ds = DatasetFromList(list(range(N)))
100+
sampler = InferenceSampler(len(ds))
101+
dl = build_batch_data_loader(ds, sampler, 8, num_workers=3)
102+
103+
data = list(iter(dl))
104+
data = [x for batch in data for x in batch] # flatten the batches
105+
self.assertEqual(len(data), N)
106+
self.assertEqual(set(data), set(range(N)))

0 commit comments

Comments
 (0)