Skip to content

Commit 7be5147

Browse files
author
Turkka Helinoja
committed
WIP: Add custom dataset class to load single images and videos
1 parent e4a6fc7 commit 7be5147

File tree

3 files changed

+87
-13
lines changed

3 files changed

+87
-13
lines changed

requirements.txt

+1
Original file line numberDiff line numberDiff line change
@@ -9,3 +9,4 @@ json_tricks
99
scikit-image
1010
yacs>=0.1.5
1111
tensorboardX
12+
Pillow

tools/custom_dataset.py

+69
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
import queue
2+
import pathlib
3+
4+
from torch.utils.data import Dataset, IterableDataset
5+
import torchvision.transforms
6+
import torchvision.transforms.functional as TF
7+
8+
from PIL import Image
9+
10+
class CustomDataset(Dataset):
11+
"""Custom Dataset for handling single images and videos.
12+
13+
Custom torch.utils.data.Dataset class for handling single images and videos
14+
and creating Dataset objects from them.
15+
"""
16+
def __init__(self, path: str, transform: torchvision.transforms = None) -> None:
17+
"""Initializes CustomDataset
18+
19+
Args:
20+
path (str): Path to the image or video
21+
transform (torchvision.transforms, optional):
22+
Transfrom that's used for transforming the images to tensors.
23+
Defaults to None.
24+
"""
25+
super().__init__()
26+
self.path = path
27+
self.transform = transform
28+
self.queue = self.__load_data(self.path)
29+
30+
31+
def __load_data(self, path: str) -> queue.Queue:
32+
"""Loads the data from the path. Data can be image or video.
33+
TODO: Implement video support.
34+
35+
Args:
36+
path (str): Path to the image or video
37+
38+
Returns:
39+
queue.Queue: Queue which has the images loaded
40+
"""
41+
buffer = queue.Queue()
42+
file_extension = pathlib.Path(path).suffix
43+
if file_extension == ".jpg" or file_extension == ".png":
44+
image = Image.open(path)
45+
#buffer.put(TF.to_tensor(image))
46+
buffer.put(image)
47+
return buffer
48+
49+
def __read_next_image(self):
50+
while self.queue.qsize() > 0:
51+
image = self.queue.get()
52+
if self.transform is not None:
53+
image = self.transform(image)
54+
yield image
55+
56+
return None
57+
58+
def __iter__(self):
59+
return self.__read_next_image()
60+
61+
def __len__(self):
62+
return self.queue.qsize()
63+
64+
def __getitem__(self, x):
65+
image = self.queue.get()
66+
if self.transform is not None:
67+
image = self.transform(image)
68+
69+
return image, None

tools/run_and_visualize.py

+17-13
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@
2121
from core.function import validate
2222
from utils.utils import create_logger
2323

24+
from custom_dataset import CustomDataset
25+
2426
import dataset
2527
import models
2628

@@ -113,11 +115,13 @@ def main():
113115

114116
# Load data
115117
if args.input != "":
116-
with open(args.input, "r") as image:
117118
# TODO: Write a way to handle single images
118119
# TODO: Handle visualization
119-
pass
120-
elif args.video:
120+
valid_dataset = CustomDataset(args.input, transforms.Compose([
121+
transforms.ToTensor(),
122+
normalize,
123+
]))
124+
elif args.video != "":
121125
# TODO: Write a way to handle videos image by image
122126
# TODO: Handle visualization
123127
pass
@@ -131,17 +135,17 @@ def main():
131135
])
132136
)
133137

134-
valid_loader = torch.utils.data.DataLoader(
135-
valid_dataset,
136-
batch_size=cfg.TEST.BATCH_SIZE_PER_GPU*len(cfg.GPUS),
137-
shuffle=False,
138-
num_workers=cfg.WORKERS,
139-
pin_memory=True
140-
)
138+
valid_loader = torch.utils.data.DataLoader(
139+
valid_dataset,
140+
batch_size=cfg.TEST.BATCH_SIZE_PER_GPU*len(cfg.GPUS),
141+
shuffle=False,
142+
num_workers=cfg.WORKERS,
143+
pin_memory=True
144+
)
141145

142-
# evaluate on validation set
143-
validate(cfg, valid_loader, valid_dataset, model, criterion,
144-
final_output_dir, tb_log_dir)
146+
# evaluate on validation set
147+
validate(cfg, valid_loader, valid_dataset, model, criterion,
148+
final_output_dir, tb_log_dir)
145149

146150
if __name__ == '__main__':
147151
main()

0 commit comments

Comments
 (0)