-
Notifications
You must be signed in to change notification settings - Fork 8
/
Copy pathdatasets.py
97 lines (82 loc) · 3.15 KB
/
datasets.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import os
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
from torchvision.datasets import CIFAR10, CIFAR100
class TinyImageNet(Dataset):
    """Tiny ImageNet images indexed from an extracted ``tiny-imagenet-200`` folder.

    The training split is read from ``<root>/train/<class>/**/*.JPEG``; the
    validation split is read from ``<root>/val/images`` with labels taken from
    the tab-separated ``<root>/val/val_annotations.txt`` (file name in the
    first column, class name in the second).

    Args:
        root: Either the ``tiny-imagenet-200`` directory itself (a trailing
            path separator is tolerated) or its parent directory.
        train: Index the training split when true, otherwise the validation
            split.
        transform: Optional callable applied to the RGB ``PIL.Image`` in
            ``__getitem__``.

    Attributes set by the scan:
        data: list of ``(image_path, class_index)`` tuples.
        labels_dict: maps class index -> class name (index = position in the
            sorted list of class names).
        file_to_class: (validation only) maps image file name -> class name.
    """

    def __init__(self, root, train=True, transform=None):
        # Compare against the normalised path so that a trailing separator
        # ("…/tiny-imagenet-200/") does not make us append the folder twice.
        if not os.path.normpath(root).endswith("tiny-imagenet-200"):
            root = os.path.join(root, "tiny-imagenet-200")

        self.train_dir = os.path.join(root, "train")
        self.val_dir = os.path.join(root, "val")
        self.transform = transform

        if train:
            self._scan_train()
        else:
            self._scan_val()

    @staticmethod
    def _list_jpegs(directory):
        """Return paths of all ``.JPEG`` files under *directory*, in sorted order."""
        paths = []
        # Sorting both the walk and the file names keeps indexing deterministic
        # regardless of the order the filesystem reports entries in.
        for root, _, files in sorted(os.walk(directory)):
            paths.extend(
                os.path.join(root, fname)
                for fname in sorted(files)
                if fname.endswith(".JPEG")
            )
        return paths

    def _scan_train(self):
        """Index the training split; one sub-directory of ``train`` per class."""
        with os.scandir(self.train_dir) as entries:
            classes = sorted(entry.name for entry in entries if entry.is_dir())
        assert len(classes) == 200, (
            f"expected 200 class directories in {self.train_dir}, "
            f"found {len(classes)}"
        )

        self.data = []
        for idx, name in enumerate(classes):
            class_dir = os.path.join(self.train_dir, name)
            self.data.extend((path, idx) for path in self._list_jpegs(class_dir))

        self.labels_dict = dict(enumerate(classes))

    def _scan_val(self):
        """Index the validation split using ``val_annotations.txt`` for labels.

        Raises:
            KeyError: if an image in ``val/images`` has no annotation entry.
        """
        self.file_to_class = {}
        classes = set()

        annotations = os.path.join(self.val_dir, "val_annotations.txt")
        with open(annotations, 'r') as f:
            for line in f:
                words = line.rstrip("\r\n").split("\t")
                # Skip blank/malformed lines (e.g. a trailing empty line)
                # instead of failing with an IndexError.
                if len(words) < 2:
                    continue
                self.file_to_class[words[0]] = words[1]
                classes.add(words[1])

        classes = sorted(classes)
        assert len(classes) == 200, (
            f"expected 200 classes in {annotations}, found {len(classes)}"
        )
        class_to_idx = {name: idx for idx, name in enumerate(classes)}

        self.data = []
        for path in self._list_jpegs(os.path.join(self.val_dir, "images")):
            fname = os.path.basename(path)
            self.data.append((path, class_to_idx[self.file_to_class[fname]]))

        self.labels_dict = dict(enumerate(classes))

    def __len__(self):
        """Return the number of indexed images."""
        return len(self.data)

    def __getitem__(self, idx):
        """Load the image at position *idx* and return ``(image, class_index)``."""
        path, label = self.data[idx]
        # convert() returns a new, fully loaded image, so the file handle can
        # be closed right away instead of lingering until garbage collection.
        with Image.open(path) as img:
            image = img.convert("RGB")
        if self.transform is not None:
            image = self.transform(image)
        return image, label
def get_dataset(name, root="./data", train=True, flip=False, crop=False, resize=None):
    """Build one of the supported datasets with its transform pipeline.

    Args:
        name: ``'cifar'`` (CIFAR10), ``'cifar100'`` (CIFAR100) or ``'tiny'``
            (TinyImageNet).
        root: Directory handed to the dataset constructor.
        train: Select the training split; augmentations are only added when
            this is true.
        flip: Add a random horizontal flip (training only).
        crop: Add a random crop at the dataset's native resolution with
            padding argument 4 (training only).
        resize: If not ``None``, passed to ``transforms.Resize``.

    Returns:
        The dataset instance, whose transform applies, in order: optional
        flip, optional crop, optional resize, then ``ToTensor``.

    Raises:
        NotImplementedError: if *name* is not one of the supported datasets.
    """
    # Pick the dataset class together with its native image resolution.
    if name == 'cifar':
        dataset_cls, native_res = CIFAR10, 32
    elif name == 'cifar100':
        dataset_cls, native_res = CIFAR100, 32
    elif name == 'tiny':
        dataset_cls, native_res = TinyImageNet, 64
    else:
        raise NotImplementedError

    # Assemble the pipeline front to back in its final order.
    pipeline = []
    if train and flip:
        pipeline.append(transforms.RandomHorizontalFlip())
    if train and crop:
        pipeline.append(transforms.RandomCrop(native_res, 4))
    if resize is not None:
        pipeline.append(transforms.Resize(resize))
    pipeline.append(transforms.ToTensor())

    return dataset_cls(
        root=root,
        train=train,
        transform=transforms.Compose(pipeline),
    )