22
22
from config import update_config
23
23
from core .function import get_final_preds
24
24
from utils .transforms import get_affine_transform
25
+ import pose_estimation .sort as Sort
25
26
26
27
import os
27
28
cur_dir = os .path .dirname (os .path .realpath (__file__ ))
@@ -94,26 +95,62 @@ def draw_bbox(box,img):
94
95
cv2 .rectangle (img , box [0 ], box [1 ], color = (0 , 255 , 0 ),thickness = 3 )
95
96
96
97
97
- def get_person_detection_boxes (model , img , threshold = 0.5 ):
98
def get_id_num(tracked_boxes):
    """Return the SORT tracker ID of the bounding box with the largest area.

    Each entry of *tracked_boxes* is [x1, y1, x2, y2, track_id].  Returns 0
    when the sequence is empty (or every box has non-positive area), which
    acts as the "no ID yet" sentinel for the caller.
    """
    best_area, best_id = 0, 0
    for tracked in tracked_boxes:
        width = tracked[2] - tracked[0]
        height = tracked[3] - tracked[1]
        area = width * height
        if area > best_area:
            best_area, best_id = area, tracked[4]
    return best_id
113
def get_person_detection_boxes(model, img, tracker, id_num, threshold=0.5):
    """Detect persons in a frame, track them with SORT, and return the
    bounding box of the tracked thrower.

    Args:
        model: torchvision-style detector whose output follows the COCO
            convention (``pred[0]['labels'|'boxes'|'scores']``).
        img: batched input tensor(s) accepted by ``model``.
        tracker: ``Sort`` instance assigning stable IDs across frames.
        id_num: SORT track ID of the thrower, or ``None`` on the first
            frame (the biggest box is then assumed to be the thrower).
        threshold: minimum detection score for a box to be kept.

    Returns:
        Tuple ``(person_box, id_num)`` where ``person_box`` is
        ``[[(x1, y1), (x2, y2)]]`` for the tracked person, or ``[]`` when
        there is no confident detection / the tracked ID is missing this
        frame (SORT keeps the ID alive for ``max_age`` frames, so gaps can
        be smoothed later).
    """
    pred = model(img)
    pred_classes = [COCO_INSTANCE_CATEGORY_NAMES[i]
                    for i in list(pred[0]['labels'].cpu().numpy())]  # class name per detection
    pred_boxes = [[(i[0], i[1]), (i[2], i[3])]
                  for i in list(pred[0]['boxes'].detach().cpu().numpy())]  # bounding boxes
    pred_score = list(pred[0]['scores'].detach().cpu().numpy())
    if not pred_score or max(pred_score) < threshold:
        return [], id_num

    # Keep only detections scoring strictly above the threshold.  Guarding on
    # the filtered list (instead of indexing [-1] unconditionally) avoids an
    # IndexError when max(pred_score) == threshold exactly.
    keep = [i for i, score in enumerate(pred_score) if score > threshold]
    if not keep:
        return [], id_num
    pred_t = keep[-1]
    pred_boxes = pred_boxes[:pred_t + 1]
    pred_classes = pred_classes[:pred_t + 1]

    person_boxes = []
    for idx, box in enumerate(pred_boxes):
        if pred_classes[idx] == 'person':
            # Flatten to [x1, y1, x2, y2, score] for the SORT tracker.
            box = [coord for pos in box for coord in pos]
            box.append(pred_score[idx])
            person_boxes.append(box)

    # SORT expects an (N, 5) array.  An empty list would otherwise become a
    # 1-D (0,)-shaped array and crash inside tracker.update() when it slices
    # dets[:, 0:4], so force the 2-D shape explicitly.
    person_boxes = np.asarray(person_boxes, dtype=float).reshape(-1, 5)
    boxes_tracked = tracker.update(person_boxes)

    # First frame: pick the biggest box (the person most in focus — most
    # likely the thrower) and lock onto its track ID.
    if id_num is None:
        id_num = get_id_num(boxes_tracked)

    # Reshape the tracked box to [[(x1, y1), (x2, y2)]] for pose estimation.
    try:
        person_box = [box for box in boxes_tracked if box[4] == id_num][0]
        person_box = [[(person_box[0], person_box[1]), (person_box[2], person_box[3])]]
        return person_box, id_num

    # The thrower was not detected this frame; return nothing so the gap can
    # be smoothed later.  As long as the thrower reappears within the next
    # "max_age" frames, SORT re-assigns the same ID.
    except IndexError:
        return [], id_num
117
154
118
155
119
156
def get_pose_estimation_prediction (pose_model , image , center , scale ):
@@ -211,7 +248,7 @@ class Bunch:
211
248
def __init__ (self , ** kwds ):
212
249
self .__dict__ .update (kwds )
213
250
214
- def get_deepHRnet_keypoints (video , output_dir = None , output_video = False , save_kpts = False , custom_model = None ):
251
+ def get_deepHRnet_keypoints (video , output_dir = None , output_video = False , save_kpts = False , custom_model = None , max_age = 3 ):
215
252
216
253
keypoints = None
217
254
# cudnn related setting
@@ -254,6 +291,10 @@ def get_deepHRnet_keypoints(video, output_dir=None, output_video=False, save_kpt
254
291
vid_fps = vidcap .get (cv2 .CAP_PROP_FPS )
255
292
out = cv2 .VideoWriter (save_path ,fourcc , vid_fps , (int (vidcap .get (3 )),int (vidcap .get (4 ))))
256
293
294
+ # Initialize SORT Tracker
295
+ tracker = Sort .Sort (max_age = max_age )
296
+ id_num = None
297
+
257
298
frame_num = 0
258
299
while True :
259
300
ret , image_bgr = vidcap .read ()
@@ -269,14 +310,14 @@ def get_deepHRnet_keypoints(video, output_dir=None, output_video=False, save_kpt
269
310
input .append (img_tensor )
270
311
271
312
# object detection box
272
- pred_boxes = get_person_detection_boxes (box_model , input , threshold = 0.95 )
313
+ pred_boxes , id_num = get_person_detection_boxes (box_model , input , tracker , id_num , threshold = 0.95 )
273
314
274
315
# pose estimation
275
316
if len (pred_boxes ) >= 1 :
276
317
for box in pred_boxes :
277
318
center , scale = box_to_center_scale (box , cfg .MODEL .IMAGE_SIZE [0 ], cfg .MODEL .IMAGE_SIZE [1 ])
278
319
image_pose = image .copy () if cfg .DATASET .COLOR_RGB else image_bgr .copy ()
279
- pose_preds = get_pose_estimation_prediction (pose_model , image_pose , center , scale )
320
+ pose_preds = get_pose_estimation_prediction (pose_model , image_pose , center , scale )
280
321
if len (pose_preds )>= 1 :
281
322
for i , kpt in enumerate (pose_preds ):
282
323
name = COCO_KEYPOINT_INDEXES [i ]
0 commit comments