@@ -115,7 +115,8 @@ def get_person_detection_boxes(model, img, tracker, id_num, threshold=0.5):
115
115
for i in list (pred [0 ]['boxes' ].detach ().cpu ().numpy ())] # Bounding boxes
116
116
pred_score = list (pred [0 ]['scores' ].detach ().cpu ().numpy ())
117
117
if not pred_score or max (pred_score )< threshold :
118
- return []
118
+ return [], id_num
119
+
119
120
# Get list of index with score greater than threshold
120
121
pred_t = [pred_score .index (x ) for x in pred_score if x > threshold ][- 1 ]
121
122
pred_boxes = pred_boxes [:pred_t + 1 ]
@@ -124,7 +125,7 @@ def get_person_detection_boxes(model, img, tracker, id_num, threshold=0.5):
124
125
person_boxes = []
125
126
for idx , box in enumerate (pred_boxes ):
126
127
if pred_classes [idx ] == 'person' :
127
- # Create array of structure [bb_x1, bb_y1, bb_x2, bb_y2, score] for use with SORT
128
+ # Create array of structure [bb_x1, bb_y1, bb_x2, bb_y2, score] for use with SORT tracker
128
129
box = [coord for pos in box for coord in pos ]
129
130
box .append (pred_score [idx ])
130
131
person_boxes .append (box )
@@ -133,15 +134,20 @@ def get_person_detection_boxes(model, img, tracker, id_num, threshold=0.5):
133
134
person_boxes = np .array (person_boxes )
134
135
boxes_tracked = tracker .update (person_boxes )
135
136
136
- # If this is the first frame, get the ID of the bigger bounding box (person more in focus)
137
+ # If this is the first frame, get the ID of the bigger bounding box (person more in focus, most likely the thrower )
137
138
if id_num is None :
138
139
id_num = get_id_num (boxes_tracked )
139
140
140
141
# Turn into [[(x1, y2), (x2, y2)]]
141
- person_box = [box for box in boxes_tracked if box [4 ] == id_num ][0 ]
142
- person_box = [[(person_box [0 ], person_box [1 ]), (person_box [2 ], person_box [3 ])]]
142
+ try :
143
+ person_box = [box for box in boxes_tracked if box [4 ] == id_num ][0 ]
144
+ person_box = [[(person_box [0 ], person_box [1 ]), (person_box [2 ], person_box [3 ])]]
145
+ return person_box , id_num
143
146
144
- return person_box , id_num
147
+ # If detections weren't made for our thrower in a frame for some reason, return nothing to be smoothed later
148
+ # As long as the thrower is detected within the next 3 frames, it will be assigned the same ID as before
149
+ except IndexError :
150
+ return [], id_num
145
151
146
152
147
153
def get_pose_estimation_prediction (pose_model , image , center , scale ):
@@ -282,6 +288,7 @@ def get_deepHRnet_keypoints(video, output_dir=None, output_video=False, save_kpt
282
288
vid_fps = vidcap .get (cv2 .CAP_PROP_FPS )
283
289
out = cv2 .VideoWriter (save_path ,fourcc , vid_fps , (int (vidcap .get (3 )),int (vidcap .get (4 ))))
284
290
291
+ # Initialize SORT Tracker
285
292
tracker = Sort .Sort (max_age = 3 )
286
293
id_num = None
287
294
0 commit comments