1
1
from pathlib import Path
2
- from typing import Callable , Optional , Union
2
+ from typing import Any , Callable , Optional , Union
3
3
4
- from .folder import ImageFolder
4
+ from .folder import default_loader , ImageFolder
5
5
from .utils import download_and_extract_archive , verify_str_arg
6
6
7
7
@@ -21,6 +21,9 @@ class Country211(ImageFolder):
21
21
target_transform (callable, optional): A function/transform that takes in the target and transforms it.
22
22
download (bool, optional): If True, downloads the dataset from the internet and puts it into
23
23
``root/country211/``. If dataset is already downloaded, it is not downloaded again.
24
+ loader (callable, optional): A function to load an image given its path.
25
+ By default, it uses PIL as its image loader, but users could also pass in
26
+ ``torchvision.io.decode_image`` for decoding image data into tensors directly.
24
27
"""
25
28
26
29
_URL = "https://openaipublic.azureedge.net/clip/data/country211.tgz"
@@ -33,6 +36,7 @@ def __init__(
33
36
transform : Optional [Callable ] = None ,
34
37
target_transform : Optional [Callable ] = None ,
35
38
download : bool = False ,
39
+ loader : Callable [[str ], Any ] = default_loader ,
36
40
) -> None :
37
41
self ._split = verify_str_arg (split , "split" , ("train" , "valid" , "test" ))
38
42
@@ -46,7 +50,12 @@ def __init__(
46
50
if not self ._check_exists ():
47
51
raise RuntimeError ("Dataset not found. You can use download=True to download it" )
48
52
49
- super ().__init__ (str (self ._base_folder / self ._split ), transform = transform , target_transform = target_transform )
53
+ super ().__init__ (
54
+ str (self ._base_folder / self ._split ),
55
+ transform = transform ,
56
+ target_transform = target_transform ,
57
+ loader = loader ,
58
+ )
50
59
self .root = str (root )
51
60
52
61
def _check_exists (self ) -> bool :
0 commit comments