
Commit cedf2ca

Updates (XPixelGroup#300)
* update resume pretrained paths
* test_scripts
* add dist util
* use os
* add matlab functions
* scandir and bgr2rgb replace
* cv2.flip
* replace imwrite
* update file client
* update utils.util
* update utils.download
* use relative import
* add img_util
* update strict_load
* update train.py
* update test.py
* add flow util
* update requirements.txt
* fix bugs
* fix bugs
1 parent 7d03ae2 commit cedf2ca

File tree: 93 files changed, +1216 -603 lines


basicsr/data/__init__.py

+3 -4
@@ -1,23 +1,22 @@
 import importlib
-import mmcv
 import numpy as np
 import random
 import torch
 import torch.utils.data
 from functools import partial
-from mmcv.runner import get_dist_info
 from os import path as osp
 
 from basicsr.data.prefetch_dataloader import PrefetchDataLoader
-from basicsr.utils import get_root_logger
+from basicsr.utils import get_root_logger, scandir
+from basicsr.utils.dist_util import get_dist_info
 
 __all__ = ['create_dataset', 'create_dataloader']
 
 # automatically scan and import dataset modules
 # scan all the files under the data folder with '_dataset' in file names
 data_folder = osp.dirname(osp.abspath(__file__))
 dataset_filenames = [
-    osp.splitext(osp.basename(v))[0] for v in mmcv.scandir(data_folder)
+    osp.splitext(osp.basename(v))[0] for v in scandir(data_folder)
     if v.endswith('_dataset.py')
 ]
 # import all the dataset modules
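The scandir that replaces mmcv.scandir here comes from basicsr.utils, whose new implementation is not among the files shown in this excerpt. A minimal sketch of what it could look like, assuming an os.scandir backend and a full_path switch (the exact signature is an assumption):

import os
from os import path as osp


def scandir(dir_path, suffix=None, recursive=False, full_path=False):
    """Yield file paths under dir_path, optionally filtered by suffix (sketch only)."""
    root = dir_path

    def _scandir(dir_path, suffix, recursive):
        for entry in os.scandir(dir_path):
            if not entry.name.startswith('.') and entry.is_file():
                return_path = entry.path if full_path else osp.relpath(
                    entry.path, root)
                if suffix is None or return_path.endswith(suffix):
                    yield return_path
            elif recursive and entry.is_dir():
                yield from _scandir(entry.path, suffix, recursive)

    return _scandir(dir_path, suffix, recursive)

Under that shape, scandir(data_folder) yields bare filenames, so the endswith('_dataset.py') filter above behaves the same as it did with mmcv.scandir.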

basicsr/data/ffhq_dataset.py

+4 -6
@@ -1,11 +1,9 @@
-import mmcv
-import numpy as np
 from os import path as osp
 from torch.utils import data as data
 from torchvision.transforms.functional import normalize
 
-from basicsr.data.transforms import augment, totensor
-from basicsr.utils import FileClient
+from basicsr.data.transforms import augment
+from basicsr.utils import FileClient, imfrombytes, img2tensor
 
 
 class FFHQDataset(data.Dataset):
@@ -53,12 +51,12 @@ def __getitem__(self, index):
         # load gt image
         gt_path = self.paths[index]
         img_bytes = self.file_client.get(gt_path)
-        img_gt = mmcv.imfrombytes(img_bytes).astype(np.float32) / 255.
+        img_gt = imfrombytes(img_bytes, float32=True)
 
         # random horizontal flip
         img_gt = augment(img_gt, hflip=self.opt['use_hflip'], rotation=False)
         # BGR to RGB, HWC to CHW, numpy to tensor
-        img_gt = totensor(img_gt, bgr2rgb=True, float32=True)
+        img_gt = img2tensor(img_gt, bgr2rgb=True, float32=True)
         # normalize
         normalize(img_gt, self.mean, self.std, inplace=True)
         return {'gt': img_gt, 'gt_path': gt_path}
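imfrombytes is another helper added to basicsr.utils by this commit; its body is not shown in this diff. A plausible sketch, assuming it wraps cv2.imdecode the way mmcv.imfrombytes does (the flag names are an assumption):

import cv2
import numpy as np

# assumed mapping from string flags to OpenCV imread flags
_imread_flags = {
    'color': cv2.IMREAD_COLOR,
    'grayscale': cv2.IMREAD_GRAYSCALE,
    'unchanged': cv2.IMREAD_UNCHANGED,
}


def imfrombytes(content, flag='color', float32=False):
    """Decode an image from raw bytes; optionally scale to float32 in [0, 1] (sketch)."""
    img_np = np.frombuffer(content, np.uint8)
    img = cv2.imdecode(img_np, _imread_flags[flag])
    if float32:
        img = img.astype(np.float32) / 255.
    return img

The float32=True path folds the old .astype(np.float32) / 255. normalization into the helper, which is why the call sites above lose that suffix.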

basicsr/data/paired_image_dataset.py

+7 -7
@@ -1,12 +1,10 @@
-import mmcv
-import numpy as np
 from torch.utils import data as data
 
-from basicsr.data.transforms import augment, paired_random_crop, totensor
+from basicsr.data.transforms import augment, paired_random_crop
 from basicsr.data.util import (paired_paths_from_folder,
                                paired_paths_from_lmdb,
                                paired_paths_from_meta_info_file)
-from basicsr.utils import FileClient
+from basicsr.utils import FileClient, imfrombytes, img2tensor
 
 
 class PairedImageDataset(data.Dataset):
@@ -79,10 +77,10 @@ def __getitem__(self, index):
         # image range: [0, 1], float32.
         gt_path = self.paths[index]['gt_path']
         img_bytes = self.file_client.get(gt_path, 'gt')
-        img_gt = mmcv.imfrombytes(img_bytes).astype(np.float32) / 255.
+        img_gt = imfrombytes(img_bytes, float32=True)
         lq_path = self.paths[index]['lq_path']
         img_bytes = self.file_client.get(lq_path, 'lq')
-        img_lq = mmcv.imfrombytes(img_bytes).astype(np.float32) / 255.
+        img_lq = imfrombytes(img_bytes, float32=True)
 
         # augmentation for training
         if self.opt['phase'] == 'train':
@@ -96,7 +94,9 @@ def __getitem__(self, index):
 
         # TODO: color space transform
         # BGR to RGB, HWC to CHW, numpy to tensor
-        img_gt, img_lq = totensor([img_gt, img_lq], bgr2rgb=True, float32=True)
+        img_gt, img_lq = img2tensor([img_gt, img_lq],
+                                    bgr2rgb=True,
+                                    float32=True)
 
         return {
             'lq': img_lq,
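The reflowed img2tensor call simply passes both images as a list. Assuming img2tensor keeps the list-in/list-out behaviour of the removed totensor, a quick usage check with dummy arrays (shapes are arbitrary) would look like:

import numpy as np

from basicsr.utils import img2tensor  # helper added by this commit

img_gt = np.random.rand(64, 64, 3).astype(np.float32)  # dummy HWC image, BGR order
img_lq = np.random.rand(16, 16, 3).astype(np.float32)
img_gt, img_lq = img2tensor([img_gt, img_lq], bgr2rgb=True, float32=True)
print(img_gt.shape, img_lq.shape)  # expected: torch.Size([3, 64, 64]) torch.Size([3, 16, 16])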

basicsr/data/reds_dataset.py

+15 -12
@@ -1,12 +1,12 @@
-import mmcv
 import numpy as np
 import random
 import torch
 from pathlib import Path
 from torch.utils import data as data
 
-from basicsr.data.transforms import augment, paired_random_crop, totensor
-from basicsr.utils import FileClient, get_root_logger
+from basicsr.data.transforms import augment, paired_random_crop
+from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor
+from basicsr.utils.flow_util import dequantize_flow
 
 
 class REDSDataset(data.Dataset):
@@ -144,7 +144,7 @@ def __getitem__(self, index):
         else:
             img_gt_path = self.gt_root / clip_name / f'{frame_name}.png'
         img_bytes = self.file_client.get(img_gt_path, 'gt')
-        img_gt = mmcv.imfrombytes(img_bytes).astype(np.float32) / 255.
+        img_gt = imfrombytes(img_bytes, float32=True)
 
         # get the neighboring LQ frames
         img_lqs = []
@@ -154,7 +154,7 @@ def __getitem__(self, index):
             else:
                 img_lq_path = self.lq_root / clip_name / f'{neighbor:08d}.png'
             img_bytes = self.file_client.get(img_lq_path, 'lq')
-            img_lq = mmcv.imfrombytes(img_bytes).astype(np.float32) / 255.
+            img_lq = imfrombytes(img_bytes, float32=True)
             img_lqs.append(img_lq)
 
         # get flows
@@ -168,10 +168,11 @@ def __getitem__(self, index):
                 flow_path = (
                     self.flow_root / clip_name / f'{frame_name}_p{i}.png')
                 img_bytes = self.file_client.get(flow_path, 'flow')
-                cat_flow = mmcv.imfrombytes(
-                    img_bytes, flag='grayscale')  # uint8, [0, 255]
+                cat_flow = imfrombytes(
+                    img_bytes, flag='grayscale',
+                    float32=False)  # uint8, [0, 255]
                 dx, dy = np.split(cat_flow, 2, axis=0)
-                flow = mmcv.video.dequantize_flow(
+                flow = dequantize_flow(
                     dx, dy, max_val=20,
                     denorm=False)  # we use max_val 20 here.
                 img_flows.append(flow)
@@ -183,9 +184,11 @@ def __getitem__(self, index):
                 flow_path = (
                     self.flow_root / clip_name / f'{frame_name}_n{i}.png')
                 img_bytes = self.file_client.get(flow_path, 'flow')
-                cat_flow = mmcv.imfrombytes(img_bytes, flag='grayscale')
+                cat_flow = imfrombytes(
+                    img_bytes, flag='grayscale',
+                    float32=False)  # uint8, [0, 255]
                 dx, dy = np.split(cat_flow, 2, axis=0)
-                flow = mmcv.video.dequantize_flow(
+                flow = dequantize_flow(
                     dx, dy, max_val=20,
                     denorm=False)  # we use max_val 20 here.
                 img_flows.append(flow)
@@ -210,12 +213,12 @@ def __getitem__(self, index):
         img_results = augment(img_lqs, self.opt['use_flip'],
                               self.opt['use_rot'])
 
-        img_results = totensor(img_results)
+        img_results = img2tensor(img_results)
         img_lqs = torch.stack(img_results[0:-1], dim=0)
         img_gt = img_results[-1]
 
         if self.flow_root is not None:
-            img_flows = totensor(img_flows)
+            img_flows = img2tensor(img_flows)
             # add the zero center flow
             img_flows.insert(self.num_half_frames,
                              torch.zeros_like(img_flows[0]))
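dequantize_flow moves from mmcv.video into the new basicsr/utils/flow_util.py, which is not part of this excerpt. A sketch that mirrors mmcv's behaviour, mapping the uint8 dx/dy planes back into [-max_val, max_val] and optionally denormalizing by image size (an assumption about the new module, not code from this diff):

import numpy as np


def dequantize_flow(dx, dy, max_val=0.02, denorm=True):
    """Recover a float flow field from quantized uint8 dx/dy maps (sketch)."""
    assert dx.shape == dy.shape
    h, w = dx.shape[:2]
    # undo the 255-level quantization into [-max_val, max_val]
    dx = (dx.astype(np.float32) + 0.5) * 2 * max_val / 255 - max_val
    dy = (dy.astype(np.float32) + 0.5) * 2 * max_val / 255 - max_val
    if denorm:
        # scale back flows that were normalized by image width/height
        dx *= w
        dy *= h
    return np.dstack((dx, dy))

The call sites above pass max_val=20 and denorm=False, matching how the REDS flow maps were quantized.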

basicsr/data/single_image_dataset.py

+4 -10
@@ -1,11 +1,8 @@
-import mmcv
-import numpy as np
 from os import path as osp
 from torch.utils import data as data
 from torchvision.transforms.functional import normalize
 
-from basicsr.data.transforms import totensor
-from basicsr.utils import FileClient
+from basicsr.utils import FileClient, imfrombytes, img2tensor, scandir
 
 
 class SingleImageDataset(data.Dataset):
@@ -40,10 +37,7 @@ def __init__(self, opt):
                              line.split(' ')[0]) for line in fin
                 ]
         else:
-            self.paths = [
-                osp.join(self.lq_folder, v)
-                for v in mmcv.scandir(self.lq_folder)
-            ]
+            self.paths = sorted(list(scandir(self.lq_folder, full_path=True)))
 
     def __getitem__(self, index):
         if self.file_client is None:
@@ -53,11 +47,11 @@ def __getitem__(self, index):
         # load lq image
         lq_path = self.paths[index]
         img_bytes = self.file_client.get(lq_path)
-        img_lq = mmcv.imfrombytes(img_bytes).astype(np.float32) / 255.
+        img_lq = imfrombytes(img_bytes, float32=True)
 
         # TODO: color space transform
         # BGR to RGB, HWC to CHW, numpy to tensor
-        img_lq = totensor(img_lq, bgr2rgb=True, float32=True)
+        img_lq = img2tensor(img_lq, bgr2rgb=True, float32=True)
         # normalize
         if self.mean is not None or self.std is not None:
             normalize(img_lq, self.mean, self.std, inplace=True)
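The new one-liner replaces the old join-over-mmcv.scandir pattern. Assuming scandir exposes a full_path flag as sketched earlier, the two forms should produce the same sorted list; a hedged check (the folder name is hypothetical):

from os import path as osp

from basicsr.utils import scandir  # helper added by this commit

lq_folder = 'datasets/demo_lq'  # hypothetical folder of LQ images
old_style = sorted(osp.join(lq_folder, v) for v in scandir(lq_folder))
new_style = sorted(scandir(lq_folder, full_path=True))
assert old_style == new_style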

basicsr/data/transforms.py

+9 -37
@@ -1,6 +1,5 @@
-import mmcv
+import cv2
 import random
-import torch
 
 
 def mod_crop(img, scale):
@@ -110,20 +109,20 @@ def augment(imgs, hflip=True, rotation=True, flows=None):
     rot90 = rotation and random.random() < 0.5
 
     def _augment(img):
-        if hflip:
-            mmcv.imflip_(img, 'horizontal')
-        if vflip:
-            mmcv.imflip_(img, 'vertical')
+        if hflip:  # horizontal
+            cv2.flip(img, 1, img)
+        if vflip:  # vertical
+            cv2.flip(img, 0, img)
         if rot90:
             img = img.transpose(1, 0, 2)
         return img
 
     def _augment_flow(flow):
-        if hflip:
-            mmcv.imflip_(flow, 'horizontal')
+        if hflip:  # horizontal
+            cv2.flip(flow, 1, flow)
             flow[:, :, 0] *= -1
-        if vflip:
-            mmcv.imflip_(flow, 'vertical')
+        if vflip:  # vertical
+            cv2.flip(flow, 0, flow)
             flow[:, :, 1] *= -1
         if rot90:
             flow = flow.transpose(1, 0, 2)
@@ -145,30 +144,3 @@ def _augment_flow(flow):
         return imgs, flows
     else:
         return imgs
-
-
-def totensor(imgs, bgr2rgb=True, float32=True):
-    """Numpy array to tensor.
-
-    Args:
-        imgs (list[ndarray] | ndarray): Input images.
-        bgr2rgb (bool): Whether to change bgr to rgb.
-        float32 (bool): Whether to change to float32.
-
-    Returns:
-        list[tensor] | tensor: Tensor images. If returned results only have
-            one element, just return tensor.
-    """
-
-    def _totensor(img, bgr2rgb, float32):
-        if img.shape[2] == 3 and bgr2rgb:
-            img = mmcv.bgr2rgb(img)
-        img = torch.from_numpy(img.transpose(2, 0, 1))
-        if float32:
-            img = img.float()
-        return img
-
-    if isinstance(imgs, list):
-        return [_totensor(img, bgr2rgb, float32) for img in imgs]
-    else:
-        return _totensor(imgs, bgr2rgb, float32)
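The cv2.flip(img, 1, img) calls above flip in place, since the third argument is the destination array, which matches the in-place behaviour of the removed mmcv.imflip_. The deleted totensor itself reappears in basicsr.utils as img2tensor; a sketch consistent with the removed code, with cv2.cvtColor standing in for mmcv.bgr2rgb (the actual utils implementation is not shown in this excerpt):

import cv2
import torch


def img2tensor(imgs, bgr2rgb=True, float32=True):
    """Numpy array(s) to CHW tensor(s), optional BGR->RGB and float32 cast (sketch)."""

    def _totensor(img, bgr2rgb, float32):
        if img.shape[2] == 3 and bgr2rgb:
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = torch.from_numpy(img.transpose(2, 0, 1))
        if float32:
            img = img.float()
        return img

    if isinstance(imgs, list):
        return [_totensor(img, bgr2rgb, float32) for img in imgs]
    return _totensor(imgs, bgr2rgb, float32)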

basicsr/data/util.py

+9 -8
@@ -1,10 +1,11 @@
-import mmcv
+import cv2
 import numpy as np
 import torch
 from os import path as osp
 from torch.nn import functional as F
 
-from basicsr.data.transforms import mod_crop, totensor
+from basicsr.data.transforms import mod_crop
+from basicsr.utils import img2tensor, scandir
 
 
 def read_img_seq(path, require_mod_crop=False, scale=1):
@@ -22,11 +23,11 @@ def read_img_seq(path, require_mod_crop=False, scale=1):
     if isinstance(path, list):
         img_paths = path
     else:
-        img_paths = sorted([osp.join(path, v) for v in mmcv.scandir(path)])
-    imgs = [mmcv.imread(v).astype(np.float32) / 255. for v in img_paths]
+        img_paths = sorted(list(scandir(path, full_path=True)))
+    imgs = [cv2.imread(v).astype(np.float32) / 255. for v in img_paths]
     if require_mod_crop:
         imgs = [mod_crop(img, scale) for img in imgs]
-    imgs = totensor(imgs, bgr2rgb=True, float32=True)
+    imgs = img2tensor(imgs, bgr2rgb=True, float32=True)
     imgs = torch.stack(imgs, dim=0)
     return imgs
 
@@ -227,8 +228,8 @@ def paired_paths_from_folder(folders, keys, filename_tmpl):
     input_folder, gt_folder = folders
     input_key, gt_key = keys
 
-    input_paths = list(mmcv.scandir(input_folder))
-    gt_paths = list(mmcv.scandir(gt_folder))
+    input_paths = list(scandir(input_folder))
+    gt_paths = list(scandir(gt_folder))
     assert len(input_paths) == len(gt_paths), (
         f'{input_key} and {gt_key} datasets have different number of images: '
         f'{len(input_paths)}, {len(gt_paths)}.')
@@ -256,7 +257,7 @@ def paths_from_folder(folder):
         list[str]: Returned path list.
     """
 
-    paths = list(mmcv.scandir(folder))
+    paths = list(scandir(folder))
    paths = [osp.join(folder, path) for path in paths]
     return paths
 

basicsr/data/video_test_dataset.py

+5 -10
@@ -1,12 +1,11 @@
 import glob
-import mmcv
 import torch
 from os import path as osp
 from torch.utils import data as data
 
 from basicsr.data import util as util
 from basicsr.data.util import duf_downsample
-from basicsr.utils import get_root_logger
+from basicsr.utils import get_root_logger, scandir
 
 
 class VideoTestDataset(data.Dataset):
@@ -81,14 +80,10 @@ def __init__(self, opt):
                                                   subfolders_gt):
                 # get frame list for lq and gt
                 subfolder_name = osp.basename(subfolder_lq)
-                img_paths_lq = sorted([
-                    osp.join(subfolder_lq, v)
-                    for v in mmcv.scandir(subfolder_lq)
-                ])
-                img_paths_gt = sorted([
-                    osp.join(subfolder_gt, v)
-                    for v in mmcv.scandir(subfolder_gt)
-                ])
+                img_paths_lq = sorted(
+                    list(scandir(subfolder_lq, full_path=True)))
+                img_paths_gt = sorted(
+                    list(scandir(subfolder_gt, full_path=True)))
 
                 max_idx = len(img_paths_lq)
                 assert max_idx == len(img_paths_gt), (
