|
| 1 | +import torch |
| 2 | +from torch import nn as nn |
| 3 | +from torch.nn import functional as F |
| 4 | + |
| 5 | +from basicsr.utils.registry import ARCH_REGISTRY |
| 6 | +from .arch_util import ResidualBlockNoBN, flow_warp, make_layer |
| 7 | +from .edvr_arch import PCDAlignment, TSAFusion |
| 8 | +from .spynet_arch import SpyNet |
| 9 | + |
| 10 | + |
@ARCH_REGISTRY.register()
class BasicVSR(nn.Module):
    """A bidirectional recurrent network for video SR. Only x4 is supported.

    Args:
        num_feat (int): Number of channels. Default: 64.
        num_block (int): Number of residual blocks for each branch. Default: 15
        spynet_path (str): Path to the pretrained weights of SPyNet. Default: None.
    """

    def __init__(self, num_feat=64, num_block=15, spynet_path=None):
        super().__init__()
        self.num_feat = num_feat

        # flow estimator used to align the hidden state between frames
        self.spynet = SpyNet(spynet_path)

        # one trunk per propagation direction; its input is a 3-channel frame
        # concatenated with the num_feat-channel hidden state
        trunk_in_ch = num_feat + 3
        self.backward_trunk = ConvResidualBlocks(trunk_in_ch, num_feat, num_block)
        self.forward_trunk = ConvResidualBlocks(trunk_in_ch, num_feat, num_block)

        # reconstruction head: fuse both directions, then two x2 pixel-shuffle stages
        self.fusion = nn.Conv2d(num_feat * 2, num_feat, kernel_size=1, stride=1, padding=0, bias=True)
        self.upconv1 = nn.Conv2d(num_feat, num_feat * 4, kernel_size=3, stride=1, padding=1, bias=True)
        self.upconv2 = nn.Conv2d(num_feat, 64 * 4, kernel_size=3, stride=1, padding=1, bias=True)
        self.conv_hr = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
        self.conv_last = nn.Conv2d(64, 3, kernel_size=3, stride=1, padding=1)

        self.pixel_shuffle = nn.PixelShuffle(2)

        # shared in-place activation
        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)

    def get_flow(self, x):
        """Run SPyNet on every pair of adjacent frames, in both orders.

        Args:
            x (Tensor): Input sequence unpacked as (b, n, c, h, w).

        Returns:
            tuple[Tensor, Tensor]: ``(flows_forward, flows_backward)``, each
            viewed as (b, n - 1, 2, h, w). Entry ``i`` of ``flows_backward``
            is ``spynet(frame i, frame i + 1)``; entry ``i`` of
            ``flows_forward`` is ``spynet(frame i + 1, frame i)``.
        """
        b, n, c, h, w = x.size()

        # fold the time axis into the batch axis so each direction is one call
        earlier = x[:, :-1].reshape(-1, c, h, w)
        later = x[:, 1:].reshape(-1, c, h, w)

        flows_backward = self.spynet(earlier, later).view(b, n - 1, 2, h, w)
        flows_forward = self.spynet(later, earlier).view(b, n - 1, 2, h, w)
        return flows_forward, flows_backward

    def _reconstruct(self, feat_backward, feat_forward, lr_frame):
        """Fuse the two directional features and upsample them by x4.

        The bilinearly upsampled ``lr_frame`` is added to the network output.
        """
        fused = self.lrelu(self.fusion(torch.cat([feat_backward, feat_forward], dim=1)))
        hr = self.lrelu(self.pixel_shuffle(self.upconv1(fused)))
        hr = self.lrelu(self.pixel_shuffle(self.upconv2(hr)))
        hr = self.lrelu(self.conv_hr(hr))
        hr = self.conv_last(hr)
        hr += F.interpolate(lr_frame, scale_factor=4, mode='bilinear', align_corners=False)
        return hr

    def forward(self, x):
        """Super-resolve a frame sequence.

        Args:
            x (Tensor): Input sequence unpacked as (b, n, c, h, w). The trunks
                are built for ``num_feat + 3`` input channels, so c must be 3.

        Returns:
            Tensor: Per-frame outputs stacked along dim 1.
        """
        flows_forward, flows_backward = self.get_flow(x)
        b, n, _, h, w = x.size()

        # backward pass: walk from the last frame to the first
        per_frame = []
        state = x.new_zeros(b, self.num_feat, h, w)
        for idx in reversed(range(n)):
            frame = x[:, idx]
            if idx < n - 1:
                # warp the state coming from frame idx + 1 onto frame idx
                state = flow_warp(state, flows_backward[:, idx].permute(0, 2, 3, 1))
            state = self.backward_trunk(torch.cat([frame, state], dim=1))
            per_frame.append(state)
        # features were collected last-to-first; restore temporal order
        per_frame.reverse()

        # forward pass: walk from the first frame to the last and reconstruct
        state = torch.zeros_like(state)
        for idx in range(n):
            frame = x[:, idx]
            if idx > 0:
                # warp the state coming from frame idx - 1 onto frame idx
                state = flow_warp(state, flows_forward[:, idx - 1].permute(0, 2, 3, 1))
            state = self.forward_trunk(torch.cat([frame, state], dim=1))

            # the backward feature slot is overwritten by the final output
            per_frame[idx] = self._reconstruct(per_frame[idx], state, frame)

        return torch.stack(per_frame, dim=1)
