
Commit b80bdb7

Fix v2 transforms in spawn mp context (#8067)

1 parent 96d2ce9 commit b80bdb7

File tree

3 files changed: +70 -40 lines changed

test/datasets_utils.py (+19 -15)

@@ -27,7 +27,11 @@
 import torchvision.io
 from common_utils import disable_console_output, get_tmp_dir
 from torch.utils._pytree import tree_any
+from torch.utils.data import DataLoader
+from torchvision import tv_tensors
+from torchvision.datasets import wrap_dataset_for_transforms_v2
 from torchvision.transforms.functional import get_dimensions
+from torchvision.transforms.v2.functional import get_size


 __all__ = [
@@ -568,9 +572,6 @@ def test_transforms(self, config):

     @test_all_configs
     def test_transforms_v2_wrapper(self, config):
-        from torchvision import tv_tensors
-        from torchvision.datasets import wrap_dataset_for_transforms_v2
-
         try:
             with self.create_dataset(config) as (dataset, info):
                 for target_keys in [None, "all"]:
@@ -709,26 +710,29 @@ def _no_collate(batch):
     return batch


-def check_transforms_v2_wrapper_spawn(dataset):
-    # On Linux and Windows, the DataLoader forks the main process by default. This is not available on macOS, so new
-    # subprocesses are spawned. This requires the whole pipeline including the dataset to be pickleable, which is what
-    # we are enforcing here.
-    if platform.system() != "Darwin":
-        pytest.skip("Multiprocessing spawning is only checked on macOS.")
+def check_transforms_v2_wrapper_spawn(dataset, expected_size):
+    # This check ensures that the wrapped datasets can be used with multiprocessing_context="spawn" in the DataLoader.
+    # We also check that transforms are applied correctly as a non-regression test for
+    # https://github.com/pytorch/vision/issues/8066
+    # Implicitly, this also checks that the wrapped datasets are pickleable.

-    from torch.utils.data import DataLoader
-    from torchvision import tv_tensors
-    from torchvision.datasets import wrap_dataset_for_transforms_v2
+    # To save CI/test time, we only check on Windows where "spawn" is the default
+    if platform.system() != "Windows":
+        pytest.skip("Multiprocessing spawning is only checked on Windows.")

     wrapped_dataset = wrap_dataset_for_transforms_v2(dataset)

     dataloader = DataLoader(wrapped_dataset, num_workers=2, multiprocessing_context="spawn", collate_fn=_no_collate)

-    for wrapped_sample in dataloader:
-        assert tree_any(
-            lambda item: isinstance(item, (tv_tensors.Image, tv_tensors.Video, PIL.Image.Image)), wrapped_sample
+    def resize_was_applied(item):
+        # Checking the size of the output ensures that the Resize transform was correctly applied
+        return isinstance(item, (tv_tensors.Image, tv_tensors.Video, PIL.Image.Image)) and get_size(item) == list(
+            expected_size
         )

+    for wrapped_sample in dataloader:
+        assert tree_any(resize_was_applied, wrapped_sample)
+

 def create_image_or_video_tensor(size: Sequence[int]) -> torch.Tensor:
     r"""Create a random uint8 tensor.

test/test_datasets.py (+38 -24)

@@ -24,6 +24,7 @@
 import torch.nn.functional as F
 from common_utils import combinations_grid
 from torchvision import datasets
+from torchvision.transforms import v2


 class STL10TestCase(datasets_utils.ImageDatasetTestCase):
@@ -184,8 +185,9 @@ def test_combined_targets(self):
             f"{actual} is not {expected}",

     def test_transforms_v2_wrapper_spawn(self):
-        with self.create_dataset(target_type="category") as (dataset, _):
-            datasets_utils.check_transforms_v2_wrapper_spawn(dataset)
+        expected_size = (123, 321)
+        with self.create_dataset(target_type="category", transform=v2.Resize(size=expected_size)) as (dataset, _):
+            datasets_utils.check_transforms_v2_wrapper_spawn(dataset, expected_size=expected_size)


 class Caltech256TestCase(datasets_utils.ImageDatasetTestCase):
@@ -263,8 +265,9 @@ def inject_fake_data(self, tmpdir, config):
         return split_to_num_examples[config["split"]]

     def test_transforms_v2_wrapper_spawn(self):
-        with self.create_dataset() as (dataset, _):
-            datasets_utils.check_transforms_v2_wrapper_spawn(dataset)
+        expected_size = (123, 321)
+        with self.create_dataset(transform=v2.Resize(size=expected_size)) as (dataset, _):
+            datasets_utils.check_transforms_v2_wrapper_spawn(dataset, expected_size=expected_size)


 class CityScapesTestCase(datasets_utils.ImageDatasetTestCase):
@@ -391,9 +394,10 @@ def test_feature_types_target_polygon(self):
             (polygon_target, info["expected_polygon_target"])

     def test_transforms_v2_wrapper_spawn(self):
+        expected_size = (123, 321)
         for target_type in ["instance", "semantic", ["instance", "semantic"]]:
-            with self.create_dataset(target_type=target_type) as (dataset, _):
-                datasets_utils.check_transforms_v2_wrapper_spawn(dataset)
+            with self.create_dataset(target_type=target_type, transform=v2.Resize(size=expected_size)) as (dataset, _):
+                datasets_utils.check_transforms_v2_wrapper_spawn(dataset, expected_size=expected_size)


 class ImageNetTestCase(datasets_utils.ImageDatasetTestCase):
@@ -427,8 +431,9 @@ def inject_fake_data(self, tmpdir, config):
         return num_examples

     def test_transforms_v2_wrapper_spawn(self):
-        with self.create_dataset() as (dataset, _):
-            datasets_utils.check_transforms_v2_wrapper_spawn(dataset)
+        expected_size = (123, 321)
+        with self.create_dataset(transform=v2.Resize(size=expected_size)) as (dataset, _):
+            datasets_utils.check_transforms_v2_wrapper_spawn(dataset, expected_size=expected_size)


 class CIFAR10TestCase(datasets_utils.ImageDatasetTestCase):
@@ -625,9 +630,10 @@ def test_images_names_split(self):
         assert merged_imgs_names == all_imgs_names

     def test_transforms_v2_wrapper_spawn(self):
+        expected_size = (123, 321)
         for target_type in ["identity", "bbox", ["identity", "bbox"]]:
-            with self.create_dataset(target_type=target_type) as (dataset, _):
-                datasets_utils.check_transforms_v2_wrapper_spawn(dataset)
+            with self.create_dataset(target_type=target_type, transform=v2.Resize(size=expected_size)) as (dataset, _):
+                datasets_utils.check_transforms_v2_wrapper_spawn(dataset, expected_size=expected_size)


 class VOCSegmentationTestCase(datasets_utils.ImageDatasetTestCase):
@@ -717,8 +723,9 @@ def add_bndbox(obj, bndbox=None):
         return data

     def test_transforms_v2_wrapper_spawn(self):
-        with self.create_dataset() as (dataset, _):
-            datasets_utils.check_transforms_v2_wrapper_spawn(dataset)
+        expected_size = (123, 321)
+        with self.create_dataset(transform=v2.Resize(size=expected_size)) as (dataset, _):
+            datasets_utils.check_transforms_v2_wrapper_spawn(dataset, expected_size=expected_size)


 class VOCDetectionTestCase(VOCSegmentationTestCase):
@@ -741,8 +748,9 @@ def test_annotations(self):
         assert object == info["annotation"]

     def test_transforms_v2_wrapper_spawn(self):
-        with self.create_dataset() as (dataset, _):
-            datasets_utils.check_transforms_v2_wrapper_spawn(dataset)
+        expected_size = (123, 321)
+        with self.create_dataset(transform=v2.Resize(size=expected_size)) as (dataset, _):
+            datasets_utils.check_transforms_v2_wrapper_spawn(dataset, expected_size=expected_size)


 class CocoDetectionTestCase(datasets_utils.ImageDatasetTestCase):
@@ -815,8 +823,9 @@ def _create_json(self, root, name, content):
         return file

     def test_transforms_v2_wrapper_spawn(self):
-        with self.create_dataset() as (dataset, _):
-            datasets_utils.check_transforms_v2_wrapper_spawn(dataset)
+        expected_size = (123, 321)
+        with self.create_dataset(transform=v2.Resize(size=expected_size)) as (dataset, _):
+            datasets_utils.check_transforms_v2_wrapper_spawn(dataset, expected_size=expected_size)


 class CocoCaptionsTestCase(CocoDetectionTestCase):
@@ -1005,9 +1014,11 @@ def inject_fake_data(self, tmpdir, config):
         )
         return num_videos_per_class * len(classes)

+    @pytest.mark.xfail(reason="FIXME")
     def test_transforms_v2_wrapper_spawn(self):
-        with self.create_dataset(output_format="TCHW") as (dataset, _):
-            datasets_utils.check_transforms_v2_wrapper_spawn(dataset)
+        expected_size = (123, 321)
+        with self.create_dataset(output_format="TCHW", transform=v2.Resize(size=expected_size)) as (dataset, _):
+            datasets_utils.check_transforms_v2_wrapper_spawn(dataset, expected_size=expected_size)


 class HMDB51TestCase(datasets_utils.VideoDatasetTestCase):
@@ -1237,8 +1248,9 @@ def _file_stem(self, idx):
         return f"2008_{idx:06d}"

     def test_transforms_v2_wrapper_spawn(self):
-        with self.create_dataset(mode="segmentation") as (dataset, _):
-            datasets_utils.check_transforms_v2_wrapper_spawn(dataset)
+        expected_size = (123, 321)
+        with self.create_dataset(mode="segmentation", transforms=v2.Resize(size=expected_size)) as (dataset, _):
+            datasets_utils.check_transforms_v2_wrapper_spawn(dataset, expected_size=expected_size)


 class FakeDataTestCase(datasets_utils.ImageDatasetTestCase):
@@ -1690,8 +1702,9 @@ def inject_fake_data(self, tmpdir, config):
         return split_to_num_examples[config["train"]]

     def test_transforms_v2_wrapper_spawn(self):
-        with self.create_dataset() as (dataset, _):
-            datasets_utils.check_transforms_v2_wrapper_spawn(dataset)
+        expected_size = (123, 321)
+        with self.create_dataset(transform=v2.Resize(size=expected_size)) as (dataset, _):
+            datasets_utils.check_transforms_v2_wrapper_spawn(dataset, expected_size=expected_size)


 class SvhnTestCase(datasets_utils.ImageDatasetTestCase):
@@ -2568,8 +2581,9 @@ def _meta_to_split_and_classification_ann(self, meta, idx):
         return (image_id, class_id, species, breed_id)

     def test_transforms_v2_wrapper_spawn(self):
-        with self.create_dataset() as (dataset, _):
-            datasets_utils.check_transforms_v2_wrapper_spawn(dataset)
+        expected_size = (123, 321)
+        with self.create_dataset(transform=v2.Resize(size=expected_size)) as (dataset, _):
+            datasets_utils.check_transforms_v2_wrapper_spawn(dataset, expected_size=expected_size)


 class StanfordCarsTestCase(datasets_utils.ImageDatasetTestCase):
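
Editorial note, not part of the diff: most tests above pass the resize via transform=, while the segmentation case (mode="segmentation") passes transforms=. VisionDataset accepts either an input-only transform or a joint transforms(input, target) callable, and a v2 transform can serve in both roles because it accepts an arbitrary number of inputs. A small illustrative sketch of the two call patterns:

import torch
from torchvision import tv_tensors
from torchvision.transforms import v2

resize = v2.Resize(size=(123, 321))
image = tv_tensors.Image(torch.randint(0, 256, (3, 50, 60), dtype=torch.uint8))
mask = tv_tensors.Mask(torch.zeros(50, 60, dtype=torch.uint8))

out_image = resize(image)                  # input-only role, i.e. transform=
out_image, out_mask = resize(image, mask)  # joint role, i.e. transforms=
assert out_image.shape[-2:] == out_mask.shape[-2:] == (123, 321)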

torchvision/tv_tensors/_dataset_wrapper.py (+13 -1)

@@ -6,6 +6,7 @@

 import contextlib
 from collections import defaultdict
+from copy import copy

 import torch

@@ -198,8 +199,19 @@ def __getitem__(self, idx):
     def __len__(self):
         return len(self._dataset)

+    # TODO: maybe we should use __getstate__ and __setstate__ instead of __reduce__, as recommended in the docs.
     def __reduce__(self):
-        return wrap_dataset_for_transforms_v2, (self._dataset, self._target_keys)
+        # __reduce__ gets called when we try to pickle the dataset.
+        # In a DataLoader with spawn context, this gets called `num_workers` times from the main process.
+
+        # We have to reset the [target_]transform[s] attributes of the dataset
+        # to their original values, because we previously set them to None in __init__().
+        dataset = copy(self._dataset)
+        dataset.transform = self.transform
+        dataset.transforms = self.transforms
+        dataset.target_transform = self.target_transform
+
+        return wrap_dataset_for_transforms_v2, (dataset, self._target_keys)


 def raise_not_supported(description):
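
Editorial note, not part of the diff: to see why the __reduce__ change fixes #8066, consider what pickle does with it. It re-wraps a copy of the inner dataset whose transform attributes have been restored, so the wrapper rebuilt in a spawned worker goes through __init__ again with the original transforms intact. A self-contained toy mirroring the pattern (ToyDataset and ToyWrapper are illustrative stand-ins, not torchvision classes):

import pickle
from copy import copy

class ToyDataset:
    def __init__(self, transform=None):
        self.transform = transform

class ToyWrapper:
    def __init__(self, dataset):
        # Like wrap_dataset_for_transforms_v2: take over the transform and
        # disable it on the inner dataset so it is not applied twice.
        self.transform = dataset.transform
        dataset = copy(dataset)
        dataset.transform = None
        self._dataset = dataset

    def __reduce__(self):
        # Restore the transform before re-wrapping. Returning the inner
        # dataset as-is would rebuild a wrapper with transform=None in the
        # spawned worker, which is the bug reported in issue #8066.
        dataset = copy(self._dataset)
        dataset.transform = self.transform
        return ToyWrapper, (dataset,)

wrapped = ToyWrapper(ToyDataset(transform="resize"))
restored = pickle.loads(pickle.dumps(wrapped))
assert restored.transform == "resize"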
