Commit b24ff93

Merge pull request #6 from InfiniteSkyAI/multi_person_detection
Multi-person detection
2 parents b61f95b + 1fe10ca commit b24ff93

File tree

1 file changed (+48 −7 lines)

demo/demo.py

@@ -22,6 +22,7 @@
 from config import update_config
 from core.function import get_final_preds
 from utils.transforms import get_affine_transform
+import pose_estimation.sort as Sort

 import os
 cur_dir = os.path.dirname(os.path.realpath(__file__))
@@ -94,26 +95,62 @@ def draw_bbox(box,img):
     cv2.rectangle(img, box[0], box[1], color=(0, 255, 0),thickness=3)


-def get_person_detection_boxes(model, img, threshold=0.5):
+def get_id_num(tracked_boxes):
+    """
+    Get the SORT tracker ID number of the bounding box with the biggest area
+    """
+    max_area = 0
+    id_num = 0
+    for box in tracked_boxes:
+        box_area = (box[2] - box[0]) * (box[3] - box[1])
+        if box_area > max_area:
+            max_area = box_area
+            id_num = box[4]
+
+    return id_num
+
+
+def get_person_detection_boxes(model, img, tracker, id_num, threshold=0.5):
     pred = model(img)
     pred_classes = [COCO_INSTANCE_CATEGORY_NAMES[i]
                     for i in list(pred[0]['labels'].cpu().numpy())]  # Get the Prediction Score
     pred_boxes = [[(i[0], i[1]), (i[2], i[3])]
                   for i in list(pred[0]['boxes'].detach().cpu().numpy())]  # Bounding boxes
     pred_score = list(pred[0]['scores'].detach().cpu().numpy())
     if not pred_score or max(pred_score)<threshold:
-        return []
+        return [], id_num
+
     # Get list of index with score greater than threshold
     pred_t = [pred_score.index(x) for x in pred_score if x > threshold][-1]
-    pred_boxes = pred_boxes[:1]
+    pred_boxes = pred_boxes[:pred_t+1]
     pred_classes = pred_classes[:pred_t+1]

     person_boxes = []
     for idx, box in enumerate(pred_boxes):
         if pred_classes[idx] == 'person':
+            # Create array of structure [bb_x1, bb_y1, bb_x2, bb_y2, score] for use with the SORT tracker
+            box = [coord for pos in box for coord in pos]
+            box.append(pred_score[idx])
             person_boxes.append(box)
+
+    # Get IDs for each person
+    person_boxes = np.array(person_boxes)
+    boxes_tracked = tracker.update(person_boxes)
+
+    # If this is the first frame, get the ID of the biggest bounding box (person more in focus, most likely the thrower)
+    if id_num is None:
+        id_num = get_id_num(boxes_tracked)

-    return person_boxes
+    # Turn into [[(x1, y1), (x2, y2)]]
+    try:
+        person_box = [box for box in boxes_tracked if box[4] == id_num][0]
+        person_box = [[(person_box[0], person_box[1]), (person_box[2], person_box[3])]]
+        return person_box, id_num
+
+    # If detections weren't made for our thrower in a frame for some reason, return nothing to be smoothed later
+    # As long as the thrower is detected within the next "max_age" frames, it will be assigned the same ID as before
+    except IndexError:
+        return [], id_num


 def get_pose_estimation_prediction(pose_model, image, center, scale):
@@ -211,7 +248,7 @@ class Bunch:
     def __init__(self, **kwds):
         self.__dict__.update(kwds)

-def get_deepHRnet_keypoints(video, output_dir=None, output_video=False, save_kpts=False, custom_model=None):
+def get_deepHRnet_keypoints(video, output_dir=None, output_video=False, save_kpts=False, custom_model=None, max_age=3):

     keypoints = None
     # cudnn related setting
@@ -254,6 +291,10 @@ def get_deepHRnet_keypoints(video, output_dir=None, output_video=False, save_kpt
     vid_fps = vidcap.get(cv2.CAP_PROP_FPS)
     out = cv2.VideoWriter(save_path,fourcc, vid_fps, (int(vidcap.get(3)),int(vidcap.get(4))))

+    # Initialize SORT Tracker
+    tracker = Sort.Sort(max_age=max_age)
+    id_num = None
+
     frame_num = 0
     while True:
         ret, image_bgr = vidcap.read()
@@ -269,14 +310,14 @@ def get_deepHRnet_keypoints(video, output_dir=None, output_video=False, save_kpt
         input.append(img_tensor)

         # object detection box
-        pred_boxes = get_person_detection_boxes(box_model, input, threshold=0.95)
+        pred_boxes, id_num = get_person_detection_boxes(box_model, input, tracker, id_num, threshold=0.95)

         # pose estimation
         if len(pred_boxes) >= 1:
             for box in pred_boxes:
                 center, scale = box_to_center_scale(box, cfg.MODEL.IMAGE_SIZE[0], cfg.MODEL.IMAGE_SIZE[1])
                 image_pose = image.copy() if cfg.DATASET.COLOR_RGB else image_bgr.copy()
-                pose_preds = get_pose_estimation_prediction(pose_model, image_pose, center, scale)
+                pose_preds = get_pose_estimation_prediction(pose_model, image_pose, center, scale)
                 if len(pose_preds)>=1:
                     for i, kpt in enumerate(pose_preds):
                         name = COCO_KEYPOINT_INDEXES[i]
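
Note on the data contract this diff relies on: detections are passed to SORT as rows of [x1, y1, x2, y2, score], and tracker.update() returns rows of [x1, y1, x2, y2, track_id]; the thrower is then locked to the track ID of the largest box in the first frame. Below is a minimal, self-contained sketch of that selection logic, runnable without the detector or the SORT dependency. The sample frames and track IDs are hypothetical illustrations, not part of the commit.

import numpy as np

def get_id_num(tracked_boxes):
    # Track ID of the largest-area box, assumed to be the thrower (as in the diff above).
    max_area, id_num = 0, 0
    for box in tracked_boxes:
        box_area = (box[2] - box[0]) * (box[3] - box[1])
        if box_area > max_area:
            max_area, id_num = box_area, box[4]
    return id_num

# Hypothetical tracker output for the first frame: one row per tracked person.
frame1 = np.array([[100., 50., 300., 400., 1.],   # large box -> likely the thrower
                   [400., 60., 450., 200., 2.]])  # smaller background person
id_num = get_id_num(frame1)  # locks onto track ID 1

# On later frames, keep only the locked ID. An empty result means the thrower
# was missed this frame; SORT re-assigns the same ID if they reappear within
# max_age frames, so the gap can be smoothed downstream.
frame2 = np.array([[110., 55., 310., 405., 1.],
                   [405., 65., 455., 205., 2.]])
thrower = [box for box in frame2 if box[4] == id_num]
print(thrower[0][:4] if thrower else [])  # -> [110. 55. 310. 405.]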
