Skip to content

Commit 9ffcaf2

Browse files
committed May 1, 2023
added frozen layers
1 parent fb10fcd commit 9ffcaf2

File tree

4 files changed

+169
-168
lines changed

4 files changed

+169
-168
lines changed
 

‎demo/demo.py

+3-1
Original file line number | Diff line number | Diff line change
@@ -460,7 +460,9 @@ def main():
460460
)
461461
if len(pose_preds) >= 1:
462462
for kpt in pose_preds:
463-
if len(kpt) == 41:
463+
if len(kpt) == 58:
464+
draw_pose_infinity_coco(kpt, image_bgr)
465+
elif len(kpt) == 41:
464466
draw_pose_infinity(kpt, image_bgr)
465467
else:
466468
draw_pose(kpt, image_bgr) # draw the poses

‎experiments/infinity_coco/hrnet/w48_384x288_adam_lr1e-3.yaml

+15-3
Original file line number | Diff line number | Diff line change
@@ -5,8 +5,8 @@ CUDNN:
55
ENABLED: true
66
DATA_DIR: ""
77
GPUS: (0,)
8-
OUTPUT_DIR: "output_infinity_coco"
9-
LOG_DIR: "log_infinity_coco"
8+
OUTPUT_DIR: "output_infinity_coco_frozen"
9+
LOG_DIR: "log_infinity_coco_frozen"
1010
WORKERS: 2
1111
PRINT_FREQ: 10
1212

@@ -36,6 +36,18 @@ MODEL:
3636
- 96
3737
SIGMA: 3
3838
EXTRA:
39+
FREEZE_LAYERS: true
40+
FROZEN_LAYERS:
41+
- "conv1"
42+
- "bn1"
43+
- "conv2"
44+
- "bn2"
45+
- "layer1"
46+
- "transition1"
47+
- "stage2"
48+
- "transition2"
49+
- "stage3"
50+
- "transition3"
3951
PRETRAINED_LAYERS:
4052
- "conv1"
4153
- "bn1"
@@ -91,7 +103,7 @@ MODEL:
91103
LOSS:
92104
USE_TARGET_WEIGHT: true
93105
TRAIN:
94-
BATCH_SIZE_PER_GPU: 40
106+
BATCH_SIZE_PER_GPU: 2
95107
SHUFFLE: true
96108
BEGIN_EPOCH: 0
97109
END_EPOCH: 200

‎lib/core/function.py

+104-102
Original file line number | Diff line number | Diff line change
@@ -4,35 +4,46 @@
44
# Written by Bin Xiao (Bin.Xiao@microsoft.com)
55
# ------------------------------------------------------------------------------
66

7-
from __future__ import absolute_import
8-
from __future__ import division
9-
from __future__ import print_function
10-
11-
import time
7+
from __future__ import absolute_import, division, print_function
8+
129
import logging
1310
import os
11+
import time
1412

1513
import numpy as np
1614
import torch
17-
1815
from core.evaluate import accuracy
1916
from core.inference import get_final_preds
2017
from utils.transforms import flip_back
2118
from utils.vis import save_debug_images
2219

23-
2420
logger = logging.getLogger(__name__)
2521

2622

27-
def train(config, train_loader, model, criterion, optimizer, epoch,
28-
output_dir, tb_log_dir, writer_dict):
23+
def train(
24+
config,
25+
train_loader,
26+
model,
27+
criterion,
28+
optimizer,
29+
epoch,
30+
output_dir,
31+
tb_log_dir,
32+
writer_dict,
33+
):
2934
batch_time = AverageMeter()
3035
data_time = AverageMeter()
3136
losses = AverageMeter()
3237
acc = AverageMeter()
3338

3439
# switch to train mode
3540
model.train()
41+
# freeze specified layers
42+
extra = config.MODEL.EXTRA
43+
if "FREEZE_LAYERS" in extra and extra["FREEZE_LAYERS"]:
44+
frozen_layers = extra.FROZEN_LAYERS
45+
for layer in frozen_layers:
46+
eval("model.module." + layer + ".requires_grad_(False)")
3647

