Skip to content

Commit 29309ab

Browse files
committed
ECBSR support RGB training
1 parent 9f8b6a2 commit 29309ab

File tree

2 files changed

+146
-1
lines changed

2 files changed

+146
-1
lines changed

basicsr/archs/ecbsr_arch.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,8 @@ class ECBSR(nn.Module):
227227

228228
def __init__(self, num_in_ch, num_out_ch, num_block, num_channel, with_idt, act_type, scale):
229229
super(ECBSR, self).__init__()
230+
self.num_in_ch = num_in_ch
231+
self.scale = scale
230232

231233
backbone = []
232234
backbone += [ECB(num_in_ch, num_channel, depth_multiplier=2.0, act_type=act_type, with_idt=with_idt)]
@@ -240,6 +242,10 @@ def __init__(self, num_in_ch, num_out_ch, num_block, num_channel, with_idt, act_
240242
self.upsampler = nn.PixelShuffle(scale)
241243

242244
def forward(self, x):
243-
y = self.backbone(x) + x # will repeat the input in the channel dimension (repeat scale * scale times)
245+
if self.num_in_ch > 1:
246+
shortcut = torch.repeat_interleave(x, self.scale * self.scale, dim=1)
247+
else:
248+
shortcut = x # will repeat the input in the channel dimension (repeat scale * scale times)
249+
y = self.backbone(x) + shortcut
244250
y = self.upsampler(y)
245251
return y
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,139 @@
1+
# general settings
2+
name: 100_train_ECBSR_x4_m4c16_prelu_RGB
3+
model_type: SRModel
4+
scale: 4
5+
num_gpu: 1 # set num_gpu: 0 for cpu mode
6+
manual_seed: 0
7+
8+
# dataset and data loader settings
9+
datasets:
10+
train:
11+
name: DIV2K
12+
type: PairedImageDataset
13+
# It is strongly recommended to use lmdb for faster IO speed, especially for small networks
14+
dataroot_gt: datasets/DF2K/DIV2K_train_HR_sub.lmdb
15+
dataroot_lq: datasets/DF2K/DIV2K_train_LR_bicubic_X4_sub.lmdb
16+
meta_info_file: basicsr/data/meta_info/meta_info_DIV2K800sub_GT.txt
17+
filename_tmpl: '{}'
18+
io_backend:
19+
type: lmdb
20+
21+
gt_size: 256
22+
use_flip: true
23+
use_rot: true
24+
25+
# data loader
26+
use_shuffle: true
27+
num_worker_per_gpu: 12
28+
batch_size_per_gpu: 32
29+
dataset_enlarge_ratio: 10
30+
prefetch_mode: ~
31+
32+
# we use multiple validation datasets. The SR benchmark datasets can be download from: https://cv.snu.ac.kr/research/EDSR/benchmark.tar
33+
val:
34+
name: Set5
35+
type: PairedImageDataset
36+
dataroot_gt: datasets/benchmark/Set5/HR
37+
dataroot_lq: datasets/benchmark/Set5/LR_bicubic/X4
38+
filename_tmpl: '{}x4'
39+
io_backend:
40+
type: disk
41+
42+
val_2:
43+
name: Set14
44+
type: PairedImageDataset
45+
dataroot_gt: datasets/benchmark/Set14/HR
46+
dataroot_lq: datasets/benchmark/Set14/LR_bicubic/X4
47+
filename_tmpl: '{}x4'
48+
io_backend:
49+
type: disk
50+
51+
val_3:
52+
name: B100
53+
type: PairedImageDataset
54+
dataroot_gt: datasets/benchmark/B100/HR
55+
dataroot_lq: datasets/benchmark/B100/LR_bicubic/X4
56+
filename_tmpl: '{}x4'
57+
io_backend:
58+
type: disk
59+
60+
val_4:
61+
name: Urban100
62+
type: PairedImageDataset
63+
dataroot_gt: datasets/benchmark/Urban100/HR
64+
dataroot_lq: datasets/benchmark/Urban100/LR_bicubic/X4
65+
filename_tmpl: '{}x4'
66+
io_backend:
67+
type: disk
68+
69+
# network structures
70+
network_g:
71+
type: ECBSR
72+
num_in_ch: 3
73+
num_out_ch: 3
74+
num_block: 4
75+
num_channel: 16
76+
with_idt: False
77+
act_type: prelu
78+
scale: 4
79+
80+
# path
81+
path:
82+
pretrain_network_g: ~
83+
strict_load_g: true
84+
resume_state: ~
85+
86+
# training settings
87+
train:
88+
ema_decay: 0
89+
optim_g:
90+
type: Adam
91+
lr: !!float 5e-4
92+
weight_decay: 0
93+
betas: [0.9, 0.99]
94+
95+
scheduler:
96+
type: MultiStepLR
97+
milestones: [1600000]
98+
gamma: 1
99+
100+
total_iter: 1600000
101+
warmup_iter: -1 # no warm up
102+
103+
# losses
104+
pixel_opt:
105+
type: L1Loss
106+
loss_weight: 1.0
107+
reduction: mean
108+
109+
# validation settings
110+
val:
111+
val_freq: !!float 1600 # the same as the original setting. # TODO: Can be larger
112+
save_img: false
113+
pbar: False
114+
115+
metrics:
116+
psnr:
117+
type: calculate_psnr
118+
crop_border: 4
119+
test_y_channel: true
120+
better: higher # the higher, the better. Default: higher
121+
ssim:
122+
type: calculate_ssim
123+
crop_border: 4
124+
test_y_channel: true
125+
better: higher # the higher, the better. Default: higher
126+
127+
# logging settings
128+
logger:
129+
print_freq: 100
130+
save_checkpoint_freq: !!float 1600
131+
use_tb_logger: true
132+
wandb:
133+
project: ~
134+
resume_id: ~
135+
136+
# dist training settings
137+
dist_params:
138+
backend: nccl
139+
port: 29500

0 commit comments

Comments
 (0)