Skip to content

Commit 1a57331

Browse files
committed
function, pos_hrnet landmark version 추가
1 parent fdb64fd commit 1a57331

File tree

5 files changed

+796
-6
lines changed

5 files changed

+796
-6
lines changed

lib/core/function_plus.py

+256
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,256 @@
1+
# ------------------------------------------------------------------------------
2+
# Copyright (c) Microsoft
3+
# Licensed under the MIT License.
4+
# Written by Bin Xiao (Bin.Xiao@microsoft.com)
5+
# ------------------------------------------------------------------------------
6+
7+
from __future__ import absolute_import
8+
from __future__ import division
9+
from __future__ import print_function
10+
11+
import time
12+
import logging
13+
import os
14+
15+
import numpy as np
16+
import torch
17+
18+
from core.evaluate import accuracy, accuracy_classification, accuracy_landmark
19+
from utils.vis import save_result_images, save_debug_images, save_images_landmark
20+
21+
22+
logger = logging.getLogger(__name__)
23+
24+
25+
def train(config, train_loader, model, criterion, optimizer, epoch,
          output_dir, tb_log_dir, writer_dict):
    """Run one training epoch for the classification + landmark model.

    Args:
        config: Experiment config; reads ``TRAIN.LOSS_WEIGHT[1]`` and
            ``PRINT_FREQ``.
        train_loader: Iterable yielding ``(input, target, target_weight,
            meta)`` batches. Only ``input`` and ``meta`` (keys ``"visible"``
            and ``"joints"``) are used; the ground truth is rebuilt from
            ``meta``.
        model: Callable returning ``(classification, landmark)`` for a batch.
        criterion: Indexable pair of losses; ``criterion[0]`` is applied to
            the classification output, ``criterion[1]`` to the landmark
            output.
        optimizer: Optimizer stepped once per batch.
        epoch: Epoch index, used only in the log line.
        output_dir, tb_log_dir, writer_dict: Not used by this function.

    Only the landmark loss (scaled by ``config.TRAIN.LOSS_WEIGHT[1]``) is
    back-propagated; the classification loss is computed and logged but is
    not part of the optimized objective.
    """
    # Running statistics reported in the periodic log line.
    time_meter = AverageMeter()
    load_meter = AverageMeter()
    total_loss_meter = AverageMeter()
    cls_loss_meter = AverageMeter()
    lm_loss_meter = AverageMeter()
    lm_acc_meter = AverageMeter()
    cls_acc_meter = AverageMeter()

    # switch to train mode
    model.train()

    tick = time.time()
    for step, (images, _, _, meta) in enumerate(train_loader):
        # Time spent waiting for the loader to deliver this batch.
        load_meter.update(time.time() - tick)

        # Forward pass: the model has two heads.
        classification, landmark = model(images)

        # Classification ground truth comes from the per-joint visibility.
        visible_gt = meta["visible"].type(torch.FloatTensor).cuda(
            non_blocking=True)
        cls_loss = criterion[0](classification, visible_gt)

        # Landmark ground truth: joints flattened to 64 values per sample.
        # NOTE(review): 64 is hard-coded; presumably it must match the width
        # of the landmark head -- confirm against the model definition.
        joints_gt = meta["joints"].reshape(-1, 64).type(
            torch.FloatTensor).cuda(non_blocking=True)
        lm_loss = criterion[1](landmark, joints_gt)

        # The combined objective
        #   LOSS_WEIGHT[0] * cls_loss + LOSS_WEIGHT[1] * lm_loss
        # is currently disabled; only the landmark term is optimized.
        loss = config.TRAIN.LOSS_WEIGHT[1] * lm_loss

        # compute gradient and do update step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Record losses, weighted by the number of samples in the batch.
        total_loss_meter.update(loss.item(), images.size(0))
        cls_loss_meter.update(cls_loss.item(), images.size(0))
        lm_loss_meter.update(lm_loss.item(), images.size(0))

        lm_acc, lm_cnt = accuracy_landmark(
            landmark.detach().cpu().numpy(),
            joints_gt.detach().cpu().numpy())
        lm_acc_meter.update(lm_acc, lm_cnt)

        cls_acc, cls_cnt = accuracy_classification(
            classification.detach().cpu().numpy(),
            visible_gt.detach().cpu().numpy())
        cls_acc_meter.update(cls_acc, cls_cnt)

        # measure elapsed time
        time_meter.update(time.time() - tick)
        tick = time.time()

        if step % config.PRINT_FREQ != 0:
            continue

        template = (
            'Epoch: [{0}][{1}/{2}]\t'
            'Time {batch_time.val:.3f}s ({batch_time.avg:.3f}s)\t'
            'Speed {speed:.1f} samples/s\t'
            'Data {data_time.val:.3f}s ({data_time.avg:.3f}s)\t'
            'Loss {loss.val:.5f} ({loss.avg:.5f}) ({classific.avg: .5f}+{lm.avg: .5f})\t'
            'Accuracy(landmark) {acc.val:.3f} ({acc.avg:.3f})\t'
            'Accuracy(classification) {acc_cls.val:.3f} ({acc_cls.avg:.3f})'
        )
        msg = template.format(
            epoch, step, len(train_loader),
            batch_time=time_meter,
            speed=images.size(0) / time_meter.val,
            data_time=load_meter,
            loss=total_loss_meter,
            classific=cls_loss_meter,
            lm=lm_loss_meter,
            acc=lm_acc_meter,
            acc_cls=cls_acc_meter)
        logger.info(msg)
