Skip to content

Commit 86f8eb0

Browse files
feat: expose loader argument in Country211 and EuroSAT. (#8922)
Co-authored-by: Nicolas Hug <nh.nicolas.hug@gmail.com>
1 parent ba94923 commit 86f8eb0

File tree

5 files changed: 49 additions and 11 deletions.

test/datasets_utils.py

+18-5
Original file line number | Diff line number | Diff line change
@@ -611,6 +611,7 @@ class ImageDatasetTestCase(DatasetTestCase):
611611
"""
612612

613613
FEATURE_TYPES = (PIL.Image.Image, int)
614+
SUPPORT_TV_IMAGE_DECODE: bool = False
614615

615616
@contextlib.contextmanager
616617
def create_dataset(
@@ -632,22 +633,34 @@ def create_dataset(
632633
# This problem only occurs during testing since some tests, e.g. DatasetTestCase.test_feature_types open an
633634
# image, but never use the underlying data. During normal operation it is reasonable to assume that the
634635
# user wants to work with the image he just opened rather than deleting the underlying file.
635-
with self._force_load_images():
636+
with self._force_load_images(loader=(config or {}).get("loader", None)):
636637
yield dataset, info
637638

638639
@contextlib.contextmanager
639-
def _force_load_images(self):
640-
open = PIL.Image.open
640+
def _force_load_images(self, loader: Optional[Callable[[str], Any]] = None):
641+
open = loader or PIL.Image.open
641642

642643
def new(fp, *args, **kwargs):
643644
image = open(fp, *args, **kwargs)
644-
if isinstance(fp, (str, pathlib.Path)):
645+
if isinstance(fp, (str, pathlib.Path)) and isinstance(image, PIL.Image.Image):
645646
image.load()
646647
return image
647648

648-
with unittest.mock.patch("PIL.Image.open", new=new):
649+
with unittest.mock.patch(open.__module__ + "." + open.__qualname__, new=new):
649650
yield
650651

652+
def test_tv_decode_image_support(self):
653+
if not self.SUPPORT_TV_IMAGE_DECODE:
654+
pytest.skip(f"{self.DATASET_CLASS.__name__} does not support torchvision.io.decode_image.")
655+
656+
with self.create_dataset(
657+
config=dict(
658+
loader=torchvision.io.decode_image,
659+
)
660+
) as (dataset, _):
661+
image = dataset[0][0]
662+
assert isinstance(image, torch.Tensor)
663+
651664

652665
class VideoDatasetTestCase(DatasetTestCase):
653666
"""Abstract base class for video dataset testcases.

test/test_datasets.py

+5
Original file line number | Diff line number | Diff line change
@@ -405,6 +405,8 @@ class ImageNetTestCase(datasets_utils.ImageDatasetTestCase):
405405
REQUIRED_PACKAGES = ("scipy",)
406406
ADDITIONAL_CONFIGS = combinations_grid(split=("train", "val"))
407407

408+
SUPPORT_TV_IMAGE_DECODE = True
409+
408410
def inject_fake_data(self, tmpdir, config):
409411
tmpdir = pathlib.Path(tmpdir)
410412

@@ -2308,6 +2310,7 @@ def inject_fake_data(self, tmpdir, config):
23082310
class EuroSATTestCase(datasets_utils.ImageDatasetTestCase):
23092311
DATASET_CLASS = datasets.EuroSAT
23102312
FEATURE_TYPES = (PIL.Image.Image, int)
2313+
SUPPORT_TV_IMAGE_DECODE = True
23112314

23122315
def inject_fake_data(self, tmpdir, config):
23132316
data_folder = os.path.join(tmpdir, "eurosat", "2750")
@@ -2749,6 +2752,8 @@ class Country211TestCase(datasets_utils.ImageDatasetTestCase):
27492752

27502753
ADDITIONAL_CONFIGS = combinations_grid(split=("train", "valid", "test"))
27512754

2755+
SUPPORT_TV_IMAGE_DECODE = True
2756+
27522757
def inject_fake_data(self, tmpdir: str, config):
27532758
split_folder = pathlib.Path(tmpdir) / "country211" / config["split"]
27542759
split_folder.mkdir(parents=True, exist_ok=True)

torchvision/datasets/country211.py

+12-3
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from pathlib import Path
2-
from typing import Callable, Optional, Union
2+
from typing import Any, Callable, Optional, Union
33

4-
from .folder import ImageFolder
4+
from .folder import default_loader, ImageFolder
55
from .utils import download_and_extract_archive, verify_str_arg
66

77

@@ -21,6 +21,9 @@ class Country211(ImageFolder):
2121
target_transform (callable, optional): A function/transform that takes in the target and transforms it.
2222
download (bool, optional): If True, downloads the dataset from the internet and puts it into
2323
``root/country211/``. If dataset is already downloaded, it is not downloaded again.
24+
loader (callable, optional): A function to load an image given its path.
25+
By default, it uses PIL as its image loader, but users could also pass in
26+
``torchvision.io.decode_image`` for decoding image data into tensors directly.
2427
"""
2528

2629
_URL = "https://openaipublic.azureedge.net/clip/data/country211.tgz"
@@ -33,6 +36,7 @@ def __init__(
3336
transform: Optional[Callable] = None,
3437
target_transform: Optional[Callable] = None,
3538
download: bool = False,
39+
loader: Callable[[str], Any] = default_loader,
3640
) -> None:
3741
self._split = verify_str_arg(split, "split", ("train", "valid", "test"))
3842

@@ -46,7 +50,12 @@ def __init__(
4650
if not self._check_exists():
4751
raise RuntimeError("Dataset not found. You can use download=True to download it")
4852

49-
super().__init__(str(self._base_folder / self._split), transform=transform, target_transform=target_transform)
53+
super().__init__(
54+
str(self._base_folder / self._split),
55+
transform=transform,
56+
target_transform=target_transform,
57+
loader=loader,
58+
)
5059
self.root = str(root)
5160

5261
def _check_exists(self) -> bool:

torchvision/datasets/eurosat.py

+12-3
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
import os
22
from pathlib import Path
3-
from typing import Callable, Optional, Union
3+
from typing import Any, Callable, Optional, Union
44

5-
from .folder import ImageFolder
5+
from .folder import default_loader, ImageFolder
66
from .utils import download_and_extract_archive
77

88

@@ -21,6 +21,9 @@ class EuroSAT(ImageFolder):
2121
download (bool, optional): If True, downloads the dataset from the internet and
2222
puts it in root directory. If dataset is already downloaded, it is not
2323
downloaded again. Default is False.
24+
loader (callable, optional): A function to load an image given its path.
25+
By default, it uses PIL as its image loader, but users could also pass in
26+
``torchvision.io.decode_image`` for decoding image data into tensors directly.
2427
"""
2528

2629
def __init__(
@@ -29,6 +32,7 @@ def __init__(
2932
transform: Optional[Callable] = None,
3033
target_transform: Optional[Callable] = None,
3134
download: bool = False,
35+
loader: Callable[[str], Any] = default_loader,
3236
) -> None:
3337
self.root = os.path.expanduser(root)
3438
self._base_folder = os.path.join(self.root, "eurosat")
@@ -40,7 +44,12 @@ def __init__(
4044
if not self._check_exists():
4145
raise RuntimeError("Dataset not found. You can use download=True to download it")
4246

43-
super().__init__(self._data_folder, transform=transform, target_transform=target_transform)
47+
super().__init__(
48+
self._data_folder,
49+
transform=transform,
50+
target_transform=target_transform,
51+
loader=loader,
52+
)
4453
self.root = os.path.expanduser(root)
4554

4655
def __len__(self) -> int:

torchvision/datasets/imagenet.py

+2
Original file line number | Diff line number | Diff line change
@@ -36,6 +36,8 @@ class ImageNet(ImageFolder):
3636
target_transform (callable, optional): A function/transform that takes in the
3737
target and transforms it.
3838
loader (callable, optional): A function to load an image given its path.
39+
By default, it uses PIL as its image loader, but users could also pass in
40+
``torchvision.io.decode_image`` for decoding image data into tensors directly.
3941
4042
Attributes:
4143
classes (list): List of the class name tuples.

0 commit comments

Comments
 (0)