diff --git a/basicsr/archs/spynet_arch.py b/basicsr/archs/spynet_arch.py
index 4c7af133d..afb4a6922 100644
--- a/basicsr/archs/spynet_arch.py
+++ b/basicsr/archs/spynet_arch.py
@@ -37,7 +37,7 @@ def __init__(self, load_path=None):
         super(SpyNet, self).__init__()
         self.basic_module = nn.ModuleList([BasicModule() for _ in range(6)])
         if load_path:
-            self.load_state_dict(torch.load(load_path, map_location=lambda storage, loc: storage)['params'])
+            self.load_state_dict(torch.load(load_path, map_location=lambda storage, loc: storage)['params'], strict=True)
 
         self.register_buffer('mean', torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
         self.register_buffer('std', torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
diff --git a/esrgan_dir.py b/esrgan_dir.py
new file mode 100644
index 000000000..c5e480658
--- /dev/null
+++ b/esrgan_dir.py
@@ -0,0 +1,53 @@
+import os
+import subprocess
+import argparse
+from tqdm import tqdm
+
+
+def esrgan_dir(input_dir, output_dir, model, scaling_factor=4):
+    files = os.listdir(input_dir)
+
+    # Choose model and file identification letter
+    if model == 'RealESRGAN':
+        scale_dict = {2: ['RealESRGAN_x2plus.pth', 'B'], 4: ['RealESRGAN_x4plus.pth', 'C']}
+    else:
+        scale_dict = {2: ['DF2KOST_official-ff704c30.pth', 'B'], 4: ['DF2KOST_official-ff704c30.pth', 'C']}
+
+    if scaling_factor == 2:
+        width, height = 1920, 1088
+    elif scaling_factor == 4:
+        width, height = 960, 544
+
+    output_dir = os.path.join(output_dir, model, 'x' + str(scaling_factor))
+    os.makedirs(output_dir, exist_ok=True)  # the per-model subfolder is not created by the caller
+
+    for file in tqdm(files, desc='Processing'):
+
+        if file[0] != scale_dict[scaling_factor][1]:
+            continue
+
+        input_file = os.path.join(input_dir, file)
+        output_file = os.path.join(output_dir, file)
+        cmd = [
+            'python3', 'inference/yuv_esrgan.py',
+            '--model_path', 'experiments/pretrained_models/esrgan/{0}'.format(scale_dict[scaling_factor][0]),
+            '--input', input_file,
+            '--output', output_file,
+            '--num_frames', '64',
+            '--width', str(width),
+            '--height', str(height),
+            '--scale', str(scaling_factor)
+        ]
+
+        subprocess.run(cmd)
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--input', type=str, default='datasets/Set14/LRbicx4', help='input test image folder')
+    parser.add_argument('--output', type=str, default='results/ESRGAN', help='output folder')
+    parser.add_argument('--scaling_factor', type=int, default=4, help='scaling factor')
+    parser.add_argument('--model', type=str, default='RealESRGAN', help='model to use')
+    args = parser.parse_args()
+
+    if not os.path.exists(args.output):
+        os.makedirs(args.output)
+
+    esrgan_dir(args.input, args.output, args.model, args.scaling_factor)
\ No newline at end of file
diff --git a/inference/inference_basicvsr.py b/inference/inference_basicvsr.py
index 7b5e4b945..741a308f8 100644
--- a/inference/inference_basicvsr.py
+++ b/inference/inference_basicvsr.py
@@ -48,7 +48,13 @@ def main():
         video_name = os.path.splitext(os.path.split(args.input_path)[-1])[0]
         input_path = os.path.join('./BasicVSR_tmp', video_name)
         os.makedirs(os.path.join('./BasicVSR_tmp', video_name), exist_ok=True)
-        os.system(f'ffmpeg -i {args.input_path} -qscale:v 1 -qmin 1 -qmax 1 -vsync 0 {input_path} /frame%08d.png')
+        os.system(f'ffmpeg -i {args.input_path} -qscale:v 1 -qmin 1 -qmax 1 -vsync 0 {input_path}/frame%08d.png')
+
+
+    # want to process yuv input frames
+    # load yuv frames in rgb format
+    # convert to tensor
+    # load data and inference
 
     imgs_list = sorted(glob.glob(os.path.join(input_path, '*')))
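Both the SpyNet change above and the loaders added throughout this patch assume a specific top-level key in the checkpoint dict, and `strict=True` makes any mismatch fail loudly instead of silently skipping weights. A minimal sketch for peeking at a BasicSR-style checkpoint before a strict load (the path is a placeholder):

```python
# Sketch: inspect which top-level key holds the weights before a strict load.
# The checkpoint path is a placeholder; substitute any BasicSR-style .pth file.
import torch

ckpt = torch.load('experiments/pretrained_models/BasicVSR_REDS4.pth', map_location='cpu')
print(list(ckpt.keys()))  # typically ['params'] or ['params', 'params_ema']

key = 'params_ema' if 'params_ema' in ckpt else 'params'
print(f'{len(ckpt[key])} tensors stored under {key!r}')
```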
diff --git a/inference/triple_run.sh b/inference/triple_run.sh
new file mode 100644
index 000000000..8c449e5fb
--- /dev/null
+++ b/inference/triple_run.sh
@@ -0,0 +1,3 @@
+python3 yuv_swinir.py --input /mnt/e/datasets/bvi-aom/ds-yuv/BFireS21Mitch_1920x1088_24fps_10bit_420.yuv --output /mnt/e/datasets/bvi-aom/excluded_sequences/mitchx2.yuv --num_frames 60 --width 1920 --height 1088 --scale 2
+python3 yuv_swinir.py --input /mnt/e/datasets/bvi-aom/ds-yuv/CFireS21Mitch_960x544_24fps_10bit_420.yuv --output /mnt/e/datasets/bvi-aom/excluded_sequences/mitchx4.yuv --num_frames 60 --width 960 --height 544 --scale 4
+python3 yuv_swinir.py --input /mnt/e/datasets/bvi-aom/ds-yuv/BFireS21Mitch_480x272_24fps_10bit_420.yuv --output /mnt/e/datasets/bvi-aom/excluded_sequences/mitchx8.yuv --num_frames 60 --width 480 --height 272 --scale 8
\ No newline at end of file
diff --git a/inference/yuv_basic_vsr.py b/inference/yuv_basic_vsr.py
new file mode 100644
index 000000000..d15976dbc
--- /dev/null
+++ b/inference/yuv_basic_vsr.py
@@ -0,0 +1,69 @@
+import argparse
+import cv2
+import glob
+import os
+import shutil
+import torch
+
+from basicsr.archs.basicvsr_arch import BasicVSR
+from basicsr.data.data_util import read_img_seq
+from basicsr.utils.img_util import tensor2img
+from yuv_utils import *
+
+
+def inference(frames_tensor, model, save_path):
+    with torch.no_grad():
+        outputs = model(frames_tensor)
+    # save imgs
+    outputs = outputs.squeeze()
+    outputs = outputs.permute(0, 2, 3, 1)
+    outputs = outputs.cpu()
+    print(outputs.shape)
+    rgb_to_yuv420p10bit(outputs, save_path)
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--model_path', type=str, default='experiments/pretrained_models/BasicVSR_REDS4.pth')
+    parser.add_argument(
+        '--input_path', type=str, default='datasets/REDS4/sharp_bicubic/000', help='input test image folder')
+    parser.add_argument('--save_path', type=str, default='results/BasicVSR', help='save image path')
+    parser.add_argument('--interval', type=int, default=15, help='interval size')
+    args = parser.parse_args()
+
+    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+    # set up model
+    model = BasicVSR(num_feat=64, num_block=30)
+    model.load_state_dict(torch.load(args.model_path)['params'], strict=True)
+    model.eval()
+    model = model.to(device)
+
+    # want to process yuv input frames
+    # load yuv frames in rgb format
+    # convert to tensor
+    frames_np = load_yuv_frames(
+        video_file_path=args.input_path,
+        start_idx=0,
+        num_frames=12,
+        width=256,
+        height=256,
+        bit_depth=10,
+        pixel_format='yuv420p'
+    )
+    print(frames_np.shape)
+    frames_tensor = torch.from_numpy(frames_np).permute(0, 3, 1, 2).float()
+    print(frames_tensor.shape)
+    # load data and inference
+    num_frames = len(frames_tensor)
+    if num_frames <= args.interval:  # too many images may cause CUDA out of memory
+        frames_tensor = frames_tensor.unsqueeze(0).to(device)
+        inference(frames_tensor, model, args.save_path)
+    else:
+        for idx in range(0, num_frames, args.interval):
+            # process one window of at most `interval` frames at a time
+            interval = min(args.interval, num_frames - idx)
+            chunk = frames_tensor[idx:idx + interval].unsqueeze(0).to(device)
+            inference(chunk, model, args.save_path)
+
+
+if __name__ == '__main__':
+    main()
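BasicVSR is recurrent over time, so the script above limits each forward pass to a window of at most `interval` frames to bound GPU memory. A minimal sketch of the slicing logic on a dummy (T, C, H, W) clip:

```python
import torch

def iter_chunks(frames, interval=15):
    """Yield (1, t, C, H, W) windows of at most `interval` frames."""
    for idx in range(0, frames.shape[0], interval):
        yield frames[idx:idx + interval].unsqueeze(0)

frames = torch.rand(37, 3, 64, 64)  # dummy 37-frame clip
print([chunk.shape[1] for chunk in iter_chunks(frames)])  # [15, 15, 7]
```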
diff --git a/inference/yuv_basic_vsrpp.py b/inference/yuv_basic_vsrpp.py
new file mode 100644
index 000000000..c39f2daf3
--- /dev/null
+++ b/inference/yuv_basic_vsrpp.py
@@ -0,0 +1,70 @@
+import argparse
+import cv2
+import glob
+import os
+import shutil
+import torch
+from tqdm import tqdm
+
+from basicsr.archs.basicvsrpp_arch import BasicVSRPlusPlus
+from basicsr.data.data_util import read_img_seq
+from basicsr.utils.img_util import tensor2img
+from yuv_utils import *
+
+
+def inference(frames_tensor, model, save_path):
+    with torch.no_grad():
+        outputs = model(frames_tensor)
+    # save imgs
+    outputs = outputs.squeeze()
+    outputs = outputs.permute(0, 2, 3, 1)
+    outputs = outputs.cpu()
+    rgb_to_yuv420p10bit(outputs, save_path)
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--input', type=str, default='', help='input yuv video')
+    parser.add_argument('--output', type=str, default='results/BasicVSRPP', help='save image path')
+    parser.add_argument('--num_frames', type=int, default=60, help='Number of frames to process')
+    parser.add_argument('--width', type=int, default=960, help='Width of the video')
+    parser.add_argument('--height', type=int, default=544, help='Height of the video')
+    parser.add_argument('--interval', type=int, default=15, help='interval size')
+    args = parser.parse_args()
+
+    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+    model_path = '/home/sk24938/source/sr/BasicSR/experiments/pretrained_models/basic_vsr_pp/basicvsr_plusplus_reds4.pth'
+
+    # set up model
+    model = BasicVSRPlusPlus(mid_channels=64, num_blocks=7, spynet_path='/home/sk24938/source/sr/BasicSR/experiments/pretrained_models/spynet_20210409-c6c1bd09.pth')
+    model.load_state_dict(torch.load(model_path, weights_only=True)['state_dict'], strict=True)
+    model.eval()
+    model = model.to(device)
+
+    # want to process yuv input frames
+    # load yuv frames in rgb format
+    # convert to tensor
+    frames_np = load_yuv_frames(
+        video_file_path=args.input,
+        start_idx=0,
+        num_frames=args.num_frames,
+        width=args.width,
+        height=args.height,
+        bit_depth=10,
+        pixel_format='yuv420p'
+    )
+    frames_tensor = torch.from_numpy(frames_np).permute(0, 3, 1, 2).float()
+
+    # load data and inference
+    if args.num_frames <= args.interval:  # too many images may cause CUDA out of memory
+        frames_tensor = frames_tensor.unsqueeze(0).to(device)
+        inference(frames_tensor, model, args.output)
+    else:
+        for idx in tqdm(range(0, args.num_frames, args.interval), desc='BasicVSR++'):
+            # process one window of at most `interval` frames at a time
+            interval = min(args.interval, args.num_frames - idx)
+            chunk = frames_tensor[idx:idx + interval].unsqueeze(0).to(device)
+            inference(chunk, model, args.output)
+
+
+if __name__ == '__main__':
+    main()
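A hypothetical invocation of the script above, mirroring how the directory drivers in this patch shell out to the per-model scripts (all paths are placeholders):

```python
import subprocess

subprocess.run([
    'python3', 'inference/yuv_basic_vsrpp.py',
    '--input', 'datasets/example_960x544_10bit_420.yuv',  # placeholder clip
    '--output', 'results/example_basicvsrpp_x4.yuv',
    '--num_frames', '60',
    '--width', '960',
    '--height', '544',
    '--interval', '15',
])
```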
diff --git a/inference/yuv_conventional.py b/inference/yuv_conventional.py
new file mode 100644
index 000000000..937c1a308
--- /dev/null
+++ b/inference/yuv_conventional.py
@@ -0,0 +1,106 @@
+from PIL import Image
+from yuv_utils import *
+import cv2
+import argparse
+from tqdm import tqdm
+
+
+def write_yuv_file(y, u, v, output_file_path):
+    """
+    Writes YUV frames to a .yuv file in 4:2:0 format.
+    """
+    with open(output_file_path, "wb") as f:
+        for i in range(len(y)):
+            f.write(y[i].tobytes())
+            f.write(u[i].tobytes())
+            f.write(v[i].tobytes())
+
+
+def rescale_frame(frame, scaling_factor=2, bit_depth=10, method='bicubic'):
+    interpolation_methods = {
+        'nearest': cv2.INTER_NEAREST,
+        'bicubic': cv2.INTER_CUBIC
+    }
+
+    if method not in interpolation_methods:
+        raise ValueError("Invalid method. Choose 'nearest' or 'bicubic'.")
+    interpolation = interpolation_methods[method]
+    height, width = frame.shape[:2]
+    # Clip to the valid range for the given bit depth
+    frame = np.clip(frame, 0, (2 ** bit_depth) - 1)
+
+    # Resize Y (luma) channel
+    Y = frame[:, :, 0]
+    Y_resized = cv2.resize(
+        Y.astype(np.uint16),  # Ensure correct type (uint16 for 10-bit)
+        (int(width * scaling_factor), int(height * scaling_factor)),
+        interpolation=interpolation  # use the selected method for luma as well
+    )
+
+    # Resize U and V (chroma) channels
+    U = frame[:, :, 1]
+    V = frame[:, :, 2]
+    U_resized = cv2.resize(
+        U.astype(np.uint16),
+        (int(width * scaling_factor / 2), int(height * scaling_factor / 2)),
+        interpolation=interpolation
+    )
+    V_resized = cv2.resize(
+        V.astype(np.uint16),
+        (int(width * scaling_factor / 2), int(height * scaling_factor / 2)),
+        interpolation=interpolation
+    )
+
+    # Clip resized channels to keep values within 10-bit range
+    Y_resized = np.clip(Y_resized, 0, (2 ** bit_depth) - 1)
+    U_resized = np.clip(U_resized, 0, (2 ** bit_depth) - 1)
+    V_resized = np.clip(V_resized, 0, (2 ** bit_depth) - 1)
+
+    return Y_resized, U_resized, V_resized
+
+
+def main():
+    arg_parser = argparse.ArgumentParser(description='Rescale YUV video')
+    arg_parser.add_argument('--input', help='Path to the input YUV video file')
+    arg_parser.add_argument('--output', help='Path to the output YUV video file')
+    arg_parser.add_argument('--num_frames', type=int, default=60, help='Number of frames to process')
+    arg_parser.add_argument('--width', type=int, default=3840, help='Width of the video')
+    arg_parser.add_argument('--height', type=int, default=2176, help='Height of the video')
+    arg_parser.add_argument('--scale', type=int, default=4, help='Scaling factor')
+    arg_parser.add_argument('--method', type=str, default='bicubic', help='Interpolation method')
+
+    args = arg_parser.parse_args()
+
+    # Load the YUV frames
+    yuv_frames = load_yuv_frames(
+        video_file_path=args.input,
+        start_idx=0,
+        num_frames=args.num_frames,
+        width=args.width,
+        height=args.height,
+        bit_depth=10,
+        pixel_format='yuv420p',
+        convert_to_rgb=False
+    )
+
+    # Rescale each YUV frame
+    y_arr = []
+    u_arr = []
+    v_arr = []
+    for i, frame in tqdm(enumerate(yuv_frames), total=args.num_frames, desc=f'{args.method}'):
+        y, u, v = rescale_frame(
+            frame=frame,
+            scaling_factor=args.scale,
+            bit_depth=10,
+            method=args.method
+        )
+        y_arr.append(y)
+        u_arr.append(u)
+        v_arr.append(v)
+
+    # Write the rescaled YUV video to a file
+    write_yuv_file(y_arr, u_arr, v_arr, args.output)
+
+
+if __name__ == '__main__':
+    main()
\ No newline at end of file
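In `rescale_frame`, luma is resized to `scale` times the input size while each chroma plane is resized to half of that per axis, which is exactly what keeps the output 4:2:0. A quick check of the resulting plane shapes, assuming a 960x544 input upscaled x4:

```python
width, height, scale = 960, 544, 4

y_shape = (height * scale, width * scale)              # luma plane
uv_shape = (height * scale // 2, width * scale // 2)   # each chroma plane
print(y_shape)   # (2176, 3840)
print(uv_shape)  # (1088, 1920)

# In 4:2:0, the two quarter-size chroma planes add half a luma plane of samples.
samples = y_shape[0] * y_shape[1] + 2 * uv_shape[0] * uv_shape[1]
assert samples == int(y_shape[0] * y_shape[1] * 1.5)
```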
diff --git a/inference/yuv_edsr.py b/inference/yuv_edsr.py
new file mode 100644
index 000000000..993b8303a
--- /dev/null
+++ b/inference/yuv_edsr.py
@@ -0,0 +1,87 @@
+import argparse
+import cv2
+import glob
+import os
+import shutil
+import torch
+from tqdm import tqdm
+
+from basicsr.archs.edsr_arch import EDSR
+from yuv_utils import *
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        '--input', type=str, default='datasets/REDS4/sharp_bicubic/000', help='input test image folder')
+    parser.add_argument('--output', type=str, default='results/edsr/test.yuv', help='save image path')
+    parser.add_argument('--num_frames', type=int, default=60, help='Number of frames to process')
+    parser.add_argument('--width', type=int, default=960, help='Width of the video')
+    parser.add_argument('--height', type=int, default=544, help='Height of the video')
+    parser.add_argument('--scale', type=int, default=4, help='Scaling factor')
+    args = parser.parse_args()
+
+    if args.scale == 4:
+        model_path = '/home/sk24938/source/sr/BasicSR/experiments/pretrained_models/edsr/EDSR_Lx4_f256b32_DIV2K_official-76ee1c8f.pth'
+    elif args.scale == 2:
+        model_path = '/home/sk24938/source/sr/BasicSR/experiments/pretrained_models/edsr/EDSR_Lx2_f256b32_DIV2K_official-be38e77d.pth'
+    else:
+        raise ValueError('Scale factor not supported')
+
+    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+    # set up model
+    model = EDSR(
+        num_in_ch=3,
+        num_out_ch=3,
+        num_feat=256,
+        num_block=32,
+        upscale=args.scale,
+        res_scale=0.1,
+        img_range=1.0,
+        rgb_mean=(0.4488, 0.4371, 0.4040)
+    )
+    model.load_state_dict(torch.load(model_path, weights_only=True)['params'], strict=True)
+    model.eval()
+    model = model.to(device)
+
+    # want to process yuv input frames
+    # load yuv frames in rgb format
+    # convert to tensor
+    frames_np = load_yuv_frames(
+        video_file_path=args.input,
+        start_idx=0,
+        num_frames=args.num_frames,
+        width=args.width,
+        height=args.height,
+        bit_depth=10,
+        pixel_format='yuv420p',
+        convert_to_rgb=True
+    )
+
+    frames_tensor = torch.from_numpy(frames_np).permute(0, 3, 1, 2).float()
+    frames_tensor = frames_tensor.to(device)
+
+    frame_list = []
+    for i in tqdm(range(frames_tensor.shape[0]), desc='EDSR'):
+        frame = frames_tensor[i].unsqueeze(0)  # (1, C, H, W): the model expects a batch dimension
+        # inference
+        try:
+            with torch.no_grad():
+                output = model(frame)
+        except Exception as error:
+            print('Error', error, i)
+        else:
+            # save image
+            output = output.data.squeeze().cpu()
+            frame_list.append(output)
+
+    video = torch.stack(frame_list, dim=0)
+    video = video.permute(0, 2, 3, 1)
+    rgb_to_yuv420p10bit(video, args.output)
+
+
+if __name__ == '__main__':
+    main()
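EDSR is a single-image model, so the loop above pushes one frame at a time through the network with a batch dimension of one. The same pattern on a dummy module, as a minimal sketch:

```python
import torch
import torch.nn as nn

model = nn.Conv2d(3, 3, 3, padding=1)  # stand-in for the EDSR network
frames = torch.rand(5, 3, 32, 32)      # dummy (T, C, H, W) clip

with torch.no_grad():
    outputs = [model(frames[i].unsqueeze(0)).squeeze(0) for i in range(frames.shape[0])]

video = torch.stack(outputs, dim=0)
print(video.shape)  # torch.Size([5, 3, 32, 32])
```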
diff --git a/inference/yuv_edvr.py b/inference/yuv_edvr.py
new file mode 100644
index 000000000..d4b653c5f
--- /dev/null
+++ b/inference/yuv_edvr.py
@@ -0,0 +1,93 @@
+import argparse
+import cv2
+import glob
+import os
+import shutil
+import torch
+from tqdm import tqdm
+
+from basicsr.archs.edvr_arch import EDVR
+from yuv_utils import *
+
+
+def inference(frames_tensor, model):
+    with torch.no_grad():
+        outputs = model(frames_tensor)
+    # return the predicted centre frame as (H, W, C) on the CPU
+    outputs = outputs.permute(0, 2, 3, 1).squeeze()
+    outputs = outputs.cpu()
+    return outputs
+
+
+def pad_frames(frames_tensor, pad_size):
+    # Pad the frames by repeating the first and last frames
+    start_padding = frames_tensor[:, :1, :, :, :].expand(-1, pad_size, -1, -1, -1)
+    end_padding = frames_tensor[:, -1:, :, :, :].expand(-1, pad_size, -1, -1, -1)
+    padded_frames = torch.cat([start_padding, frames_tensor, end_padding], dim=1)
+    return padded_frames
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        '--input', type=str, default='datasets/REDS4/sharp_bicubic/000', help='input test image folder')
+    parser.add_argument('--output', type=str, default='results/edvr/test.yuv', help='save image path')
+    parser.add_argument('--num_frames', type=int, default=60, help='Number of frames to process')
+    parser.add_argument('--width', type=int, default=960, help='Width of the video')
+    parser.add_argument('--height', type=int, default=544, help='Height of the video')
+    parser.add_argument('--interval', type=int, default=5, help='Interval size for processing frames in chunks')
+
+    args = parser.parse_args()
+
+    model_path = '/home/sk24938/source/sr/BasicSR/experiments/pretrained_models/edvr/EDVR_M_x4_SR_REDS_official-32075921.pth'
+
+    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+    # set up model
+    model = EDVR(
+        num_in_ch=3,
+        num_out_ch=3,
+        num_feat=64,
+        num_frame=args.interval,
+        deformable_groups=8,
+        num_extract_block=5,
+        num_reconstruct_block=10,
+        hr_in=False,
+    )
+    model.load_state_dict(torch.load(model_path, weights_only=True)['params'], strict=True)
+    model.eval()
+    model = model.to(device)
+
+    # want to process yuv input frames
+    # load yuv frames in rgb format
+    # convert to tensor
+    frames_np = load_yuv_frames(
+        video_file_path=args.input,
+        start_idx=0,
+        num_frames=args.num_frames,
+        width=args.width,
+        height=args.height,
+        bit_depth=10,
+        pixel_format='yuv420p',
+        convert_to_rgb=True
+    )
+
+    frames_tensor = torch.from_numpy(frames_np).permute(0, 3, 1, 2).float()
+    frames_tensor = frames_tensor.unsqueeze(0)
+    pad_size = args.interval // 2
+    frames_tensor = pad_frames(frames_tensor, pad_size)
+
+    frame_list = []
+    for idx in tqdm(range(pad_size, args.num_frames + pad_size), desc='EDVR'):
+        start_idx = idx - pad_size
+        end_idx = idx + pad_size
+        frames_tensor_chunk = frames_tensor[:, start_idx:end_idx + 1, :, :, :].to(device)
+        output = inference(frames_tensor_chunk, model)
+        frame_list.append(output)
+        torch.cuda.empty_cache()
+
+    video = torch.stack(frame_list, dim=0)
+    rgb_to_yuv420p10bit(video, args.output)
+
+
+if __name__ == '__main__':
+    main()
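EDVR predicts the centre frame of an `interval`-frame window, so the clip is padded by `interval // 2` frames at each end and the window slides one frame at a time. A toy check of the index arithmetic, assuming `num_frames=4` and `interval=5`:

```python
num_frames, interval = 4, 5
pad_size = interval // 2  # 2 repeated frames at each end

for idx in range(pad_size, num_frames + pad_size):
    start_idx, end_idx = idx - pad_size, idx + pad_size
    window = list(range(start_idx, end_idx + 1))  # indices into the padded clip
    print(idx - pad_size, '->', window)
# 0 -> [0, 1, 2, 3, 4]
# 1 -> [1, 2, 3, 4, 5]
# 2 -> [2, 3, 4, 5, 6]
# 3 -> [3, 4, 5, 6, 7]
```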
diff --git a/inference/yuv_esrgan.py b/inference/yuv_esrgan.py
new file mode 100644
index 000000000..8d5ead660
--- /dev/null
+++ b/inference/yuv_esrgan.py
@@ -0,0 +1,75 @@
+import argparse
+import cv2
+import glob
+import numpy as np
+import os
+import torch
+
+from basicsr.archs.rrdbnet_arch import RRDBNet
+from yuv_utils import *
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        '--model_path',
+        type=str,
+        default=  # noqa: E251
+        'experiments/pretrained_models/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth'  # noqa: E501
+    )
+    parser.add_argument('--input', type=str, default='datasets/Set14/LRbicx4', help='input test image folder')
+    parser.add_argument('--output', type=str, default='results/ESRGAN', help='output folder')
+    parser.add_argument('--num_frames', type=int, default=64, help='number of frames to process')
+    parser.add_argument('--width', type=int, default=960, help='width of the frames')
+    parser.add_argument('--height', type=int, default=544, help='height of the frames')
+    parser.add_argument('--scale', type=int, default=4, help='upscaling factor of the loaded model')
+    args = parser.parse_args()
+
+    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+    # set up model; the network scale must match the checkpoint being loaded
+    model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=args.scale)
+    checkpoint = torch.load(args.model_path)
+
+    # Real-ESRGAN checkpoints keep their weights under 'params_ema'
+    param_key = 'params_ema' if 'params_ema' in checkpoint else 'params'
+
+    model.load_state_dict(checkpoint[param_key], strict=True)
+    model.eval()
+    model = model.to(device)
+
+    frames_np = load_yuv_frames(
+        video_file_path=args.input,
+        start_idx=0,
+        num_frames=args.num_frames,
+        width=args.width,
+        height=args.height,
+        bit_depth=10,
+        pixel_format='yuv420p'
+    )
+    frames_tensor = torch.from_numpy(frames_np).permute(0, 3, 1, 2).float()
+    frames_tensor = frames_tensor.unsqueeze(0).to(device)
+
+    frame_list = []
+    for i in range(frames_tensor.shape[1]):
+        frame = frames_tensor[:, i]
+        # inference
+        try:
+            with torch.no_grad():
+                output = model(frame)
+        except Exception as error:
+            print('Error', error, i)
+        else:
+            # save image
+            output = output.data.squeeze().cpu()
+            frame_list.append(output)
+
+    print('Saving video...')
+    video = torch.stack(frame_list, dim=0)
+    video = video.permute(0, 2, 3, 1)
+    print('Upsampled video shape: ', video.shape)
+    rgb_to_yuv420p10bit(video, args.output)
+    print('Done!')
+
+
+if __name__ == '__main__':
+    main()
diff --git a/inference/yuv_swinir.py b/inference/yuv_swinir.py
new file mode 100644
index 000000000..4ed9aa08a
--- /dev/null
+++ b/inference/yuv_swinir.py
@@ -0,0 +1,100 @@
+# Modified from https://github.com/JingyunLiang/SwinIR
+import argparse
+import cv2
+import glob
+import numpy as np
+import os
+import torch
+from torch.nn import functional as F
+
+from basicsr.archs.swinir_arch import SwinIR
+from yuv_utils import *
+from tqdm import tqdm
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--input', type=str, default='datasets/swin.yuv', help='input yuv file')
+    parser.add_argument('--output', type=str, default='results/swin.yuv', help='output yuv file')
+    parser.add_argument('--num_frames', type=int, default=60, help='number of frames to process')
+    parser.add_argument('--width', type=int, default=960, help='width of the frames')
+    parser.add_argument('--height', type=int, default=544, help='height of the frames')
+    parser.add_argument('--scale', type=int, default=4, help='scale factor: 2, 4, 8')
+    parser.add_argument('--patch_size', type=int, default=64, help='patch size')
+    parser.add_argument('--model_path', type=str, default=None,
+                        help='pretrained model; defaults to the classical SR checkpoint for the chosen scale')
+    args = parser.parse_args()
+
+    model_path = args.model_path
+    if model_path is None:
+        model_path = f"/home/sk24938/source/sr/BasicSR/experiments/pretrained_models/swinir/001_classicalSR_DF2K_s64w8_SwinIR-M_x{args.scale}.pth"
+    window_size = 8
+
+    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+    # set up model
+    model = define_model(args, model_path)
+    model.eval()
+    model = model.to(device)
+
+    frames_np = load_yuv_frames(
+        video_file_path=args.input,
+        start_idx=0,
+        num_frames=args.num_frames,
+        width=args.width,
+        height=args.height,
+        bit_depth=10,
+        pixel_format='yuv420p'
+    )
+    frames_tensor = torch.from_numpy(frames_np).permute(0, 3, 1, 2).float().unsqueeze(0)
+
+    frames_list = []
+    for idx in tqdm(range(frames_tensor.shape[1]), desc='SwinIR'):
+        frame = frames_tensor[:, idx].to(device)
+        # inference
+        with torch.no_grad():
+            # pad input image to be a multiple of window_size
+            mod_pad_h, mod_pad_w = 0, 0
+            _, _, h, w = frame.shape
+            if h % window_size != 0:
+                mod_pad_h = window_size - h % window_size
+            if w % window_size != 0:
+                mod_pad_w = window_size - w % window_size
+            frame = F.pad(frame, (0, mod_pad_w, 0, mod_pad_h), 'reflect')
+
+            output = model(frame)
+            _, _, h, w = output.size()
+            # remove padding
+            output = output[:, :, 0:h - mod_pad_h * args.scale, 0:w - mod_pad_w * args.scale]
+
+        # save image
+        output = output.data.squeeze().float().cpu()
+        frames_list.append(output)
+
+    video = torch.stack(frames_list, dim=0)
+    video = video.permute(0, 2, 3, 1)
+    rgb_to_yuv420p10bit(video, args.output)
+
+
+def define_model(args, model_path):
+    # 001 classical image sr
+    model = SwinIR(
+        upscale=args.scale,
+        in_chans=3,
+        img_size=args.patch_size,
+        window_size=8,
+        img_range=1.,
+        depths=[6, 6, 6, 6, 6, 6],
+        embed_dim=180,
+        num_heads=[6, 6, 6, 6, 6, 6],
+        mlp_ratio=2,
+        upsampler='pixelshuffle',
+        resi_connection='1conv')
+
+    loadnet = torch.load(model_path, weights_only=True)
+    if 'params_ema' in loadnet:
+        keyname = 'params_ema'
+    else:
+        keyname = 'params'
+    model.load_state_dict(loadnet[keyname], strict=True)
+
+    return model
+
+
+if __name__ == '__main__':
+    main()
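SwinIR requires spatial dimensions that are multiples of its attention window (8 here), so each frame is reflection-padded up and the padding is cropped back off the upscaled output. A minimal sketch of that round trip on a 542-row dummy frame:

```python
import torch
import torch.nn.functional as F

window_size, scale = 8, 4
frame = torch.rand(1, 3, 542, 960)  # 542 is not a multiple of 8

_, _, h, w = frame.shape
mod_pad_h = (window_size - h % window_size) % window_size  # 2
mod_pad_w = (window_size - w % window_size) % window_size  # 0
padded = F.pad(frame, (0, mod_pad_w, 0, mod_pad_h), 'reflect')
print(padded.shape)  # torch.Size([1, 3, 544, 960])

# After x4 upscaling, cropping mod_pad * scale recovers the true output size.
out_h = padded.shape[2] * scale - mod_pad_h * scale
out_w = padded.shape[3] * scale - mod_pad_w * scale
print(out_h, out_w)  # 2168 3840, i.e. 542 * 4 by 960 * 4
```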
diff --git a/inference/yuv_utils.py b/inference/yuv_utils.py
new file mode 100644
index 000000000..22a25bf9c
--- /dev/null
+++ b/inference/yuv_utils.py
@@ -0,0 +1,291 @@
+import numpy as np
+
+
+def rgb_to_yuv420p10bit(rgb_frames, output_file_path):
+    """Convert a sequence of RGB frames stored as np.ndarray to
+    YUV420P 10-bit format and write to file.
+
+    Args:
+        rgb_frames (np.ndarray): sequence of frames to save.
+        output_file_path (str): location to save the YUV file.
+    """
+    num_frames, height, width, _ = rgb_frames.shape
+
+    # List to store the YUV data per frame (as a list of arrays)
+    yuv_data = []
+
+    for frame in rgb_frames:
+        # Convert RGB to YUV using BT.601 limited-range coefficients
+        yuv_frame = np.zeros((height, width, 3), dtype=np.float32)
+        # Apply the RGB to YUV conversion matrix (for RGB in [0, 1])
+        yuv_frame[:, :, 0] = 0.257 * frame[:, :, 0] + 0.504 * frame[:, :, 1] + 0.098 * frame[:, :, 2] + 0.0625  # Y
+        yuv_frame[:, :, 1] = -0.148 * frame[:, :, 0] - 0.291 * frame[:, :, 1] + 0.439 * frame[:, :, 2] + 0.5  # U
+        yuv_frame[:, :, 2] = 0.439 * frame[:, :, 0] - 0.368 * frame[:, :, 1] - 0.071 * frame[:, :, 2] + 0.5  # V
+
+        # Scale to 10-bit range and round
+        yuv_frame = np.round(yuv_frame * 1023).astype(np.uint16)
+
+        # Extract Y, U, V components
+        y = yuv_frame[:, :, 0]
+        u = yuv_frame[:, :, 1]
+        v = yuv_frame[:, :, 2]
+
+        # Downsample U and V channels for YUV420 format
+        u_downsampled = u[::2, ::2]
+        v_downsampled = v[::2, ::2]
+
+        # Flatten channels and append to YUV data for writing
+        frame_data = np.concatenate([y.flatten(), u_downsampled.flatten(), v_downsampled.flatten()])
+        yuv_data.append(frame_data)
+
+    # Convert the list of frames into a single numpy array
+    yuv_data = np.concatenate(yuv_data)
+
+    # Write the YUV data to the file as 10-bit values (2 bytes per sample)
+    with open(output_file_path, 'wb') as f:
+        f.write(yuv_data.tobytes())
+
+    print(f'YUV file saved to {output_file_path}')
+
+
+def patch_2d(yuv_frames, patch_idx=(0, 0), patch_shape=(256, 256)):
+    """
+    Extract a specific spatial patch from a np.ndarray of YUV frames.
+
+    Args:
+        yuv_frames (np.ndarray): Sequence of YUV frames of shape (num_frames, height, width, 3).
+        patch_idx (tuple): The (row, col) coordinates of the top-left corner of the patch.
+        patch_shape (tuple): The shape of the patch (patch_height, patch_width).
+
+    Returns:
+        np.ndarray: The extracted patch.
+    """
+    num_frames, height, width, _ = yuv_frames.shape
+    if patch_idx[0] + patch_shape[0] > height or patch_idx[1] + patch_shape[1] > width:
+        raise ValueError("Patch dimensions exceed frame dimensions.")
+    return yuv_frames[
+        :,
+        patch_idx[0]:patch_idx[0] + patch_shape[0],
+        patch_idx[1]:patch_idx[1] + patch_shape[1]
+    ]
+
+
+def yuv2rgb(image, bit_depth, normalize=False):
+    """Convert an image from YUV to RGB color space (BT.2020 coefficients)."""
+    N = (2 ** bit_depth) - 1
+
+    Y = np.float32(image[:, :, 0]) / N
+    U = np.float32(image[:, :, 1]) / N
+    V = np.float32(image[:, :, 2]) / N
+
+    fy = Y
+    fu = U - 0.5
+    fv = V - 0.5
+
+    # BT.2020 luma coefficients
+    KR = 0.2627
+    KG = 0.6780
+    KB = 0.0593
+
+    R = fy + 1.4746 * fv
+    B = fy + 1.8814 * fu
+    G = -(B * KB + KR * R - Y) / KG
+
+    # Clip to the valid [0, 1] range
+    R = np.clip(R, 0, 1)
+    G = np.clip(G, 0, 1)
+    B = np.clip(B, 0, 1)
+
+    rgb_image = np.stack((R, G, B), axis=2)
+    if not normalize:
+        rgb_image = rgb_image * N
+    return rgb_image
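Note that the two directions in this module use different conventions: `rgb_to_yuv420p10bit` encodes with BT.601 limited-range constants, while `yuv2rgb` decodes with BT.2020 coefficients and no limited-range offset, so a write-then-read round trip is not exactly identity. A one-pixel sketch that makes the mismatch visible:

```python
# Mid-grey in [0, 1] through the forward transform used by rgb_to_yuv420p10bit.
r = g = b = 0.5
y = 0.257 * r + 0.504 * g + 0.098 * b + 0.0625  # BT.601 limited-range luma

# yuv2rgb treats neutral chroma as R = G = B = Y, so the grey comes back
# as ~0.492 rather than the original 0.5.
print(round(y, 4))  # 0.492
```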
+
+
+def load_yuv_frame(video_file, idx, width, height, bit_depth=10, pixel_format='yuv420p'):
+    """
+    Load a single YUV frame from an open video file.
+
+    Args:
+        video_file (file object): The video file to read from.
+        idx (int): The index of the frame to load.
+        width (int): The width of the frame.
+        height (int): The height of the frame.
+        bit_depth (int): Bit depth of the video (default: 10).
+        pixel_format (str): The pixel format of the video (default: 'yuv420p').
+
+    Returns:
+        np.ndarray: The loaded frame as an RGB numpy array.
+    """
+    # Determine bytes per sample from the bit depth
+    if bit_depth == 10:
+        multiplier = 2
+        _dtype = np.uint16
+    elif bit_depth == 8:
+        multiplier = 1
+        _dtype = np.uint8
+    else:
+        raise ValueError(f"Unsupported bit depth: {bit_depth}")
+
+    wh = width * height
+    wh_2 = wh // 2
+    wh_4 = wh // 4
+    h_2, w_2 = height // 2, width // 2
+    if pixel_format == 'yuv420p':
+        frame_size = wh * 1.5  # Y + U + V (U and V downsampled)
+    elif pixel_format == 'yuv422p':
+        frame_size = wh * 2  # Y + U + V (U and V downsampled horizontally)
+    elif pixel_format == 'yuv444p':
+        frame_size = wh * 3  # Y + U + V (all full resolution)
+    else:
+        raise ValueError(f"Unsupported pixel format: {pixel_format}")
+
+    # Seek to the specified frame and read the YUV data
+    video_file.seek(int(frame_size * idx * multiplier), 0)
+    yuv_frame = np.frombuffer(video_file.read(int(frame_size * multiplier)), dtype=_dtype)
+
+    # Check if we read enough data
+    if len(yuv_frame) < frame_size:
+        raise ValueError(f"Not enough data read for frame index {idx}. Expected {int(frame_size)} samples, got {len(yuv_frame)}.")
+
+    if bit_depth == 10:
+        yuv_frame = yuv_frame & 0x03FF  # keep only the 10 significant bits of each 16-bit sample
+
+    # Load Y, U, and V components
+    y = yuv_frame[0:wh].reshape((height, width))
+
+    if pixel_format == 'yuv420p':
+        u = yuv_frame[wh:wh + wh_4].reshape((h_2, w_2))
+        v = yuv_frame[wh + wh_4:].reshape((h_2, w_2))
+        # Upsample U and V channels to match Y channel size
+        u = np.repeat(np.repeat(u, 2, axis=0), 2, axis=1)
+        v = np.repeat(np.repeat(v, 2, axis=0), 2, axis=1)
+    elif pixel_format == 'yuv422p':
+        u = yuv_frame[wh:wh + wh_2].reshape((height, w_2))
+        v = yuv_frame[wh + wh_2:].reshape((height, w_2))
+        # Upsample U and V channels to match Y channel size
+        u = u.repeat(2, axis=1)
+        v = v.repeat(2, axis=1)
+    elif pixel_format == 'yuv444p':
+        u = yuv_frame[wh:wh * 2].reshape((height, width))
+        v = yuv_frame[wh * 2:].reshape((height, width))
+
+    # Stack the Y, U, and V components
+    frame = np.stack((y, u, v), axis=2)
+    # Convert YUV to RGB
+    frame = yuv2rgb(frame, bit_depth=bit_depth, normalize=True)
+    return frame
+
+
+def process_yuv_frames(
+        yuv_data,
+        num_frames,
+        width,
+        height,
+        bit_depth,
+        pixel_format,
+        frame_size,
+        convert_to_rgb=True):
+    """Split a flat buffer of YUV samples into frames and optionally convert them to RGB."""
+
+    wh = width * height
+    wh_2 = wh // 2
+    wh_4 = wh // 4
+    h_2, w_2 = height // 2, width // 2
+
+    frames = np.empty((num_frames, height, width, 3), dtype=np.float32)
+    for i in range(num_frames):
+        # Extract the portion of yuv_data corresponding to the current frame
+        frame_start = i * frame_size
+        frame_end = (i + 1) * frame_size
+        frame_data = yuv_data[frame_start:frame_end]
+        if bit_depth == 10:
+            frame_data = frame_data & 0x03FF  # keep only the 10 significant bits of each 16-bit sample
+
+        # Load Y, U, and V components
+        y = frame_data[:wh].reshape((height, width))
+        if pixel_format == 'yuv420p':
+            u = frame_data[wh:wh + wh_4].reshape((h_2, w_2))
+            v = frame_data[wh + wh_4:].reshape((h_2, w_2))
+            # Upsample U and V channels to match Y channel size
+            u = np.repeat(np.repeat(u, 2, axis=0), 2, axis=1)
+            v = np.repeat(np.repeat(v, 2, axis=0), 2, axis=1)
+        elif pixel_format == 'yuv422p':
+            u = frame_data[wh:wh + wh_2].reshape((height, w_2))
+            v = frame_data[wh + wh_2:].reshape((height, w_2))
+            # Upsample U and V channels to match Y channel size
+            u = u.repeat(2, axis=1)
+            v = v.repeat(2, axis=1)
+        elif pixel_format == 'yuv444p':
+            u = frame_data[wh:wh * 2].reshape((height, width))
+            v = frame_data[wh * 2:].reshape((height, width))
+
+        # Stack the Y, U, and V components
+        frame = np.stack((y, u, v), axis=2)
+
+        if convert_to_rgb:
+            # Convert YUV to RGB
+            frame = yuv2rgb(frame, bit_depth=bit_depth, normalize=True)
+        frames[i] = frame
+    return frames
+
+
+def load_yuv_frames(video_file_path, start_idx, num_frames, width, height, bit_depth=10, pixel_format='yuv420p', convert_to_rgb=True):
+    """
+    Load a specified number of YUV frames from a video file.
+
+    Args:
+        video_file_path (str): Path to the YUV video file to read from.
+        start_idx (int): The starting index of the frame to load.
+        num_frames (int): The number of frames to load.
+        width (int): The width of the frame.
+        height (int): The height of the frame.
+        bit_depth (int): Bit depth of the video (default: 10).
+        pixel_format (str): The pixel format of the video (default: 'yuv420p').
+        convert_to_rgb (bool): Whether to convert the loaded frames to RGB (default: True).
+
+    Returns:
+        np.ndarray: The loaded frames as a (num_frames, height, width, 3) array.
+    """
+    # Determine bytes per sample from the bit depth
+    if bit_depth == 10:
+        multiplier = 2
+        _dtype = np.uint16
+    elif bit_depth == 8:
+        multiplier = 1
+        _dtype = np.uint8
+    else:
+        raise ValueError(f"Unsupported bit depth: {bit_depth}")
+
+    wh = width * height
+    if pixel_format == 'yuv420p':
+        frame_size = int(wh * 1.5)  # Y + U + V (U and V downsampled)
+    elif pixel_format == 'yuv422p':
+        frame_size = int(wh * 2)  # Y + U + V (U and V downsampled horizontally)
+    elif pixel_format == 'yuv444p':
+        frame_size = int(wh * 3)  # Y + U + V (all full resolution)
+    else:
+        raise ValueError(f"Unsupported pixel format: {pixel_format}")
+
+    # Read the data for all frames at once
+    total_size = int(frame_size * num_frames * multiplier)
+
+    with open(video_file_path, 'rb') as video_file:
+        # Seek to the starting frame
+        video_file.seek(int(frame_size * start_idx * multiplier), 0)
+        yuv_data = np.frombuffer(video_file.read(total_size), dtype=_dtype)
+
+    # Split the buffer into frames (and convert to RGB if requested)
+    frames = process_yuv_frames(yuv_data, num_frames, width, height, bit_depth, pixel_format, frame_size, convert_to_rgb)
+    return frames
\ No newline at end of file
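`load_yuv_frames` seeks in whole-frame units: a 4:2:0 frame holds `width * height * 1.5` samples, and at 10 bits each sample occupies 2 bytes on disk. A quick sanity check of the byte offsets, assuming a 960x544 10-bit clip:

```python
width, height = 960, 544
samples_per_frame = int(width * height * 1.5)  # 783360 samples (Y + U + V)
bytes_per_frame = samples_per_frame * 2        # 10-bit samples stored as uint16
print(bytes_per_frame)                         # 1566720

# Byte offset passed to video_file.seek() for frames 0..2:
for k in range(3):
    print(k, k * bytes_per_frame)
```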
diff --git a/swinir_dir.py b/swinir_dir.py
new file mode 100644
index 000000000..959ece526
--- /dev/null
+++ b/swinir_dir.py
@@ -0,0 +1,57 @@
+import os
+import subprocess
+import argparse
+from tqdm import tqdm
+
+
+def swinir_dir(input_dir, output_dir, scaling_factor=4):
+    files = os.listdir(input_dir)
+
+    # Choose model and file identification letter
+    if scaling_factor == 2:
+        width, height = 1920, 1088
+        letter = 'B'
+        model = 'experiments/pretrained_models/swinir/001_classicalSR_DF2K_s64w8_SwinIR-M_x2.pth'
+    elif scaling_factor == 4:
+        width, height = 960, 544
+        letter = 'C'
+        model = 'experiments/pretrained_models/swinir/001_classicalSR_DF2K_s64w8_SwinIR-M_x4.pth'
+    elif scaling_factor == 8:
+        width, height = 480, 272
+        letter = 'D'
+        model = 'experiments/pretrained_models/swinir/001_classicalSR_DF2K_s64w8_SwinIR-M_x8.pth'
+    else:
+        raise ValueError('Scaling factor not supported')
+
+    for file in tqdm(files, desc='Processing videos'):
+
+        if file[0] != letter:
+            continue
+
+        input_file = os.path.join(input_dir, file)
+        output_file = os.path.join(output_dir, file)
+        tqdm.write('Processing {0}'.format(file))
+        cmd = [
+            'python3', 'inference/yuv_swinir.py',
+            '--model_path', model,
+            '--input', input_file,
+            '--output', output_file,
+            '--num_frames', '64',
+            '--width', str(width),
+            '--height', str(height),
+            '--scale', str(scaling_factor)
+        ]
+
+        subprocess.run(cmd)
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--input', type=str, default='datasets/swinir.yuv', help='input folder of YUV videos')
+    parser.add_argument('--output', type=str, default='results/swinir.yuv', help='output folder')
+    parser.add_argument('--scaling_factor', type=int, default=4, help='scaling factor')
+    args = parser.parse_args()
+
+    if not os.path.exists(args.output):
+        os.makedirs(args.output)
+
+    swinir_dir(args.input, args.output, args.scaling_factor)
\ No newline at end of file
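A hypothetical end-to-end run of the driver above, assuming a directory of clips whose filenames start with 'B', 'C' or 'D' according to resolution (paths are placeholders):

```python
import subprocess

subprocess.run([
    'python3', 'swinir_dir.py',
    '--input', '/mnt/e/datasets/bvi-aom/ds-yuv',  # placeholder clip directory
    '--output', 'results/swinir',
    '--scaling_factor', '4',  # picks the x4 model and the 'C'-prefixed 960x544 clips
])
```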