Skip to content

Commit c59490a

Browse files
xinntao and gagajian authored
Add BasicVSR and IconVSR (XPixelGroup#427)
* add basicvsr, iconvsr * fixed bugs (XPixelGroup#424) * update .gitignore * minor fix * Add BasicVSR-GAN model code; add BasicVSR inference code (XPixelGroup#426) * fixed bugs * Add BasicVSR inference code * Add python bicubic downsampling code * Add BasicVSR_GAN model * Minor improvement in BasicVSR_GAN model * Format code * remove DS_store * format code * format code * merge * reorganize * reorganize * minor update * update readme * update datasets * update publish_models * add BasicVSR * update readme * fix lgtm warns Co-authored-by: gagajian <32482923+gagajian@users.noreply.github.com>
1 parent 35d53ad commit c59490a

26 files changed

+2186
-19
lines changed

.gitignore

+1
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ tb_logger/*
66
wandb/*
77
tmp/*
88

9+
*.DS_Store
910
.vscode
1011
.idea
1112

README.md

+3-1
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,9 @@ BasicSR (**Basic** **S**uper **R**estoration) is an open source **image and vide
2525

2626
## :sparkles: New Features
2727

28-
- Nov 29, 2020. Add **ESRGAN** and **DFDNet** [colab demo](colab).
28+
- July 31, 2021. Add **bi-directional video super-resolution** codes: [**BasicVSR** and IconVSR](https://arxiv.org/abs/2012.02181).
29+
- July 20, 2021. Add **dual-blind face restoration** codes: [HiFaceGAN](https://github.com/Lotayou/Face-Renovation) by [Lotayou](https://lotayou.github.io/).
30+
- Nov 29, 2020. Add **ESRGAN** and **DFDNet** [colab demo](colab).
2931
- Sep 8, 2020. Add **blind face restoration** inference codes: [DFDNet](https://github.com/csxmli2016/DFDNet).
3032
- Aug 27, 2020. Add **StyleGAN2 training and testing** codes: [StyleGAN2](https://github.com/rosinality/stylegan2-pytorch).
3133

README_CN.md

+2
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@ BasicSR (**Basic** **S**uper **R**estoration) 是一个基于 PyTorch 的开源
2525

2626
## :sparkles: 新的特性
2727

28+
- July 31, 2021. Add **bi-directional video super-resolution** codes: [**BasicVSR** and IconVSR](https://arxiv.org/abs/2012.02181).
29+
- July 20, 2021. Add **dual-blind face restoration** codes: [**HiFaceGAN**](https://github.com/Lotayou/Face-Renovation) by [Lotayou](https://lotayou.github.io/).
2830
- Nov 29, 2020. 添加 **ESRGAN** and **DFDNet** [colab demo](colab).
2931
- Sep 8, 2020. 添加 **盲人脸复原**测试代码: [DFDNet](https://github.com/csxmli2016/DFDNet).
3032
- Aug 27, 2020. 添加 **StyleGAN2 训练和测试** 代码: [StyleGAN2](https://github.com/rosinality/stylegan2-pytorch).

basicsr/archs/basicvsr_arch.py

+309
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,309 @@
1+
import torch
2+
from torch import nn as nn
3+
from torch.nn import functional as F
4+
5+
from basicsr.utils.registry import ARCH_REGISTRY
6+
from .arch_util import ResidualBlockNoBN, flow_warp, make_layer
7+
from .edvr_arch import PCDAlignment, TSAFusion
8+
from .spynet_arch import SpyNet
9+
10+
11+
@ARCH_REGISTRY.register()
class BasicVSR(nn.Module):
    """Bidirectional recurrent network for video super-resolution (x4 only).

    A feature state is propagated through the sequence twice: first from the
    last frame to the first, then from the first frame to the last. Before
    each step the state is warped with the optical flow produced by SPyNet.
    The two per-frame features are fused by a 1x1 conv and upsampled with two
    pixel-shuffle (x2) stages; a bilinear x4 upsampling of the input frame is
    added as a residual.

    Args:
        num_feat (int): Number of channels. Default: 64.
        num_block (int): Number of residual blocks for each branch. Default: 15
        spynet_path (str): Path to the pretrained weights of SPyNet. Default: None.
    """

    def __init__(self, num_feat=64, num_block=15, spynet_path=None):
        super().__init__()
        self.num_feat = num_feat

        # Optical-flow estimator used for alignment.
        self.spynet = SpyNet(spynet_path)

        # Propagation trunks; each consumes an RGB frame (3 ch) plus the state.
        trunk_in_ch = num_feat + 3
        self.backward_trunk = ConvResidualBlocks(trunk_in_ch, num_feat, num_block)
        self.forward_trunk = ConvResidualBlocks(trunk_in_ch, num_feat, num_block)

        # Reconstruction head.
        self.fusion = nn.Conv2d(num_feat * 2, num_feat, kernel_size=1, stride=1, padding=0, bias=True)
        self.upconv1 = nn.Conv2d(num_feat, num_feat * 4, kernel_size=3, stride=1, padding=1, bias=True)
        self.upconv2 = nn.Conv2d(num_feat, 64 * 4, kernel_size=3, stride=1, padding=1, bias=True)
        self.conv_hr = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
        self.conv_last = nn.Conv2d(64, 3, kernel_size=3, stride=1, padding=1)

        self.pixel_shuffle = nn.PixelShuffle(2)

        # Shared activation.
        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)

    def get_flow(self, x):
        """Run SPyNet on every pair of neighbouring frames, in both orders.

        Args:
            x (Tensor): Frame sequence, unpacked here as (b, n, c, h, w).

        Returns:
            tuple[Tensor, Tensor]: ``(flows_forward, flows_backward)``, each
            viewed as (b, n - 1, 2, h, w).
        """
        b, n, c, h, w = x.size()

        # Frames 0..n-2 and frames 1..n-1, flattened over batch and time.
        earlier = x[:, :-1].reshape(-1, c, h, w)
        later = x[:, 1:].reshape(-1, c, h, w)

        flows_backward = self.spynet(earlier, later).view(b, n - 1, 2, h, w)
        flows_forward = self.spynet(later, earlier).view(b, n - 1, 2, h, w)

        return flows_forward, flows_backward

    def forward(self, x):
        """Super-resolve a frame sequence.

        Args:
            x (Tensor): Input frames, unpacked here as (b, n, c, h, w).

        Returns:
            Tensor: The per-frame outputs stacked along dim 1.
        """
        flows_forward, flows_backward = self.get_flow(x)
        b, n, _, h, w = x.size()

        # Backward-time pass: walk from the last frame to the first.
        backward_feats = []
        state = x.new_zeros(b, self.num_feat, h, w)
        for idx in reversed(range(n)):
            frame = x[:, idx]
            if idx < n - 1:
                # Align the state coming from frame idx + 1 to frame idx.
                state = flow_warp(state, flows_backward[:, idx].permute(0, 2, 3, 1))
            state = self.backward_trunk(torch.cat([frame, state], dim=1))
            backward_feats.append(state)
        # Collected last-to-first; restore chronological order.
        backward_feats.reverse()

        # Forward-time pass, reconstructing each frame as soon as both
        # directional features are available.
        state = torch.zeros_like(state)
        outputs = []
        for idx in range(n):
            frame = x[:, idx]
            if idx > 0:
                # Align the state coming from frame idx - 1 to frame idx.
                state = flow_warp(state, flows_forward[:, idx - 1].permute(0, 2, 3, 1))
            state = self.forward_trunk(torch.cat([frame, state], dim=1))

            # Fuse both directions and upsample x4.
            sr = torch.cat([backward_feats[idx], state], dim=1)
            sr = self.lrelu(self.fusion(sr))
            sr = self.lrelu(self.pixel_shuffle(self.upconv1(sr)))
            sr = self.lrelu(self.pixel_shuffle(self.upconv2(sr)))
            sr = self.lrelu(self.conv_hr(sr))
            sr = self.conv_last(sr)
            # Residual on top of a bilinear x4 upsampling of the input frame.
            sr += F.interpolate(frame, scale_factor=4, mode='bilinear', align_corners=False)
            outputs.append(sr)

        return torch.stack(outputs, dim=1)
94+
95+
96+
class ConvResidualBlocks(nn.Module):
    """A 3x3 conv and LeakyReLU(0.1) followed by a stack of residual blocks.

    Args:
        num_in_ch (int): Input channels of the leading conv. Default: 3.
        num_out_ch (int): Output channels of the leading conv; also passed as
            ``num_feat`` to every ``ResidualBlockNoBN``. Default: 64.
        num_block (int): Number of residual blocks. Default: 15.
    """

    def __init__(self, num_in_ch=3, num_out_ch=64, num_block=15):
        super().__init__()
        head_conv = nn.Conv2d(num_in_ch, num_out_ch, kernel_size=3, stride=1, padding=1, bias=True)
        head_act = nn.LeakyReLU(negative_slope=0.1, inplace=True)
        res_stack = make_layer(ResidualBlockNoBN, num_block, num_feat=num_out_ch)
        # Positions 0/1/2 inside the Sequential define the parameter names.
        self.main = nn.Sequential(head_conv, head_act, res_stack)

    def forward(self, fea):
        """Apply the conv head and the residual stack to ``fea``."""
        return self.main(fea)
106+
107+
108+
@ARCH_REGISTRY.register()
class IconVSR(nn.Module):
    """IconVSR, proposed also in the BasicVSR paper.

    Bidirectional flow-guided propagation in which, at keyframes, a feature
    extracted from a temporal window by ``EDVRFeatureExtractor`` is fused into
    the propagated state. The forward-time branch additionally receives the
    backward-time feature of the same frame. Output is upsampled x4.

    Args:
        num_feat (int): Number of channels. Default: 64.
        num_block (int): Number of residual blocks for each branch. Default: 15.
        keyframe_stride (int): Frames 0, stride, 2 * stride, ... are keyframes;
            the last frame is always a keyframe as well. Default: 5.
        temporal_padding (int): Number of frames added at each end of the
            sequence before keyframe features are extracted; the extraction
            window is ``2 * temporal_padding + 1`` frames. Default: 2.
        spynet_path (str): Passed to ``SpyNet``. Default: None.
        edvr_path (str): Passed to ``EDVRFeatureExtractor`` as the checkpoint
            to load. Default: None.
    """

    def __init__(self,
                 num_feat=64,
                 num_block=15,
                 keyframe_stride=5,
                 temporal_padding=2,
                 spynet_path=None,
                 edvr_path=None):
        super().__init__()

        self.num_feat = num_feat
        self.temporal_padding = temporal_padding
        self.keyframe_stride = keyframe_stride

        # keyframe_branch
        self.edvr = EDVRFeatureExtractor(temporal_padding * 2 + 1, num_feat, edvr_path)
        # alignment
        self.spynet = SpyNet(spynet_path)

        # propagation
        self.backward_fusion = nn.Conv2d(2 * num_feat, num_feat, 3, 1, 1, bias=True)
        self.backward_trunk = ConvResidualBlocks(num_feat + 3, num_feat, num_block)

        self.forward_fusion = nn.Conv2d(2 * num_feat, num_feat, 3, 1, 1, bias=True)
        # The forward trunk sees frame (3) + backward feature + forward state.
        self.forward_trunk = ConvResidualBlocks(2 * num_feat + 3, num_feat, num_block)

        # reconstruction
        self.upconv1 = nn.Conv2d(num_feat, num_feat * 4, 3, 1, 1, bias=True)
        self.upconv2 = nn.Conv2d(num_feat, 64 * 4, 3, 1, 1, bias=True)
        self.conv_hr = nn.Conv2d(64, 64, 3, 1, 1)
        self.conv_last = nn.Conv2d(64, 3, 3, 1, 1)

        self.pixel_shuffle = nn.PixelShuffle(2)

        # activation functions
        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)

    def pad_spatial(self, x):
        """ Apply padding spatially.

        Since the PCD module in EDVR requires that the resolution is a multiple
        of 4, we apply padding to the input LR images if their resolution is
        not divisible by 4.

        Args:
            x (Tensor): Input LR sequence with shape (n, t, c, h, w).
        Returns:
            Tensor: Padded LR sequence with shape (n, t, c, h_pad, w_pad).
        """
        n, t, c, h, w = x.size()

        pad_h = (4 - h % 4) % 4
        pad_w = (4 - w % 4) % 4

        # padding (reshape rather than view so non-contiguous input also works,
        # matching get_flow)
        x = x.reshape(-1, c, h, w)
        x = F.pad(x, [0, pad_w, 0, pad_h], mode='reflect')

        return x.view(n, t, c, h + pad_h, w + pad_w)

    def get_flow(self, x):
        """Run SPyNet on every pair of neighbouring frames, in both orders.

        Args:
            x (Tensor): Frame sequence, unpacked here as (b, n, c, h, w).

        Returns:
            tuple[Tensor, Tensor]: ``(flows_forward, flows_backward)``, each
            viewed as (b, n - 1, 2, h, w).
        """
        b, n, c, h, w = x.size()

        x_1 = x[:, :-1, :, :, :].reshape(-1, c, h, w)
        x_2 = x[:, 1:, :, :, :].reshape(-1, c, h, w)

        flows_backward = self.spynet(x_1, x_2).view(b, n - 1, 2, h, w)
        flows_forward = self.spynet(x_2, x_1).view(b, n - 1, 2, h, w)

        return flows_forward, flows_backward

    def get_keyframe_feature(self, x, keyframe_idx):
        """Extract one EDVR feature per keyframe from a temporal window.

        The sequence is first extended by ``temporal_padding`` frames at each
        end, so that the window starting at padded index ``i`` is centred on
        original frame ``i``.

        Args:
            x (Tensor): Frame sequence with time on dim 1.
            keyframe_idx (list[int]): Indices of the keyframes.

        Returns:
            dict[int, Tensor]: Maps each keyframe index to the output of
            ``self.edvr`` for its window.

        Raises:
            IndexError: If the sequence has fewer than
                ``2 * temporal_padding + 1`` frames, because the padding
                indices then fall outside the sequence.
        """
        pad = self.temporal_padding
        if pad > 0:
            # Frames prepended / appended around the sequence. This reproduces
            # the previously hard-coded cases exactly:
            #   pad == 2 -> [4, 3]    and [-4, -5]
            #   pad == 3 -> [6, 5, 4] and [-5, -6, -7]
            head_idx = list(range(2 * pad, pad, -1))
            tail_idx = list(range(-(pad + 2), -(2 * pad + 2), -1))
            x = torch.cat([x[:, head_idx], x, x[:, tail_idx]], dim=1)

        num_frames = 2 * pad + 1
        feats_keyframe = {}
        for i in keyframe_idx:
            feats_keyframe[i] = self.edvr(x[:, i:i + num_frames].contiguous())
        return feats_keyframe

    def forward(self, x):
        """Super-resolve a frame sequence.

        Args:
            x (Tensor): Input frames, unpacked here as (b, n, c, h, w).

        Returns:
            Tensor: Per-frame outputs stacked along dim 1 and cropped to
            4 * h by 4 * w of the unpadded input.
        """
        b, n, _, h_input, w_input = x.size()

        x = self.pad_spatial(x)
        h, w = x.shape[3:]

        keyframe_idx = list(range(0, n, self.keyframe_stride))
        if keyframe_idx[-1] != n - 1:
            keyframe_idx.append(n - 1)  # last frame is a keyframe

        # compute flow and keyframe features
        flows_forward, flows_backward = self.get_flow(x)
        feats_keyframe = self.get_keyframe_feature(x, keyframe_idx)

        # backward branch
        out_l = []
        feat_prop = x.new_zeros(b, self.num_feat, h, w)
        for i in range(n - 1, -1, -1):
            x_i = x[:, i, :, :, :]
            if i < n - 1:
                flow = flows_backward[:, i, :, :, :]
                feat_prop = flow_warp(feat_prop, flow.permute(0, 2, 3, 1))
            # feats_keyframe is keyed by exactly the keyframe indices.
            if i in feats_keyframe:
                feat_prop = torch.cat([feat_prop, feats_keyframe[i]], dim=1)
                feat_prop = self.backward_fusion(feat_prop)
            feat_prop = torch.cat([x_i, feat_prop], dim=1)
            feat_prop = self.backward_trunk(feat_prop)
            out_l.insert(0, feat_prop)

        # forward branch
        feat_prop = torch.zeros_like(feat_prop)
        for i in range(0, n):
            x_i = x[:, i, :, :, :]
            if i > 0:
                flow = flows_forward[:, i - 1, :, :, :]
                feat_prop = flow_warp(feat_prop, flow.permute(0, 2, 3, 1))
            if i in feats_keyframe:
                feat_prop = torch.cat([feat_prop, feats_keyframe[i]], dim=1)
                feat_prop = self.forward_fusion(feat_prop)

            # out_l[i] still holds the backward-branch feature of frame i here;
            # it is replaced by the final output at the end of this iteration.
            feat_prop = torch.cat([x_i, out_l[i], feat_prop], dim=1)
            feat_prop = self.forward_trunk(feat_prop)

            # upsample
            out = self.lrelu(self.pixel_shuffle(self.upconv1(feat_prop)))
            out = self.lrelu(self.pixel_shuffle(self.upconv2(out)))
            out = self.lrelu(self.conv_hr(out))
            out = self.conv_last(out)
            base = F.interpolate(x_i, scale_factor=4, mode='bilinear', align_corners=False)
            out += base
            out_l[i] = out

        # Drop the rows/columns that correspond to the spatial padding.
        return torch.stack(out_l, dim=1)[..., :4 * h_input, :4 * w_input]
249+
250+
251+
class EDVRFeatureExtractor(nn.Module):
    """Feature extractor built from EDVR's PCD alignment and TSA fusion.

    For a window of frames it extracts a three-level feature pyramid per
    frame, aligns every frame to the centre frame with ``PCDAlignment`` and
    fuses the aligned features with ``TSAFusion``.

    Args:
        num_input_frame (int): Number of frames in the input window. The
            reference frame is the one at index ``num_input_frame // 2``.
        num_feat (int): Number of feature channels.
        load_path (str): If truthy, a checkpoint is read from this path with
            ``torch.load`` and its ``'params'`` entry is loaded as the state
            dict of this module.
    """

    def __init__(self, num_input_frame, num_feat, load_path):

        super(EDVRFeatureExtractor, self).__init__()

        self.center_frame_idx = num_input_frame // 2

        # extract pyramid features
        self.conv_first = nn.Conv2d(3, num_feat, 3, 1, 1)
        # Use num_feat (not a hard-coded 64) so the blocks match conv_first.
        self.feature_extraction = make_layer(ResidualBlockNoBN, 5, num_feat=num_feat)
        self.conv_l2_1 = nn.Conv2d(num_feat, num_feat, 3, 2, 1)
        self.conv_l2_2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.conv_l3_1 = nn.Conv2d(num_feat, num_feat, 3, 2, 1)
        self.conv_l3_2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)

        # pcd and tsa module
        self.pcd_align = PCDAlignment(num_feat=num_feat, deformable_groups=8)
        self.fusion = TSAFusion(num_feat=num_feat, num_frame=num_input_frame, center_frame_idx=self.center_frame_idx)

        # activation function
        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)

        if load_path:
            # NOTE: torch.load unpickles the file; only load trusted checkpoints.
            self.load_state_dict(torch.load(load_path, map_location=lambda storage, loc: storage)['params'])

    def forward(self, x):
        """Extract the fused feature of the centre frame of a window.

        Args:
            x (Tensor): Frame window, unpacked here as (b, n, c, h, w). The
                views below use ``h // 2`` and ``h // 4`` (likewise for w)
                for the two stride-2 levels, so h and w are expected to be
                multiples of 4.

        Returns:
            Tensor: Output of ``TSAFusion`` on the aligned features.
            NOTE(review): presumably (b, num_feat, h, w); confirm against
            TSAFusion.
        """
        b, n, c, h, w = x.size()

        # extract features for each frame
        # L1
        feat_l1 = self.lrelu(self.conv_first(x.reshape(-1, c, h, w)))
        feat_l1 = self.feature_extraction(feat_l1)
        # L2
        feat_l2 = self.lrelu(self.conv_l2_1(feat_l1))
        feat_l2 = self.lrelu(self.conv_l2_2(feat_l2))
        # L3
        feat_l3 = self.lrelu(self.conv_l3_1(feat_l2))
        feat_l3 = self.lrelu(self.conv_l3_2(feat_l3))

        # Restore the (batch, time) split at every pyramid level.
        feat_l1 = feat_l1.view(b, n, -1, h, w)
        feat_l2 = feat_l2.view(b, n, -1, h // 2, w // 2)
        feat_l3 = feat_l3.view(b, n, -1, h // 4, w // 4)

        # PCD alignment
        ref_feat_l = [  # reference feature list
            feat_l1[:, self.center_frame_idx, :, :, :].clone(),
            feat_l2[:, self.center_frame_idx, :, :, :].clone(),
            feat_l3[:, self.center_frame_idx, :, :, :].clone(),
        ]
        aligned_feat = []
        for i in range(n):
            nbr_feat_l = [  # neighboring feature list
                feat_l1[:, i, :, :, :].clone(),
                feat_l2[:, i, :, :, :].clone(),
                feat_l3[:, i, :, :, :].clone(),
            ]
            aligned_feat.append(self.pcd_align(nbr_feat_l, ref_feat_l))
        aligned_feat = torch.stack(aligned_feat, dim=1)  # (b, t, c, h, w)

        # TSA fusion
        return self.fusion(aligned_feat)

0 commit comments

Comments
 (0)