92+
93+
94+
def validate(config, val_loader, val_dataset, model, criterion, output_dir,
             tb_log_dir, writer_dict=None):
    """Evaluate the model on ``val_loader`` and return the landmark accuracy.

    Args:
        config: Experiment config; reads ``TRAIN.LOSS_WEIGHT[1]`` and
            ``PRINT_FREQ``.
        val_loader: Iterable yielding ``(input, target, target_weight,
            meta)`` batches. Only ``input`` and ``meta`` (keys ``"visible"``
            and ``"joints"``) are used.
        val_dataset: Not used; kept so existing callers keep working.
        model: Callable returning ``(classification, landmark)`` for a batch.
        criterion: Indexable pair of losses; ``criterion[0]`` is applied to
            the classification output, ``criterion[1]`` to the landmark
            output.
        output_dir, tb_log_dir, writer_dict: Not used by this function.

    Returns:
        The running average held by the landmark-accuracy meter: each
        batch's ``accuracy_landmark`` value weighted by the count it returns
        (``0`` when the loader yields no batches).
    """
    batch_time = AverageMeter()
    losses = AverageMeter()
    acc = AverageMeter()

    # switch to evaluate mode
    model.eval()

    with torch.no_grad():
        end = time.time()
        for i, (images, _, _, meta) in enumerate(val_loader):
            # compute output
            classification, landmark = model(images)

            # The classification loss is still evaluated, as in train(), but
            # it is not part of the reported loss below.
            target = meta["visible"].type(torch.FloatTensor).cuda(
                non_blocking=True)
            classloss = criterion[0](classification, target)

            # Landmark ground truth: joints flattened to 64 values per sample.
            # NOTE(review): 64 is hard-coded; presumably it must match the
            # width of the landmark head -- confirm against the model.
            target2 = meta["joints"].reshape(-1, 64).type(
                torch.FloatTensor).cuda(non_blocking=True)
            lmloss = criterion[1](landmark, target2)

            # Reported loss is the weighted landmark term only; the combined
            # LOSS_WEIGHT[0]*classloss + LOSS_WEIGHT[1]*lmloss is disabled.
            loss = config.TRAIN.LOSS_WEIGHT[1] * lmloss

            # measure accuracy and record loss
            losses.update(loss.item(), images.size(0))
            avg_acc, cnt = accuracy_landmark(
                landmark.detach().cpu().numpy(),
                target2.detach().cpu().numpy())
            acc.update(avg_acc, cnt)

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % config.PRINT_FREQ == 0:
                msg = (
                    'Test: [{0}/{1}]\t'
                    'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                    'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                    'Accuracy {acc.val:.3f} ({acc.avg:.3f})'
                ).format(
                    i, len(val_loader), batch_time=batch_time,
                    loss=losses, acc=acc)
                logger.info(msg)

    return acc.avg
