
Commit 7684446

Yuv inference ready for sr generation
1 parent 559c601 commit 7684446

10 files changed, +35 −46 lines

basicsr/archs/spynet_arch.py

-1
@@ -37,7 +37,6 @@ def __init__(self, load_path=None):
         super(SpyNet, self).__init__()
         self.basic_module = nn.ModuleList([BasicModule() for _ in range(6)])
         if load_path:
-            print('KEYS',torch.load(load_path).keys())
             self.load_state_dict(torch.load(load_path, map_location=lambda storage, loc: storage)['params'], strict=True)

         self.register_buffer('mean', torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
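A note on the line that survives the cleanup: map_location=lambda storage, loc: storage returns each deserialized storage unchanged, so every tensor lands on the CPU regardless of the device it was saved from. A minimal equivalent sketch (the checkpoint path here is a placeholder, not a file from this repo):

import torch

# Placeholder path; BasicSR SpyNet checkpoints keep their weights under 'params'.
ckpt = torch.load('spynet.pth', map_location='cpu')  # same effect as the lambda form
state_dict = ckpt['params']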

bvi-aom

-1
This file was deleted.

bvi-sr

-1
This file was deleted.

inference/triple_run.sh

+3
@@ -0,0 +1,3 @@
+python3 yuv_swinir.py --input /mnt/e/datasets/bvi-aom/ds-yuv/BFireS21Mitch_1920x1088_24fps_10bit_420.yuv --output /mnt/e/datasets/bvi-aom/excluded_sequences/mitchx2.yuv --num_frames 60 --width 1920 --height 1088 --scale 2
+python3 yuv_swinir.py --input /mnt/e/datasets/bvi-aom/ds-yuv/CFireS21Mitch_960x544_24fps_10bit_420.yuv --output /mnt/e/datasets/bvi-aom/excluded_sequences/mitchx4.yuv --num_frames 60 --width 960 --height 544 --scale 4
+python3 yuv_swinir.py --input /mnt/e/datasets/bvi-aom/ds-yuv/BFireS21Mitch_480x272_24fps_10bit_420.yuv --output /mnt/e/datasets/bvi-aom/excluded_sequences/mitchx8.yuv --num_frames 60 --width 480 --height 272 --scale 8
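The three runs are consistent with each other: each input is a progressively downsampled copy of the same sequence, so width × scale and height × scale give the same output resolution in every case. A quick check in Python, with the values taken from the commands above:

# (width, height, scale) per run; all three upscale back to 3840x2176.
runs = [(1920, 1088, 2), (960, 544, 4), (480, 272, 8)]
for w, h, s in runs:
    assert (w * s, h * s) == (3840, 2176)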

inference/yuv_basic_vsrpp.py

+18 −18

@@ -4,6 +4,7 @@
 import os
 import shutil
 import torch
+from tqdm import tqdm

 from basicsr.archs.basicvsrpp_arch import BasicVSRPlusPlus
 from basicsr.data.data_util import read_img_seq
@@ -18,52 +19,51 @@ def inference(frames_tensor, model, save_path):
     outputs = outputs.squeeze()
     outputs = outputs.permute(0, 2, 3, 1)
     outputs = outputs.cpu()
-    print("Output video shape: ", outputs.shape)
     rgb_to_yuv420p10bit(outputs, save_path)

 def main():
     parser = argparse.ArgumentParser()
-    parser.add_argument('--model_path', type=str, default='experiments/pretrained_models/basicvsr_plusplus_c64n7_8x1_600k_reds4_20210217-db622b2f.pth')
-    parser.add_argument(
-        '--input_path', type=str, default='', help='input yuv video')
-    parser.add_argument('--save_path', type=str, default='results/BasicVSRPP', help='save image path')
+    parser.add_argument('--input', type=str, default='', help='input yuv video')
+    parser.add_argument('--output', type=str, default='results/BasicVSRPP', help='save image path')
+    parser.add_argument('--num_frames', type=int, default=60, help='Number of frames to process')
+    parser.add_argument('--width', type=int, default=960, help='Width of the video')
+    parser.add_argument('--height', type=int, default=544, help='Height of the video')
     parser.add_argument('--interval', type=int, default=15, help='interval size')
     args = parser.parse_args()

     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

+    model_path = '/home/sk24938/source/sr/BasicSR/experiments/pretrained_models/basic_vsr_pp/basicvsr_plusplus_reds4.pth'
+
     # set up model
     model = BasicVSRPlusPlus(mid_channels=64, num_blocks=7, spynet_path='/home/sk24938/source/sr/BasicSR/experiments/pretrained_models/spynet_20210409-c6c1bd09.pth')
-    chkpt = torch.load(args.model_path)
-    print(chkpt.keys())
-    model.load_state_dict(torch.load(args.model_path)['state_dict'], strict=True)
+    model.load_state_dict(torch.load(model_path, weights_only=True)['state_dict'], strict=True)
     model.eval()
     model = model.to(device)

     # want to process yuv input frames
     # load yuv frames in rgb format
     # convert to tensor
     frames_np = load_yuv_frames(
-        video_file_path=args.input_path,
+        video_file_path=args.input,
         start_idx=0,
-        num_frames=12,
-        width=256,
-        height=256,
+        num_frames=args.num_frames,
+        width=args.width,
+        height=args.height,
         bit_depth=10,
         pixel_format='yuv420p'
     )
     frames_tensor = torch.from_numpy(frames_np).permute(0, 3, 1, 2).float()

     # load data and inference
-    num_frames = len(frames_tensor)
-    if num_frames <= args.interval: # too many images may cause CUDA out of memory
+    if args.num_frames <= args.interval: # too many images may cause CUDA out of memory
         frames_tensor = frames_tensor.unsqueeze(0).to(device)
-        inference(frames_tensor, model, args.save_path)
+        inference(frames_tensor, model, args.output)
     else:
-        for idx in range(0, num_frames, args.interval):
-            interval = min(args.interval, num_frames - idx)
+        for idx in tqdm(range(0, args.num_frames, args.interval), desc='BasicVSR++'):
+            interval = min(args.interval, args.num_frames - idx)
             frames_tensor = frames_tensor.unsqueeze(0).to(device)
-            inference(frames_tensor, model, args.save_path)
+            inference(frames_tensor, model, args.output)


 if __name__ == '__main__':
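The interval logic above exists because pushing all frames through a recurrent model in one pass can exhaust GPU memory, as the inline comment notes. For reference, a minimal sketch of the chunking pattern in isolation; run_model is a hypothetical stand-in for the model call, not a function from this repo:

import torch

def chunked_inference(frames, run_model, interval=15):
    # frames: (T, C, H, W); process at most `interval` frames per forward pass.
    outputs = []
    for idx in range(0, frames.shape[0], interval):
        chunk = frames[idx:idx + interval].unsqueeze(0)  # add batch dim -> (1, t, C, H, W)
        with torch.no_grad():
            outputs.append(run_model(chunk).cpu())
    return torch.cat(outputs, dim=1)  # reassemble along the time axis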

inference/yuv_conventional.py

+2 −3

@@ -2,6 +2,7 @@
 from yuv_utils import *
 import cv2
 import argparse
+from tqdm import tqdm


 def write_yuv_file(y, u, v, output_file_path):
@@ -14,8 +15,6 @@ def write_yuv_file(y, u, v, output_file_path):
         f.write(u[i].tobytes())
         f.write(v[i].tobytes())

-    print(f"YUV file written to {output_file_path}")
-

 def rescale_frame(frame, scaling_factor=2, bit_depth=10, method='bicubic'):
     interpolation_methods = {
@@ -88,7 +87,7 @@ def main():
     y_arr = []
     u_arr = []
     v_arr = []
-    for i, frame in enumerate(yuv_frame):
+    for i, frame in tqdm(enumerate(yuv_frame), desc=f'{args.method}'):
         y, u, v = rescale_frame(
             frame=frame,
             scaling_factor=args.scale,
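For context, rescale_frame upscales each YUV plane separately with OpenCV (its interpolation_methods dict maps method names to cv2 flags). A minimal sketch of that idea, assuming 10-bit planes stored as 2-D uint16 arrays; upscale_plane is illustrative, not the repo's exact function:

import cv2
import numpy as np

def upscale_plane(plane: np.ndarray, scaling_factor: int = 2) -> np.ndarray:
    # plane: one Y, U, or V plane; note cv2.resize takes (width, height) order.
    h, w = plane.shape
    return cv2.resize(plane, (w * scaling_factor, h * scaling_factor),
                      interpolation=cv2.INTER_CUBIC)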

inference/yuv_edsr.py

+4 −6

@@ -19,7 +19,7 @@ def main():
     parser.add_argument('--width', type=int, default=960, help='Width of the video')
     parser.add_argument('--height', type=int, default=544, help='Height of the video')
     parser.add_argument('--scale', type=int, default=4, help='Scaling factor')
-
+    # Before inference
     args = parser.parse_args()

     if args.scale == 4:
@@ -42,7 +42,7 @@ def main():
         img_range=1.0,
         rgb_mean=(0.4488, 0.4371, 0.4040)
     )
-    model.load_state_dict(torch.load(model_path)['params'], strict=True)
+    model.load_state_dict(torch.load(model_path, weights_only=True)['params'], strict=True)
     model.eval()
     model = model.to(device)

@@ -64,25 +64,23 @@ def main():
     frames_tensor = frames_tensor.to(device)

     frame_list = []
-    for i in tqdm(range(frames_tensor.shape[0]), desc='Processing frames'):
+    for i in tqdm(range(frames_tensor.shape[0]), desc='EDSR'):
         frame = frames_tensor[i, :, :, :]
         # inference
         try:
             with torch.no_grad():
                 output = model(frame)
+
         except Exception as error:
             print('Error', error, i)
         else:
             # save image
             output = output.data.squeeze().cpu()
             frame_list.append(output)

-    print('Saving video...')
     video = torch.stack(frame_list, dim=0)
     video = video.permute(0, 2, 3, 1)
-    print('Upsampled video shape: ', video.shape)
     rgb_to_yuv420p10bit(video, args.output)
-    print('Done!')


 if __name__ == '__main__':
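The weights_only=True change, repeated across the model scripts in this commit, tells torch.load to restrict unpickling to tensors and other allow-listed types instead of arbitrary Python objects, so a tampered checkpoint cannot run code at load time (the flag needs a reasonably recent PyTorch). The pattern in isolation, with a placeholder path:

import torch

# 'model.pth' is a placeholder; BasicSR EDSR checkpoints store weights under 'params'.
state = torch.load('model.pth', map_location='cpu', weights_only=True)['params']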

inference/yuv_edvr.py

+2 −6

@@ -54,7 +54,7 @@ def main():
         num_reconstruct_block=10,
         hr_in=False,
     )
-    model.load_state_dict(torch.load(model_path)['params'], strict=True)
+    model.load_state_dict(torch.load(model_path, weights_only=True)['params'], strict=True)
     model.eval()
     model = model.to(device)

@@ -78,20 +78,16 @@ def main():
     frames_tensor = pad_frames(frames_tensor, pad_size)

     frame_list = []
-    for idx in tqdm(range(0 + pad_size, args.num_frames + pad_size, 1), desc='Processing frames'):
+    for idx in tqdm(range(0 + pad_size, args.num_frames + pad_size, 1), desc='EDVR'):
         start_idx = idx - pad_size
         end_idx = idx + pad_size
         frames_tensor_chunk = frames_tensor[:, start_idx:end_idx+1, :, :, :].to(device)
         output = inference(frames_tensor_chunk, model, args.output)
         frame_list.append(output)
         torch.cuda.empty_cache()

-    print('Saving video...')
     video = torch.stack(frame_list, dim=0)
-    print('Upsampled video shape: ', video.shape)
     rgb_to_yuv420p10bit(video, args.output)
-    print('Done!')
-

 if __name__ == '__main__':
     main()
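EDVR is a sliding-window model: each output frame is reconstructed from 2 * pad_size + 1 consecutive inputs, which is what the start_idx/end_idx arithmetic above implements after pad_frames (assumed here to replicate pad_size frames at each end of the sequence). The indexing on its own:

def window_bounds(num_frames, pad_size):
    # Frame idx in the padded tensor corresponds to original frame idx - pad_size;
    # each window is the inclusive slice [idx - pad_size, idx + pad_size].
    for idx in range(pad_size, num_frames + pad_size):
        yield idx - pad_size, idx + pad_size

# e.g. num_frames=60, pad_size=2 -> (0, 4), (1, 5), ..., (59, 63)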

inference/yuv_swinir.py

+6 −10

@@ -16,23 +16,22 @@ def main():
     parser = argparse.ArgumentParser()
     parser.add_argument('--input', type=str, default='datasets/swin.yuv', help='input yuv file')
     parser.add_argument('--output', type=str, default='results/swin.yuv', help='output yuv file')
-    parser.add_argument('--num_frames', type=int, default=64, help='number of frames to process')
+    parser.add_argument('--num_frames', type=int, default=60, help='number of frames to process')
     parser.add_argument('--width', type=int, default=960, help='width of the frames')
    parser.add_argument('--height', type=int, default=544, help='height of the frames')
-    parser.add_argument('--patch_size', type=int, default=64, help='training patch size')
     parser.add_argument('--scale', type=int, default=4, help='scale factor: 2, 4, 8')
+    parser.add_argument('--patch_size', type=int, default=64, help='patch size')
     args = parser.parse_args()

-    model_path = f"experiments/pretrained_models/swinir/001_classicalSR_DF2K_s64w8_SwinIR-M_x{args.scale}.pth"
+    model_path = f"/home/sk24938/source/sr/BasicSR/experiments/pretrained_models/swinir/001_classicalSR_DF2K_s64w8_SwinIR-M_x{args.scale}.pth"
+    window_size = 8

     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
     # set up model
     model = define_model(args, model_path)
     model.eval()
     model = model.to(device)

-    window_size = 8
-
     frames_np = load_yuv_frames(
         video_file_path=args.input,
         start_idx=0,
@@ -45,7 +44,7 @@ def main():
     frames_tensor = torch.from_numpy(frames_np).permute(0, 3, 1, 2).float().unsqueeze(0)

     frames_list = []
-    for idx in tqdm(range(frames_tensor.shape[1]), desc='Processing frames', leave=False):
+    for idx in tqdm(range(frames_tensor.shape[1]), desc='SwinIR'):
         frame = frames_tensor[:, idx].to(device)
         # inference
         with torch.no_grad():
@@ -67,12 +66,9 @@ def main():
         output = output.data.squeeze().float().cpu()
         frames_list.append(output)

-    tqdm.write('Saving video')
     video = torch.stack(frames_list, dim=0)
     video = video.permute(0, 2, 3, 1)
-    tqdm.write(f'Upsampled video shape: {video.shape}')
     rgb_to_yuv420p10bit(video, args.output)
-    tqdm.write('Done!')

 def define_model(args, model_path):
     # 001 classical image sr
@@ -90,7 +86,7 @@ def define_model(args, model_path):
                         resi_connection='1conv')


-    loadnet = torch.load(model_path)
+    loadnet = torch.load(model_path, weights_only=True)
     if 'params_ema' in loadnet:
         keyname = 'params_ema'
     else:
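The hoisted window_size = 8 matters because SwinIR's shifted-window attention needs the spatial dimensions to be multiples of the window size; the reference SwinIR test code handles this by mirror-padding the input and cropping the output after upscaling. A minimal sketch of the padding step, assuming a (1, C, H, W) float tensor:

import torch

def pad_to_window(x, window_size=8):
    # Mirror-pad H and W up to the next multiple of window_size.
    _, _, h, w = x.shape
    pad_h = (window_size - h % window_size) % window_size
    pad_w = (window_size - w % window_size) % window_size
    x = torch.cat([x, torch.flip(x, [2])], 2)[:, :, :h + pad_h, :]
    x = torch.cat([x, torch.flip(x, [3])], 3)[:, :, :, :w + pad_w]
    return x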

test.yuv

-23.9 MB
Binary file not shown.
