|
15 | 15 | MapDataset,
|
16 | 16 | ToIterableDataset,
|
17 | 17 | build_batch_data_loader,
|
| 18 | + build_detection_test_loader, |
18 | 19 | build_detection_train_loader,
|
19 | 20 | )
|
20 | 21 | from detectron2.data.samplers import InferenceSampler, TrainingSampler
|
@@ -82,25 +83,46 @@ def _get_kwargs(self):
|
82 | 83 | kwargs = {k: instantiate(v) for k, v in cfg.items()}
|
83 | 84 | return kwargs
|
84 | 85 |
|
85 |
| - def test_build_dataloader(self): |
| 86 | + def test_build_dataloader_train(self): |
86 | 87 | kwargs = self._get_kwargs()
|
87 | 88 | dl = build_detection_train_loader(**kwargs)
|
88 | 89 | next(iter(dl))
|
89 | 90 |
|
90 |
| - def test_build_iterable_dataloader(self): |
| 91 | + def test_build_iterable_dataloader_train(self): |
91 | 92 | kwargs = self._get_kwargs()
|
92 | 93 | ds = DatasetFromList(kwargs.pop("dataset"))
|
93 | 94 | ds = ToIterableDataset(ds, TrainingSampler(len(ds)))
|
94 | 95 | dl = build_detection_train_loader(dataset=ds, **kwargs)
|
95 | 96 | next(iter(dl))
|
96 | 97 |
|
97 |
| - def test_build_dataloader_inference(self): |
| 98 | + def _check_is_range(self, data_loader, N): |
| 99 | + # check that data_loader produces range(N) |
| 100 | + data = list(iter(data_loader)) |
| 101 | + data = [x for batch in data for x in batch] # flatten the batches |
| 102 | + self.assertEqual(len(data), N) |
| 103 | + self.assertEqual(set(data), set(range(N))) |
| 104 | + |
| 105 | + def test_build_batch_dataloader_inference(self): |
| 106 | + # Test that build_batch_data_loader can be used for inference |
98 | 107 | N = 96
|
99 | 108 | ds = DatasetFromList(list(range(N)))
|
100 | 109 | sampler = InferenceSampler(len(ds))
|
101 | 110 | dl = build_batch_data_loader(ds, sampler, 8, num_workers=3)
|
| 111 | + self._check_is_range(dl, N) |
102 | 112 |
|
103 |
| - data = list(iter(dl)) |
104 |
| - data = [x for batch in data for x in batch] # flatten the batches |
105 |
| - self.assertEqual(len(data), N) |
106 |
| - self.assertEqual(set(data), set(range(N))) |
| 113 | + def test_build_dataloader_inference(self): |
| 114 | + N = 50 |
| 115 | + ds = DatasetFromList(list(range(N))) |
| 116 | + sampler = InferenceSampler(len(ds)) |
| 117 | + dl = build_detection_test_loader( |
| 118 | + dataset=ds, sampler=sampler, mapper=lambda x: x, num_workers=3 |
| 119 | + ) |
| 120 | + self._check_is_range(dl, N) |
| 121 | + |
| 122 | + def test_build_iterable_dataloader_inference(self): |
| 123 | + # Test that build_detection_test_loader supports iterable dataset |
| 124 | + N = 50 |
| 125 | + ds = DatasetFromList(list(range(N))) |
| 126 | + ds = ToIterableDataset(ds, InferenceSampler(len(ds))) |
| 127 | + dl = build_detection_test_loader(dataset=ds, mapper=lambda x: x, num_workers=3) |
| 128 | + self._check_is_range(dl, N) |
0 commit comments