151+
152+
def test(config, val_loader, val_dataset, model, criterion, output_dir,
         tb_log_dir, writer_dict=None):
    """Run inference over ``val_loader``, log metrics and save result images.

    For every batch the landmark and classification accuracies are logged
    and ``save_images_landmark`` is called with the predictions, using
    ``<output_dir>/result`` as the file prefix and the batch index.

    Args:
        config: Experiment config; reads ``TRAIN.LOSS_WEIGHT[1]``.
        val_loader: Iterable yielding ``(input, target, target_weight,
            meta)`` batches. Only ``input`` and ``meta`` (keys ``"visible"``
            and ``"joints"``) are used.
        val_dataset: Not used; kept so existing callers keep working.
        model: Callable returning ``(classification, landmark)`` for a batch.
        criterion: Indexable pair of losses; ``criterion[0]`` is applied to
            the classification output, ``criterion[1]`` to the landmark
            output.
        output_dir: Directory under which the ``result`` prefix is built.
        tb_log_dir, writer_dict: Not used by this function.

    Returns:
        Always ``0``.
    """
    batch_time = AverageMeter()
    losses = AverageMeter()
    acc = AverageMeter()
    acc_cls = AverageMeter()

    # switch to evaluate mode
    model.eval()

    # Same prefix for every batch, so build it once.
    prefix = os.path.join(output_dir, 'result')

    with torch.no_grad():
        end = time.time()
        for i, (images, _, _, meta) in enumerate(val_loader):
            # compute output
            classification, landmark = model(images)

            # The classification loss is still evaluated, as in train(), but
            # it is not part of the reported loss below.
            target = meta["visible"].type(torch.FloatTensor).cuda(
                non_blocking=True)
            classloss = criterion[0](classification, target)

            # Landmark ground truth: joints flattened to 64 values per sample.
            # NOTE(review): 64 is hard-coded; presumably it must match the
            # width of the landmark head -- confirm against the model.
            target2 = meta["joints"].reshape(-1, 64).type(
                torch.FloatTensor).cuda(non_blocking=True)
            lmloss = criterion[1](landmark, target2)

            # Reported loss is the weighted landmark term only; the combined
            # LOSS_WEIGHT[0]*classloss + LOSS_WEIGHT[1]*lmloss is disabled.
            loss = config.TRAIN.LOSS_WEIGHT[1] * lmloss

            # measure accuracy and record loss
            losses.update(loss.item(), images.size(0))

            avg_acc, cnt = accuracy_landmark(
                landmark.detach().cpu().numpy(),
                target2.detach().cpu().numpy())
            acc.update(avg_acc, cnt)

            avg_acc, cnt = accuracy_classification(
                classification.detach().cpu().numpy(),
                target.detach().cpu().numpy())
            acc_cls.update(avg_acc, cnt)

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            # Log every batch (first accuracy: landmark, second:
            # classification).
            msg = (
                'Test: [{0}/{1}]\t'
                'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                'Accuracy {acc.val:.3f} ({acc.avg:.3f})\t'
                'Accuracy {acc2.val:.3f} ({acc2.avg:.3f})'
            ).format(
                i, len(val_loader), batch_time=batch_time,
                loss=losses, acc=acc, acc2=acc_cls)
            logger.info(msg)

            save_images_landmark(
                meta,
                landmark.detach().cpu().numpy(),
                classification.detach().cpu().numpy(),
                prefix,
                i)

    return 0
218+
219+
220+
# markdown format output
221+
def _print_name_value(name_value, full_arch_name):
    """Log a three-line markdown table of metric names and values.

    Args:
        name_value: Mapping of metric name to numeric value; keys form the
            header row, values are rendered with three decimals.
        full_arch_name: Architecture label for the first column. Labels
            longer than 15 characters are cut to their first 8 characters
            followed by ``...``.
    """
    # Header row.
    header_cells = ' '.join('| {}'.format(key) for key in name_value.keys())
    logger.info('| Arch ' + header_cells + ' |')

    # Separator row: one cell per metric plus one for the "Arch" column.
    logger.info('|---' * (len(name_value) + 1) + '|')

    # Value row, with an abbreviated architecture label if it is too long.
    arch_label = full_arch_name
    if len(arch_label) > 15:
        arch_label = arch_label[:8] + '...'
    value_cells = ' '.join(
        '| {:.3f}'.format(number) for number in name_value.values())
    logger.info('| ' + arch_label + ' ' + value_cells + ' |')
239+
240+
241+
class AverageMeter(object):
    """Keep the most recent value and a weighted running average.

    Attributes (all reset to ``0`` by :meth:`reset`):
        val: Last value passed to :meth:`update`.
        sum: Accumulated ``val * n`` over all updates.
        count: Accumulated ``n`` over all updates.
        avg: ``sum / count``, or ``0`` while ``count`` is zero.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear every tracked statistic back to zero."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the latest value, weighted by ``n`` samples."""
        self.val = val
        self.sum += val * n
        self.count += n
        # Guard the division so a zero total weight yields 0, not an error.
        if self.count != 0:
            self.avg = self.sum / self.count
        else:
            self.avg = 0

lib/models/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -14,3 +14,4 @@
1414

1515
import models.pose_resnet
1616
import models.pose_hrnet
17+
import models.pose_hrnet_plus

0 commit comments

Comments
 (0)