Skip to content

Commit 359efc7

Browse files
committed
HRNet (heatmap+classification)
1 parent 42de22a commit 359efc7

File tree

10 files changed

+165
-175
lines changed

10 files changed

+165
-175
lines changed

experiments/pano/hrnet/pano_test1.yaml

-127
This file was deleted.

lib/config/default.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@
3737
_C.MODEL.NAME = 'pose_hrnet'
3838
_C.MODEL.INIT_WEIGHTS = True
3939
_C.MODEL.PRETRAINED = ''
40-
_C.MODEL.NUM_JOINTS = 17
40+
_C.MODEL.NUM_JOINTS = 32
4141
_C.MODEL.TAG_PER_JOINT = True
4242
_C.MODEL.TARGET_TYPE = 'gaussian'
4343
_C.MODEL.IMAGE_SIZE = [256, 256] # width * height, ex: 192 * 256
@@ -82,6 +82,7 @@
8282
_C.TRAIN.NESTEROV = False
8383
_C.TRAIN.GAMMA1 = 0.99
8484
_C.TRAIN.GAMMA2 = 0.0
85+
_C.TRAIN.LOSS_WEIGHT = [1e1,1e1,1e1]
8586

8687
_C.TRAIN.BEGIN_EPOCH = 0
8788
_C.TRAIN.END_EPOCH = 140

lib/core/evaluate.py

+41
Original file line numberDiff line numberDiff line change
@@ -70,4 +70,45 @@ def accuracy(output, target, hm_type='gaussian', thr=0.5):
7070
acc[0] = avg_acc
7171
return acc, avg_acc, cnt, pred
7272

73+
def accuracy_classification(output, target, thres=0.0):
    '''
    Compute the binary classification accuracy of a batch.

    Every score in `output` is binarised against `thres` (score >= thres
    counts as 1, anything else as 0) and compared element-wise with
    `target`. The accuracy of a sample is the fraction of its labels that
    match; the batch accuracy is the mean over samples.

    Args:
        output: array of scores whose first axis indexes the samples.
        target: array of 0/1 labels holding the same number of values
            per sample as `output`.
        thres: decision threshold applied to `output`.

    Returns:
        (avg_acc, cnt): mean per-sample accuracy and the number of
        samples. An empty batch yields (0.0, 0).

    Raises:
        ValueError: if `output` and `target` do not hold the same number
            of values per sample.
    '''
    output = np.asarray(output)
    cnt = output.shape[0]
    if cnt == 0:
        return 0.0, 0

    # Binarise into a fresh array. Two successive in-place assignments
    # (`< thres -> 0`, then `>= thres -> 1`) turn every score into 1 when
    # thres <= 0, because the freshly written zeros pass the second test,
    # and they also overwrite the caller's array.
    pred = (output >= thres).reshape(cnt, -1)
    target = np.asarray(target).reshape(cnt, -1)
    if pred.shape != target.shape:
        raise ValueError(
            'output and target disagree on values per sample: '
            '{} vs {}'.format(pred.shape[1], target.shape[1]))

    # Normalise by the real number of labels per sample instead of a
    # hard-coded joint count.
    acc = (pred == target).mean(axis=1)

    avg_acc = np.mean(acc)
    return avg_acc, cnt
94+
95+
def accuracy_landmark(output, target):
    '''
    Compute the fraction of predicted values that equal their targets.

    Both arrays are flattened per sample and compared element-wise with
    exact equality. The accuracy of a sample is the fraction of matching
    values; the batch accuracy is the mean over samples.

    Args:
        output: array of predictions whose first axis indexes the samples.
        target: array of ground-truth values holding the same number of
            values per sample as `output`.

    Returns:
        (avg_acc, cnt): mean per-sample accuracy and the number of
        samples. An empty batch yields (0.0, 0).

    Raises:
        ValueError: if `output` and `target` do not hold the same number
            of values per sample.
    '''
    # NOTE(review): exact equality rarely holds for real-valued regression
    # outputs; a distance tolerance is presumably what is wanted here --
    # confirm with the author before relying on this metric.
    output = np.asarray(output)
    cnt = output.shape[0]
    if cnt == 0:
        return 0.0, 0

    pred = output.reshape(cnt, -1)
    target = np.asarray(target).reshape(cnt, -1)
    if pred.shape != target.shape:
        raise ValueError(
            'output and target disagree on values per sample: '
            '{} vs {}'.format(pred.shape[1], target.shape[1]))

    # Normalise by the real number of values per sample: a fixed divisor
    # of 32 reports accuracies above 1 as soon as a sample carries more
    # than 32 values (e.g. 32 joints flattened to 64 coordinates).
    acc = (pred == target).mean(axis=1)

    avg_acc = np.mean(acc)
    return avg_acc, cnt
113+
73114

lib/core/function.py

+51-30
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
import numpy as np
1616
import torch
1717