3748
end = time.time()
3849
for i, (input, target, target_weight, meta) in enumerate(train_loader):
@@ -63,39 +74,55 @@ def train(config, train_loader, model, criterion, optimizer, epoch,
6374
# measure accuracy and record loss
6475
losses.update(loss.item(), input.size(0))
6576

66-
_, avg_acc, cnt, pred = accuracy(output.detach().cpu().numpy(),
67-
target.detach().cpu().numpy())
77+
_, avg_acc, cnt, pred = accuracy(
78+
output.detach().cpu().numpy(), target.detach().cpu().numpy()
79+
)
6880
acc.update(avg_acc, cnt)
6981

7082
# measure elapsed time
7183
batch_time.update(time.time() - end)
7284
end = time.time()
7385

7486
if i % config.PRINT_FREQ == 0:
75-
msg = 'Epoch: [{0}][{1}/{2}]\t' \
76-
'Time {batch_time.val:.3f}s ({batch_time.avg:.3f}s)\t' \
77-
'Speed {speed:.1f} samples/s\t' \
78-
'Data {data_time.val:.3f}s ({data_time.avg:.3f}s)\t' \
79-
'Loss {loss.val:.5f} ({loss.avg:.5f})\t' \
80-
'Accuracy {acc.val:.3f} ({acc.avg:.3f})'.format(
81-
epoch, i, len(train_loader), batch_time=batch_time,
82-
speed=input.size(0)/batch_time.val,
83-
data_time=data_time, loss=losses, acc=acc)
87+
msg = (
88+
"Epoch: [{0}][{1}/{2}]\t"
89+
"Time {batch_time.val:.3f}s ({batch_time.avg:.3f}s)\t"
90+
"Speed {speed:.1f} samples/s\t"
91+
"Data {data_time.val:.3f}s ({data_time.avg:.3f}s)\t"
92+
"Loss {loss.val:.5f} ({loss.avg:.5f})\t"
93+
"Accuracy {acc.val:.3f} ({acc.avg:.3f})".format(
94+
epoch,
95+
i,
96+
len(train_loader),
97+
batch_time=batch_time,
98+
speed=input.size(0) / batch_time.val,
99+
data_time=data_time,
100+
loss=losses,
101+
acc=acc,
102+
)
103+
)
84104
logger.info(msg)
85105

86-
writer = writer_dict['writer']
87-
global_steps = writer_dict['train_global_steps']
88-
writer.add_scalar('train_loss', losses.val, global_steps)
89-
writer.add_scalar('train_acc', acc.val, global_steps)
90-
writer_dict['train_global_steps'] = global_steps + 1
91-
92-
prefix = '{}_{}'.format(os.path.join(output_dir, 'train'), i)
93-
save_debug_images(config, input, meta, target, pred*4, output,
94-
prefix)
95-
96-
97-
def validate(config, val_loader, val_dataset, model, criterion, output_dir,
98-
tb_log_dir, writer_dict=None):
106+
writer = writer_dict["writer"]
107+
global_steps = writer_dict["train_global_steps"]
108+
writer.add_scalar("train_loss", losses.val, global_steps)
109+
writer.add_scalar("train_acc", acc.val, global_steps)
110+
writer_dict["train_global_steps"] = global_steps + 1
111+
112+
prefix = "{}_{}".format(os.path.join(output_dir, "train"), i)
113+
save_debug_images(config, input, meta, target, pred * 4, output, prefix)
114+
115+
116+
def validate(
117+
config,
118+
val_loader,
119+
val_dataset,
120+
model,
121+
criterion,
122+
output_dir,
123+
tb_log_dir,
124+
writer_dict=None,
125+
):
99126
batch_time = AverageMeter()
100127
losses = AverageMeter()
101128
acc = AverageMeter()
@@ -104,10 +131,7 @@ def validate(config, val_loader, val_dataset, model, criterion, output_dir,
104131
model.eval()
105132

106133
num_samples = len(val_dataset)
107-
all_preds = np.zeros(
108-
(num_samples, config.MODEL.NUM_JOINTS, 3),
109-
dtype=np.float32
110-
)
134+
all_preds = np.zeros((num_samples, config.MODEL.NUM_JOINTS, 3), dtype=np.float32)
111135
all_boxes = np.zeros((num_samples, 6))
112136
image_path = []
113137
filenames = []
@@ -132,15 +156,14 @@ def validate(config, val_loader, val_dataset, model, criterion, output_dir,
132156
else:
133157
output_flipped = outputs_flipped
134158

135-
output_flipped = flip_back(output_flipped.cpu().numpy(),
136-
val_dataset.flip_pairs)
159+
output_flipped = flip_back(
160+
output_flipped.cpu().numpy(), val_dataset.flip_pairs
161+
)
137162
output_flipped = torch.from_numpy(output_flipped.copy()).cuda()
138163

139-
140164
# feature is not aligned, shift flipped heatmap for higher accuracy
141165
if config.TEST.SHIFT_HEATMAP:
142-
output_flipped[:, :, :, 1:] = \
143-
output_flipped.clone()[:, :, :, 0:-1]
166+
output_flipped[:, :, :, 1:] = output_flipped.clone()[:, :, :, 0:-1]
144167

145168
output = (output + output_flipped) * 0.5
146169

@@ -152,51 +175,47 @@ def validate(config, val_loader, val_dataset, model, criterion, output_dir,
152175
num_images = input.size(0)
153176
# measure accuracy and record loss
154177
losses.update(loss.item(), num_images)
155-
_, avg_acc, cnt, pred = accuracy(output.cpu().numpy(),
156-
target.cpu().numpy())
178+
_, avg_acc, cnt, pred = accuracy(output.cpu().numpy(), target.cpu().numpy())
157179

158180
acc.update(avg_acc, cnt)
159181

160182
# measure elapsed time
161183
batch_time.update(time.time() - end)
162184
end = time.time()
163185

164-
c = meta['center'].numpy()
165-
s = meta['scale'].numpy()
166-
score = meta['score'].numpy()
186+
c = meta["center"].numpy()
187+
s = meta["scale"].numpy()
188+
score = meta["score"].numpy()
167189

168-
preds, maxvals = get_final_preds(
169-
config, output.clone().cpu().numpy(), c, s)
190+
preds, maxvals = get_final_preds(config, output.clone().cpu().numpy(), c, s)
170191

171-
all_preds[idx:idx + num_images, :, 0:2] = preds[:, :, 0:2]
172-
all_preds[idx:idx + num_images, :, 2:3] = maxvals
192+
all_preds[idx : idx + num_images, :, 0:2] = preds[:, :, 0:2]
193+
all_preds[idx : idx + num_images, :, 2:3] = maxvals
173194
# double check this all_boxes parts
174-
all_boxes[idx:idx + num_images, 0:2] = c[:, 0:2]
175-
all_boxes[idx:idx + num_images, 2:4] = s[:, 0:2]
176-
all_boxes[idx:idx + num_images, 4] = np.prod(s*200, 1)
177-
all_boxes[idx:idx + num_images, 5] = score
178-
image_path.extend(meta['image'])
195+
all_boxes[idx : idx + num_images, 0:2] = c[:, 0:2]
196+
all_boxes[idx : idx + num_images, 2:4] = s[:, 0:2]
197+
all_boxes[idx : idx + num_images, 4] = np.prod(s * 200, 1)
198+
all_boxes[idx : idx + num_images, 5] = score
199+
image_path.extend(meta["image"])
179200

180201
idx += num_images
181202

182203
if i % config.PRINT_FREQ == 0:
183-
msg = 'Test: [{0}/{1}]\t' \
184-
'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' \
185-
'Loss {loss.val:.4f} ({loss.avg:.4f})\t' \
186-
'Accuracy {acc.val:.3f} ({acc.avg:.3f})'.format(
187-
i, len(val_loader), batch_time=batch_time,
188-
loss=losses, acc=acc)
204+
msg = (
205+
"Test: [{0}/{1}]\t"
206+
"Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t"
207+
"Loss {loss.val:.4f} ({loss.avg:.4f})\t"
208+
"Accuracy {acc.val:.3f} ({acc.avg:.3f})".format(
209+
i, len(val_loader), batch_time=batch_time, loss=losses, acc=acc
210+
)
211+
)
189212
logger.info(msg)
190213

191-
prefix = '{}_{}'.format(
192-
os.path.join(output_dir, 'val'), i
193-
)
194-
save_debug_images(config, input, meta, target, pred*4, output,
195-
prefix)
214+
prefix = "{}_{}".format(os.path.join(output_dir, "val"), i)
215+
save_debug_images(config, input, meta, target, pred * 4, output, prefix)
196216

197217
name_values, perf_indicator = val_dataset.evaluate(
198-
config, all_preds, output_dir, all_boxes, image_path,
199-
filenames, imgnums
218+
config, all_preds, output_dir, all_boxes, image_path, filenames, imgnums
200219
)
201220

202221
model_name = config.MODEL.NAME
@@ -207,32 +226,16 @@ def validate(config, val_loader, val_dataset, model, criterion, output_dir,
207226
_print_name_value(name_values, model_name)
208227

209228
if writer_dict:
210-
writer = writer_dict['writer']
211-
global_steps = writer_dict['valid_global_steps']
212-
writer.add_scalar(
213-
'valid_loss',
214-
losses.avg,
215-
global_steps
216-
)
217-
writer.add_scalar(
218-
'valid_acc',
219-
acc.avg,
220-
global_steps
221-
)
229+
writer = writer_dict["writer"]
230+
global_steps = writer_dict["valid_global_steps"]
231+
writer.add_scalar("valid_loss", losses.avg, global_steps)
232+
writer.add_scalar("valid_acc", acc.avg, global_steps)
222233
if isinstance(name_values, list):
223234
for name_value in name_values:
224-
writer.add_scalars(
225-
'valid',
226-
dict(name_value),
227-
global_steps
228-
)
235+
writer.add_scalars("valid", dict(name_value), global_steps)
229236
else:
230-
writer.add_scalars(
231-
'valid',
232-
dict(name_values),
233-
global_steps
234-
)
235-
writer_dict['valid_global_steps'] = global_steps + 1
237+
writer.add_scalars("valid", dict(name_values), global_steps)
238+
writer_dict["valid_global_steps"] = global_steps + 1
236239

237240
return perf_indicator
238241

@@ -242,24 +245,23 @@ def _print_name_value(name_value, full_arch_name):
242245
names = name_value.keys()
243246
values = name_value.values()
244247
num_values = len(name_value)
245-
logger.info(
246-
'| Arch ' +
247-
' '.join(['| {}'.format(name) for name in names]) +
248-
' |'
249-
)
250-
logger.info('|---' * (num_values+1) + '|')
248+
logger.info("| Arch " + " ".join(["| {}".format(name) for name in names]) + " |")
249+
logger.info("|---" * (num_values + 1) + "|")
251250

252251
if len(full_arch_name) > 15:
253-
full_arch_name = full_arch_name[:8] + '...'
252+
full_arch_name = full_arch_name[:8] + "..."
254253
logger.info(
255-
'| ' + full_arch_name + ' ' +
256-
' '.join(['| {:.3f}'.format(value) for value in values]) +
257-
' |'
254+
"| "
255+
+ full_arch_name
256+
+ " "
257+
+ " ".join(["| {:.3f}".format(value) for value in values])
258+
+ " |"
258259
)
259260

260261

261262
class AverageMeter(object):
262263
"""Computes and stores the average and current value"""
264+
263265
def __init__(self):
264266
self.reset()
265267

‎tools/test.py

+47-62
Original file line number | Diff line number | Diff line change
@@ -5,62 +5,49 @@
55
# Written by Bin Xiao (Bin.Xiao@microsoft.com)
66
# ------------------------------------------------------------------------------
77

8-
from __future__ import absolute_import
9-
from __future__ import division
10-
from __future__ import print_function
8+
from __future__ import absolute_import, division, print_function
119

1210
import argparse
1311
import os
1412
import pprint
1513

14+
import _init_paths
15+
import dataset
1616
import torch
17-
import torch.nn.parallel
1817
import torch.backends.cudnn as cudnn
18+
import torch.nn.parallel
1919
import torch.optim
2020
import torch.utils.data
2121
import torch.utils.data.distributed
2222
import torchvision.transforms as transforms
23-
24-
import _init_paths
25-
from config import cfg
26-
from config import update_config
27-
from core.loss import JointsMSELoss
23+
from config import cfg, update_config
2824
from core.function import validate
25+
from core.loss import JointsMSELoss
2926
from utils.utils import create_logger
3027

31-
import dataset
3228
import models
3329

3430

3531
def parse_args():
36-
parser = argparse.ArgumentParser(description='Train keypoints network')
32+
parser = argparse.ArgumentParser(description="Train keypoints network")
3733
# general
38-
parser.add_argument('--cfg',
39-
help='experiment configure file name',
40-
required=True,
41-
type=str)
42-
43-
parser.add_argument('opts',
44-
help="Modify config options using the command-line",
45-
default=None,
46-
nargs=argparse.REMAINDER)
47-
48-
parser.add_argument('--modelDir',
49-
help='model directory',
50-
type=str,
51-
default='')
52-
parser.add_argument('--logDir',
53-
help='log directory',
54-
type=str,
55-
default='')
56-
parser.add_argument('--dataDir',
57-
help='data directory',
58-
type=str,
59-
default='')
60-
parser.add_argument('--prevModelDir',
61-
help='prev Model directory',
62-
type=str,
63-
default='')
34+
parser.add_argument(
35+
"--cfg", help="experiment configure file name", required=True, type=str
36+
)
37+
38+
parser.add_argument(
39+
"opts",
40+
help="Modify config options using the command-line",
41+
default=None,
42+
nargs=argparse.REMAINDER,
43+
)
44+
45+
parser.add_argument("--modelDir", help="model directory", type=str, default="")
46+
parser.add_argument("--logDir", help="log directory", type=str, default="")
47+
parser.add_argument("--dataDir", help="data directory", type=str, default="")
48+
parser.add_argument(
49+
"--prevModelDir", help="prev Model directory", type=str, default=""
50+
)
6451

6552
args = parser.parse_args()
6653
return args
@@ -70,8 +57,7 @@ def main():
7057
args = parse_args()
7158
update_config(cfg, args)
7259

73-
logger, final_output_dir, tb_log_dir = create_logger(
74-
cfg, args.cfg, 'valid')
60+
logger, final_output_dir, tb_log_dir = create_logger(cfg, args.cfg, "valid")
7561

7662
logger.info(pprint.pformat(args))
7763
logger.info(cfg)
@@ -81,50 +67,49 @@ def main():
8167
torch.backends.cudnn.deterministic = cfg.CUDNN.DETERMINISTIC
8268
torch.backends.cudnn.enabled = cfg.CUDNN.ENABLED
8369

84-
model = eval('models.'+cfg.MODEL.NAME+'.get_pose_net')(
85-
cfg, is_train=False
86-
)
70+
model = eval("models." + cfg.MODEL.NAME + ".get_pose_net")(cfg, is_train=False)
8771

8872
if cfg.TEST.MODEL_FILE:
89-
logger.info('=> loading model from {}'.format(cfg.TEST.MODEL_FILE))
73+
logger.info("=> loading model from {}".format(cfg.TEST.MODEL_FILE))
9074
model.load_state_dict(torch.load(cfg.TEST.MODEL_FILE), strict=False)
9175
else:
92-
model_state_file = os.path.join(
93-
final_output_dir, 'final_state.pth'
94-
)
95-
logger.info('=> loading model from {}'.format(model_state_file))
76+
model_state_file = os.path.join(final_output_dir, "final_state.pth")
77+
logger.info("=> loading model from {}".format(model_state_file))
9678
model.load_state_dict(torch.load(model_state_file))
9779

9880
model = torch.nn.DataParallel(model, device_ids=cfg.GPUS).cuda()
99-
10081
# define loss function (criterion) and optimizer
101-
criterion = JointsMSELoss(
102-
use_target_weight=cfg.LOSS.USE_TARGET_WEIGHT
103-
).cuda()
82+
criterion = JointsMSELoss(use_target_weight=cfg.LOSS.USE_TARGET_WEIGHT).cuda()
10483

10584
# Data loading code
10685
normalize = transforms.Normalize(
10786
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
10887
)
109-
valid_dataset = eval('dataset.'+cfg.DATASET.DATASET)(
110-
cfg, cfg.DATASET.ROOT, cfg.DATASET.TEST_SET, False,
111-
transforms.Compose([
112-
transforms.ToTensor(),
113-
normalize,
114-
])
88+
valid_dataset = eval("dataset." + cfg.DATASET.DATASET)(
89+
cfg,
90+
cfg.DATASET.ROOT,
91+
cfg.DATASET.TEST_SET,
92+
False,
93+
transforms.Compose(
94+
[
95+
transforms.ToTensor(),
96+
normalize,
97+
]
98+
),
11599
)
116100
valid_loader = torch.utils.data.DataLoader(
117101
valid_dataset,
118-
batch_size=cfg.TEST.BATCH_SIZE_PER_GPU*len(cfg.GPUS),
102+
batch_size=cfg.TEST.BATCH_SIZE_PER_GPU * len(cfg.GPUS),
119103
shuffle=False,
120104
num_workers=cfg.WORKERS,
121-
pin_memory=True
105+
pin_memory=True,
122106
)
123107

124108
# evaluate on validation set
125-
validate(cfg, valid_loader, valid_dataset, model, criterion,
126-
final_output_dir, tb_log_dir)
109+
validate(
110+
cfg, valid_loader, valid_dataset, model, criterion, final_output_dir, tb_log_dir
111+
)
127112

128113

129-
if __name__ == '__main__':
114+
if __name__ == "__main__":
130115
main()

0 commit comments

Comments (0)
Please sign in to comment.