Skip to content

Commit ab1b85f

Browse files
author
Depu Meng (FA Talent)
committedNov 19, 2019
added visualization codes
1 parent f793253 commit ab1b85f

File tree

3 files changed

+573
-0
lines changed

3 files changed

+573
-0
lines changed
 

‎README.md

+12
Original file line numberDiff line numberDiff line change
@@ -219,6 +219,18 @@ python tools/train.py \
219219
--cfg experiments/coco/hrnet/w32_256x192_adam_lr1e-3.yaml \
220220
```
221221

222+
### Visualization
223+
224+
#### Visualizing predictions on COCO val
225+
226+
```
227+
python visualize/plot_coco.py --prediction [your/prediction/path.json]
228+
```
229+
#### Visualizing predictions on MPII test
230+
231+
```
232+
python visualize/plot_mpii.py --prediction [your/prediction/path.mat]
233+
```
222234

223235
### Other applications
224236
Many other dense prediction tasks, such as segmentation, face alignment and object detection, etc. have been benefited by HRNet. More information can be found at [High-Resolution Networks](https://github.com/HRNet).

‎visualize/plot_coco.py

+309
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,309 @@
1+
# ------------------------------------------------------------------------------
2+
# Copyright (c) Microsoft
3+
# Licensed under the MIT License.
4+
# Written by Ke Sun (sunk@mail.ustc.edu.cn)
5+
# Modified by Depu Meng (mdp@mail.ustc.edu.cn)
6+
# ------------------------------------------------------------------------------
7+
8+
import argparse
9+
import numpy as np
10+
import matplotlib.pyplot as plt
11+
import cv2
12+
import json
13+
import matplotlib.lines as mlines
14+
import matplotlib.patches as mpatches
15+
from pycocotools.coco import COCO
16+
from pycocotools.cocoeval import COCOeval
17+
import os
18+
19+
20+
class ColorStyle:
21+
def __init__(self, color, link_pairs, point_color):
22+
self.color = color
23+
self.link_pairs = link_pairs
24+
self.point_color = point_color
25+
26+
for i in range(len(self.color)):
27+
self.link_pairs[i].append(tuple(np.array(self.color[i])/255.))
28+
29+
self.ring_color = []
30+
for i in range(len(self.point_color)):
31+
self.ring_color.append(tuple(np.array(self.point_color[i])/255.))
32+
33+
# Xiaochu Style
34+
# (R,G,B)
35+
color1 = [(179,0,0),(228,26,28),(255,255,51),
36+
(49,163,84), (0,109,45), (255,255,51),
37+
(240,2,127),(240,2,127),(240,2,127), (240,2,127), (240,2,127),
38+
(217,95,14), (254,153,41),(255,255,51),
39+
(44,127,184),(0,0,255)]
40+
41+
link_pairs1 = [
42+
[15, 13], [13, 11], [11, 5],
43+
[12, 14], [14, 16], [12, 6],
44+
[3, 1],[1, 2],[1, 0],[0, 2],[2,4],
45+
[9, 7], [7,5], [5, 6],
46+
[6, 8], [8, 10],
47+
]
48+
49+
point_color1 = [(240,2,127),(240,2,127),(240,2,127),
50+
(240,2,127), (240,2,127),
51+
(255,255,51),(255,255,51),
52+
(254,153,41),(44,127,184),
53+
(217,95,14),(0,0,255),
54+
(255,255,51),(255,255,51),(228,26,28),
55+
(49,163,84),(252,176,243),(0,176,240),
56+
(255,255,0),(169, 209, 142),
57+
(255,255,0),(169, 209, 142),
58+
(255,255,0),(169, 209, 142)]
59+
60+
xiaochu_style = ColorStyle(color1, link_pairs1, point_color1)
61+
62+
63+
# Chunhua Style
64+
# (R,G,B)
65+
color2 = [(252,176,243),(252,176,243),(252,176,243),
66+
(0,176,240), (0,176,240), (0,176,240),
67+
(240,2,127),(240,2,127),(240,2,127), (240,2,127), (240,2,127),
68+
(255,255,0), (255,255,0),(169, 209, 142),
69+
(169, 209, 142),(169, 209, 142)]
70+
71+
link_pairs2 = [
72+
[15, 13], [13, 11], [11, 5],
73+
[12, 14], [14, 16], [12, 6],
74+
[3, 1],[1, 2],[1, 0],[0, 2],[2,4],
75+
[9, 7], [7,5], [5, 6], [6, 8], [8, 10],
76+
]
77+
78+
point_color2 = [(240,2,127),(240,2,127),(240,2,127),
79+
(240,2,127), (240,2,127),
80+
(255,255,0),(169, 209, 142),
81+
(255,255,0),(169, 209, 142),
82+
(255,255,0),(169, 209, 142),
83+
(252,176,243),(0,176,240),(252,176,243),
84+
(0,176,240),(252,176,243),(0,176,240),
85+
(255,255,0),(169, 209, 142),
86+
(255,255,0),(169, 209, 142),
87+
(255,255,0),(169, 209, 142)]
88+
89+
chunhua_style = ColorStyle(color2, link_pairs2, point_color2)
90+
91+
def parse_args():
92+
parser = argparse.ArgumentParser(description='Visualize COCO predictions')
93+
# general
94+
parser.add_argument('--image-path',
95+
help='Path of COCO val images',
96+
type=str,
97+
default='data/coco/images/val2017/'
98+
)
99+
100+
parser.add_argument('--gt-anno',
101+
help='Path of COCO val annotation',
102+
type=str,
103+
default='data/coco/annotations/person_keypoints_val2017.json'
104+
)
105+
106+
parser.add_argument('--save-path',
107+
help="Path to save the visualizations",
108+
type=str,
109+
default='visualize/coco/')
110+
111+
parser.add_argument('--prediction',
112+
help="Prediction file to visualize",
113+
type=str,
114+
required=True)
115+
116+
parser.add_argument('--style',
117+
help="Style of the visualization: Chunhua style or Xiaochu style",
118+
type=str,
119+
default='chunhua')
120+
121+
args = parser.parse_args()
122+
123+
return args
124+
125+
126+
def map_joint_dict(joints):
127+
joints_dict = {}
128+
for i in range(joints.shape[0]):
129+
x = int(joints[i][0])
130+
y = int(joints[i][1])
131+
id = i
132+
joints_dict[id] = (x, y)
133+
134+
return joints_dict
135+
136+
def plot(data, gt_file, img_path, save_path,
137+
link_pairs, ring_color, save=True):
138+
139+
# joints
140+
coco = COCO(gt_file)
141+
coco_dt = coco.loadRes(data)
142+
coco_eval = COCOeval(coco, coco_dt, 'keypoints')
143+
coco_eval._prepare()
144+
gts_ = coco_eval._gts
145+
dts_ = coco_eval._dts
146+
147+
p = coco_eval.params
148+
p.imgIds = list(np.unique(p.imgIds))
149+
if p.useCats:
150+
p.catIds = list(np.unique(p.catIds))
151+
p.maxDets = sorted(p.maxDets)
152+
153+
# loop through images, area range, max detection number
154+
catIds = p.catIds if p.useCats else [-1]
155+
threshold = 0.3
156+
joint_thres = 0.2
157+
for catId in catIds:
158+
for imgId in p.imgIds[:5000]:
159+
# dimention here should be Nxm
160+
gts = gts_[imgId, catId]
161+
dts = dts_[imgId, catId]
162+
inds = np.argsort([-d['score'] for d in dts], kind='mergesort')
163+
dts = [dts[i] for i in inds]
164+
if len(dts) > p.maxDets[-1]:
165+
dts = dts[0:p.maxDets[-1]]
166+
if len(gts) == 0 or len(dts) == 0:
167+
continue
168+
169+
sum_score = 0
170+
num_box = 0
171+
img_name = str(imgId).zfill(12)
172+
173+
# Read Images
174+
img_file = img_path + img_name + '.jpg'
175+
data_numpy = cv2.imread(img_file, cv2.IMREAD_COLOR | cv2.IMREAD_IGNORE_ORIENTATION)
176+
h = data_numpy.shape[0]
177+
w = data_numpy.shape[1]
178+
179+
# Plot
180+
fig = plt.figure(figsize=(w/100, h/100), dpi=100)
181+
ax = plt.subplot(1,1,1)
182+
bk = plt.imshow(data_numpy[:,:,::-1])
183+
bk.set_zorder(-1)
184+
print(img_name)
185+
for j, gt in enumerate(gts):
186+
# matching dt_box and gt_box
187+
bb = gt['bbox']
188+
x0 = bb[0] - bb[2]; x1 = bb[0] + bb[2] * 2
189+
y0 = bb[1] - bb[3]; y1 = bb[1] + bb[3] * 2
190+
191+
# create bounds for ignore regions(double the gt bbox)
192+
g = np.array(gt['keypoints'])
193+
#xg = g[0::3]; yg = g[1::3];
194+
vg = g[2::3]
195+
196+
for i, dt in enumerate(dts):
197+
# Calculate IoU
198+
dt_bb = dt['bbox']
199+
dt_x0 = dt_bb[0] - dt_bb[2]; dt_x1 = dt_bb[0] + dt_bb[2] * 2
200+
dt_y0 = dt_bb[1] - dt_bb[3]; dt_y1 = dt_bb[1] + dt_bb[3] * 2
201+
202+
ol_x = min(x1, dt_x1) - max(x0, dt_x0)
203+
ol_y = min(y1, dt_y1) - max(y0, dt_y0)
204+
ol_area = ol_x * ol_y
205+
s_x = max(x1, dt_x1) - min(x0, dt_x0)
206+
s_y = max(y1, dt_y1) - min(y0, dt_y0)
207+
sum_area = s_x * s_y
208+
iou = ol_area / (sum_area + np.spacing(1))
209+
score = dt['score']
210+
211+
if iou < 0.1 or score < threshold:
212+
continue
213+
else:
214+
print('iou: ', iou)
215+
dt_w = dt_x1 - dt_x0
216+
dt_h = dt_y1 - dt_y0
217+
ref = min(dt_w, dt_h)
218+
num_box += 1
219+
sum_score += dt['score']
220+
dt_joints = np.array(dt['keypoints']).reshape(17,-1)
221+
joints_dict = map_joint_dict(dt_joints)
222+
223+
# stick
224+
for k, link_pair in enumerate(link_pairs):
225+
if link_pair[0] in joints_dict \
226+
and link_pair[1] in joints_dict:
227+
if dt_joints[link_pair[0],2] < joint_thres \
228+
or dt_joints[link_pair[1],2] < joint_thres \
229+
or vg[link_pair[0]] == 0 \
230+
or vg[link_pair[1]] == 0:
231+
continue
232+
if k in range(6,11):
233+
lw = 1
234+
else:
235+
lw = ref / 100.
236+
line = mlines.Line2D(
237+
np.array([joints_dict[link_pair[0]][0],
238+
joints_dict[link_pair[1]][0]]),
239+
np.array([joints_dict[link_pair[0]][1],
240+
joints_dict[link_pair[1]][1]]),
241+
ls='-', lw=lw, alpha=1, color=link_pair[2],)
242+
line.set_zorder(0)
243+
ax.add_line(line)
244+
# black ring
245+
for k in range(dt_joints.shape[0]):
246+
if dt_joints[k,2] < joint_thres \
247+
or vg[link_pair[0]] == 0 \
248+
or vg[link_pair[1]] == 0:
249+
continue
250+
if dt_joints[k,0] > w or dt_joints[k,1] > h:
251+
continue
252+
if k in range(5):
253+
radius = 1
254+
else:
255+
radius = ref / 100
256+
257+
circle = mpatches.Circle(tuple(dt_joints[k,:2]),
258+
radius=radius,
259+
ec='black',
260+
fc=ring_color[k],
261+
alpha=1,
262+
linewidth=1)
263+
circle.set_zorder(1)
264+
ax.add_patch(circle)
265+
266+
avg_score = (sum_score / (num_box+np.spacing(1)))*1000
267+
268+
plt.gca().xaxis.set_major_locator(plt.NullLocator())
269+
plt.gca().yaxis.set_major_locator(plt.NullLocator())
270+
plt.axis('off')
271+
plt.subplots_adjust(top=1,bottom=0,left=0,right=1,hspace=0,wspace=0)
272+
plt.margins(0,0)
273+
if save:
274+
plt.savefig(save_path + \
275+
'score_'+str(np.int(avg_score))+ \
276+
'_id_'+str(imgId)+ \
277+
'_'+img_name + '.png',
278+
format='png', bbox_inckes='tight', dpi=100)
279+
plt.savefig(save_path +'id_'+str(imgId)+ '.pdf', format='pdf',
280+
bbox_inckes='tight', dpi=100)
281+
# plt.show()
282+
plt.close()
283+
284+
if __name__ == '__main__':
285+
286+
args = parse_args()
287+
if args.style == 'xiaochu':
288+
# Xiaochu Style
289+
colorstyle = xiaochu_style
290+
elif args.style == 'chunhua':
291+
# Chunhua Style
292+
colorstyle = chunhua_style
293+
else:
294+
raise Exception('Invalid color style')
295+
296+
save_path = args.save_path
297+
img_path = args.image_path
298+
if not os.path.exists(save_path):
299+
try:
300+
os.makedirs(save_path)
301+
except Exception:
302+
print('Fail to make {}'.format(save_path))
303+
304+
305+
with open(args.prediction) as f:
306+
data = json.load(f)
307+
gt_file = args.gt_anno
308+
plot(data, gt_file, img_path, save_path, colorstyle.link_pairs, colorstyle.ring_color, save=True)
309+

‎visualize/plot_mpii.py

+252
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,252 @@
1+
# ------------------------------------------------------------------------------
2+
# Copyright (c) Microsoft
3+
# Licensed under the MIT License.
4+
# Written by Ke Sun (sunk@mail.ustc.edu.cn)
5+
# Modified by Depu Meng (mdp@mail.ustc.edu.cn)
6+
# ------------------------------------------------------------------------------
7+
import matplotlib as mpl
8+
mpl.use('Agg')
9+
import numpy as np
10+
from scipy.io import loadmat
11+
import matplotlib.pyplot as plt
12+
import cv2
13+
import matplotlib.lines as mlines
14+
import matplotlib.patches as mpatches
15+
import math
16+
import os
17+
import argparse
18+
19+
20+
def parse_args():
21+
parser = argparse.ArgumentParser(description='Visualize COCO predictions')
22+
# general
23+
parser.add_argument('--image-path',
24+
help='Path of COCO val images',
25+
type=str,
26+
default='data/mpii/images/'
27+
)
28+
29+
parser.add_argument('--save-path',
30+
help="Path to save the visualizations",
31+
type=str,
32+
default='visualize/mpii/')
33+
34+
parser.add_argument('--prediction',
35+
help="Prediction file to visualize",
36+
type=str,
37+
required=True)
38+
39+
parser.add_argument('--style',
40+
help="Style of the visualization: chunhua style or xiaochu style or openpose style",
41+
type=str,
42+
default='chunhua')
43+
44+
args = parser.parse_args()
45+
46+
return args
47+
48+
49+
"""
50+
# pose track
51+
link_pairs = [
52+
[0, 1], [1, 2], [3, 4], [4, 5], [6, 7], [7, 8], [8, 9], [2, 3],
53+
[9, 10], [10, 11], [12, 13], [13, 14], [2, 8], [3, 9]
54+
]
55+
"""
56+
57+
# joint[0]->joint[1](color: joint[2])
58+
# MPII
59+
60+
class ColorStyle:
61+
def __init__(self, color, link_pairs, point_color, ignore_id):
62+
self.color = color
63+
self.link_pairs = link_pairs
64+
self.point_color = point_color
65+
self.ignore_id = ignore_id
66+
for i in range(len(self.color)):
67+
self.link_pairs[i].append(tuple(np.array(self.color[i])/255.))
68+
69+
self.ring_color = []
70+
for i in range(len(self.point_color)):
71+
self.ring_color.append(tuple(np.array(self.point_color[i])/255.))
72+
73+
74+
# XiaoChu
75+
# (R,G,B)
76+
color1 = [(0,109,45),(49,163,84),(255,255,51),(228,26,28),(179,0,0),
77+
(255,255,51), (240,2,127), (0,0,255), (44,127,184), (255,255,51),
78+
(255,255,51), (254,153,41), (217,95,14)]
79+
80+
link_pairs1 = [
81+
[0, 1], [1, 2], [2, 12], [3, 4], [4, 5], [3, 13], [8, 9],
82+
[10, 11], [11,12], [12, 7], [7, 13], [13, 14], [14, 15],
83+
]
84+
85+
point_color1 = [(0,109,45),(49,163,84),(255,255,51),
86+
(255,255,51),(228,26,28),(179,0,0),
87+
(255,255,51),(240,2,127), (240,2,127),
88+
(0,0,255), (44,127,184), (255,255,51),
89+
(255,255,51), (254,153,41), (217,95,14)]
90+
91+
ignore_id1 = [6]
92+
93+
xiaochu_style = ColorStyle(color1, link_pairs1, point_color1, ignore_id1)
94+
95+
# Chunhua
96+
# (R,G,B)
97+
color2 = [(252,176,243),(252,176,243),(252,176,243),
98+
(0,176,240), (0,176,240), (0,176,240),
99+
(165, 104, 210), (255,0,0),
100+
(255,255,0), (255,255,0), (255,255,0),
101+
(169, 209, 142),(169, 209, 142),(169, 209, 142)]
102+
103+
link_pairs2 = [
104+
[0, 1], [1, 2], [2, 6], [6,3], [3, 4], [4, 5], [6, 7], [8, 9],
105+
[10, 11], [11,12], [12, 8], [8, 13], [13, 14], [14, 15],
106+
]
107+
108+
point_color2 = [(252,176,243),(252,176,243),(252,176,243),
109+
(0,176,240),(0,176,240),(0,176,240),
110+
(165, 104, 210), (165, 104, 210), (255,0,0), (255,0,0),
111+
(255,255,0),(255,255,0), (255,255,0),
112+
(169, 209, 142), (169, 209, 142), (169, 209, 142)]
113+
114+
ignore_id2 = []
115+
116+
chunhua_style = ColorStyle(color2, link_pairs2, point_color2, ignore_id2)
117+
118+
# OpenPose
119+
# (R,G,B)
120+
color3 = [(121,67,226),(74,87,226),(47,118,177),
121+
(163,61,204), (216,53,204), (211,48,121),
122+
(63, 214, 217), (177,24,21),
123+
(43,192,128), (83,224,91), (111,210,58),
124+
(220, 132, 72),(194, 169, 37),(172, 214, 69)]
125+
126+
link_pairs3 = [
127+
[0, 1], [1, 2], [2, 6], [6,3], [3, 4], [4, 5], [6, 8], [8, 9],
128+
[10, 11], [11,12], [12, 8], [8, 13], [13, 14], [14, 15],
129+
]
130+
131+
point_color3 = [(121,67,226),(74,87,226),(74,87,226),
132+
(163,61,204),(216,53,204),(211,48,121),
133+
(63, 214, 217),(63, 214, 217), (177,24,21),
134+
(43,192,128), (83,224,91), (111,210,58),
135+
(220, 132, 72), (194, 169, 37), (172, 214, 69)]
136+
137+
ignore_id3 = [7]
138+
139+
openpose_style = ColorStyle(color3, link_pairs3, point_color3, ignore_id3)
140+
141+
142+
"""
143+
def map_joint_array(joints, ignore_id):
144+
new_joints = np.zeros((16,3))
145+
for i in range(joints.shape[1]):
146+
new_joints[i,2] = int(joints[0,i][2][0][0])
147+
if new_joints[i,2] in ignore_id:
148+
continue
149+
else:
150+
new_joints[i,2] = int(joints[0,i][2][0][0])
151+
new_joints[i,0] = int(joints[0,i][0][0,0])
152+
new_joints[i,1] = int(joints[0,i][1][0,0])
153+
return new_joints
154+
"""
155+
156+
def map_joint_array(joints, ignore_id):
157+
new_joints = []
158+
for i in range(joints.shape[1]):
159+
if int(joints[0,i][2][0][0]) in ignore_id:
160+
continue
161+
else:
162+
joint = [int(joints[0,i][j][0,0]) for j in range(3)]
163+
new_joints.append(joint)
164+
return np.array(new_joints)
165+
166+
167+
def map_joint_dict(joints):
168+
joints_dict = {}
169+
for i in range(joints.shape[1]):
170+
x = int(joints[0,i][0][0,0])
171+
y = int(joints[0,i][1][0,0])
172+
id = int(joints[0,i][2][0][0])
173+
joints_dict[id] = (x, y)
174+
175+
return joints_dict
176+
177+
def plot_joints(image, joints):
178+
for id, pos in joints.items():
179+
cv2.circle(image, pos, 3, (0,255,0), 2)
180+
181+
182+
183+
if __name__ == '__main__':
184+
args = parse_args()
185+
save_path = args.save_path
186+
if not os.path.exists(save_path):
187+
try:
188+
os.makedirs(save_path)
189+
except Exception:
190+
print('Fail to make {}'.format(save_path))
191+
pred = loadmat(args.prediction)['pred']
192+
193+
# change color style here
194+
if args.style == 'chunhua':
195+
color_style = chunhua_style
196+
elif args.style == 'xiaochu':
197+
color_style = xiaochu_style
198+
elif args.style == 'openpose':
199+
color_style = openpose_style
200+
201+
link_pairs = color_style.link_pairs
202+
ignore_id = color_style.ignore_id
203+
ring_color = color_style.ring_color
204+
205+
for i in range(0, 1000):
206+
if len(pred[0,i][1]) < 1 or pred[0,i][1][0,0] is None or len(pred[0,i][1][0,0]) < 3 or len(pred[0,i][1][0,0][2][0]) < 1:
207+
continue
208+
img_name = pred[0,i][0][0,0][0][0][:-4]
209+
print('id: ', i)
210+
img_file = args.image_path + pred[0,i][0][0,0][0][0]
211+
scale = pred[0,i][1][0,0][0][0,0]
212+
center = (pred[0,i][1][0,0][1][0,0][0][0,0],pred[0,i][1][0,0][1][0,0][1][0,0])
213+
joints = pred[0,i][1][0,0][2][0,0][0]
214+
joints_array = map_joint_array(joints, ignore_id)
215+
joints_dict = map_joint_dict(joints)
216+
217+
data_numpy = cv2.imread(img_file, cv2.IMREAD_COLOR | cv2.IMREAD_IGNORE_ORIENTATION)
218+
h = data_numpy.shape[0]
219+
w = data_numpy.shape[1]
220+
ref = np.min((h,w))
221+
fig = plt.figure(figsize=(w/100, h/100), dpi=100)
222+
ax = plt.subplot(1,1,1)
223+
bk = plt.imshow(data_numpy[:,:,::-1])
224+
bk.set_zorder(-1)
225+
226+
# stick
227+
for link_pair in link_pairs:
228+
if link_pair[0] in joints_dict \
229+
and link_pair[1] in joints_dict:
230+
line = mlines.Line2D(
231+
np.array([joints_dict[link_pair[0]][0],joints_dict[link_pair[1]][0]]),
232+
np.array([joints_dict[link_pair[0]][1],joints_dict[link_pair[1]][1]]),
233+
ls='-', lw=ref/90, alpha=1, color=link_pair[2],)
234+
line.set_zorder(0)
235+
ax.add_line(line)
236+
237+
# black ring
238+
for j in range(joints_array.shape[0]):
239+
circle = mpatches.Circle(tuple(joints_array[j,:2]), radius=ref/90,
240+
ec='black', fc=ring_color[j], alpha=1, linewidth=ref/270)
241+
circle.set_zorder(1)
242+
ax.add_patch(circle)
243+
244+
245+
plt.gca().xaxis.set_major_locator(plt.NullLocator())
246+
plt.gca().yaxis.set_major_locator(plt.NullLocator())
247+
plt.axis('off')
248+
plt.subplots_adjust(top=1,bottom=0,left=0,right=1,hspace=0,wspace=0)
249+
plt.margins(0,0)
250+
plt.savefig(save_path + 'id_' +str(i)+ '.pdf', format='pdf', bbox_inckes='tight', dpi=100)
251+
plt.savefig(save_path + 'id_' +str(i)+ '.png', format='png', bbox_inckes='tight', dpi=100)
252+
plt.close()

0 commit comments

Comments
 (0)
Please sign in to comment.