-
Notifications
You must be signed in to change notification settings - Fork 478
/
Copy pathutil.py
executable file
·83 lines (66 loc) · 2.58 KB
/
util.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import os
import torch
import torchvision
import random
import numpy as np
IMG_EXTENSIONS = ['.jpg', '.JPG', '.jpeg', '.JPEG',
                  '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP']


def is_image_file(filename):
    """Return True if *filename* ends with a recognized image extension.

    Matching is exact-case against IMG_EXTENSIONS (both all-lower and
    all-upper variants are listed), so mixed-case suffixes such as
    '.Jpg' are rejected — same behavior as the original any() loop.
    """
    # str.endswith accepts a tuple of suffixes: a single C-level call
    # instead of a Python-level generator over the extension list.
    return filename.endswith(tuple(IMG_EXTENSIONS))
def get_paths_from_images(path):
    """Recursively collect the sorted paths of all image files under *path*.

    Args:
        path: root directory to scan.

    Returns:
        Sorted list of full paths whose names match IMG_EXTENSIONS.

    Raises:
        ValueError: if *path* is not a directory or contains no image file.
    """
    # Explicit raises instead of assert: asserts are stripped when Python
    # runs with -O, which would silently disable this validation.
    if not os.path.isdir(path):
        raise ValueError('{:s} is not a valid directory'.format(path))
    images = [os.path.join(dirpath, fname)
              for dirpath, _, fnames in sorted(os.walk(path))
              for fname in sorted(fnames)
              if is_image_file(fname)]
    if not images:
        raise ValueError('{:s} has no valid image file'.format(path))
    return sorted(images)
def augment(img_list, hflip=True, rot=True, split='val'):
    """Randomly flip / rotate the HWC arrays in *img_list*.

    The three coin flips are drawn once per call, so every image in the
    list receives the identical augmentation (keeps paired LR/HR images
    aligned). For any split other than 'train' the images pass through
    unchanged.
    """
    training = split == 'train'
    # Fixed draw order; short-circuiting skips the RNG call whenever the
    # corresponding transform is disabled (matches original evaluation).
    do_hflip = hflip and training and random.random() < 0.5
    do_vflip = rot and training and random.random() < 0.5
    do_rot90 = rot and training and random.random() < 0.5

    def _apply(image):
        if do_hflip:
            image = image[:, ::-1, :]
        if do_vflip:
            image = image[::-1, :, :]
        if do_rot90:
            image = image.transpose(1, 0, 2)
        return image

    return [_apply(image) for image in img_list]
def transform2numpy(img):
    """Convert an image (PIL image or array-like) to a float32 HWC array
    scaled to [0, 1], always with an explicit channel axis and at most
    three channels.
    """
    arr = np.asarray(img).astype(np.float32) / 255.
    if arr.ndim == 2:
        # grayscale: promote HW -> HW1 so the downstream CHW transpose works
        arr = arr[:, :, np.newaxis]
    # some images carry an alpha (or extra) channel — keep at most RGB
    if arr.shape[2] > 3:
        arr = arr[:, :, :3]
    return arr
def transform2tensor(img, min_max=(0, 1)):
    """Convert an HWC numpy image in [0, 1] to a CHW float tensor,
    linearly rescaled into the [min_max[0], min_max[1]] range.
    """
    lo, hi = min_max
    # HWC -> CHW; ascontiguousarray avoids from_numpy on a strided view
    chw = np.ascontiguousarray(img.transpose(2, 0, 1))
    tensor = torch.from_numpy(chw).float()
    # map [0, 1] -> [lo, hi]
    return tensor * (hi - lo) + lo
# Torchvision-based pipeline, chosen over composing the numpy/torch helpers
# above (transform2numpy + augment + transform2tensor); rationale:
# https://github.com/Janspiry/Image-Super-Resolution-via-Iterative-Refinement/issues/14
totensor = torchvision.transforms.ToTensor()
hflip = torchvision.transforms.RandomHorizontalFlip()


def transform_augment(img_list, split='val', min_max=(0, 1)):
    """Convert images to CHW tensors, hflip them as one batch when
    training, and rescale from [0, 1] into the min_max range.

    Stacking before RandomHorizontalFlip makes the single flip decision
    shared across the whole list, so paired LR/HR images stay aligned.
    """
    tensors = [totensor(image) for image in img_list]
    if split == 'train':
        stacked = torch.stack(tensors, dim=0)
        tensors = list(torch.unbind(hflip(stacked), dim=0))
    lo, hi = min_max
    # ToTensor yields values in [0, 1]; rescale linearly into [lo, hi]
    return [t * (hi - lo) + lo for t in tensors]