
Commit ab1b85f

Author: Depu Meng (FA Talent)
Commit message: added visualization codes
1 parent f793253 · commit ab1b85f

File tree: 3 files changed, +573 -0 lines

README.md (+12 lines)

@@ -219,6 +219,18 @@ python tools/train.py \
     --cfg experiments/coco/hrnet/w32_256x192_adam_lr1e-3.yaml \
 ```
 
+### Visualization
+
+#### Visualizing predictions on COCO val
+
+```
+python visualize/plot_coco.py --prediction [your/prediction/path.json]
+```
+#### Visualizing predictions on MPII test
+
+```
+python visualize/plot_mpii.py --prediction [your/prediction/path.mat]
+```
 
 ### Other applications
 Many other dense prediction tasks, such as segmentation, face alignment, and object detection, have benefited from HRNet. More information can be found at [High-Resolution Networks](https://github.com/HRNet).
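A note on the `--prediction` file used by the COCO command above: `visualize/plot_coco.py` (shown below) reads it with `json.load` and hands it to pycocotools' `COCO.loadRes`, so it must be a standard COCO keypoint results list. A minimal sketch of writing such a file, with made-up numbers purely for illustration:

```
# Sketch of the COCO keypoint results format consumed by visualize/plot_coco.py.
# All values below are illustrative placeholders, not real predictions.
import json

predictions = [
    {
        "image_id": 397133,                       # id of a val2017 image
        "category_id": 1,                         # person category in COCO
        "keypoints": [382.5, 175.2, 0.92] * 17,   # 17 joints x (x, y, confidence) = 51 numbers
        "score": 0.87,                            # overall detection score
    },
]

with open("predictions.json", "w") as f:
    json.dump(predictions, f)
```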

visualize/plot_coco.py (new file, +309 lines)

```
# ------------------------------------------------------------------------------
# Copyright (c) Microsoft
# Licensed under the MIT License.
# Written by Ke Sun (sunk@mail.ustc.edu.cn)
# Modified by Depu Meng (mdp@mail.ustc.edu.cn)
# ------------------------------------------------------------------------------

import argparse
import numpy as np
import matplotlib.pyplot as plt
import cv2
import json
import matplotlib.lines as mlines
import matplotlib.patches as mpatches
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval
import os


class ColorStyle:
    def __init__(self, color, link_pairs, point_color):
        self.color = color
        self.link_pairs = link_pairs
        self.point_color = point_color

        for i in range(len(self.color)):
            self.link_pairs[i].append(tuple(np.array(self.color[i])/255.))

        self.ring_color = []
        for i in range(len(self.point_color)):
            self.ring_color.append(tuple(np.array(self.point_color[i])/255.))


# Xiaochu Style
# (R,G,B)
color1 = [(179,0,0),(228,26,28),(255,255,51),
          (49,163,84), (0,109,45), (255,255,51),
          (240,2,127),(240,2,127),(240,2,127), (240,2,127), (240,2,127),
          (217,95,14), (254,153,41),(255,255,51),
          (44,127,184),(0,0,255)]

link_pairs1 = [
    [15, 13], [13, 11], [11, 5],
    [12, 14], [14, 16], [12, 6],
    [3, 1], [1, 2], [1, 0], [0, 2], [2, 4],
    [9, 7], [7, 5], [5, 6],
    [6, 8], [8, 10],
]

point_color1 = [(240,2,127),(240,2,127),(240,2,127),
                (240,2,127), (240,2,127),
                (255,255,51),(255,255,51),
                (254,153,41),(44,127,184),
                (217,95,14),(0,0,255),
                (255,255,51),(255,255,51),(228,26,28),
                (49,163,84),(252,176,243),(0,176,240),
                (255,255,0),(169, 209, 142),
                (255,255,0),(169, 209, 142),
                (255,255,0),(169, 209, 142)]

xiaochu_style = ColorStyle(color1, link_pairs1, point_color1)


# Chunhua Style
# (R,G,B)
color2 = [(252,176,243),(252,176,243),(252,176,243),
          (0,176,240), (0,176,240), (0,176,240),
          (240,2,127),(240,2,127),(240,2,127), (240,2,127), (240,2,127),
          (255,255,0), (255,255,0),(169, 209, 142),
          (169, 209, 142),(169, 209, 142)]

link_pairs2 = [
    [15, 13], [13, 11], [11, 5],
    [12, 14], [14, 16], [12, 6],
    [3, 1], [1, 2], [1, 0], [0, 2], [2, 4],
    [9, 7], [7, 5], [5, 6], [6, 8], [8, 10],
]

point_color2 = [(240,2,127),(240,2,127),(240,2,127),
                (240,2,127), (240,2,127),
                (255,255,0),(169, 209, 142),
                (255,255,0),(169, 209, 142),
                (255,255,0),(169, 209, 142),
                (252,176,243),(0,176,240),(252,176,243),
                (0,176,240),(252,176,243),(0,176,240),
                (255,255,0),(169, 209, 142),
                (255,255,0),(169, 209, 142),
                (255,255,0),(169, 209, 142)]

chunhua_style = ColorStyle(color2, link_pairs2, point_color2)


def parse_args():
    parser = argparse.ArgumentParser(description='Visualize COCO predictions')
    # general
    parser.add_argument('--image-path',
                        help='Path of COCO val images',
                        type=str,
                        default='data/coco/images/val2017/')

    parser.add_argument('--gt-anno',
                        help='Path of COCO val annotation',
                        type=str,
                        default='data/coco/annotations/person_keypoints_val2017.json')

    parser.add_argument('--save-path',
                        help="Path to save the visualizations",
                        type=str,
                        default='visualize/coco/')

    parser.add_argument('--prediction',
                        help="Prediction file to visualize",
                        type=str,
                        required=True)

    parser.add_argument('--style',
                        help="Style of the visualization: Chunhua style or Xiaochu style",
                        type=str,
                        default='chunhua')

    args = parser.parse_args()

    return args


def map_joint_dict(joints):
    joints_dict = {}
    for i in range(joints.shape[0]):
        x = int(joints[i][0])
        y = int(joints[i][1])
        id = i
        joints_dict[id] = (x, y)

    return joints_dict


def plot(data, gt_file, img_path, save_path,
         link_pairs, ring_color, save=True):

    # joints
    coco = COCO(gt_file)
    coco_dt = coco.loadRes(data)
    coco_eval = COCOeval(coco, coco_dt, 'keypoints')
    coco_eval._prepare()
    gts_ = coco_eval._gts
    dts_ = coco_eval._dts

    p = coco_eval.params
    p.imgIds = list(np.unique(p.imgIds))
    if p.useCats:
        p.catIds = list(np.unique(p.catIds))
    p.maxDets = sorted(p.maxDets)

    # loop through images, area range, max detection number
    catIds = p.catIds if p.useCats else [-1]
    threshold = 0.3
    joint_thres = 0.2
    for catId in catIds:
        for imgId in p.imgIds[:5000]:
            # dimension here should be Nxm
            gts = gts_[imgId, catId]
            dts = dts_[imgId, catId]
            inds = np.argsort([-d['score'] for d in dts], kind='mergesort')
            dts = [dts[i] for i in inds]
            if len(dts) > p.maxDets[-1]:
                dts = dts[0:p.maxDets[-1]]
            if len(gts) == 0 or len(dts) == 0:
                continue

            sum_score = 0
            num_box = 0
            img_name = str(imgId).zfill(12)

            # Read Images
            img_file = img_path + img_name + '.jpg'
            data_numpy = cv2.imread(img_file, cv2.IMREAD_COLOR | cv2.IMREAD_IGNORE_ORIENTATION)
            h = data_numpy.shape[0]
            w = data_numpy.shape[1]

            # Plot
            fig = plt.figure(figsize=(w/100, h/100), dpi=100)
            ax = plt.subplot(1, 1, 1)
            bk = plt.imshow(data_numpy[:, :, ::-1])
            bk.set_zorder(-1)
            print(img_name)
            for j, gt in enumerate(gts):
                # matching dt_box and gt_box
                bb = gt['bbox']
                x0 = bb[0] - bb[2]; x1 = bb[0] + bb[2] * 2
                y0 = bb[1] - bb[3]; y1 = bb[1] + bb[3] * 2

                # create bounds for ignore regions (double the gt bbox)
                g = np.array(gt['keypoints'])
                # xg = g[0::3]; yg = g[1::3]
                vg = g[2::3]

                for i, dt in enumerate(dts):
                    # Calculate IoU
                    dt_bb = dt['bbox']
                    dt_x0 = dt_bb[0] - dt_bb[2]; dt_x1 = dt_bb[0] + dt_bb[2] * 2
                    dt_y0 = dt_bb[1] - dt_bb[3]; dt_y1 = dt_bb[1] + dt_bb[3] * 2

                    ol_x = min(x1, dt_x1) - max(x0, dt_x0)
                    ol_y = min(y1, dt_y1) - max(y0, dt_y0)
                    ol_area = ol_x * ol_y
                    s_x = max(x1, dt_x1) - min(x0, dt_x0)
                    s_y = max(y1, dt_y1) - min(y0, dt_y0)
                    sum_area = s_x * s_y
                    iou = ol_area / (sum_area + np.spacing(1))
                    score = dt['score']

                    if iou < 0.1 or score < threshold:
                        continue
                    else:
                        print('iou: ', iou)
                        dt_w = dt_x1 - dt_x0
                        dt_h = dt_y1 - dt_y0
                        ref = min(dt_w, dt_h)
                        num_box += 1
                        sum_score += dt['score']
                        dt_joints = np.array(dt['keypoints']).reshape(17, -1)
                        joints_dict = map_joint_dict(dt_joints)

                        # stick
                        for k, link_pair in enumerate(link_pairs):
                            if link_pair[0] in joints_dict \
                               and link_pair[1] in joints_dict:
                                if dt_joints[link_pair[0], 2] < joint_thres \
                                   or dt_joints[link_pair[1], 2] < joint_thres \
                                   or vg[link_pair[0]] == 0 \
                                   or vg[link_pair[1]] == 0:
                                    continue
                            if k in range(6, 11):
                                lw = 1
                            else:
                                lw = ref / 100.
                            line = mlines.Line2D(
                                np.array([joints_dict[link_pair[0]][0],
                                          joints_dict[link_pair[1]][0]]),
                                np.array([joints_dict[link_pair[0]][1],
                                          joints_dict[link_pair[1]][1]]),
                                ls='-', lw=lw, alpha=1, color=link_pair[2],)
                            line.set_zorder(0)
                            ax.add_line(line)
                        # black ring
                        for k in range(dt_joints.shape[0]):
                            if dt_joints[k, 2] < joint_thres \
                               or vg[link_pair[0]] == 0 \
                               or vg[link_pair[1]] == 0:
                                continue
                            if dt_joints[k, 0] > w or dt_joints[k, 1] > h:
                                continue
                            if k in range(5):
                                radius = 1
                            else:
                                radius = ref / 100

                            circle = mpatches.Circle(tuple(dt_joints[k, :2]),
                                                     radius=radius,
                                                     ec='black',
                                                     fc=ring_color[k],
                                                     alpha=1,
                                                     linewidth=1)
                            circle.set_zorder(1)
                            ax.add_patch(circle)

            avg_score = (sum_score / (num_box + np.spacing(1))) * 1000

            plt.gca().xaxis.set_major_locator(plt.NullLocator())
            plt.gca().yaxis.set_major_locator(plt.NullLocator())
            plt.axis('off')
            plt.subplots_adjust(top=1, bottom=0, left=0, right=1, hspace=0, wspace=0)
            plt.margins(0, 0)
            if save:
                plt.savefig(save_path +
                            'score_' + str(int(avg_score)) +
                            '_id_' + str(imgId) +
                            '_' + img_name + '.png',
                            format='png', bbox_inches='tight', dpi=100)
                plt.savefig(save_path + 'id_' + str(imgId) + '.pdf', format='pdf',
                            bbox_inches='tight', dpi=100)
            # plt.show()
            plt.close()


if __name__ == '__main__':

    args = parse_args()
    if args.style == 'xiaochu':
        # Xiaochu Style
        colorstyle = xiaochu_style
    elif args.style == 'chunhua':
        # Chunhua Style
        colorstyle = chunhua_style
    else:
        raise Exception('Invalid color style')

    save_path = args.save_path
    img_path = args.image_path
    if not os.path.exists(save_path):
        try:
            os.makedirs(save_path)
        except Exception:
            print('Failed to make {}'.format(save_path))

    with open(args.prediction) as f:
        data = json.load(f)
    gt_file = args.gt_anno
    plot(data, gt_file, img_path, save_path, colorstyle.link_pairs, colorstyle.ring_color, save=True)
```
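As a quick illustration of what `ColorStyle` does to its inputs, here is a tiny toy check, assuming it is run in the same module (or with `ColorStyle` imported from `visualize/plot_coco.py`); the link and colors below are hypothetical, not part of the commit:

```
# Toy check of ColorStyle: one hypothetical link between joints 0 and 1,
# and two per-joint ring colors, all given as (R, G, B) in 0-255.
toy_color = [(255, 0, 0)]                         # one color per link
toy_link_pairs = [[0, 1]]                         # joint-index pairs to connect
toy_point_color = [(0, 176, 240), (240, 2, 127)]  # one color per joint

toy_style = ColorStyle(toy_color, toy_link_pairs, toy_point_color)

# Each link pair now carries its color scaled to [0, 1] as a third element,
# e.g. [0, 1, (1.0, 0.0, 0.0)], and ring_color holds the per-joint colors in [0, 1].
print(toy_style.link_pairs)
print(toy_style.ring_color)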

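Finally, a hedged end-to-end sketch of driving the visualizer from Python; the prediction path is a placeholder for wherever your COCO keypoint results JSON lives, and the default `--image-path` / `--gt-anno` layout is assumed. Per the `savefig` calls above, each rendered image is written to `--save-path` twice: as `score_<1000 x average score>_id_<imgId>_<zero-padded imgId>.png` and as `id_<imgId>.pdf`.

```
# Hedged usage sketch: the prediction path below is a placeholder,
# not a file created by this commit.
import subprocess

subprocess.run(
    [
        "python", "visualize/plot_coco.py",
        "--prediction", "path/to/keypoints_val2017_results.json",  # placeholder
        "--style", "chunhua",          # or "xiaochu"
        "--save-path", "visualize/coco/",
    ],
    check=True,
)
```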