
Commit 8dfaae7

pano heatmap network
1 parent 8548d15 commit 8dfaae7

9 files changed: +71 −29 lines


lib/core/evaluate.py

+1 −1

@@ -28,7 +28,7 @@ def calc_dists(preds, target, normalize):
     return dists
 
 
-def dist_acc(dists, thr=0.5):
+def dist_acc(dists, thr=0.2):
     ''' Return percentage below threshold while ignoring values with a -1 '''
     dist_cal = np.not_equal(dists, -1)
     num_dist_cal = dist_cal.sum()
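
The threshold change makes the landmark accuracy metric stricter: a prediction now counts as correct only when its normalized distance to the ground truth falls below 0.2 instead of 0.5. A minimal sketch of how the metric behaves, assuming the rest of dist_acc follows the standard HRNet implementation (the sample distances below are made up):

import numpy as np

def dist_acc(dists, thr=0.2):
    ''' Return percentage below threshold while ignoring values with a -1 '''
    dist_cal = np.not_equal(dists, -1)       # joints that could actually be measured
    num_dist_cal = dist_cal.sum()
    if num_dist_cal > 0:
        return np.less(dists[dist_cal], thr).sum() * 1.0 / num_dist_cal
    return -1

# hypothetical normalized distances for five joints; -1 marks an unlabeled joint
dists = np.array([0.10, 0.25, 0.15, -1.0, 0.40])
print(dist_acc(dists))    # 0.5 at thr=0.2; the old thr=0.5 would report 1.0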

lib/core/function.py

+1 −1

@@ -175,7 +175,7 @@ def validate(config, val_loader, val_dataset, model, criterion, output_dir,
 
             idx += num_images
 
-            if i % config.PRINT_FREQ == 0:
+            if i % 100 == 0:
                msg = 'Test: [{0}/{1}]\t' \
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' \
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t' \

lib/core/loss.py

+19

@@ -55,6 +55,25 @@ def forward(self, output, target):
 
         return loss / num_joints
 
+
+class JointsDistLoss(nn.Module):
+    def __init__(self):
+        super(JointsDistLoss, self).__init__()
+        self.criterion = nn.MSELoss(reduction='mean').cuda()
+
+    def forward(self, output, target):
+        batch = output.size(0)
+        num_joints = output.size(1)
+
+        output = output.reshape(batch, 32, 2)
+        target = target.reshape(batch, 32, 2)
+
+        loss = self.criterion(output[:, :, 0], target[:, :, 0]) + 0.3*self.criterion(output[:, :, 1], target[:, :, 1])
+
+        #diff = [batch, 32]
+        #diff = torch.sqrt((output[:, :, 0] - target[:, :, 0])**2 + (output[:, :, 1] - target[:, :, 1])**2)
+
+        return loss
+
 
 class JointsOHKMMSELoss(nn.Module):
     def __init__(self, use_target_weight, topk=8):
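
The new JointsDistLoss reshapes the flat landmark output into 32 (x, y) pairs and penalizes the horizontal error fully while weighting the vertical error by 0.3. A minimal CPU-only usage sketch, assuming the landmark head emits a [batch, 64] tensor (the tensor shapes and random inputs are assumptions, not taken from the repo):

import torch
import torch.nn as nn

criterion = nn.MSELoss(reduction='mean')    # .cuda() dropped so the sketch runs anywhere

def joints_dist_loss(output, target, y_weight=0.3):
    batch = output.size(0)
    output = output.reshape(batch, 32, 2)   # 32 teeth, (x, y) per tooth
    target = target.reshape(batch, 32, 2)
    # x-error counts fully, y-error is down-weighted, mirroring JointsDistLoss.forward
    return criterion(output[:, :, 0], target[:, :, 0]) \
        + y_weight * criterion(output[:, :, 1], target[:, :, 1])

pred = torch.rand(4, 64)    # hypothetical predicted landmarks
gt = torch.rand(4, 64)      # hypothetical ground-truth landmarks
print(joints_dist_loss(pred, gt))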

lib/dataset/pano.py

+1 −1

@@ -79,7 +79,7 @@ def __init__(self, cfg, root, image_set, state, transform=None):
         )
 
         print('==> initializing pano {} data.'.format(state))
-        self.annot_path = os.path.join('C:\\Users\CGIP\Desktop\CenterNet\data\pano', state)
+        self.annot_path = os.path.join('C:\\Users\CGIP\Desktop\github\CenterNet\data\pano', state)
         self.pano = pano.PANODataset(self.annot_path)
 
         # load image file names

lib/dataset/panoDataset.py

+35 −8

@@ -25,10 +25,15 @@ def _loadImg(self, file):
 
     def loadData(self, file):
         original_img = self._loadImg(file)
-        img, anns = self._cropImg(original_img, self._loadAnn(file))
+        anns, mask = self._loadAnn(file)
+        img, anns = self._cropImg(original_img, anns)
         img, anns = self._resize(img, anns)
         img = self._normalize(img)
 
+        #mask = self._normalize(self._resize(self._cropImg(mask)))
+        #cv2.imshow('result', mask)
+        #cv2.waitKey(0)
+
         return img, anns
 
     def show_results(self, debugger, original_image, anns, save=True):

@@ -95,18 +100,21 @@ def _normalize(self, img):
         inp = (inp - mean) / std
         return inp
 
-    def _resize(self, img, anns):
+    def _resize(self, img, anns=None):
         img_new = cv2.resize(img, (512, 256), interpolation=cv2.INTER_CUBIC)
         fx = np.size(img, 1) / 512
         fy = np.size(img, 0) / 256
 
+        if anns==None:
+            return img_new
+
         for ipt in range(1, self.num_joints+1):
             anns[ipt]["bbox"] = (anns[ipt]["bbox"]/np.array([fx, fy, fx, fy])).astype(int)
             anns[ipt]["center"] = (anns[ipt]["center"]/np.array([fx, fy])).astype(int)
 
         return img_new, anns
 
-    def _cropImg(self, img, anns):
+    def _cropImg(self, img, anns=None):
         width = np.size(img, 1)
         height = np.size(img, 0)
 

@@ -129,13 +137,16 @@ def _cropImg(self, img, anns):
         diffW = int((width-newW)/2)
         img = img[diffH:diffH+newH,diffW:diffW+newW,:]
 
+        if anns==None:
+            return img
+
         for ipt in range(1, self.num_joints+1):
             anns[ipt]["bbox"] = anns[ipt]["bbox"] - np.array([diffW + self.w_pad, diffH + self.h_pad, diffW + self.w_pad, diffH + self.h_pad])
             anns[ipt]["center"] = anns[ipt]["center"] - np.array([diffW + self.w_pad, diffH + self.h_pad])
 
         return img, anns
 
-    def _loadAnn(self, file, isBinary=False):
+    def _loadAnn(self, file):
         '''
         :param file: pano jpg image path
         :return data: { 1 : {"bbox":[(0,0),(0,0),(0,0),(0,0)], "center":(0,0), "visible":1/0} , 2: {} , ... }

@@ -149,30 +160,46 @@ def _loadAnn(self, file):
         f = open(file.replace('jpg', 'txt'), 'r')
         txt = f.readlines()
         f.close()
-        for t in txt:
+
+        for idx, t in enumerate(txt):
            dict = {}
            t = t.replace('\n', '').split(', ')
 
            x_coord = []
            y_coord = []
+            coords = np.zeros((8, 2), np.int32)
 
            for i in range(8):
+                coords[i] = [centerX + int(t[2 * i + 3]), centerY + int(t[2 * i + 4])]
                x_coord.append(centerX + int(t[2 * i + 3]))
                y_coord.append(centerY + int(t[2 * i + 4]))
 
-            coords = [min(x_coord), min(y_coord), max(x_coord), max(y_coord)] # minX, minY, maxX, maxY
+            bboxcoords = [min(x_coord), min(y_coord), max(x_coord), max(y_coord)] # minX, minY, maxX, maxY
 
-            dict["bbox"] = np.array([coords[0], coords[1], coords[2], coords[3]]) # from top-left corner, clockwise
+            dict["bbox"] = np.array([bboxcoords[0], bboxcoords[1], bboxcoords[2], bboxcoords[3]]) # from top-left corner, clockwise
            dict["center"] = np.array([centerX + int(t[27]), centerY + int(t[28])])
 
            if t[1]=='True':
                dict["visible"] = 1
            else:
                dict["visible"] = 0
 
+            dict["mask"] = coords
+
            data[teeth_num[t[0]]] = dict
 
-        return data
+        mask = self.getSegMask(np.zeros_like(img), data)
+
+        return data, mask
+
+    def getSegMask(self, img, data):
+        for i in range(1, 33):
+            visible = data[i]["visible"]
+            mask = data[i]["mask"]
+            if visible==1:
+                cv2.fillConvexPoly(img, mask, (255,0,0))
+
+        return img
 
     def showImage(self, img, anns):
         for ann in anns:
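
loadData now also produces a segmentation mask: _loadAnn stores the eight contour points of each tooth under "mask", and getSegMask rasterizes every visible tooth onto a blank image with cv2.fillConvexPoly. A minimal sketch of that rasterization for a single tooth, assuming a 256×512 canvas and made-up contour coordinates:

import numpy as np
import cv2

# hypothetical 8-point contour of one visible tooth, in image coordinates
coords = np.array([[60, 40], [80, 38], [95, 55], [97, 90],
                   [85, 120], [65, 122], [52, 95], [50, 60]], np.int32)

mask = np.zeros((256, 512, 3), np.uint8)        # blank canvas, like np.zeros_like(img)
cv2.fillConvexPoly(mask, coords, (255, 0, 0))   # fill the tooth polygon, as getSegMask does

# the bounding box _loadAnn derives from the same points: minX, minY, maxX, maxY
x, y = coords[:, 0], coords[:, 1]
print(np.array([x.min(), y.min(), x.max(), y.max()]))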

lib/models/__init__.py

+1 −7

@@ -8,10 +8,4 @@
 from __future__ import division
 from __future__ import print_function
 
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import models.pose_resnet
-import models.pose_hrnet
-import models.pose_hrnet_plus
+import models.panoNet

lib/utils/vis.py

+1 −1

@@ -259,7 +259,7 @@ def save_images_landmark(meta, batch_landmark, batch_classification, prefix, i):
 
         cls = batch_classification[0][i]
 
-        if cls>0:
+        if True:
             cv2.putText(image, teeth_num[i+1], (int(joint_lm[0])+4, int(joint_lm[1])+4), cv2.FONT_ITALIC, 0.4,
                         [255, 0, 0], 1)
             cv2.circle(image, (int(joint_lm[0]), int(joint_lm[1])), 2, [255, 0, 0], 2)

tools/test.py

+5 −4

@@ -25,8 +25,8 @@
 import _init_paths
 from config import cfg
 from config import update_config
-from core.loss import JointsMSELoss, JointsCELoss
-from core.function_plus import test
+from core.loss import JointsMSELoss, JointsCELoss, JointsDistLoss
+from core.function import test
 from utils.utils import create_logger
 
 import dataset

@@ -105,7 +105,8 @@ def main():
 
     # classifierLoss = nn.MSELoss(reduction='mean').cuda()
     classifierLoss = JointsCELoss().cuda()
-    lmloss = nn.MSELoss(reduction='mean').cuda()
+    #lmloss = nn.MSELoss(reduction='mean').cuda()
+    lmloss = JointsDistLoss().cuda()
 
     # Data loading code
     test_dataset = eval('dataset.' + cfg.DATASET.DATASET)(

@@ -124,7 +125,7 @@ def main():
     )
 
     # evaluate on validation set
-    test(cfg, test_loader, test_dataset, model, [classifierLoss, lmloss],
+    test(cfg, test_loader, test_dataset, model, [heatmapLoss, lmloss],
         final_output_dir, tb_log_dir)
 
 
tools/train.py

+7 −6

@@ -26,11 +26,11 @@
 import _init_paths
 from config import cfg
 from config import update_config
-from core.loss import JointsMSELoss, JointsCELoss
+from core.loss import JointsMSELoss, JointsCELoss, JointsDistLoss
 #from core.function import train
 #from core.function import validate
-from core.function_plus import train
-from core.function_plus import validate
+from core.function import train
+from core.function import validate
 from utils.utils import get_optimizer
 from utils.utils import save_checkpoint
 from utils.utils import create_logger

@@ -123,7 +123,8 @@ def main():
 
     #classifierLoss = nn.MSELoss(reduction='mean').cuda()
     classifierLoss = JointsCELoss().cuda()
-    lmloss = nn.MSELoss(reduction='mean').cuda()
+    #lmloss = nn.MSELoss(reduction='mean').cuda()
+    lmloss = JointsDistLoss().cuda()
 
     # Data loading code
     train_dataset = eval('dataset.'+cfg.DATASET.DATASET)(

@@ -185,13 +186,13 @@ def main():
         lr_scheduler.step()
 
         # train for one epoch
-        train(cfg, train_loader, model, [classifierLoss, lmloss], optimizer, epoch,
+        train(cfg, train_loader, model, [heatmapLoss, lmloss], optimizer, epoch,
              final_output_dir, tb_log_dir, writer_dict)
 
 
        # evaluate on validation set
        perf_indicator = validate(
-            cfg, valid_loader, valid_dataset, model, [classifierLoss, lmloss],
+            cfg, valid_loader, valid_dataset, model, [heatmapLoss, lmloss],
            final_output_dir, tb_log_dir, writer_dict
        )
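
With this commit, train, validate, and test all take a two-element criterion list, [heatmapLoss, lmloss], pairing a heatmap loss with the new landmark-distance loss. A minimal sketch of how such a pair could be combined inside a training step, using plain nn.MSELoss stand-ins for both criteria (the tensor shapes and the unweighted sum are assumptions; the actual combination lives in core.function):

import torch
import torch.nn as nn

heatmap_loss = nn.MSELoss(reduction='mean')   # stand-in for the repo's heatmap criterion
lm_loss = nn.MSELoss(reduction='mean')        # stand-in for JointsDistLoss
criterion = [heatmap_loss, lm_loss]

# hypothetical network outputs: 32 heatmaps plus a flat 64-value landmark vector
pred_heatmaps = torch.rand(2, 32, 64, 128, requires_grad=True)
pred_landmarks = torch.rand(2, 64, requires_grad=True)
gt_heatmaps = torch.rand(2, 32, 64, 128)
gt_landmarks = torch.rand(2, 64)

loss = criterion[0](pred_heatmaps, gt_heatmaps) + criterion[1](pred_landmarks, gt_landmarks)
loss.backward()   # in train() this would be followed by optimizer.step()
print(loss.item())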