| 94 | + |
| 95 | + |
class ConvResidualBlocks(nn.Module):
    """A 3x3 conv and LeakyReLU followed by a stack of residual blocks.

    Args:
        num_in_ch (int): Number of input channels. Default: 3.
        num_out_ch (int): Number of output channels, also passed as
            ``num_feat`` to the residual blocks. Default: 64.
        num_block (int): Block count handed to ``make_layer``. Default: 15.
    """

    def __init__(self, num_in_ch=3, num_out_ch=64, num_block=15):
        super().__init__()
        head_conv = nn.Conv2d(num_in_ch, num_out_ch, kernel_size=3, stride=1, padding=1, bias=True)
        activation = nn.LeakyReLU(negative_slope=0.1, inplace=True)
        res_stack = make_layer(ResidualBlockNoBN, num_block, num_feat=num_out_ch)
        # keep the three-element Sequential layout so parameter names stay main.0 / main.2
        self.main = nn.Sequential(head_conv, activation, res_stack)

    def forward(self, fea):
        """Apply the conv, activation and residual stack to ``fea``."""
        return self.main(fea)
| 106 | + |
| 107 | + |
@ARCH_REGISTRY.register()
class IconVSR(nn.Module):
    """IconVSR, proposed also in the BasicVSR paper.

    Bidirectional recurrent propagation as in BasicVSR, with two additions
    visible in this class: features extracted by an EDVR-style module are
    fused into the hidden state at keyframes, and the forward trunk also
    receives the backward-branch feature of the current frame.

    Args:
        num_feat (int): Number of channels. Default: 64.
        num_block (int): Number of residual blocks for each branch. Default: 15.
        keyframe_stride (int): Every ``keyframe_stride``-th frame (starting at
            frame 0) is a keyframe; the last frame is always one. Default: 5.
        temporal_padding (int): Number of neighbouring frames on each side of
            a keyframe fed to the EDVR feature extractor, i.e. the extractor
            sees ``2 * temporal_padding + 1`` frames. Default: 2.
        spynet_path (str): Path to the pretrained weights of SPyNet. Default: None.
        edvr_path (str): Path to the pretrained EDVR feature-extractor
            weights. Default: None.
    """

    def __init__(self,
                 num_feat=64,
                 num_block=15,
                 keyframe_stride=5,
                 temporal_padding=2,
                 spynet_path=None,
                 edvr_path=None):
        super().__init__()

        self.num_feat = num_feat
        self.temporal_padding = temporal_padding
        self.keyframe_stride = keyframe_stride

        # keyframe_branch
        self.edvr = EDVRFeatureExtractor(temporal_padding * 2 + 1, num_feat, edvr_path)
        # alignment
        self.spynet = SpyNet(spynet_path)

        # propagation
        self.backward_fusion = nn.Conv2d(2 * num_feat, num_feat, 3, 1, 1, bias=True)
        self.backward_trunk = ConvResidualBlocks(num_feat + 3, num_feat, num_block)

        self.forward_fusion = nn.Conv2d(2 * num_feat, num_feat, 3, 1, 1, bias=True)
        # the forward trunk additionally takes the backward feature of the frame
        self.forward_trunk = ConvResidualBlocks(2 * num_feat + 3, num_feat, num_block)

        # reconstruction
        self.upconv1 = nn.Conv2d(num_feat, num_feat * 4, 3, 1, 1, bias=True)
        self.upconv2 = nn.Conv2d(num_feat, 64 * 4, 3, 1, 1, bias=True)
        self.conv_hr = nn.Conv2d(64, 64, 3, 1, 1)
        self.conv_last = nn.Conv2d(64, 3, 3, 1, 1)

        self.pixel_shuffle = nn.PixelShuffle(2)

        # activation functions
        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)

    def pad_spatial(self, x):
        """Apply padding spatially.

        Since the PCD module in EDVR requires that the resolution is a multiple
        of 4, we apply padding to the input LR images if their resolution is
        not divisible by 4.

        Args:
            x (Tensor): Input LR sequence with shape (n, t, c, h, w).

        Returns:
            Tensor: Padded LR sequence with shape (n, t, c, h_pad, w_pad).
        """
        n, t, c, h, w = x.size()

        # amount needed to reach the next multiple of 4 (0 if already aligned)
        pad_h = (4 - h % 4) % 4
        pad_w = (4 - w % 4) % 4

        # F.pad works on 4D input here, so fold time into the batch axis.
        # reshape (not view) so non-contiguous inputs are accepted as well.
        x = x.reshape(-1, c, h, w)
        x = F.pad(x, [0, pad_w, 0, pad_h], mode='reflect')

        return x.view(n, t, c, h + pad_h, w + pad_w)

    def get_flow(self, x):
        """Run SPyNet on every pair of adjacent frames, in both orders.

        Args:
            x (Tensor): Input sequence unpacked as (b, n, c, h, w).

        Returns:
            tuple[Tensor, Tensor]: ``(flows_forward, flows_backward)``, each
            viewed as (b, n - 1, 2, h, w). Entry ``i`` of ``flows_backward``
            is ``spynet(frame i, frame i + 1)``; entry ``i`` of
            ``flows_forward`` is ``spynet(frame i + 1, frame i)``.
        """
        b, n, c, h, w = x.size()

        x_1 = x[:, :-1, :, :, :].reshape(-1, c, h, w)
        x_2 = x[:, 1:, :, :, :].reshape(-1, c, h, w)

        flows_backward = self.spynet(x_1, x_2).view(b, n - 1, 2, h, w)
        flows_forward = self.spynet(x_2, x_1).view(b, n - 1, 2, h, w)

        return flows_forward, flows_backward

    def get_keyframe_feature(self, x, keyframe_idx):
        """Extract EDVR features for each keyframe.

        The sequence is extended along the time axis by ``temporal_padding``
        frames at each end so that a window of ``2 * temporal_padding + 1``
        frames centred on any original frame exists. For ``temporal_padding``
        of 2 and 3 the padding frames are taken at indices ``[4, 3]`` /
        ``[-4, -5]`` and ``[6, 5, 4]`` / ``[-5, -6, -7]`` respectively; other
        values follow the same index pattern.

        Args:
            x (Tensor): Spatially padded sequence, time on dim 1.
            keyframe_idx (list[int]): Indices of keyframes in the unpadded
                sequence.

        Returns:
            dict[int, Tensor]: Keyframe index -> feature from ``self.edvr``.

        Raises:
            IndexError: If the sequence has fewer than
                ``2 * temporal_padding + 1`` frames, since the padding
                indices then fall outside dim 1.
        """
        pad = self.temporal_padding
        if pad > 0:
            # indices 2*pad, 2*pad-1, ..., pad+1 in front and
            # -(pad+2), ..., -(2*pad+1) at the back
            front = x[:, list(range(2 * pad, pad, -1))]
            back = x[:, list(range(-(pad + 2), -(2 * pad + 2), -1))]
            x = torch.cat([front, x, back], dim=1)

        num_frames = 2 * pad + 1
        feats_keyframe = {}
        for i in keyframe_idx:
            # after padding, the window starting at i is centred on original frame i
            feats_keyframe[i] = self.edvr(x[:, i:i + num_frames].contiguous())
        return feats_keyframe

    def forward(self, x):
        """Super-resolve a frame sequence by x4.

        Args:
            x (Tensor): Input sequence unpacked as (b, n, c, h, w). The trunks
                are built for 3-channel frames.

        Returns:
            Tensor: Per-frame outputs stacked along dim 1 and cropped to
            ``4 * h`` by ``4 * w`` of the unpadded input.
        """
        b, n, _, h_input, w_input = x.size()

        x = self.pad_spatial(x)
        h, w = x.shape[3:]

        keyframe_idx = list(range(0, n, self.keyframe_stride))
        if keyframe_idx[-1] != n - 1:
            keyframe_idx.append(n - 1)  # last frame is a keyframe

        # compute flow and keyframe features
        flows_forward, flows_backward = self.get_flow(x)
        feats_keyframe = self.get_keyframe_feature(x, keyframe_idx)

        # backward branch
        out_l = []
        feat_prop = x.new_zeros(b, self.num_feat, h, w)
        for i in range(n - 1, -1, -1):
            x_i = x[:, i, :, :, :]
            if i < n - 1:
                flow = flows_backward[:, i, :, :, :]
                feat_prop = flow_warp(feat_prop, flow.permute(0, 2, 3, 1))
            # dict membership: feats_keyframe holds exactly the keyframe indices
            if i in feats_keyframe:
                feat_prop = torch.cat([feat_prop, feats_keyframe[i]], dim=1)
                feat_prop = self.backward_fusion(feat_prop)
            feat_prop = torch.cat([x_i, feat_prop], dim=1)
            feat_prop = self.backward_trunk(feat_prop)
            out_l.append(feat_prop)
        # features were collected last-to-first; restore temporal order
        out_l.reverse()

        # forward branch
        feat_prop = torch.zeros_like(feat_prop)
        for i in range(0, n):
            x_i = x[:, i, :, :, :]
            if i > 0:
                flow = flows_forward[:, i - 1, :, :, :]
                feat_prop = flow_warp(feat_prop, flow.permute(0, 2, 3, 1))
            if i in feats_keyframe:
                feat_prop = torch.cat([feat_prop, feats_keyframe[i]], dim=1)
                feat_prop = self.forward_fusion(feat_prop)

            # frame + backward feature of this frame + forward hidden state
            feat_prop = torch.cat([x_i, out_l[i], feat_prop], dim=1)
            feat_prop = self.forward_trunk(feat_prop)

            # upsample
            out = self.lrelu(self.pixel_shuffle(self.upconv1(feat_prop)))
            out = self.lrelu(self.pixel_shuffle(self.upconv2(out)))
            out = self.lrelu(self.conv_hr(out))
            out = self.conv_last(out)
            base = F.interpolate(x_i, scale_factor=4, mode='bilinear', align_corners=False)
            out += base
            # the backward feature slot is no longer needed; reuse it for the output
            out_l[i] = out

        # drop the rows/columns that correspond to the spatial padding
        return torch.stack(out_l, dim=1)[..., :4 * h_input, :4 * w_input]
