forked from leoxiaobin/deep-high-resolution-net.pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path custom_dataset.py
69 lines (54 loc) · 2 KB
/
custom_dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import pathlib
import queue
from typing import Any, Callable, Iterator, List, Optional, Tuple

import torchvision.transforms
import torchvision.transforms.functional as TF
from PIL import Image
from torch.utils.data import Dataset, IterableDataset


class CustomDataset(Dataset):
    """Custom Dataset for handling single images and videos.

    Custom torch.utils.data.Dataset class for handling single images and videos
    and creating Dataset objects from them. Currently only single still images
    (see ``IMAGE_EXTENSIONS``) are loaded; any other path yields an empty
    dataset.

    Loaded images are kept in memory and accessed non-destructively, so the
    dataset can be indexed, measured with ``len`` and iterated any number of
    times (e.g. across several DataLoader epochs).
    """

    # File suffixes (compared case-insensitively) that are loaded as still images.
    IMAGE_EXTENSIONS = frozenset({".jpg", ".jpeg", ".png"})

    def __init__(
        self,
        path: str,
        transform: Optional[Callable[[Any], Any]] = None,
    ) -> None:
        """Initializes CustomDataset.

        Args:
            path (str): Path to the image or video.
            transform (callable, optional): Transform applied to every image
                when it is accessed, e.g. a ``torchvision.transforms`` pipeline
                that converts the PIL images to tensors. Defaults to None
                (images are returned unchanged).
        """
        super().__init__()
        self.path = path
        self.transform = transform
        # List of loaded images; indexing it never removes anything, unlike the
        # previous queue-based storage.
        self.images = self.__load_data(self.path)

    def __load_data(self, path: str) -> List[Any]:
        """Loads the data from the path. Data can be image or video.

        TODO: Implement video support.

        Args:
            path (str): Path to the image or video.

        Returns:
            list: The loaded images; empty when the file suffix is not one of
            ``IMAGE_EXTENSIONS``.
        """
        images = []
        if pathlib.Path(path).suffix.lower() in self.IMAGE_EXTENSIONS:
            # Image.open is lazy and keeps the file handle open. copy() forces
            # the pixel data into memory, so the `with` block can close the file.
            with Image.open(path) as image:
                images.append(image.copy())
        return images

    def __apply_transform(self, image: Any) -> Any:
        """Returns *image* passed through ``self.transform`` when one is set."""
        if self.transform is None:
            return image
        return self.transform(image)

    def __read_next_image(self) -> Iterator[Any]:
        """Yields every (transformed) image without removing it from the dataset."""
        for image in self.images:
            yield self.__apply_transform(image)

    def __iter__(self) -> Iterator[Any]:
        """Iterates over the transformed images (not ``(image, label)`` pairs)."""
        return self.__read_next_image()

    def __len__(self) -> int:
        """Returns the number of loaded images; constant for the dataset's lifetime."""
        return len(self.images)

    def __getitem__(self, x: int) -> Tuple[Any, None]:
        """Returns ``(image, None)`` for index *x*; there is no label.

        NOTE(review): torch's default ``collate_fn`` cannot batch ``None``.
        A DataLoader over this dataset likely needs a custom ``collate_fn``;
        confirm with callers.

        Args:
            x (int): Index of the image (negative indices count from the end).

        Returns:
            tuple: The transformed image and ``None``.

        Raises:
            IndexError: If *x* is out of range.
        """
        return self.__apply_transform(self.images[x]), None