Skip to content

Commit 676df87

Browse files
committed
add comments
1 parent 2355340 commit 676df87

File tree

1 file changed

+13
-6
lines changed

1 file changed

+13
-6
lines changed

demo/demo.py

+13-6
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,8 @@ def get_person_detection_boxes(model, img, tracker, id_num, threshold=0.5):
115115
for i in list(pred[0]['boxes'].detach().cpu().numpy())] # Bounding boxes
116116
pred_score = list(pred[0]['scores'].detach().cpu().numpy())
117117
if not pred_score or max(pred_score)<threshold:
118-
return []
118+
return [], id_num
119+
119120
# Get list of index with score greater than threshold
120121
pred_t = [pred_score.index(x) for x in pred_score if x > threshold][-1]
121122
pred_boxes = pred_boxes[:pred_t+1]
@@ -124,7 +125,7 @@ def get_person_detection_boxes(model, img, tracker, id_num, threshold=0.5):
124125
person_boxes = []
125126
for idx, box in enumerate(pred_boxes):
126127
if pred_classes[idx] == 'person':
127-
# Create array of structure [bb_x1, bb_y1, bb_x2, bb_y2, score] for use with SORT
128+
# Create array of structure [bb_x1, bb_y1, bb_x2, bb_y2, score] for use with SORT tracker
128129
box = [coord for pos in box for coord in pos]
129130
box.append(pred_score[idx])
130131
person_boxes.append(box)
@@ -133,15 +134,20 @@ def get_person_detection_boxes(model, img, tracker, id_num, threshold=0.5):
133134
person_boxes = np.array(person_boxes)
134135
boxes_tracked = tracker.update(person_boxes)
135136

136-
# If this is the first frame, get the ID of the bigger bounding box (person more in focus)
137+
# If this is the first frame, get the ID of the bigger bounding box (person more in focus, most likely the thrower)
137138
if id_num is None:
138139
id_num = get_id_num(boxes_tracked)
139140

140141
# Turn into [[(x1, y2), (x2, y2)]]
141-
person_box = [box for box in boxes_tracked if box[4] == id_num][0]
142-
person_box = [[(person_box[0], person_box[1]), (person_box[2], person_box[3])]]
142+
try:
143+
person_box = [box for box in boxes_tracked if box[4] == id_num][0]
144+
person_box = [[(person_box[0], person_box[1]), (person_box[2], person_box[3])]]
145+
return person_box, id_num
143146

144-
return person_box, id_num
147+
# If detections weren't made for our thrower in a frame for some reason, return nothing to be smoothed later
148+
# As long as the thrower is detected within the next 3 frames, it will be assigned the same ID as before
149+
except IndexError:
150+
return [], id_num
145151

146152

147153
def get_pose_estimation_prediction(pose_model, image, center, scale):
@@ -282,6 +288,7 @@ def get_deepHRnet_keypoints(video, output_dir=None, output_video=False, save_kpt
282288
vid_fps = vidcap.get(cv2.CAP_PROP_FPS)
283289
out = cv2.VideoWriter(save_path,fourcc, vid_fps, (int(vidcap.get(3)),int(vidcap.get(4))))
284290

291+
# Initialize SORT Tracker
285292
tracker = Sort.Sort(max_age=3)
286293
id_num = None
287294

0 commit comments

Comments
 (0)