-
Notifications
You must be signed in to change notification settings - Fork 7k
/
Copy pathomniglot.py
107 lines (87 loc) · 4.39 KB
/
omniglot.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
from os.path import join
from pathlib import Path
from typing import Any, Callable, List, Optional, Tuple, Union
from PIL import Image
from .utils import check_integrity, download_and_extract_archive, list_dir, list_files
from .vision import VisionDataset
class Omniglot(VisionDataset):
    """`Omniglot <https://github.com/brendenlake/omniglot>`_ Dataset.

    Args:
        root (str or ``pathlib.Path``): Root directory of dataset where directory
            ``omniglot-py`` exists.
        background (bool, optional): If True, creates dataset from the "background" set, otherwise
            creates from the "evaluation" set. This terminology is defined by the authors.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g., ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        download (bool, optional): If true, downloads the zip file of the selected set from the
            internet and extracts it via ``download_and_extract_archive``. If the zip file is
            already present and passes the integrity check, it is not downloaded again.
        loader (callable, optional): A function to load an image given its path.
            By default, it uses PIL as its image loader, but users could also pass in
            ``torchvision.io.decode_image`` for decoding image data into tensors directly.

    Raises:
        RuntimeError: If the zip archive of the selected set fails the integrity check
            (after the optional download).
    """

    # Sub-directory of ``root`` that is handed to ``VisionDataset`` as the dataset root.
    folder = "omniglot-py"
    # URL prefix under which ``<set name>.zip`` is fetched.
    download_url_prefix = "https://raw.githubusercontent.com/brendenlake/omniglot/master/python"
    # Expected MD5 checksum of each archive, keyed by set name (archive name without ".zip").
    zips_md5 = {
        "images_background": "68d2efa1b9178cc56df9314c21c6e718",
        "images_evaluation": "6b91aef0f799c5bb55b94e3f2daec811",
    }

    def __init__(
        self,
        root: Union[str, Path],
        background: bool = True,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        download: bool = False,
        loader: Optional[Callable[[Union[str, Path]], Any]] = None,
    ) -> None:
        super().__init__(join(root, self.folder), transform=transform, target_transform=target_transform)
        # Must be set before download()/_check_integrity(): both derive the set name from it.
        self.background = background

        if download:
            self.download()

        if not self._check_integrity():
            raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")

        self.target_folder = join(self.root, self._get_target_folder())
        self._alphabets = list_dir(self.target_folder)
        # Relative "<alphabet>/<character>" paths; the position in this list is the class index.
        # Built with a nested comprehension instead of ``sum(lists, [])``, whose repeated list
        # concatenation re-copies the accumulated result on every step (quadratic time).
        self._characters: List[str] = [
            join(alphabet, character)
            for alphabet in self._alphabets
            for character in list_dir(join(self.target_folder, alphabet))
        ]
        # One inner list per character, holding (image file name, class index) pairs.
        self._character_images = [
            [(image, idx) for image in list_files(join(self.target_folder, character), ".png")]
            for idx, character in enumerate(self._characters)
        ]
        # Flat view over all samples, in character order; this is what indexing operates on.
        self._flat_character_images: List[Tuple[str, int]] = [
            sample for images in self._character_images for sample in images
        ]
        self.loader = loader

    def __len__(self) -> int:
        """Return the total number of image samples across all characters."""
        return len(self._flat_character_images)

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is index of the target character class.
        """
        image_name, character_class = self._flat_character_images[index]
        image_path = join(self.target_folder, self._characters[character_class], image_name)
        # Without a custom loader the image is opened with PIL and converted to mode "L".
        image = Image.open(image_path, mode="r").convert("L") if self.loader is None else self.loader(image_path)

        if self.transform:
            image = self.transform(image)

        if self.target_transform:
            character_class = self.target_transform(character_class)

        return image, character_class

    def _check_integrity(self) -> bool:
        """Return True if the zip archive of the selected set passes ``check_integrity``.

        Only the ``.zip`` file under ``self.root`` is inspected here; the extracted
        image folder is not examined.
        """
        zip_filename = self._get_target_folder()
        if not check_integrity(join(self.root, zip_filename + ".zip"), self.zips_md5[zip_filename]):
            return False
        return True

    def download(self) -> None:
        """Download and extract the archive of the selected set unless it already passes the integrity check."""
        if self._check_integrity():
            return

        filename = self._get_target_folder()
        zip_filename = filename + ".zip"
        url = self.download_url_prefix + "/" + zip_filename
        download_and_extract_archive(url, self.root, filename=zip_filename, md5=self.zips_md5[filename])

    def _get_target_folder(self) -> str:
        """Return the set name ("images_background" or "images_evaluation") chosen by ``self.background``."""
        return "images_background" if self.background else "images_evaluation"