
Commit f7cf78c

pano_hrnet48
1 parent 4fac5b3 commit f7cf78c

File tree

4 files changed (+27, -114 lines)

lib/core/function.py (+19, -89)

@@ -15,8 +15,7 @@
 import numpy as np
 import torch

-from core.evaluate import accuracy, accuracy_classification, accuracy_landmark
-from core.inference import get_final_preds
+from core.evaluate import accuracy
 from utils.transforms import flip_back
 from utils.vis import save_result_images, save_debug_images

@@ -25,7 +24,7 @@


 def train(config, train_loader, model, criterion, optimizer, epoch,
-          output_dir, tb_log_dir, writer_dict):
+          output_dir):
     batch_time = AverageMeter()
     data_time = AverageMeter()
     losses = AverageMeter()
@@ -46,12 +45,12 @@ def train(config, train_loader, model, criterion, optimizer, epoch,
         target_weight = target_weight.cuda(non_blocking=True)

         if isinstance(heatmap, list):
-            loss = criterion[0](heatmap[0], target, target_weight)
+            loss = criterion(heatmap[0], target, target_weight)
             for output in heatmap[1:]:
-                loss += criterion[0](output, target, target_weight)
+                loss += criterion(output, target, target_weight)
         else:
             output = heatmap
-            loss = criterion[0](output, target, target_weight)
+            loss = criterion(output, target, target_weight)

         # loss = criterion(output, target, target_weight)

@@ -90,8 +89,7 @@ def train(config, train_loader, model, criterion, optimizer, epoch,
                               prefix)


-def validate(config, val_loader, val_dataset, model, criterion, output_dir,
-             tb_log_dir, writer_dict=None):
+def validate(config, val_loader, val_dataset, model, criterion, output_dir):
     batch_time = AverageMeter()
     losses = AverageMeter()
     acc = AverageMeter()
@@ -100,14 +98,7 @@ def validate(config, val_loader, val_dataset, model, criterion, output_dir,
     model.eval()

     num_samples = len(val_dataset)
-    all_preds = np.zeros(
-        (num_samples, config.MODEL.NUM_JOINTS, 3),
-        dtype=np.float32
-    )
-    all_boxes = np.zeros((num_samples, 6))
-    image_path = []
-    filenames = []
-    imgnums = []
+
     idx = 0
     with torch.no_grad():
         end = time.time()
@@ -119,31 +110,10 @@ def validate(config, val_loader, val_dataset, model, criterion, output_dir,
             else:
                 output = heatmap

-            if config.TEST.FLIP_TEST:
-                input_flipped = input.flip(3)
-                outputs_flipped = model(input_flipped)
-
-                if isinstance(outputs_flipped, list):
-                    output_flipped = outputs_flipped[-1]
-                else:
-                    output_flipped = outputs_flipped
-
-                output_flipped = flip_back(output_flipped.cpu().numpy(),
-                                           val_dataset.flip_pairs)
-                output_flipped = torch.from_numpy(output_flipped.copy()).cuda()
-
-
-                # feature is not aligned, shift flipped heatmap for higher accuracy
-                if config.TEST.SHIFT_HEATMAP:
-                    output_flipped[:, :, :, 1:] = \
-                        output_flipped.clone()[:, :, :, 0:-1]
-
-                output = (output + output_flipped) * 0.5
-
             target = target.cuda(non_blocking=True)
             target_weight = target_weight.cuda(non_blocking=True)

-            loss = criterion[0](output, target, target_weight)
+            loss = criterion(output, target, target_weight)

             num_images = input.size(0)
             # measure accuracy and record loss
@@ -157,22 +127,6 @@ def validate(config, val_loader, val_dataset, model, criterion, output_dir,
             batch_time.update(time.time() - end)
             end = time.time()

-            c = meta['center'].numpy()
-            s = meta['scale'].numpy()
-            score = meta['score'].numpy()
-
-            preds, maxvals = get_final_preds(
-                config, output.clone().cpu().numpy(), c, s)
-
-            all_preds[idx:idx + num_images, :, 0:2] = preds[:, :, 0:2]
-            all_preds[idx:idx + num_images, :, 2:3] = maxvals
-            # double check this all_boxes parts
-            all_boxes[idx:idx + num_images, 0:2] = c[:, 0:2]
-            all_boxes[idx:idx + num_images, 2:4] = s[:, 0:2]
-            all_boxes[idx:idx + num_images, 4] = np.prod(s*200, 1)
-            all_boxes[idx:idx + num_images, 5] = score
-            image_path.extend(meta['image'])
-
             idx += num_images

             if i % 100 == 0:
@@ -192,24 +146,16 @@ def validate(config, val_loader, val_dataset, model, criterion, output_dir,

     return acc.avg

-def test(config, val_loader, val_dataset, model, criterion, output_dir,
-         tb_log_dir, writer_dict=None):
+def test(config, val_loader, val_dataset, model, criterion, output_dir):
     batch_time = AverageMeter()
     losses = AverageMeter()
     acc = AverageMeter()
+    acc_mse = AverageMeter()

     # switch to evaluate mode
     model.eval()

     num_samples = len(val_dataset)
-    all_preds = np.zeros(
-        (num_samples, config.MODEL.NUM_JOINTS, 3),
-        dtype=np.float32
-    )
-    all_boxes = np.zeros((num_samples, 6))
-    image_path = []
-    filenames = []
-    imgnums = []
     idx = 0
     with torch.no_grad():
         end = time.time()
@@ -247,7 +193,7 @@ def test(config, val_loader, val_dataset, model, criterion, output_dir,

             target_class = meta["visible"].type(torch.FloatTensor).cuda(non_blocking=True)

-            loss = criterion[0](output, target, target_weight)
+            loss = criterion(output, target, target_weight)

             num_images = input.size(0)
             # measure accuracy and record loss
@@ -260,38 +206,22 @@ def test(config, val_loader, val_dataset, model, criterion, output_dir,
             batch_time.update(time.time() - end)
             end = time.time()

-            c = meta['center'].numpy()
-            s = meta['scale'].numpy()
-            score = meta['score'].numpy()
-
-            preds, maxvals = get_final_preds(
-                config, output.clone().cpu().numpy(), c, s)
-
-            all_preds[idx:idx + num_images, :, 0:2] = preds[:, :, 0:2]
-            all_preds[idx:idx + num_images, :, 2:3] = maxvals
-            # double check this all_boxes parts
-            all_boxes[idx:idx + num_images, 0:2] = c[:, 0:2]
-            all_boxes[idx:idx + num_images, 2:4] = s[:, 0:2]
-            all_boxes[idx:idx + num_images, 4] = np.prod(s*200, 1)
-            all_boxes[idx:idx + num_images, 5] = score
-            image_path.extend(meta['image'])
-
             idx += num_images

             if i % 1 == 0:
-                msg = 'Test: [{0}/{1}]\t' \
-                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' \
-                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t' \
-                      'Accuracy {acc.val:.3f} ({acc.avg:.3f})'.format(
-                          i, len(val_loader), batch_time=batch_time,
-                          loss=losses, acc=acc)
-                logger.info(msg)
-
                 prefix = os.path.join(output_dir, 'result')

                 save_result_images(config, input, meta, target, pred*4, output,
                                    prefix, i)

+                msg = 'Test: [{0}/{1}]\t' \
+                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' \
+                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t' \
+                      'Accuracy {acc.val:.3f} {acc_mse.val:.3f} ({acc.avg:.3f} {acc_mse.avg:.3f})'.format(
+                          i, len(val_loader), batch_time=batch_time,
+                          loss=losses, acc=acc, acc_mse=acc_mse)
+                logger.info(msg)
+
     return 0
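Not part of the diff: a minimal sketch of the single-criterion interface that train/validate/test now rely on. HeatmapMSELoss below is a hypothetical stand-in with the same (output, target, target_weight) call signature as the heatmapLoss built in tools/train.py and tools/test.py; it is not the project's actual core.loss implementation.

# Hypothetical heatmap criterion with the call convention used after this commit.
import torch
import torch.nn as nn


class HeatmapMSELoss(nn.Module):
    def __init__(self, use_target_weight=True):
        super().__init__()
        self.mse = nn.MSELoss(reduction='mean')
        self.use_target_weight = use_target_weight

    def forward(self, output, target, target_weight):
        # output/target: (batch, num_joints, H, W); target_weight: (batch, num_joints, 1)
        batch, num_joints = output.shape[:2]
        pred = output.reshape(batch, num_joints, -1)
        gt = target.reshape(batch, num_joints, -1)
        if self.use_target_weight:
            pred = pred * target_weight
            gt = gt * target_weight
        return self.mse(pred, gt)


# After this commit the trainer receives one criterion and calls it directly,
#     loss = criterion(heatmap[0], target, target_weight)
# instead of indexing into a [heatmapLoss, lmloss] list via criterion[0](...).
criterion = HeatmapMSELoss(use_target_weight=True)
heatmap = torch.rand(2, 17, 64, 48)
target = torch.rand(2, 17, 64, 48)
target_weight = torch.ones(2, 17, 1)
loss = criterion(heatmap, target, target_weight)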

lib/models/__init__.py (+1, -1)

@@ -8,4 +8,4 @@
 from __future__ import division
 from __future__ import print_function

-import models.panoNet
+import models.pose_hrnet

tools/test.py (+3, -8)

@@ -82,7 +82,7 @@ def main():
     torch.backends.cudnn.deterministic = cfg.CUDNN.DETERMINISTIC
     torch.backends.cudnn.enabled = cfg.CUDNN.ENABLED

-    model = eval('models.'+cfg.MODEL.NAME+'.get_pose_net')(
+    model = eval('models.' + cfg.MODEL.NAME + '.get_pose_net')(
         cfg, is_train=False
     )

@@ -103,11 +103,6 @@ def main():
         use_target_weight=cfg.LOSS.USE_TARGET_WEIGHT # true
     ).cuda()

-    # classifierLoss = nn.MSELoss(reduction='mean').cuda()
-    classifierLoss = JointsCELoss().cuda()
-    #lmloss = nn.MSELoss(reduction='mean').cuda()
-    lmloss = JointsDistLoss().cuda()
-
     # Data loading code
     test_dataset = eval('dataset.' + cfg.DATASET.DATASET)(
         cfg, cfg.DATASET.ROOT, cfg.DATASET.TEST_SET, 'test',

@@ -125,8 +120,8 @@ def main():
     )

     # evaluate on validation set
-    test(cfg, test_loader, test_dataset, model, [heatmapLoss, lmloss],
-         final_output_dir, tb_log_dir)
+    test(cfg, test_loader, test_dataset, model, heatmapLoss,
+         final_output_dir)


 if __name__ == '__main__':
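Side note, not part of the diff: the eval-based model factory kept in tools/test.py resolves a module name from the config and calls its get_pose_net function. The sketch below shows an importlib equivalent for illustration only; cfg is assumed to be the project's yacs config, and after this commit cfg.MODEL.NAME is expected to match the module imported in lib/models/__init__.py (pose_hrnet).

# Illustrative equivalent of eval('models.' + cfg.MODEL.NAME + '.get_pose_net')(cfg, is_train=False)
import importlib


def build_model(cfg, is_train=False):
    # Import the model module named in the config and call its factory.
    module = importlib.import_module('models.' + cfg.MODEL.NAME)
    return module.get_pose_net(cfg, is_train=is_train)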

tools/train.py (+4, -16)

@@ -102,12 +102,6 @@ def main():
         final_output_dir)
     # logger.info(pprint.pformat(model))

-    writer_dict = {
-        'writer': SummaryWriter(log_dir=tb_log_dir),
-        'train_global_steps': 0,
-        'valid_global_steps': 0,
-    }
-
     dump_input = torch.rand(
         (1, 3, cfg.MODEL.IMAGE_SIZE[1], cfg.MODEL.IMAGE_SIZE[0])
     )

@@ -121,11 +115,6 @@ def main():
         use_target_weight=cfg.LOSS.USE_TARGET_WEIGHT #true
     ).cuda()

-    #classifierLoss = nn.MSELoss(reduction='mean').cuda()
-    classifierLoss = JointsCELoss().cuda()
-    #lmloss = nn.MSELoss(reduction='mean').cuda()
-    lmloss = JointsDistLoss().cuda()
-
     # Data loading code
     train_dataset = eval('dataset.'+cfg.DATASET.DATASET)(
         cfg, cfg.DATASET.ROOT, cfg.DATASET.TRAIN_SET, 'train',

@@ -186,14 +175,14 @@ def main():
         lr_scheduler.step()

         # train for one epoch
-        train(cfg, train_loader, model, [heatmapLoss, lmloss], optimizer, epoch,
-              final_output_dir, tb_log_dir, writer_dict)
+        train(cfg, train_loader, model, heatmapLoss, optimizer, epoch,
+              final_output_dir)


         # evaluate on validation set
         perf_indicator = validate(
-            cfg, valid_loader, valid_dataset, model, [heatmapLoss, lmloss],
-            final_output_dir, tb_log_dir, writer_dict
+            cfg, valid_loader, valid_dataset, model, heatmapLoss,
+            final_output_dir
         )

         if perf_indicator >= best_perf:

@@ -225,7 +214,6 @@ def main():
            final_model_state_file)
     )
     torch.save(model.module.state_dict(), final_model_state_file)
-    writer_dict['writer'].close()


 if __name__ == '__main__':
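For reference, a sketch of the TensorBoard plumbing this commit removes from tools/train.py, reconstructed from the deleted lines. How train() and validate() consumed the step counters is not visible in this diff, so the add_scalar usage below is only an assumed illustration, and the original import may have been tensorboardX rather than torch.utils.tensorboard.

# Sketch of the removed writer_dict wiring (assumptions noted above).
from torch.utils.tensorboard import SummaryWriter

writer_dict = {
    'writer': SummaryWriter(log_dir='log/example_run'),  # was log_dir=tb_log_dir
    'train_global_steps': 0,
    'valid_global_steps': 0,
}

# Hypothetical consumer: log a scalar and advance the step counter per iteration.
writer = writer_dict['writer']
writer.add_scalar('train_loss', 0.123, writer_dict['train_global_steps'])
writer_dict['train_global_steps'] += 1

writer.close()  # matched the removed writer_dict['writer'].close()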
