1
- import mmcv
2
1
import numpy as np
3
2
import random
4
3
import torch
5
4
from pathlib import Path
6
5
from torch .utils import data as data
7
6
8
- from basicsr .data .transforms import augment , paired_random_crop , totensor
9
- from basicsr .utils import FileClient , get_root_logger
7
+ from basicsr .data .transforms import augment , paired_random_crop
8
+ from basicsr .utils import FileClient , get_root_logger , imfrombytes , img2tensor
9
+ from basicsr .utils .flow_util import dequantize_flow
10
10
11
11
12
12
class REDSDataset (data .Dataset ):
@@ -144,7 +144,7 @@ def __getitem__(self, index):
144
144
else :
145
145
img_gt_path = self .gt_root / clip_name / f'{ frame_name } .png'
146
146
img_bytes = self .file_client .get (img_gt_path , 'gt' )
147
- img_gt = mmcv . imfrombytes (img_bytes ). astype ( np . float32 ) / 255.
147
+ img_gt = imfrombytes (img_bytes , float32 = True )
148
148
149
149
# get the neighboring LQ frames
150
150
img_lqs = []
@@ -154,7 +154,7 @@ def __getitem__(self, index):
154
154
else :
155
155
img_lq_path = self .lq_root / clip_name / f'{ neighbor :08d} .png'
156
156
img_bytes = self .file_client .get (img_lq_path , 'lq' )
157
- img_lq = mmcv . imfrombytes (img_bytes ). astype ( np . float32 ) / 255.
157
+ img_lq = imfrombytes (img_bytes , float32 = True )
158
158
img_lqs .append (img_lq )
159
159
160
160
# get flows
@@ -168,10 +168,11 @@ def __getitem__(self, index):
168
168
flow_path = (
169
169
self .flow_root / clip_name / f'{ frame_name } _p{ i } .png' )
170
170
img_bytes = self .file_client .get (flow_path , 'flow' )
171
- cat_flow = mmcv .imfrombytes (
172
- img_bytes , flag = 'grayscale' ) # uint8, [0, 255]
171
+ cat_flow = imfrombytes (
172
+ img_bytes , flag = 'grayscale' ,
173
+ float32 = False ) # uint8, [0, 255]
173
174
dx , dy = np .split (cat_flow , 2 , axis = 0 )
174
- flow = mmcv . video . dequantize_flow (
175
+ flow = dequantize_flow (
175
176
dx , dy , max_val = 20 ,
176
177
denorm = False ) # we use max_val 20 here.
177
178
img_flows .append (flow )
@@ -183,9 +184,11 @@ def __getitem__(self, index):
183
184
flow_path = (
184
185
self .flow_root / clip_name / f'{ frame_name } _n{ i } .png' )
185
186
img_bytes = self .file_client .get (flow_path , 'flow' )
186
- cat_flow = mmcv .imfrombytes (img_bytes , flag = 'grayscale' )
187
+ cat_flow = imfrombytes (
188
+ img_bytes , flag = 'grayscale' ,
189
+ float32 = False ) # uint8, [0, 255]
187
190
dx , dy = np .split (cat_flow , 2 , axis = 0 )
188
- flow = mmcv . video . dequantize_flow (
191
+ flow = dequantize_flow (
189
192
dx , dy , max_val = 20 ,
190
193
denorm = False ) # we use max_val 20 here.
191
194
img_flows .append (flow )
@@ -210,12 +213,12 @@ def __getitem__(self, index):
210
213
img_results = augment (img_lqs , self .opt ['use_flip' ],
211
214
self .opt ['use_rot' ])
212
215
213
- img_results = totensor (img_results )
216
+ img_results = img2tensor (img_results )
214
217
img_lqs = torch .stack (img_results [0 :- 1 ], dim = 0 )
215
218
img_gt = img_results [- 1 ]
216
219
217
220
if self .flow_root is not None :
218
- img_flows = totensor (img_flows )
221
+ img_flows = img2tensor (img_flows )
219
222
# add the zero center flow
220
223
img_flows .insert (self .num_half_frames ,
221
224
torch .zeros_like (img_flows [0 ]))
0 commit comments