|
| 1 | +# ------------------------------------------------------------------------------ |
| 2 | +# Copyright (c) Microsoft |
| 3 | +# Licensed under the MIT License. |
| 4 | +# Written by Ke Sun (sunk@mail.ustc.edu.cn) |
| 5 | +# Modified by Depu Meng (mdp@mail.ustc.edu.cn) |
| 6 | +# ------------------------------------------------------------------------------ |
| 7 | + |
| 8 | +import argparse |
| 9 | +import numpy as np |
| 10 | +import matplotlib.pyplot as plt |
| 11 | +import cv2 |
| 12 | +import json |
| 13 | +import matplotlib.lines as mlines |
| 14 | +import matplotlib.patches as mpatches |
| 15 | +from pycocotools.coco import COCO |
| 16 | +from pycocotools.cocoeval import COCOeval |
| 17 | +import os |
| 18 | + |
| 19 | + |
| 20 | +class ColorStyle: |
| 21 | + def __init__(self, color, link_pairs, point_color): |
| 22 | + self.color = color |
| 23 | + self.link_pairs = link_pairs |
| 24 | + self.point_color = point_color |
| 25 | + |
| 26 | + for i in range(len(self.color)): |
| 27 | + self.link_pairs[i].append(tuple(np.array(self.color[i])/255.)) |
| 28 | + |
| 29 | + self.ring_color = [] |
| 30 | + for i in range(len(self.point_color)): |
| 31 | + self.ring_color.append(tuple(np.array(self.point_color[i])/255.)) |
| 32 | + |
| 33 | +# Xiaochu Style |
| 34 | +# (R,G,B) |
| 35 | +color1 = [(179,0,0),(228,26,28),(255,255,51), |
| 36 | + (49,163,84), (0,109,45), (255,255,51), |
| 37 | + (240,2,127),(240,2,127),(240,2,127), (240,2,127), (240,2,127), |
| 38 | + (217,95,14), (254,153,41),(255,255,51), |
| 39 | + (44,127,184),(0,0,255)] |
| 40 | + |
| 41 | +link_pairs1 = [ |
| 42 | + [15, 13], [13, 11], [11, 5], |
| 43 | + [12, 14], [14, 16], [12, 6], |
| 44 | + [3, 1],[1, 2],[1, 0],[0, 2],[2,4], |
| 45 | + [9, 7], [7,5], [5, 6], |
| 46 | + [6, 8], [8, 10], |
| 47 | + ] |
| 48 | + |
| 49 | +point_color1 = [(240,2,127),(240,2,127),(240,2,127), |
| 50 | + (240,2,127), (240,2,127), |
| 51 | + (255,255,51),(255,255,51), |
| 52 | + (254,153,41),(44,127,184), |
| 53 | + (217,95,14),(0,0,255), |
| 54 | + (255,255,51),(255,255,51),(228,26,28), |
| 55 | + (49,163,84),(252,176,243),(0,176,240), |
| 56 | + (255,255,0),(169, 209, 142), |
| 57 | + (255,255,0),(169, 209, 142), |
| 58 | + (255,255,0),(169, 209, 142)] |
| 59 | + |
| 60 | +xiaochu_style = ColorStyle(color1, link_pairs1, point_color1) |
| 61 | + |
| 62 | + |
| 63 | +# Chunhua Style |
| 64 | +# (R,G,B) |
| 65 | +color2 = [(252,176,243),(252,176,243),(252,176,243), |
| 66 | + (0,176,240), (0,176,240), (0,176,240), |
| 67 | + (240,2,127),(240,2,127),(240,2,127), (240,2,127), (240,2,127), |
| 68 | + (255,255,0), (255,255,0),(169, 209, 142), |
| 69 | + (169, 209, 142),(169, 209, 142)] |
| 70 | + |
| 71 | +link_pairs2 = [ |
| 72 | + [15, 13], [13, 11], [11, 5], |
| 73 | + [12, 14], [14, 16], [12, 6], |
| 74 | + [3, 1],[1, 2],[1, 0],[0, 2],[2,4], |
| 75 | + [9, 7], [7,5], [5, 6], [6, 8], [8, 10], |
| 76 | + ] |
| 77 | + |
| 78 | +point_color2 = [(240,2,127),(240,2,127),(240,2,127), |
| 79 | + (240,2,127), (240,2,127), |
| 80 | + (255,255,0),(169, 209, 142), |
| 81 | + (255,255,0),(169, 209, 142), |
| 82 | + (255,255,0),(169, 209, 142), |
| 83 | + (252,176,243),(0,176,240),(252,176,243), |
| 84 | + (0,176,240),(252,176,243),(0,176,240), |
| 85 | + (255,255,0),(169, 209, 142), |
| 86 | + (255,255,0),(169, 209, 142), |
| 87 | + (255,255,0),(169, 209, 142)] |
| 88 | + |
| 89 | +chunhua_style = ColorStyle(color2, link_pairs2, point_color2) |
| 90 | + |
| 91 | +def parse_args(): |
| 92 | + parser = argparse.ArgumentParser(description='Visualize COCO predictions') |
| 93 | + # general |
| 94 | + parser.add_argument('--image-path', |
| 95 | + help='Path of COCO val images', |
| 96 | + type=str, |
| 97 | + default='data/coco/images/val2017/' |
| 98 | + ) |
| 99 | + |
| 100 | + parser.add_argument('--gt-anno', |
| 101 | + help='Path of COCO val annotation', |
| 102 | + type=str, |
| 103 | + default='data/coco/annotations/person_keypoints_val2017.json' |
| 104 | + ) |
| 105 | + |
| 106 | + parser.add_argument('--save-path', |
| 107 | + help="Path to save the visualizations", |
| 108 | + type=str, |
| 109 | + default='visualize/coco/') |
| 110 | + |
| 111 | + parser.add_argument('--prediction', |
| 112 | + help="Prediction file to visualize", |
| 113 | + type=str, |
| 114 | + required=True) |
| 115 | + |
| 116 | + parser.add_argument('--style', |
| 117 | + help="Style of the visualization: Chunhua style or Xiaochu style", |
| 118 | + type=str, |
| 119 | + default='chunhua') |
| 120 | + |
| 121 | + args = parser.parse_args() |
| 122 | + |
| 123 | + return args |
| 124 | + |
| 125 | + |
| 126 | +def map_joint_dict(joints): |
| 127 | + joints_dict = {} |
| 128 | + for i in range(joints.shape[0]): |
| 129 | + x = int(joints[i][0]) |
| 130 | + y = int(joints[i][1]) |
| 131 | + id = i |
| 132 | + joints_dict[id] = (x, y) |
| 133 | + |
| 134 | + return joints_dict |
| 135 | + |
| 136 | +def plot(data, gt_file, img_path, save_path, |
| 137 | + link_pairs, ring_color, save=True): |
| 138 | + |
| 139 | + # joints |
| 140 | + coco = COCO(gt_file) |
| 141 | + coco_dt = coco.loadRes(data) |
| 142 | + coco_eval = COCOeval(coco, coco_dt, 'keypoints') |
| 143 | + coco_eval._prepare() |
| 144 | + gts_ = coco_eval._gts |
| 145 | + dts_ = coco_eval._dts |
| 146 | + |
| 147 | + p = coco_eval.params |
| 148 | + p.imgIds = list(np.unique(p.imgIds)) |
| 149 | + if p.useCats: |
| 150 | + p.catIds = list(np.unique(p.catIds)) |
| 151 | + p.maxDets = sorted(p.maxDets) |
| 152 | + |
| 153 | + # loop through images, area range, max detection number |
| 154 | + catIds = p.catIds if p.useCats else [-1] |
| 155 | + threshold = 0.3 |
| 156 | + joint_thres = 0.2 |
| 157 | + for catId in catIds: |
| 158 | + for imgId in p.imgIds[:5000]: |
| 159 | + # dimention here should be Nxm |
| 160 | + gts = gts_[imgId, catId] |
| 161 | + dts = dts_[imgId, catId] |
| 162 | + inds = np.argsort([-d['score'] for d in dts], kind='mergesort') |
| 163 | + dts = [dts[i] for i in inds] |
| 164 | + if len(dts) > p.maxDets[-1]: |
| 165 | + dts = dts[0:p.maxDets[-1]] |
| 166 | + if len(gts) == 0 or len(dts) == 0: |
| 167 | + continue |
| 168 | + |
| 169 | + sum_score = 0 |
| 170 | + num_box = 0 |
| 171 | + img_name = str(imgId).zfill(12) |
| 172 | + |
| 173 | + # Read Images |
| 174 | + img_file = img_path + img_name + '.jpg' |
| 175 | + data_numpy = cv2.imread(img_file, cv2.IMREAD_COLOR | cv2.IMREAD_IGNORE_ORIENTATION) |
| 176 | + h = data_numpy.shape[0] |
| 177 | + w = data_numpy.shape[1] |
| 178 | + |
| 179 | + # Plot |
| 180 | + fig = plt.figure(figsize=(w/100, h/100), dpi=100) |
| 181 | + ax = plt.subplot(1,1,1) |
| 182 | + bk = plt.imshow(data_numpy[:,:,::-1]) |
| 183 | + bk.set_zorder(-1) |
| 184 | + print(img_name) |
| 185 | + for j, gt in enumerate(gts): |
| 186 | + # matching dt_box and gt_box |
| 187 | + bb = gt['bbox'] |
| 188 | + x0 = bb[0] - bb[2]; x1 = bb[0] + bb[2] * 2 |
| 189 | + y0 = bb[1] - bb[3]; y1 = bb[1] + bb[3] * 2 |
| 190 | + |
| 191 | + # create bounds for ignore regions(double the gt bbox) |
| 192 | + g = np.array(gt['keypoints']) |
| 193 | + #xg = g[0::3]; yg = g[1::3]; |
| 194 | + vg = g[2::3] |
| 195 | + |
| 196 | + for i, dt in enumerate(dts): |
| 197 | + # Calculate IoU |
| 198 | + dt_bb = dt['bbox'] |
| 199 | + dt_x0 = dt_bb[0] - dt_bb[2]; dt_x1 = dt_bb[0] + dt_bb[2] * 2 |
| 200 | + dt_y0 = dt_bb[1] - dt_bb[3]; dt_y1 = dt_bb[1] + dt_bb[3] * 2 |
| 201 | + |
| 202 | + ol_x = min(x1, dt_x1) - max(x0, dt_x0) |
| 203 | + ol_y = min(y1, dt_y1) - max(y0, dt_y0) |
| 204 | + ol_area = ol_x * ol_y |
| 205 | + s_x = max(x1, dt_x1) - min(x0, dt_x0) |
| 206 | + s_y = max(y1, dt_y1) - min(y0, dt_y0) |
| 207 | + sum_area = s_x * s_y |
| 208 | + iou = ol_area / (sum_area + np.spacing(1)) |
| 209 | + score = dt['score'] |
| 210 | + |
| 211 | + if iou < 0.1 or score < threshold: |
| 212 | + continue |
| 213 | + else: |
| 214 | + print('iou: ', iou) |
| 215 | + dt_w = dt_x1 - dt_x0 |
| 216 | + dt_h = dt_y1 - dt_y0 |
| 217 | + ref = min(dt_w, dt_h) |
| 218 | + num_box += 1 |
| 219 | + sum_score += dt['score'] |
| 220 | + dt_joints = np.array(dt['keypoints']).reshape(17,-1) |
| 221 | + joints_dict = map_joint_dict(dt_joints) |
| 222 | + |
| 223 | + # stick |
| 224 | + for k, link_pair in enumerate(link_pairs): |
| 225 | + if link_pair[0] in joints_dict \ |
| 226 | + and link_pair[1] in joints_dict: |
| 227 | + if dt_joints[link_pair[0],2] < joint_thres \ |
| 228 | + or dt_joints[link_pair[1],2] < joint_thres \ |
| 229 | + or vg[link_pair[0]] == 0 \ |
| 230 | + or vg[link_pair[1]] == 0: |
| 231 | + continue |
| 232 | + if k in range(6,11): |
| 233 | + lw = 1 |
| 234 | + else: |
| 235 | + lw = ref / 100. |
| 236 | + line = mlines.Line2D( |
| 237 | + np.array([joints_dict[link_pair[0]][0], |
| 238 | + joints_dict[link_pair[1]][0]]), |
| 239 | + np.array([joints_dict[link_pair[0]][1], |
| 240 | + joints_dict[link_pair[1]][1]]), |
| 241 | + ls='-', lw=lw, alpha=1, color=link_pair[2],) |
| 242 | + line.set_zorder(0) |
| 243 | + ax.add_line(line) |
| 244 | + # black ring |
| 245 | + for k in range(dt_joints.shape[0]): |
| 246 | + if dt_joints[k,2] < joint_thres \ |
| 247 | + or vg[link_pair[0]] == 0 \ |
| 248 | + or vg[link_pair[1]] == 0: |
| 249 | + continue |
| 250 | + if dt_joints[k,0] > w or dt_joints[k,1] > h: |
| 251 | + continue |
| 252 | + if k in range(5): |
| 253 | + radius = 1 |
| 254 | + else: |
| 255 | + radius = ref / 100 |
| 256 | + |
| 257 | + circle = mpatches.Circle(tuple(dt_joints[k,:2]), |
| 258 | + radius=radius, |
| 259 | + ec='black', |
| 260 | + fc=ring_color[k], |
| 261 | + alpha=1, |
| 262 | + linewidth=1) |
| 263 | + circle.set_zorder(1) |
| 264 | + ax.add_patch(circle) |
| 265 | + |
| 266 | + avg_score = (sum_score / (num_box+np.spacing(1)))*1000 |
| 267 | + |
| 268 | + plt.gca().xaxis.set_major_locator(plt.NullLocator()) |
| 269 | + plt.gca().yaxis.set_major_locator(plt.NullLocator()) |
| 270 | + plt.axis('off') |
| 271 | + plt.subplots_adjust(top=1,bottom=0,left=0,right=1,hspace=0,wspace=0) |
| 272 | + plt.margins(0,0) |
| 273 | + if save: |
| 274 | + plt.savefig(save_path + \ |
| 275 | + 'score_'+str(np.int(avg_score))+ \ |
| 276 | + '_id_'+str(imgId)+ \ |
| 277 | + '_'+img_name + '.png', |
| 278 | + format='png', bbox_inckes='tight', dpi=100) |
| 279 | + plt.savefig(save_path +'id_'+str(imgId)+ '.pdf', format='pdf', |
| 280 | + bbox_inckes='tight', dpi=100) |
| 281 | + # plt.show() |
| 282 | + plt.close() |
| 283 | + |
| 284 | +if __name__ == '__main__': |
| 285 | + |
| 286 | + args = parse_args() |
| 287 | + if args.style == 'xiaochu': |
| 288 | + # Xiaochu Style |
| 289 | + colorstyle = xiaochu_style |
| 290 | + elif args.style == 'chunhua': |
| 291 | + # Chunhua Style |
| 292 | + colorstyle = chunhua_style |
| 293 | + else: |
| 294 | + raise Exception('Invalid color style') |
| 295 | + |
| 296 | + save_path = args.save_path |
| 297 | + img_path = args.image_path |
| 298 | + if not os.path.exists(save_path): |
| 299 | + try: |
| 300 | + os.makedirs(save_path) |
| 301 | + except Exception: |
| 302 | + print('Fail to make {}'.format(save_path)) |
| 303 | + |
| 304 | + |
| 305 | + with open(args.prediction) as f: |
| 306 | + data = json.load(f) |
| 307 | + gt_file = args.gt_anno |
| 308 | + plot(data, gt_file, img_path, save_path, colorstyle.link_pairs, colorstyle.ring_color, save=True) |
| 309 | + |
0 commit comments