Skip to content

Commit 65b797b

Browse files
authored
Merge commits from BasicSR-private (XPixelGroup#263)
* add get_bare_model * add ffhq dataset * rm NoneDict * add test_ffhq_dataset * update train test commands
1 parent f3aeb69 commit 65b797b

File tree

8 files changed

+164
-45
lines changed

8 files changed

+164
-45
lines changed

basicsr/data/ffhq_dataset.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
import mmcv
2+
import numpy as np
3+
from os import path as osp
4+
from torch.utils import data as data
5+
from torchvision.transforms.functional import normalize
6+
7+
from basicsr.data.transforms import augment, totensor
8+
from basicsr.utils import FileClient
9+
10+
11+
class FFHQDataset(data.Dataset):
    """FFHQ dataset for StyleGAN2.

    Ground-truth images are read through a ``FileClient`` that is created
    lazily on the first ``__getitem__`` call rather than in ``__init__``.

    Args:
        opt (dict): Config for train datasets. It contains the following keys:
            dataroot_gt (str): Data root path for gt.
            io_backend (dict): IO backend type and other kwarg.
            mean (sequence): Mean passed to ``normalize``.
            std (sequence): Std passed to ``normalize``.
            use_hflip (bool): Passed to ``augment`` as ``hflip``.

    Raises:
        ValueError: If the io backend type is 'lmdb' and ``dataroot_gt``
            does not end with '.lmdb'.
    """

    def __init__(self, opt):
        super(FFHQDataset, self).__init__()
        self.opt = opt
        # file client (io backend); built on first access in __getitem__
        self.file_client = None
        self.io_backend_opt = opt['io_backend']

        self.gt_folder = opt['dataroot_gt']
        self.mean = opt['mean']
        self.std = opt['std']

        if self.io_backend_opt['type'] == 'lmdb':
            self.io_backend_opt['db_paths'] = self.gt_folder
            if not self.gt_folder.endswith('.lmdb'):
                raise ValueError("'dataroot_gt' should end with '.lmdb', "
                                 f'but received {self.gt_folder}')
            # lmdb keys: the part of each meta_info line before the first '.'
            with open(osp.join(self.gt_folder, 'meta_info.txt')) as fin:
                self.paths = [line.split('.')[0] for line in fin]
        else:
            # other backends: zero-padded 8-digit png names for 0..69999
            self.paths = [
                osp.join(self.gt_folder, f'{v:08d}.png') for v in range(70000)
            ]

    def __getitem__(self, index):
        """Load one image and return it with its path.

        Args:
            index (int): Index into ``self.paths``.

        Returns:
            dict: ``{'gt': <normalized image>, 'gt_path': <path or key>}``.
        """
        if self.file_client is None:
            # Pop 'type' from a copy: ``self.io_backend_opt`` is the same
            # object as ``opt['io_backend']``, and removing the key in place
            # would break any later reuse of that config.
            backend_kwargs = dict(self.io_backend_opt)
            backend_type = backend_kwargs.pop('type')
            self.file_client = FileClient(backend_type, **backend_kwargs)

        # load gt image
        gt_path = self.paths[index]
        img_bytes = self.file_client.get(gt_path)
        img_gt = mmcv.imfrombytes(img_bytes).astype(np.float32) / 255.

        # random horizontal flip
        img_gt = augment([img_gt], hflip=self.opt['use_hflip'], rotation=False)
        # BGR to RGB, HWC to CHW, numpy to tensor
        img_gt = totensor(img_gt, bgr2rgb=True, float32=True)
        # NOTE(review): augment/totensor are defined elsewhere; since a list
        # is passed in, a one-element list may presumably come back - confirm.
        # normalize() needs a single tensor, so unwrap defensively.
        if isinstance(img_gt, (list, tuple)):
            img_gt = img_gt[0]
        # normalize
        normalize(img_gt, self.mean, self.std, inplace=True)
        return {'gt': img_gt, 'gt_path': gt_path}

    def __len__(self):
        """Return the number of image paths."""
        return len(self.paths)

basicsr/models/base_model.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -100,11 +100,13 @@ def setup_schedulers(self):
100100
raise NotImplementedError(
101101
f'Scheduler {scheduler_type} is not implemented yet.')
102102

103-
def _get_network_description(self, net):
104-
"""Get the string and total parameters of the network"""
103+
def get_bare_model(self, net):
104+
"""Get bare model, especially under wrapping with
105+
DistributedDataParallel or DataParallel.
106+
"""
105107
if isinstance(net, (DataParallel, DistributedDataParallel)):
106108
net = net.module
107-
return str(net), sum(map(lambda x: x.numel(), net.parameters()))
109+
return net
108110

109111
@master_only
110112
def print_network(self, net):
@@ -113,13 +115,16 @@ def print_network(self, net):
113115
Args:
114116
net (nn.Module)
115117
"""
116-
net_str, net_params = self._get_network_description(net)
117118
if isinstance(net, (DataParallel, DistributedDataParallel)):
118119
net_cls_str = (f'{net.__class__.__name__} - '
119120
f'{net.module.__class__.__name__}')
120121
else:
121122
net_cls_str = f'{net.__class__.__name__}'
122123

124+
net = self.get_bare_model(net)
125+
net_str = str(net)
126+
net_params = sum(map(lambda x: x.numel(), net.parameters()))
127+
123128
logger.info(
124129
f'Network: {net_cls_str}, with parameters: {net_params:,d}')
125130
logger.info(net_str)
@@ -255,10 +260,9 @@ def load_network(self, net, load_path, strict=True, param_key='params'):
255260
param_key (str): The parameter key of loaded network.
256261
Default: 'params'.
257262
"""
258-
if isinstance(net, (DataParallel, DistributedDataParallel)):
259-
net = net.module
260-
net_cls_name = net.__class__.__name__
261-
logger.info(f'Loading {net_cls_name} model from {load_path}.')
263+
net = self.get_bare_model(net)
264+
logger.info(
265+
f'Loading {net.__class__.__name__} model from {load_path}.')
262266
load_net = torch.load(load_path)[param_key]
263267
# remove unnecessary 'module.'
264268
for k, v in load_net.items():

basicsr/test.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from basicsr.data import create_dataloader, create_dataset
77
from basicsr.models import create_model
88
from basicsr.utils import get_env_info, get_root_logger, make_exp_dirs
9-
from basicsr.utils.options import dict2str, dict_to_nonedict, parse
9+
from basicsr.utils.options import dict2str, parse
1010

1111

1212
def main():
@@ -34,8 +34,6 @@ def main():
3434
else:
3535
init_dist(args.launcher)
3636

37-
opt = dict_to_nonedict(opt)
38-
3937
make_exp_dirs(opt)
4038
log_file = osp.join(opt['path']['log'],
4139
f"test_{opt['name']}_{get_time_str()}.log")

basicsr/train.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from basicsr.utils import (MessageLogger, check_resume, get_env_info,
1414
get_root_logger, init_tb_logger, init_wandb_logger,
1515
make_exp_dirs, set_random_seed)
16-
from basicsr.utils.options import dict2str, dict_to_nonedict, parse
16+
from basicsr.utils.options import dict2str, parse
1717

1818

1919
def main():
@@ -52,9 +52,6 @@ def main():
5252
else:
5353
resume_state = None
5454

55-
# convert to NoneDict, which returns None for missing keys
56-
opt = dict_to_nonedict(opt)
57-
5855
# mkdir and loggers
5956
if resume_state is None:
6057
make_exp_dirs(opt)
@@ -82,6 +79,7 @@ def main():
8279
# torch.backends.cudnn.deterministic = True
8380

8481
# create train and val dataloaders
82+
train_loader, val_loader = None, None
8583
for phase, dataset_opt in opt['datasets'].items():
8684
if phase == 'train':
8785
# dataset_ratio: enlarge the size of datasets for each epoch
@@ -169,8 +167,8 @@ def main():
169167
model.save(epoch, current_iter)
170168

171169
# validation
172-
if opt['datasets'][
173-
'val'] and current_iter % opt['val']['val_freq'] == 0:
170+
if opt['val']['val_freq'] is not None and current_iter % opt[
171+
'val']['val_freq'] == 0:
174172
model.validation(val_loader, current_iter, tb_logger,
175173
opt['val']['save_img'])
176174

@@ -183,7 +181,7 @@ def main():
183181
logger.info('Saving the latest model.')
184182
model.save(epoch=-1, current_iter=-1) # -1 for the latest
185183
# last validation
186-
if opt['datasets']['val']:
184+
if opt['val']['val_freq'] is not None:
187185
model.validation(val_loader, current_iter, tb_logger,
188186
opt['val']['save_img'])
189187

basicsr/utils/options.py

Lines changed: 0 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -106,30 +106,3 @@ def dict2str(opt, indent_level=1):
106106
else:
107107
msg += ' ' * (indent_level * 2) + k + ': ' + str(v) + '\n'
108108
return msg
109-
110-
111-
class NoneDict(dict):
112-
"""None dict. It will return none if key is not in the dict."""
113-
114-
def __missing__(self, key):
115-
return None
116-
117-
118-
def dict_to_nonedict(opt):
119-
"""Convert to NoneDict, which returns None for missing keys.
120-
121-
Args:
122-
opt (dict): Option dict.
123-
124-
Returns:
125-
(dict): NoneDict for options.
126-
"""
127-
if isinstance(opt, dict):
128-
new_opt = dict()
129-
for key, sub_opt in opt.items():
130-
new_opt[key] = dict_to_nonedict(sub_opt)
131-
return NoneDict(**new_opt)
132-
elif isinstance(opt, list):
133-
return [dict_to_nonedict(sub_opt) for sub_opt in opt]
134-
else:
135-
return opt

docs/TrainTest.md

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ In general, both the training and testing include the following steps:
2626
### Single GPU Training
2727

2828
```bash
29+
PYTHONPATH="./:${PYTHONPATH}" \
2930
CUDA_VISIBLE_DEVICES=0 \
3031
python basicsr/train.py -opt options/train/SRResNet_SRGAN/train_MSRResNet_x4.yml
3132
```
@@ -35,13 +36,15 @@ python basicsr/train.py -opt options/train/SRResNet_SRGAN/train_MSRResNet_x4.yml
3536
**8 GPUs**
3637

3738
```bash
39+
PYTHONPATH="./:${PYTHONPATH}" \
3840
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
3941
python -m torch.distributed.launch --nproc_per_node=8 --master_port=4321 basicsr/train.py -opt options/train/EDVR/train_EDVR_M_x4_SR_REDS_woTSA.yml --launcher pytorch
4042
```
4143

4244
**4 GPUs**
4345

4446
```bash
47+
PYTHONPATH="./:${PYTHONPATH}" \
4548
CUDA_VISIBLE_DEVICES=0,1,2,3 \
4649
python -m torch.distributed.launch --nproc_per_node=4 --master_port=4321 basicsr/train.py -opt options/train/EDVR/train_EDVR_M_x4_SR_REDS_woTSA.yml --launcher pytorch
4750
```
@@ -53,6 +56,7 @@ python -m torch.distributed.launch --nproc_per_node=4 --master_port=4321 basicsr
5356
**1 GPU**
5457

5558
```bash
59+
PYTHONPATH="./:${PYTHONPATH}" \
5660
GLOG_vmodule=MemcachedClient=-1 \
5761
srun -p [partition] --mpi=pmi2 --job-name=MSRResNetx4 --gres=gpu:1 --ntasks=1 --ntasks-per-node=1 --cpus-per-task=6 --kill-on-bad-exit=1 \
5862
python -u basicsr/train.py -opt options/train/SRResNet_SRGAN/train_MSRResNet_x4.yml --launcher="slurm"
@@ -61,6 +65,7 @@ python -u basicsr/train.py -opt options/train/SRResNet_SRGAN/train_MSRResNet_x4.
6165
**4 GPUs**
6266

6367
```bash
68+
PYTHONPATH="./:${PYTHONPATH}" \
6469
GLOG_vmodule=MemcachedClient=-1 \
6570
srun -p [partition] --mpi=pmi2 --job-name=EDVRMwoTSA --gres=gpu:4 --ntasks=4 --ntasks-per-node=4 --cpus-per-task=4 --kill-on-bad-exit=1 \
6671
python -u basicsr/train.py -opt options/train/EDVR/train_EDVR_M_x4_SR_REDS_woTSA.yml --launcher="slurm"
@@ -69,6 +74,7 @@ python -u basicsr/train.py -opt options/train/EDVR/train_EDVR_M_x4_SR_REDS_woTSA
6974
**8 GPUs**
7075

7176
```bash
77+
PYTHONPATH="./:${PYTHONPATH}" \
7278
GLOG_vmodule=MemcachedClient=-1 \
7379
srun -p [partition] --mpi=pmi2 --job-name=EDVRMwoTSA --gres=gpu:8 --ntasks=8 --ntasks-per-node=8 --cpus-per-task=6 --kill-on-bad-exit=1 \
7480
python -u basicsr/train.py -opt options/train/EDVR/train_EDVR_M_x4_SR_REDS_woTSA.yml --launcher="slurm"
@@ -79,6 +85,7 @@ python -u basicsr/train.py -opt options/train/EDVR/train_EDVR_M_x4_SR_REDS_woTSA
7985
### Single GPU Testing
8086

8187
```bash
88+
PYTHONPATH="./:${PYTHONPATH}" \
8289
CUDA_VISIBLE_DEVICES=0 \
8390
python basicsr/test.py -opt options/test/SRResNet_SRGAN/test_MSRResNet_x4.yml
8491
```
@@ -88,13 +95,15 @@ python basicsr/test.py -opt options/test/SRResNet_SRGAN/test_MSRResNet_x4.yml
8895
**8 GPUs**
8996

9097
```bash
98+
PYTHONPATH="./:${PYTHONPATH}" \
9199
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
92100
python -m torch.distributed.launch --nproc_per_node=8 --master_port=4321 basicsr/test.py -opt options/test/EDVR/test_EDVR_M_x4_SR_REDS.yml --launcher pytorch
93101
```
94102

95103
**4 GPUs**
96104

97105
```bash
106+
PYTHONPATH="./:${PYTHONPATH}" \
98107
CUDA_VISIBLE_DEVICES=0,1,2,3 \
99108
python -m torch.distributed.launch --nproc_per_node=4 --master_port=4321 basicsr/test.py -opt options/test/EDVR/test_EDVR_M_x4_SR_REDS.yml --launcher pytorch
100109
```
@@ -106,6 +115,7 @@ python -m torch.distributed.launch --nproc_per_node=4 --master_port=4321 basicsr
106115
**1 GPU**
107116

108117
```bash
118+
PYTHONPATH="./:${PYTHONPATH}" \
109119
GLOG_vmodule=MemcachedClient=-1 \
110120
srun -p [partition] --mpi=pmi2 --job-name=test --gres=gpu:1 --ntasks=1 --ntasks-per-node=1 --cpus-per-task=6 --kill-on-bad-exit=1 \
111121
python -u basicsr/test.py -opt options/test/SRResNet_SRGAN/test_MSRResNet_x4.yml --launcher="slurm"
@@ -114,6 +124,7 @@ python -u basicsr/test.py -opt options/test/SRResNet_SRGAN/test_MSRResNet_x4.yml
114124
**4 GPUs**
115125

116126
```bash
127+
PYTHONPATH="./:${PYTHONPATH}" \
117128
GLOG_vmodule=MemcachedClient=-1 \
118129
srun -p [partition] --mpi=pmi2 --job-name=test --gres=gpu:4 --ntasks=4 --ntasks-per-node=4 --cpus-per-task=4 --kill-on-bad-exit=1 \
119130
python -u basicsr/test.py -opt options/test/EDVR/test_EDVR_M_x4_SR_REDS.yml --launcher="slurm"
@@ -122,6 +133,7 @@ python -u basicsr/test.py -opt options/test/EDVR/test_EDVR_M_x4_SR_REDS.yml --la
122133
**8 GPUs**
123134

124135
```bash
136+
PYTHONPATH="./:${PYTHONPATH}" \
125137
GLOG_vmodule=MemcachedClient=-1 \
126138
srun -p [partition] --mpi=pmi2 --job-name=test --gres=gpu:8 --ntasks=8 --ntasks-per-node=8 --cpus-per-task=6 --kill-on-bad-exit=1 \
127139
python -u basicsr/test.py -opt options/test/EDVR/test_EDVR_M_x4_SR_REDS.yml --launcher="slurm"

docs/TrainTest_CN.md

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
### 单GPU训练
2727

2828
```bash
29+
PYTHONPATH="./:${PYTHONPATH}" \
2930
CUDA_VISIBLE_DEVICES=0 \
3031
python basicsr/train.py -opt options/train/SRResNet_SRGAN/train_MSRResNet_x4.yml
3132
```
@@ -35,13 +36,15 @@ python basicsr/train.py -opt options/train/SRResNet_SRGAN/train_MSRResNet_x4.yml
3536
**8 GPUs**
3637

3738
```bash
39+
PYTHONPATH="./:${PYTHONPATH}" \
3840
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
3941
python -m torch.distributed.launch --nproc_per_node=8 --master_port=4321 basicsr/train.py -opt options/train/EDVR/train_EDVR_M_x4_SR_REDS_woTSA.yml --launcher pytorch
4042
```
4143

4244
**4 GPUs**
4345

4446
```bash
47+
PYTHONPATH="./:${PYTHONPATH}" \
4548
CUDA_VISIBLE_DEVICES=0,1,2,3 \
4649
python -m torch.distributed.launch --nproc_per_node=4 --master_port=4321 basicsr/train.py -opt options/train/EDVR/train_EDVR_M_x4_SR_REDS_woTSA.yml --launcher pytorch
4750
```
@@ -53,6 +56,7 @@ python -m torch.distributed.launch --nproc_per_node=4 --master_port=4321 basicsr
5356
**1 GPU**
5457

5558
```bash
59+
PYTHONPATH="./:${PYTHONPATH}" \
5660
GLOG_vmodule=MemcachedClient=-1 \
5761
srun -p [partition] --mpi=pmi2 --job-name=MSRResNetx4 --gres=gpu:1 --ntasks=1 --ntasks-per-node=1 --cpus-per-task=6 --kill-on-bad-exit=1 \
5862
python -u basicsr/train.py -opt options/train/SRResNet_SRGAN/train_MSRResNet_x4.yml --launcher="slurm"
@@ -61,6 +65,7 @@ python -u basicsr/train.py -opt options/train/SRResNet_SRGAN/train_MSRResNet_x4.
6165
**4 GPUs**
6266

6367
```bash
68+
PYTHONPATH="./:${PYTHONPATH}" \
6469
GLOG_vmodule=MemcachedClient=-1 \
6570
srun -p [partition] --mpi=pmi2 --job-name=EDVRMwoTSA --gres=gpu:4 --ntasks=4 --ntasks-per-node=4 --cpus-per-task=4 --kill-on-bad-exit=1 \
6671
python -u basicsr/train.py -opt options/train/EDVR/train_EDVR_M_x4_SR_REDS_woTSA.yml --launcher="slurm"
@@ -69,6 +74,7 @@ python -u basicsr/train.py -opt options/train/EDVR/train_EDVR_M_x4_SR_REDS_woTSA
6974
**8 GPUs**
7075

7176
```bash
77+
PYTHONPATH="./:${PYTHONPATH}" \
7278
GLOG_vmodule=MemcachedClient=-1 \
7379
srun -p [partition] --mpi=pmi2 --job-name=EDVRMwoTSA --gres=gpu:8 --ntasks=8 --ntasks-per-node=8 --cpus-per-task=6 --kill-on-bad-exit=1 \
7480
python -u basicsr/train.py -opt options/train/EDVR/train_EDVR_M_x4_SR_REDS_woTSA.yml --launcher="slurm"
@@ -79,6 +85,7 @@ python -u basicsr/train.py -opt options/train/EDVR/train_EDVR_M_x4_SR_REDS_woTSA
7985
### 单GPU测试
8086

8187
```bash
88+
PYTHONPATH="./:${PYTHONPATH}" \
8289
CUDA_VISIBLE_DEVICES=0 \
8390
python basicsr/test.py -opt options/test/SRResNet_SRGAN/test_MSRResNet_x4.yml
8491
```
@@ -88,13 +95,15 @@ python basicsr/test.py -opt options/test/SRResNet_SRGAN/test_MSRResNet_x4.yml
8895
**8 GPUs**
8996

9097
```bash
98+
PYTHONPATH="./:${PYTHONPATH}" \
9199
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
92100
python -m torch.distributed.launch --nproc_per_node=8 --master_port=4321 basicsr/test.py -opt options/test/EDVR/test_EDVR_M_x4_SR_REDS.yml --launcher pytorch
93101
```
94102

95103
**4 GPUs**
96104

97105
```bash
106+
PYTHONPATH="./:${PYTHONPATH}" \
98107
CUDA_VISIBLE_DEVICES=0,1,2,3 \
99108
python -m torch.distributed.launch --nproc_per_node=4 --master_port=4321 basicsr/test.py -opt options/test/EDVR/test_EDVR_M_x4_SR_REDS.yml --launcher pytorch
100109
```
@@ -106,6 +115,7 @@ python -m torch.distributed.launch --nproc_per_node=4 --master_port=4321 basicsr
106115
**1 GPU**
107116

108117
```bash
118+
PYTHONPATH="./:${PYTHONPATH}" \
109119
GLOG_vmodule=MemcachedClient=-1 \
110120
srun -p [partition] --mpi=pmi2 --job-name=test --gres=gpu:1 --ntasks=1 --ntasks-per-node=1 --cpus-per-task=6 --kill-on-bad-exit=1 \
111121
python -u basicsr/test.py -opt options/test/SRResNet_SRGAN/test_MSRResNet_x4.yml --launcher="slurm"
@@ -114,6 +124,7 @@ python -u basicsr/test.py -opt options/test/SRResNet_SRGAN/test_MSRResNet_x4.yml
114124
**4 GPUs**
115125

116126
```bash
127+
PYTHONPATH="./:${PYTHONPATH}" \
117128
GLOG_vmodule=MemcachedClient=-1 \
118129
srun -p [partition] --mpi=pmi2 --job-name=test --gres=gpu:4 --ntasks=4 --ntasks-per-node=4 --cpus-per-task=4 --kill-on-bad-exit=1 \
119130
python -u basicsr/test.py -opt options/test/EDVR/test_EDVR_M_x4_SR_REDS.yml --launcher="slurm"
@@ -122,6 +133,7 @@ python -u basicsr/test.py -opt options/test/EDVR/test_EDVR_M_x4_SR_REDS.yml --la
122133
**8 GPUs**
123134

124135
```bash
136+
PYTHONPATH="./:${PYTHONPATH}" \
125137
GLOG_vmodule=MemcachedClient=-1 \
126138
srun -p [partition] --mpi=pmi2 --job-name=test --gres=gpu:8 --ntasks=8 --ntasks-per-node=8 --cpus-per-task=6 --kill-on-bad-exit=1 \
127139
python -u basicsr/test.py -opt options/test/EDVR/test_EDVR_M_x4_SR_REDS.yml --launcher="slurm"

0 commit comments

Comments
 (0)