Skip to content

Commit 27b8491

Browse files
authored
only return small set of targets by default from dataset wrapper (#7488)
1 parent ce653d8 commit 27b8491

File tree

4 files changed

+223
-69
lines changed

4 files changed

+223
-69
lines changed

gallery/plot_transforms_v2_e2e.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -75,15 +75,16 @@ def load_example_coco_detection_dataset(**kwargs):
7575
# :func:`~torchvision.datasets.wrap_dataset_for_transforms_v2` function. For
7676
# :class:`~torchvision.datasets.CocoDetection`, this changes the target structure to a single dictionary of lists. It
7777
# also adds the key-value-pairs ``"boxes"``, ``"masks"``, and ``"labels"`` wrapped in the corresponding
78-
# ``torchvision.datapoints``.
78+
# ``torchvision.datapoints``. By default, it only returns ``"boxes"`` and ``"labels"`` to avoid transforming unnecessary
79+
# items down the line, but you can pass the ``target_type`` parameter for fine-grained control.
7980

8081
dataset = datasets.wrap_dataset_for_transforms_v2(dataset)
8182

8283
sample = dataset[0]
8384
image, target = sample
8485
print(type(image))
8586
print(type(target), list(target.keys()))
86-
print(type(target["boxes"]), type(target["masks"]), type(target["labels"]))
87+
print(type(target["boxes"]), type(target["labels"]))
8788

8889
########################################################################################################################
8990
# As baseline, let's have a look at a sample without transformations:

test/datasets_utils.py

+15-3
Original file line numberDiff line numberDiff line change
@@ -572,9 +572,21 @@ def test_transforms_v2_wrapper(self, config):
572572

573573
try:
574574
with self.create_dataset(config) as (dataset, _):
575-
wrapped_dataset = wrap_dataset_for_transforms_v2(dataset)
576-
wrapped_sample = wrapped_dataset[0]
577-
assert tree_any(lambda item: isinstance(item, (Datapoint, PIL.Image.Image)), wrapped_sample)
575+
for target_keys in [None, "all"]:
576+
if target_keys is not None and self.DATASET_CLASS not in {
577+
torchvision.datasets.CocoDetection,
578+
torchvision.datasets.VOCDetection,
579+
torchvision.datasets.Kitti,
580+
torchvision.datasets.WIDERFace,
581+
}:
582+
with self.assertRaisesRegex(ValueError, "`target_keys` is currently only supported for"):
583+
wrap_dataset_for_transforms_v2(dataset, target_keys=target_keys)
584+
continue
585+
586+
wrapped_dataset = wrap_dataset_for_transforms_v2(dataset, target_keys=target_keys)
587+
wrapped_sample = wrapped_dataset[0]
588+
589+
assert tree_any(lambda item: isinstance(item, (Datapoint, PIL.Image.Image)), wrapped_sample)
578590
except TypeError as error:
579591
msg = f"No wrapper exists for dataset class {type(dataset).__name__}"
580592
if str(error).startswith(msg):

test/test_datasets.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -771,6 +771,8 @@ def _create_annotations(self, image_ids, num_annotations_per_image):
771771
bbox=torch.rand(4).tolist(),
772772
segmentation=[torch.rand(8).tolist()],
773773
category_id=int(torch.randint(91, ())),
774+
area=float(torch.rand(1)),
775+
iscrowd=int(torch.randint(2, size=(1,))),
774776
)
775777
)
776778
annotion_id += 1
@@ -3336,7 +3338,7 @@ def test_subclass(self, mocker):
33363338
mocker.patch.dict(
33373339
datapoints._dataset_wrapper.WRAPPER_FACTORIES,
33383340
clear=False,
3339-
values={datasets.FakeData: lambda dataset: lambda idx, sample: sentinel},
3341+
values={datasets.FakeData: lambda dataset, target_keys: lambda idx, sample: sentinel},
33403342
)
33413343

33423344
class MyFakeData(datasets.FakeData):

0 commit comments

Comments
 (0)