@@ -4,7 +4,8 @@
 import numpy as np
 import operator
 import pickle
-import torch.utils.data
+import torch
+import torch.utils.data as torchdata
 from tabulate import tabulate
 from termcolor import colored
 
@@ -16,7 +17,7 @@
 from detectron2.utils.logger import _log_api_usage, log_first_n
 
 from .catalog import DatasetCatalog, MetadataCatalog
-from .common import AspectRatioGroupedDataset, DatasetFromList, MapDataset
+from .common import AspectRatioGroupedDataset, DatasetFromList, MapDataset, ToIterableDataset
 from .dataset_mapper import DatasetMapper
 from .detection_utils import check_metadata_consistency
 from .samplers import (
@@ -270,8 +271,9 @@ def build_batch_data_loader(
     2. use no "batch collation", because this is common for detection training
 
     Args:
-        dataset (torch.utils.data.Dataset): map-style PyTorch dataset. Can be indexed.
-        sampler (torch.utils.data.sampler.Sampler): a sampler that produces indices
+        dataset (torch.utils.data.Dataset): a pytorch map-style or iterable dataset.
+        sampler (torch.utils.data.sampler.Sampler or None): a sampler that produces indices.
+            Must be provided iff. ``dataset`` is a map-style dataset.
         total_batch_size, aspect_ratio_grouping, num_workers): see
             :func:`build_detection_train_loader`.
 
@@ -285,26 +287,27 @@ def build_batch_data_loader(
     ), "Total batch size ({}) must be divisible by the number of gpus ({}).".format(
         total_batch_size, world_size
     )
-
     batch_size = total_batch_size // world_size
+
+    if isinstance(dataset, torchdata.IterableDataset):
+        assert sampler is None, "sampler must be None if dataset is IterableDataset"
+    else:
+        dataset = ToIterableDataset(dataset, sampler)
+
     if aspect_ratio_grouping:
-        data_loader = torch.utils.data.DataLoader(
+        data_loader = torchdata.DataLoader(
             dataset,
-            sampler=sampler,
             num_workers=num_workers,
-            batch_sampler=None,
             collate_fn=operator.itemgetter(0),  # don't batch, but yield individual elements
             worker_init_fn=worker_init_reset_seed,
         )  # yield individual mapped dict
         return AspectRatioGroupedDataset(data_loader, batch_size)
     else:
-        batch_sampler = torch.utils.data.sampler.BatchSampler(
-            sampler, batch_size, drop_last=True
-        )  # drop_last so the batch always have the same size
-        return torch.utils.data.DataLoader(
+        return torchdata.DataLoader(
             dataset,
+            batch_size=batch_size,
+            drop_last=True,
             num_workers=num_workers,
-            batch_sampler=batch_sampler,
             collate_fn=trivial_batch_collator,
             worker_init_fn=worker_init_reset_seed,
         )
@@ -351,7 +354,6 @@ def _train_loader_from_config(cfg, mapper=None, *, dataset=None, sampler=None):
     }
 
 
-# TODO can allow dataset as an iterable or IterableDataset to make this function more general
 @configurable(from_config=_train_loader_from_config)
 def build_detection_train_loader(
     dataset, *, mapper, sampler=None, total_batch_size, aspect_ratio_grouping=True, num_workers=0
@@ -362,14 +364,16 @@ def build_detection_train_loader(
 
     Args:
         dataset (list or torch.utils.data.Dataset): a list of dataset dicts,
-            or a map-style pytorch dataset. They can be obtained by using
-            :func:`DatasetCatalog.get` or :func:`get_detection_dataset_dicts`.
+            or a pytorch dataset (either map-style or iterable). It can be obtained
+            by using :func:`DatasetCatalog.get` or :func:`get_detection_dataset_dicts`.
         mapper (callable): a callable which takes a sample (dict) from dataset and
             returns the format to be consumed by the model.
             When using cfg, the default choice is ``DatasetMapper(cfg, is_train=True)``.
         sampler (torch.utils.data.sampler.Sampler or None): a sampler that produces
-            indices to be applied on ``dataset``. Default to :class:`TrainingSampler`,
+            indices to be applied on ``dataset``.
+            If ``dataset`` is map-style, the default sampler is a :class:`TrainingSampler`,
             which coordinates an infinite random shuffle sequence across all workers.
+            Sampler must be None if ``dataset`` is iterable.
         total_batch_size (int): total batch size across all workers. Batching
             simply puts data into a list.
         aspect_ratio_grouping (bool): whether to group images with similar
@@ -387,9 +391,13 @@ def build_detection_train_loader(
         dataset = DatasetFromList(dataset, copy=False)
     if mapper is not None:
         dataset = MapDataset(dataset, mapper)
-    if sampler is None:
-        sampler = TrainingSampler(len(dataset))
-    assert isinstance(sampler, torch.utils.data.sampler.Sampler)
+
+    if isinstance(dataset, torchdata.IterableDataset):
+        assert sampler is None, "sampler must be None if dataset is IterableDataset"
+    else:
+        if sampler is None:
+            sampler = TrainingSampler(len(dataset))
+        assert isinstance(sampler, torchdata.Sampler), f"Expect a Sampler but got {type(sampler)}"
     return build_batch_data_loader(
         dataset,
         sampler,
@@ -462,8 +470,8 @@ def build_detection_test_loader(dataset, *, mapper, sampler=None, num_workers=0)
         sampler = InferenceSampler(len(dataset))
     # Always use 1 image per worker during inference since this is the
     # standard when reporting inference time in papers.
-    batch_sampler = torch.utils.data.sampler.BatchSampler(sampler, 1, drop_last=False)
-    data_loader = torch.utils.data.DataLoader(
+    batch_sampler = torchdata.sampler.BatchSampler(sampler, 1, drop_last=False)
+    data_loader = torchdata.DataLoader(
         dataset,
         num_workers=num_workers,
         batch_sampler=batch_sampler,
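
A minimal usage sketch of the iterable-dataset path that this change enables. The dataset class, field values, and batch sizes below are illustrative only and are not part of the commit; the only behavior taken from the diff is that an IterableDataset may be passed directly, in which case sampler must be left as None.

import torch
import torch.utils.data as torchdata
from detectron2.data import build_detection_train_loader

class StreamingDetectionDataset(torchdata.IterableDataset):
    # Hypothetical infinite stream of already-mapped training dicts.
    def __iter__(self):
        while True:
            # Placeholder payload; real data would carry annotations as well.
            yield {"image": torch.zeros(3, 640, 480), "height": 640, "width": 480}

# Iterable dataset: sampler must stay None; batching (and aspect-ratio
# grouping, which reads the "width"/"height" keys) is handled internally.
train_loader = build_detection_train_loader(
    StreamingDetectionDataset(),
    mapper=None,          # the stream above is assumed to be mapped already
    total_batch_size=16,
    num_workers=2,
)

With a map-style dataset (e.g. a list of dataset dicts) the caller-facing behavior is unchanged: a TrainingSampler is still created when sampler is None, and the dataset is wrapped in ToIterableDataset internally before batching.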