| 249 | + |
| 250 | + |
class EDVRFeatureExtractor(nn.Module):
    """Feature extractor built from EDVR's PCD alignment and TSA fusion.

    Builds a three-level feature pyramid per frame, aligns every frame to the
    centre frame with ``PCDAlignment`` and fuses the aligned features with
    ``TSAFusion``.

    Args:
        num_input_frame (int): Number of frames in each input window. The
            centre frame (index ``num_input_frame // 2``) is the reference.
        num_feat (int): Number of feature channels.
        load_path (str | None): Optional checkpoint path. When truthy, the
            ``'params'`` entry of the loaded object is used as state dict.
    """

    def __init__(self, num_input_frame, num_feat, load_path):

        super().__init__()

        self.center_frame_idx = num_input_frame // 2

        # extract pyramid features
        self.conv_first = nn.Conv2d(3, num_feat, 3, 1, 1)
        # must match conv_first's output width, so use num_feat rather than a fixed 64
        self.feature_extraction = make_layer(ResidualBlockNoBN, 5, num_feat=num_feat)
        self.conv_l2_1 = nn.Conv2d(num_feat, num_feat, 3, 2, 1)
        self.conv_l2_2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.conv_l3_1 = nn.Conv2d(num_feat, num_feat, 3, 2, 1)
        self.conv_l3_2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)

        # pcd and tsa module
        self.pcd_align = PCDAlignment(num_feat=num_feat, deformable_groups=8)
        self.fusion = TSAFusion(num_feat=num_feat, num_frame=num_input_frame, center_frame_idx=self.center_frame_idx)

        # activation function
        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)

        if load_path:
            # NOTE(review): torch.load unpickles the file; only load checkpoints
            # from trusted sources. Storages are kept on CPU by the map_location.
            self.load_state_dict(torch.load(load_path, map_location=lambda storage, loc: storage)['params'])

    def forward(self, x):
        """Extract a fused feature for the centre frame of a window.

        Args:
            x (Tensor): Frame window unpacked as (b, n, c, h, w). The views
                below use ``h // 2`` and ``h // 4`` (same for w), so h and w
                are expected to be multiples of 4.

        Returns:
            Tensor: Output of ``TSAFusion`` on the aligned features.
        """
        b, n, c, h, w = x.size()

        # extract features for each frame
        # L1 (fold time into batch; reshape also accepts non-contiguous input)
        feat_l1 = self.lrelu(self.conv_first(x.reshape(-1, c, h, w)))
        feat_l1 = self.feature_extraction(feat_l1)
        # L2 (stride-2 conv halves the resolution)
        feat_l2 = self.lrelu(self.conv_l2_1(feat_l1))
        feat_l2 = self.lrelu(self.conv_l2_2(feat_l2))
        # L3 (another stride-2 conv)
        feat_l3 = self.lrelu(self.conv_l3_1(feat_l2))
        feat_l3 = self.lrelu(self.conv_l3_2(feat_l3))

        # split batch and time again
        feat_l1 = feat_l1.view(b, n, -1, h, w)
        feat_l2 = feat_l2.view(b, n, -1, h // 2, w // 2)
        feat_l3 = feat_l3.view(b, n, -1, h // 4, w // 4)

        # PCD alignment
        ref_feat_l = [  # reference feature list
            feat_l1[:, self.center_frame_idx, :, :, :].clone(), feat_l2[:, self.center_frame_idx, :, :, :].clone(),
            feat_l3[:, self.center_frame_idx, :, :, :].clone()
        ]
        aligned_feat = []
        for i in range(n):
            nbr_feat_l = [  # neighboring feature list
                feat_l1[:, i, :, :, :].clone(), feat_l2[:, i, :, :, :].clone(), feat_l3[:, i, :, :, :].clone()
            ]
            aligned_feat.append(self.pcd_align(nbr_feat_l, ref_feat_l))
        aligned_feat = torch.stack(aligned_feat, dim=1)  # (b, t, c, h, w)

        # TSA fusion
        return self.fusion(aligned_feat)
0 commit comments