18-
from core.evaluate import accuracy
18+
from core.evaluate import accuracy, accuracy_classification, accuracy_landmark
1919
from core.inference import get_final_preds
2020
from utils.transforms import flip_back
2121
from utils.vis import save_result_images, save_debug_images
@@ -29,6 +29,9 @@ def train(config, train_loader, model, criterion, optimizer, epoch,
2929
batch_time = AverageMeter()
3030
data_time = AverageMeter()
3131
losses = AverageMeter()
32+
loss_classifier = AverageMeter()
33+
loss_heatmap = AverageMeter()
34+
loss_landmark = AverageMeter()
3235
acc = AverageMeter()
3336

3437
# switch to train mode
@@ -40,19 +43,27 @@ def train(config, train_loader, model, criterion, optimizer, epoch,
4043
data_time.update(time.time() - end)
4144

4245
# compute output
43-
outputs = model(input)
46+
heatmap, classification, landmark = model(input)
4447

4548
target = target.cuda(non_blocking=True)
4649
target_weight = target_weight.cuda(non_blocking=True)
4750

48-
if isinstance(outputs, list):
49-
loss = criterion(outputs[0], target, target_weight)
50-
for output in outputs[1:]:
51-
loss += criterion(output, target, target_weight)
51+
if isinstance(heatmap, list):
52+
heatloss = criterion[0](heatmap[0], target, target_weight)
53+
for output in heatmap[1:]:
54+
heatloss += criterion[0](output, target, target_weight)
5255
else:
53-
output = outputs
54-
loss = criterion(output, target, target_weight)
56+
output = heatmap
57+
heatloss = criterion[0](output, target, target_weight)
5558

59+
#target2 = meta["visible"].type(torch.FloatTensor).cuda(non_blocking=True).view(classification.size(0),-1)
60+
target2 = meta["visible"].type(torch.FloatTensor).cuda(non_blocking=True)
61+
classloss = criterion[1](classification, target2)
62+
63+
target3 = meta["joints"].reshape(-1,64).type(torch.FloatTensor).cuda(non_blocking=True)
64+
lmloss = criterion[2](landmark, target3)
65+
66+
loss = config.TRAIN.LOSS_WEIGHT[1]*classloss + config.TRAIN.LOSS_WEIGHT[2]*lmloss
5667
# loss = criterion(output, target, target_weight)
5768

5869
# compute gradient and do update step
@@ -62,6 +73,9 @@ def train(config, train_loader, model, criterion, optimizer, epoch,
6273

6374
# measure accuracy and record loss
6475
losses.update(loss.item(), input.size(0))
76+
loss_classifier.update(classloss.item(), input.size(0))
77+
loss_heatmap.update(heatloss.item(), input.size(0))
78+
loss_landmark.update(lmloss.item(), input.size(0))
6579

6680
_, avg_acc, cnt, pred = accuracy(output.detach().cpu().numpy(),
6781
target.detach().cpu().numpy())
@@ -76,19 +90,15 @@ def train(config, train_loader, model, criterion, optimizer, epoch,
7690
'Time {batch_time.val:.3f}s ({batch_time.avg:.3f}s)\t' \
7791
'Speed {speed:.1f} samples/s\t' \
7892
'Data {data_time.val:.3f}s ({data_time.avg:.3f}s)\t' \
79-
'Loss {loss.val:.5f} ({loss.avg:.5f})\t' \
80-
'Accuracy {acc.val:.3f} ({acc.avg:.3f})'.format(
93+
'Loss {loss.val:.5f} ({loss.avg:.5f}) ({classific.avg: .5f}+{lm.avg: .5f})\t' \
94+
'Accuracy(heatmap) {acc.val:.3f} ({acc.avg:.3f})'.format(
8195
epoch, i, len(train_loader), batch_time=batch_time,
8296
speed=input.size(0)/batch_time.val,
83-
data_time=data_time, loss=losses, acc=acc)
97+
data_time=data_time,
98+
loss=losses, classific=loss_classifier, lm=loss_landmark,
99+
acc=acc)
84100
logger.info(msg)
85101

86-
writer = writer_dict['writer']
87-
global_steps = writer_dict['train_global_steps']
88-
writer.add_scalar('train_loss', losses.val, global_steps)
89-
writer.add_scalar('train_acc', acc.val, global_steps)
90-
writer_dict['train_global_steps'] = global_steps + 1
91-
92102
prefix = '{}_{}'.format(os.path.join(output_dir, 'train'), i)
93103
save_debug_images(config, input, meta, target, pred*4, output,
94104
prefix)
@@ -117,11 +127,11 @@ def validate(config, val_loader, val_dataset, model, criterion, output_dir,
117127
end = time.time()
118128
for i, (input, target, target_weight, meta) in enumerate(val_loader):
119129
# compute output
120-
outputs = model(input)
121-
if isinstance(outputs, list):
122-
output = outputs[-1]
130+
heatmap, classification, landmark = model(input)
131+
if isinstance(heatmap, list):
132+
output = heatmap[-1]
123133
else:
124-
output = outputs
134+
output = heatmap
125135

126136
if config.TEST.FLIP_TEST:
127137
input_flipped = input.flip(3)
@@ -147,7 +157,11 @@ def validate(config, val_loader, val_dataset, model, criterion, output_dir,
147157
target = target.cuda(non_blocking=True)
148158
target_weight = target_weight.cuda(non_blocking=True)
149159

150-
loss = criterion(output, target, target_weight)
160+
target2 = meta["visible"].type(torch.FloatTensor).cuda(non_blocking=True)
161+
target3 = meta["joints"].reshape(-1, 64).type(torch.FloatTensor).cuda(non_blocking=True)
162+
163+
loss = config.TRAIN.LOSS_WEIGHT[1]*criterion[1](classification, target2) \
164+
+ config.TRAIN.LOSS_WEIGHT[2] * criterion[2](landmark, target3)
151165

152166
num_images = input.size(0)
153167
# measure accuracy and record loss
@@ -201,6 +215,7 @@ def test(config, val_loader, val_dataset, model, criterion, output_dir,
201215
batch_time = AverageMeter()
202216
losses = AverageMeter()
203217
acc = AverageMeter()
218+
acc_clas = AverageMeter()
204219

205220
# switch to evaluate mode
206221
model.eval()
@@ -219,11 +234,11 @@ def test(config, val_loader, val_dataset, model, criterion, output_dir,
219234
end = time.time()
220235
for i, (input, target, target_weight, meta) in enumerate(val_loader):
221236
# compute output
222-
outputs = model(input)
223-
if isinstance(outputs, list):
224-
output = outputs[-1]
237+
heatmap, classification = model(input)
238+
if isinstance(heatmap, list):
239+
output = heatmap[-1]
225240
else:
226-
output = outputs
241+
output = heatmap
227242

228243
if config.TEST.FLIP_TEST:
229244
input_flipped = input.flip(3)
@@ -249,16 +264,21 @@ def test(config, val_loader, val_dataset, model, criterion, output_dir,
249264
target = target.cuda(non_blocking=True)
250265
target_weight = target_weight.cuda(non_blocking=True)
251266

252-
loss = criterion(output, target, target_weight)
267+
target_class = meta["visible"].type(torch.FloatTensor).cuda(non_blocking=True)
268+
269+
loss = config.TRAIN.LOSS_WEIGHT[0]*criterion[0](output, target, target_weight) + criterion[1](classification,target_class)
253270

254271
num_images = input.size(0)
255272
# measure accuracy and record loss
256273
losses.update(loss.item(), num_images)
257274
_, avg_acc, cnt, pred = accuracy(output.cpu().numpy(),
258275
target.cpu().numpy())
259-
260276
acc.update(avg_acc, cnt)
261277

278+
avg_acc, cnt = accuracy_classification(classification.cpu().numpy(),
279+
target_class.cpu().numpy())
280+
acc_clas.update(avg_acc, cnt)
281+
262282
# measure elapsed time
263283
batch_time.update(time.time() - end)
264284
end = time.time()
@@ -285,9 +305,10 @@ def test(config, val_loader, val_dataset, model, criterion, output_dir,
285305
msg = 'Test: [{0}/{1}]\t' \
286306
'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' \
287307
'Loss {loss.val:.4f} ({loss.avg:.4f})\t' \
288-
'Accuracy {acc.val:.3f} ({acc.avg:.3f})'.format(
308+
'Accuracy {acc.val:.3f} ({acc.avg:.3f})\t'\
309+
'Accuracy {acc2.val:.3f} ({acc2.avg:.3f})'.format(
289310
i, len(val_loader), batch_time=batch_time,
290-
loss=losses, acc=acc)
311+
loss=losses, acc=acc, acc2=acc_clas)
291312
logger.info(msg)
292313

293314
prefix = os.path.join(output_dir, 'result')

lib/core/loss.py

+17
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,23 @@ def forward(self, output, target, target_weight):
3838

3939
return loss / num_joints
4040

41+
class JointsCELoss(nn.Module):
    """Mean-squared-error loss between per-joint predictions and targets.

    Despite the class name, the criterion is `nn.MSELoss`, not a
    cross-entropy loss. The returned value is the MSE averaged over the
    batch and over all joints.
    """

    def __init__(self):
        super(JointsCELoss, self).__init__()
        # The criterion holds no parameters or buffers, so it needs no
        # explicit device placement; it runs wherever its inputs live.
        self.criterion = nn.MSELoss(reduction='mean')

    def forward(self, output, target):
        """Return the scalar loss for one batch.

        Args:
            output: predictions with one value per (sample, joint);
                axis 0 is the batch and axis 1 the joints.
            target: ground truth holding one value per (sample, joint);
                a trailing singleton axis is accepted.

        Returns:
            Scalar tensor: MSE averaged over samples and joints.
        """
        batch_size = output.size(0)
        num_joints = output.size(1)

        # Every joint contributes the same number of samples, so the mean
        # of the per-joint mean errors equals the mean over all entries;
        # a single call replaces the per-joint Python loop. Both tensors
        # are brought to (batch, joints) first so nothing broadcasts.
        pred = output.reshape(batch_size, num_joints)
        gt = target.reshape(batch_size, num_joints)
        return self.criterion(pred, gt)
57+
4158

4259
class JointsOHKMMSELoss(nn.Module):
4360
def __init__(self, use_target_weight, topk=8):

0 commit comments

Comments
 (0)