From 3ba41251f8810218dbe2594182c2b74aafd2eba0 Mon Sep 17 00:00:00 2001 From: John-HarringtonNZ Date: Wed, 26 May 2021 15:37:09 -0400 Subject: [PATCH 01/23] Update README.md --- demo/README.md | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/demo/README.md b/demo/README.md index aff81f44..1fa14aee 100644 --- a/demo/README.md +++ b/demo/README.md @@ -2,6 +2,8 @@ Inferencing the deep-high-resolution-net.pytoch without using Docker. +To run pre-built model, download 'models/pytorch/pose_coco/pose_hrnet_w32_384x288.pth' from google drive. + ## Prep 1. Download the researchers' pretrained pose estimator from [google drive](https://drive.google.com/drive/folders/1hOTihvbyIxsm5ygDpbUuJ7O_tzv4oXjC?usp=sharing) to this directory under `models/` 2. Put the video file you'd like to infer on in this directory under `videos` @@ -72,4 +74,4 @@ Fig: show fps ![multi-people](inference_7.jpg) -Fig: multi-people \ No newline at end of file +Fig: multi-people From ae7b3dbc2474a3d066f2ed284cf327409e6c6111 Mon Sep 17 00:00:00 2001 From: John Harrington Date: Wed, 26 May 2021 17:39:26 -0400 Subject: [PATCH 02/23] modified demo file for end-to-end system test --- demo/demo.py | 165 +++++++++++++++++++++++---------------------------- 1 file changed, 73 insertions(+), 92 deletions(-) diff --git a/demo/demo.py b/demo/demo.py index d482e838..fca303df 100644 --- a/demo/demo.py +++ b/demo/demo.py @@ -8,6 +8,7 @@ import shutil from PIL import Image +from numpy.lib.npyio import savetxt import torch import torch.nn.parallel import torch.backends.cudnn as cudnn @@ -48,6 +49,8 @@ 16: 'right_ankle' } +test_dict = COCO_KEYPOINT_INDEXES + COCO_INSTANCE_CATEGORY_NAMES = [ '__background__', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A', 'stop sign', @@ -107,7 +110,7 @@ def get_person_detection_boxes(model, img, threshold=0.5): return [] # Get list of index with score greater than threshold pred_t = [pred_score.index(x) for x in pred_score if x > threshold][-1] - pred_boxes = pred_boxes[:pred_t+1] + pred_boxes = pred_boxes[:1] pred_classes = pred_classes[:pred_t+1] person_boxes = [] @@ -200,6 +203,7 @@ def parse_args(): parser.add_argument('--image',type=str) parser.add_argument('--write',action='store_true') parser.add_argument('--showFps',action='store_true') + parser.add_argument('--output_dir',type=str, default='/') parser.add_argument('opts', help='Modify config options using the command-line', @@ -217,6 +221,8 @@ def parse_args(): def main(): + + keypoints = None # cudnn related setting cudnn.benchmark = cfg.CUDNN.BENCHMARK torch.backends.cudnn.deterministic = cfg.CUDNN.DETERMINISTIC @@ -244,100 +250,75 @@ def main(): pose_model.eval() # Loading an video or an image or webcam - if args.webcam: - vidcap = cv2.VideoCapture(0) - elif args.video: + # if args.webcam: + # vidcap = cv2.VideoCapture(0) + if args.video: vidcap = cv2.VideoCapture(args.video) - elif args.image: - image_bgr = cv2.imread(args.image) + # elif args.image: + # image_bgr = cv2.imread(args.image) else: print('please use --video or --webcam or --image to define the input.') - return - - if args.webcam or args.video: - if args.write: - save_path = 'output.avi' - fourcc = cv2.VideoWriter_fourcc(*'XVID') - out = cv2.VideoWriter(save_path,fourcc, 24.0, (int(vidcap.get(3)),int(vidcap.get(4)))) - while True: - ret, image_bgr = vidcap.read() - if ret: - last_time = time.time() - image = image_bgr[:, :, [2, 1, 0]] - - input = [] - img = cv2.cvtColor(image_bgr, 
cv2.COLOR_BGR2RGB) - img_tensor = torch.from_numpy(img/255.).permute(2,0,1).float().to(CTX) - input.append(img_tensor) - - # object detection box - pred_boxes = get_person_detection_boxes(box_model, input, threshold=0.9) - - # pose estimation - if len(pred_boxes) >= 1: - for box in pred_boxes: - center, scale = box_to_center_scale(box, cfg.MODEL.IMAGE_SIZE[0], cfg.MODEL.IMAGE_SIZE[1]) - image_pose = image.copy() if cfg.DATASET.COLOR_RGB else image_bgr.copy() - pose_preds = get_pose_estimation_prediction(pose_model, image_pose, center, scale) - if len(pose_preds)>=1: - for kpt in pose_preds: - draw_pose(kpt,image_bgr) # draw the poses - - if args.showFps: - fps = 1/(time.time()-last_time) - img = cv2.putText(image_bgr, 'fps: '+ "%.2f"%(fps), (25, 40), cv2.FONT_HERSHEY_SIMPLEX, 1.2, (0, 255, 0), 2) - - if args.write: - out.write(image_bgr) - - cv2.imshow('demo',image_bgr) - if cv2.waitKey(1) & 0XFF==ord('q'): - break - else: - print('cannot load the video.') - break - - cv2.destroyAllWindows() - vidcap.release() - if args.write: - print('video has been saved as {}'.format(save_path)) - out.release() - - else: - # estimate on the image - last_time = time.time() - image = image_bgr[:, :, [2, 1, 0]] - - input = [] - img = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB) - img_tensor = torch.from_numpy(img/255.).permute(2,0,1).float().to(CTX) - input.append(img_tensor) - - # object detection box - pred_boxes = get_person_detection_boxes(box_model, input, threshold=0.9) - - # pose estimation - if len(pred_boxes) >= 1: - for box in pred_boxes: - center, scale = box_to_center_scale(box, cfg.MODEL.IMAGE_SIZE[0], cfg.MODEL.IMAGE_SIZE[1]) - image_pose = image.copy() if cfg.DATASET.COLOR_RGB else image_bgr.copy() - pose_preds = get_pose_estimation_prediction(pose_model, image_pose, center, scale) - if len(pose_preds)>=1: - for kpt in pose_preds: - draw_pose(kpt,image_bgr) # draw the poses - - if args.showFps: - fps = 1/(time.time()-last_time) - img = cv2.putText(image_bgr, 'fps: '+ "%.2f"%(fps), (25, 40), cv2.FONT_HERSHEY_SIMPLEX, 1.2, (0, 255, 0), 2) - - if args.write: - save_path = 'output.jpg' - cv2.imwrite(save_path,image_bgr) - print('the result image has been saved as {}'.format(save_path)) - - cv2.imshow('demo',image_bgr) - if cv2.waitKey(0) & 0XFF==ord('q'): - cv2.destroyAllWindows() + return + + # if args.webcam or args.video: + if args.write: + save_path = args.output_dir + "/output.avi" + print(save_path) + fourcc = cv2.VideoWriter_fourcc(*'XVID') + out = cv2.VideoWriter(save_path,fourcc, 24.0, (int(vidcap.get(3)),int(vidcap.get(4)))) + + while True: + ret, image_bgr = vidcap.read() + if ret: + last_time = time.time() + image = image_bgr[:, :, [2, 1, 0]] + + input = [] + img = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB) + img_tensor = torch.from_numpy(img/255.).permute(2,0,1).float().to(CTX) + input.append(img_tensor) + + # object detection box + pred_boxes = get_person_detection_boxes(box_model, input, threshold=0.95) + + # pose estimation + if len(pred_boxes) >= 1: + for box in pred_boxes: + center, scale = box_to_center_scale(box, cfg.MODEL.IMAGE_SIZE[0], cfg.MODEL.IMAGE_SIZE[1]) + image_pose = image.copy() if cfg.DATASET.COLOR_RGB else image_bgr.copy() + pose_preds = get_pose_estimation_prediction(pose_model, image_pose, center, scale) + if len(pose_preds)>=1: + for i, kpt in enumerate(pose_preds): + name = COCO_KEYPOINT_INDEXES[i] + if keypoints is None: + keypoints = np.array([kpt]) + else: + keypoints = np.append(keypoints, [kpt], axis = 0) + + #print(f"{name}: {kpt}") + 
draw_pose(kpt,image_bgr) # draw the poses + + if args.showFps: + fps = 1/(time.time()-last_time) + img = cv2.putText(image_bgr, 'fps: '+ "%.2f"%(fps), (25, 40), cv2.FONT_HERSHEY_SIMPLEX, 1.2, (0, 255, 0), 2) + + if args.write: + out.write(image_bgr) + + #cv2.imshow('demo',image_bgr) + #if cv2.waitKey(1) & 0XFF==ord('q'): + # break + else: + print('Video ended') + break + + print(keypoints) + np.save(f"{args.output_dir}/keypoints.dat", keypoints) + cv2.destroyAllWindows() + vidcap.release() + if args.write: + print('video has been saved as {}'.format(save_path)) + out.release() if __name__ == '__main__': main() From bd13fe7f33a7b027fa60e0b65b08447a198be362 Mon Sep 17 00:00:00 2001 From: John Harrington Date: Thu, 27 May 2021 11:36:19 -0400 Subject: [PATCH 03/23] removed unused line --- demo/demo.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/demo/demo.py b/demo/demo.py index fca303df..e83ee257 100644 --- a/demo/demo.py +++ b/demo/demo.py @@ -49,8 +49,6 @@ 16: 'right_ankle' } -test_dict = COCO_KEYPOINT_INDEXES - COCO_INSTANCE_CATEGORY_NAMES = [ '__background__', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A', 'stop sign', From 420bca884bff114038024d2f5edd151a344b9e97 Mon Sep 17 00:00:00 2001 From: John Harrington Date: Thu, 27 May 2021 22:52:58 -0400 Subject: [PATCH 04/23] modified demo script --- demo/demo.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/demo/demo.py b/demo/demo.py index e83ee257..ca112f62 100644 --- a/demo/demo.py +++ b/demo/demo.py @@ -310,8 +310,8 @@ def main(): print('Video ended') break - print(keypoints) - np.save(f"{args.output_dir}/keypoints.dat", keypoints) + np.save(f"{args.output_dir}/keypoints", keypoints) + print(f'keypoint saved to {args.output_dir}/keypoints.npy') cv2.destroyAllWindows() vidcap.release() if args.write: From 8a38dfd12523534a380de6cd3784b5abcc4f35be Mon Sep 17 00:00:00 2001 From: John Harrington Date: Tue, 1 Jun 2021 11:56:09 -0400 Subject: [PATCH 05/23] Added correct FPS --- demo/demo.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/demo/demo.py b/demo/demo.py index ca112f62..02ea9c8a 100644 --- a/demo/demo.py +++ b/demo/demo.py @@ -263,7 +263,8 @@ def main(): save_path = args.output_dir + "/output.avi" print(save_path) fourcc = cv2.VideoWriter_fourcc(*'XVID') - out = cv2.VideoWriter(save_path,fourcc, 24.0, (int(vidcap.get(3)),int(vidcap.get(4)))) + vid_fps = vidcap.get(cv2.CAP_PROP_FPS) + out = cv2.VideoWriter(save_path,fourcc, vid_fps, (int(vidcap.get(3)),int(vidcap.get(4)))) while True: ret, image_bgr = vidcap.read() From 5b64aed1dbdfc7efa5336aa63d7477bfa3f52332 Mon Sep 17 00:00:00 2001 From: John Harrington Date: Wed, 2 Jun 2021 10:22:49 -0400 Subject: [PATCH 06/23] Removed unnecessary imports --- demo/demo.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/demo/demo.py b/demo/demo.py index 02ea9c8a..366fcc63 100644 --- a/demo/demo.py +++ b/demo/demo.py @@ -3,12 +3,7 @@ from __future__ import print_function import argparse -import csv -import os -import shutil -from PIL import Image -from numpy.lib.npyio import savetxt import torch import torch.nn.parallel import torch.backends.cudnn as cudnn From 4c2769888dd12469b482de54c56a409fd14b48a2 Mon Sep 17 00:00:00 2001 From: John Harrington Date: Wed, 2 Jun 2021 14:26:50 -0400 Subject: [PATCH 07/23] Cleaned up demo code --- demo/demo.py | 34 +++++++++------------------------- 1 file changed, 9 insertions(+), 25 deletions(-) diff --git a/demo/demo.py 
b/demo/demo.py index 366fcc63..1633ea39 100644 --- a/demo/demo.py +++ b/demo/demo.py @@ -16,7 +16,6 @@ import numpy as np import time - import _init_paths import models from config import cfg @@ -187,6 +186,7 @@ def box_to_center_scale(box, model_image_width, model_image_height): return center, scale + def parse_args(): parser = argparse.ArgumentParser(description='Train keypoints network') # general @@ -242,25 +242,13 @@ def main(): pose_model.to(CTX) pose_model.eval() - # Loading an video or an image or webcam - # if args.webcam: - # vidcap = cv2.VideoCapture(0) - if args.video: - vidcap = cv2.VideoCapture(args.video) - # elif args.image: - # image_bgr = cv2.imread(args.image) - else: - print('please use --video or --webcam or --image to define the input.') - return + # Loading an video or an video + vidcap = cv2.VideoCapture(args.video) + save_path = args.output_dir + "/output.avi" + fourcc = cv2.VideoWriter_fourcc(*'XVID') + vid_fps = vidcap.get(cv2.CAP_PROP_FPS) + out = cv2.VideoWriter(save_path,fourcc, vid_fps, (int(vidcap.get(3)),int(vidcap.get(4)))) - # if args.webcam or args.video: - if args.write: - save_path = args.output_dir + "/output.avi" - print(save_path) - fourcc = cv2.VideoWriter_fourcc(*'XVID') - vid_fps = vidcap.get(cv2.CAP_PROP_FPS) - out = cv2.VideoWriter(save_path,fourcc, vid_fps, (int(vidcap.get(3)),int(vidcap.get(4)))) - while True: ret, image_bgr = vidcap.read() if ret: @@ -288,8 +276,6 @@ def main(): keypoints = np.array([kpt]) else: keypoints = np.append(keypoints, [kpt], axis = 0) - - #print(f"{name}: {kpt}") draw_pose(kpt,image_bgr) # draw the poses if args.showFps: @@ -299,9 +285,6 @@ def main(): if args.write: out.write(image_bgr) - #cv2.imshow('demo',image_bgr) - #if cv2.waitKey(1) & 0XFF==ord('q'): - # break else: print('Video ended') break @@ -313,6 +296,7 @@ def main(): if args.write: print('video has been saved as {}'.format(save_path)) out.release() - + + if __name__ == '__main__': main() From b3386c1477e292fd8ad78e4e7aebc0087ea94d69 Mon Sep 17 00:00:00 2001 From: John Harrington Date: Thu, 3 Jun 2021 14:54:10 -0400 Subject: [PATCH 08/23] modified requirements --- requirements.txt | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/requirements.txt b/requirements.txt index 14f225c7..45184461 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,6 @@ -EasyDict==1.7 -opencv-python==3.4.1.15 -shapely==1.6.4 +EasyDict>=1.7 +opencv-python +shapely>=1.6.4 Cython scipy pandas @@ -8,4 +8,5 @@ pyyaml json_tricks scikit-image yacs>=0.1.5 -tensorboardX==1.6 +tensorboardX>=1.6 +torchvision \ No newline at end of file From 147e6f871eb16ebe8190f33903da4deadc8c8bb8 Mon Sep 17 00:00:00 2001 From: John Harrington Date: Fri, 4 Jun 2021 10:00:54 -0400 Subject: [PATCH 09/23] testing possible cpu use --- demo/demo.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/demo/demo.py b/demo/demo.py index 1633ea39..777ee11e 100644 --- a/demo/demo.py +++ b/demo/demo.py @@ -234,7 +234,11 @@ def main(): if cfg.TEST.MODEL_FILE: print('=> loading model from {}'.format(cfg.TEST.MODEL_FILE)) - pose_model.load_state_dict(torch.load(cfg.TEST.MODEL_FILE), strict=False) + if torch.cuda.is_available(): + pose_model.load_state_dict(torch.load(cfg.TEST.MODEL_FILE), strict=False) + else: + pose_model.load_state_dict(torch.load(cfg.TEST.MODEL_FILE, map_location='cpu'), strict=False) + else: print('expected model defined in config at TEST.MODEL_FILE') From de698ba639b1cc7fdfa5d195bc6ef1ef0d0b3080 Mon Sep 17 00:00:00 2001 From: John Harrington 
Date: Fri, 4 Jun 2021 14:20:10 -0400
Subject: [PATCH 10/23] Added README and cpu code

---
 README.md    | 12 ++++++++++++
 demo/demo.py |  2 --
 2 files changed, 12 insertions(+), 2 deletions(-)

diff --git a/README.md b/README.md
index 2ff1c717..3f572c98 100644
--- a/README.md
+++ b/README.md
@@ -1,3 +1,15 @@
+# Ifinite Sky Fork of Deep HRnet 2D pose estimation
+
+## Installation Instructions
+- Once this repo is cloned (using the build_stack.sh script in the `bootstrap` repository), the pre-trained model must be manually added. This may be updated in the future, but is not currently.
+- The model is located at [this link](https://drive.google.com/drive/folders/1PufGmj1jHq3HSHr23Vne7UqQ2AETOgY4). The `models.zip` file should be downloaded and then extracted to the main directory for this repo (eg. the path `deep-high-resolution-net.pytorch/models/pytorch` should be a legitimate path). Once this model is placed, the detection can take place.
+- Other dependencies for this repository are located in the `requirements.txt` file, and should be installed during the `build_env.sh` script, also located in the `bootstrap` repository.
+- Everything below this is from the `README` on the orginal repository that was forked, for reference sake.
+
+## Inputs & Outputs
+- The main file currently interfaced in this repository is `demo.py`, located under the `demo/` directory.
+- This file takes in the input file path (currently restricted to `.avi` format, but this can be extended for flexibility).
+
 # Deep High-Resolution Representation Learning for Human Pose Estimation (CVPR 2019)
 ## News
 - [2021/04/12] Welcome to check out our recent work on bottom-up pose estimation (CVPR 2021) [HRNet-DEKR](https://github.com/HRNet/DEKR)!
diff --git a/demo/demo.py b/demo/demo.py
index 777ee11e..5ce13c66 100644
--- a/demo/demo.py
+++ b/demo/demo.py
@@ -192,8 +192,6 @@ def parse_args():
     # general
     parser.add_argument('--cfg', type=str, default='demo/inference-config.yaml')
     parser.add_argument('--video', type=str)
-    parser.add_argument('--webcam',action='store_true')
-    parser.add_argument('--image',type=str)
     parser.add_argument('--write',action='store_true')
     parser.add_argument('--showFps',action='store_true')
     parser.add_argument('--output_dir',type=str, default='/')

From 554eb9da47b0f89003dd347c037c8abc90626ce8 Mon Sep 17 00:00:00 2001
From: John Harrington
Date: Fri, 4 Jun 2021 14:56:54 -0400
Subject: [PATCH 11/23] more README info

---
 README.md | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/README.md b/README.md
index 3f572c98..316df333 100644
--- a/README.md
+++ b/README.md
@@ -9,6 +9,8 @@
 ## Inputs & Outputs
 - The main file currently interfaced in this repository is `demo.py`, located under the `demo/` directory.
 - This file takes in the input file path (currently restricted to `.avi` format, but this can be extended for flexibility).
+- The output of the demo is both a video overlaying the detected key points and a numpy data file containing the joint positions at all key frames.
+- Post processing of these datapoints require the correct fps of the video, which is currently manually inputted into the prototype config file. This will be detected automatically in later additions.
# Deep High-Resolution Representation Learning for Human Pose Estimation (CVPR 2019) ## News From 1367ef9525a588d4ece9f663a3a1845a4380d6e4 Mon Sep 17 00:00:00 2001 From: John Harrington Date: Mon, 14 Jun 2021 14:20:21 -0400 Subject: [PATCH 12/23] Added zero vectors in place of skipping frames --- demo/demo.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/demo/demo.py b/demo/demo.py index 5ce13c66..3b1cb925 100644 --- a/demo/demo.py +++ b/demo/demo.py @@ -279,6 +279,17 @@ def main(): else: keypoints = np.append(keypoints, [kpt], axis = 0) draw_pose(kpt,image_bgr) # draw the poses + else: + if keypoints is None: + keypoints = np.array([[[0, 0]]*len(COCO_KEYPOINT_INDEXES)]) + else: + keypoints = np.append(keypoints, [[[0, 0]]*len(COCO_KEYPOINT_INDEXES)], axis=0) + else: + #Fill undetected frames with zero vectors + if keypoints is None: + keypoints = np.array([[[0, 0]]*len(COCO_KEYPOINT_INDEXES)]) + else: + keypoints = np.append(keypoints, [[[0, 0]]*len(COCO_KEYPOINT_INDEXES)], axis=0) if args.showFps: fps = 1/(time.time()-last_time) From 3b43ab14d6a2fdcd253c40b155f91271b58df21c Mon Sep 17 00:00:00 2001 From: John Harrington Date: Wed, 7 Jul 2021 15:07:59 -0700 Subject: [PATCH 13/23] Modified demo script to be class based --- demo/demo.py | 38 +++++++++++++++++--------------------- 1 file changed, 17 insertions(+), 21 deletions(-) diff --git a/demo/demo.py b/demo/demo.py index 3b1cb925..2e8c8937 100644 --- a/demo/demo.py +++ b/demo/demo.py @@ -14,7 +14,7 @@ import torchvision import cv2 import numpy as np -import time +import os import _init_paths import models @@ -211,7 +211,7 @@ def parse_args(): return args -def main(): +def get_deepHRnet_keypoints(video, output_dir=None, output_video=False, save_kpts=False): keypoints = None # cudnn related setting @@ -219,8 +219,7 @@ def main(): torch.backends.cudnn.deterministic = cfg.CUDNN.DETERMINISTIC torch.backends.cudnn.enabled = cfg.CUDNN.ENABLED - args = parse_args() - update_config(cfg, args) + #update_config(cfg, args) box_model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True) box_model.to(CTX) @@ -245,16 +244,17 @@ def main(): pose_model.eval() # Loading an video or an video - vidcap = cv2.VideoCapture(args.video) - save_path = args.output_dir + "/output.avi" - fourcc = cv2.VideoWriter_fourcc(*'XVID') - vid_fps = vidcap.get(cv2.CAP_PROP_FPS) - out = cv2.VideoWriter(save_path,fourcc, vid_fps, (int(vidcap.get(3)),int(vidcap.get(4)))) + vidcap = cv2.VideoCapture(video) + vid_name, vid_type = os.path.splitext(video) + if output_dir: + save_path = output_dir + f"/{vid_name}_deephrnet_output.{vid_type}" + fourcc = cv2.VideoWriter_fourcc(*'XVID') + vid_fps = vidcap.get(cv2.CAP_PROP_FPS) + out = cv2.VideoWriter(save_path,fourcc, vid_fps, (int(vidcap.get(3)),int(vidcap.get(4)))) while True: ret, image_bgr = vidcap.read() if ret: - last_time = time.time() image = image_bgr[:, :, [2, 1, 0]] input = [] @@ -291,25 +291,21 @@ def main(): else: keypoints = np.append(keypoints, [[[0, 0]]*len(COCO_KEYPOINT_INDEXES)], axis=0) - if args.showFps: - fps = 1/(time.time()-last_time) - img = cv2.putText(image_bgr, 'fps: '+ "%.2f"%(fps), (25, 40), cv2.FONT_HERSHEY_SIMPLEX, 1.2, (0, 255, 0), 2) - - if args.write: + if output_video: out.write(image_bgr) else: print('Video ended') break + + if save_kpts: + np.save(f"{output_dir}/keypoints", keypoints) + print(f'keypoint saved to {output_dir}/keypoints.npy') - np.save(f"{args.output_dir}/keypoints", keypoints) - print(f'keypoint saved to {args.output_dir}/keypoints.npy') 
cv2.destroyAllWindows() vidcap.release() - if args.write: + if output_video: print('video has been saved as {}'.format(save_path)) out.release() - -if __name__ == '__main__': - main() + return keypoints From 6e31718ab3fca819ed6e1914ef8576f260f61807 Mon Sep 17 00:00:00 2001 From: John Harrington Date: Tue, 13 Jul 2021 16:57:40 -0400 Subject: [PATCH 14/23] Changing zeros to Nans for better missing detection marking --- demo/demo.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/demo/demo.py b/demo/demo.py index 3b1cb925..95a0d885 100644 --- a/demo/demo.py +++ b/demo/demo.py @@ -281,15 +281,15 @@ def main(): draw_pose(kpt,image_bgr) # draw the poses else: if keypoints is None: - keypoints = np.array([[[0, 0]]*len(COCO_KEYPOINT_INDEXES)]) + keypoints = np.array([[[np.nan, np.nan]]*len(COCO_KEYPOINT_INDEXES)]) else: - keypoints = np.append(keypoints, [[[0, 0]]*len(COCO_KEYPOINT_INDEXES)], axis=0) + keypoints = np.append(keypoints, [[[np.nan, np.nan]]*len(COCO_KEYPOINT_INDEXES)], axis=0) else: #Fill undetected frames with zero vectors if keypoints is None: - keypoints = np.array([[[0, 0]]*len(COCO_KEYPOINT_INDEXES)]) + keypoints = np.array([[[np.nan, np.nan]]*len(COCO_KEYPOINT_INDEXES)]) else: - keypoints = np.append(keypoints, [[[0, 0]]*len(COCO_KEYPOINT_INDEXES)], axis=0) + keypoints = np.append(keypoints, [[[np.nan, np.nan]]*len(COCO_KEYPOINT_INDEXES)], axis=0) if args.showFps: fps = 1/(time.time()-last_time) From e06ca242daa8386d3357b40bc9e514a3b6389149 Mon Sep 17 00:00:00 2001 From: kenlaz <81046715+kenlaz@users.noreply.github.com> Date: Wed, 14 Jul 2021 14:52:32 -0400 Subject: [PATCH 15/23] Update README.md --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 316df333..e0bd9906 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,4 @@ -# Ifinite Sky Fork of Deep HRnet 2D pose estimation +# Infinite Sky Fork of Deep HRnet 2D pose estimation ## Installation Instructions - Once this repo is cloned (using the build_stack.sh script in the `bootstrap` repository), the pre-trained model must be manually added. This may be updated in the future, but is not currently. 
From 995baa6bbf66739f17142d277ce3269980dc204f Mon Sep 17 00:00:00 2001 From: John Harrington Date: Mon, 26 Jul 2021 17:13:41 -0400 Subject: [PATCH 16/23] Adding custom model option --- demo/demo.py | 28 +++++++++++++++++----------- 1 file changed, 17 insertions(+), 11 deletions(-) diff --git a/demo/demo.py b/demo/demo.py index effa41c8..6291ec3a 100644 --- a/demo/demo.py +++ b/demo/demo.py @@ -23,6 +23,9 @@ from core.function import get_final_preds from utils.transforms import get_affine_transform +import os +cur_dir = os.path.dirname(os.path.realpath(__file__)) + COCO_KEYPOINT_INDEXES = { 0: 'nose', 1: 'left_eye', @@ -188,9 +191,10 @@ def box_to_center_scale(box, model_image_width, model_image_height): def parse_args(): + parser = argparse.ArgumentParser(description='Train keypoints network') # general - parser.add_argument('--cfg', type=str, default='demo/inference-config.yaml') + parser.add_argument('--cfg', type=str, default=f'{cur_dir}/inference-config.yaml') parser.add_argument('--video', type=str) parser.add_argument('--write',action='store_true') parser.add_argument('--showFps',action='store_true') @@ -211,7 +215,7 @@ def parse_args(): return args -def get_deepHRnet_keypoints(video, output_dir=None, output_video=False, save_kpts=False): +def get_deepHRnet_keypoints(video, output_dir=None, output_video=False, save_kpts=False, custom_model=None): keypoints = None # cudnn related setting @@ -219,7 +223,8 @@ def get_deepHRnet_keypoints(video, output_dir=None, output_video=False, save_kpt torch.backends.cudnn.deterministic = cfg.CUDNN.DETERMINISTIC torch.backends.cudnn.enabled = cfg.CUDNN.ENABLED - #update_config(cfg, args) + args = parse_args() + update_config(cfg, args) box_model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True) box_model.to(CTX) @@ -229,15 +234,16 @@ def get_deepHRnet_keypoints(video, output_dir=None, output_video=False, save_kpt cfg, is_train=False ) - if cfg.TEST.MODEL_FILE: - print('=> loading model from {}'.format(cfg.TEST.MODEL_FILE)) - if torch.cuda.is_available(): - pose_model.load_state_dict(torch.load(cfg.TEST.MODEL_FILE), strict=False) - else: - pose_model.load_state_dict(torch.load(cfg.TEST.MODEL_FILE, map_location='cpu'), strict=False) - + model_to_use = cfg.TEST.MODEL_FILE + if custom_model: + model_to_use = custom_model + + print('=> loading model from {}'.format(model_to_use)) + if torch.cuda.is_available(): + pose_model.load_state_dict(torch.load(model_to_use), strict=False) else: - print('expected model defined in config at TEST.MODEL_FILE') + pose_model.load_state_dict(torch.load(model_to_use, map_location='cpu'), strict=False) + pose_model = torch.nn.DataParallel(pose_model, device_ids=cfg.GPUS) pose_model.to(CTX) From d1869e00212993334f9406f2ba96da254845827b Mon Sep 17 00:00:00 2001 From: John Harrington Date: Mon, 26 Jul 2021 17:20:14 -0400 Subject: [PATCH 17/23] Adding progress notifier --- demo/demo.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/demo/demo.py b/demo/demo.py index 6291ec3a..5ec67d01 100644 --- a/demo/demo.py +++ b/demo/demo.py @@ -258,11 +258,15 @@ def get_deepHRnet_keypoints(video, output_dir=None, output_video=False, save_kpt vid_fps = vidcap.get(cv2.CAP_PROP_FPS) out = cv2.VideoWriter(save_path,fourcc, vid_fps, (int(vidcap.get(3)),int(vidcap.get(4)))) + frame_num = 0 while True: ret, image_bgr = vidcap.read() if ret: image = image_bgr[:, :, [2, 1, 0]] + frame_num += 1 + print(f"Processing frame {frame_num}") + input = [] img = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB) img_tensor = 
torch.from_numpy(img/255.).permute(2,0,1).float().to(CTX)

From 0006dc65a3f37a26d2c1b6fb02d56e75e855f51f Mon Sep 17 00:00:00 2001
From: John-HarringtonNZ
Date: Tue, 27 Jul 2021 16:36:30 -0400
Subject: [PATCH 18/23] Update README.md

---
 README.md | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/README.md b/README.md
index e0bd9906..a2eedd31 100644
--- a/README.md
+++ b/README.md
@@ -1,10 +1,8 @@
 # Infinite Sky Fork of Deep HRnet 2D pose estimation
 
 ## Installation Instructions
-- Once this repo is cloned (using the build_stack.sh script in the `bootstrap` repository), the pre-trained model must be manually added. This may be updated in the future, but is not currently.
-- The model is located at [this link](https://drive.google.com/drive/folders/1PufGmj1jHq3HSHr23Vne7UqQ2AETOgY4). The `models.zip` file should be downloaded and then extracted to the main directory for this repo (eg. the path `deep-high-resolution-net.pytorch/models/pytorch` should be a legitimate path). Once this model is placed, the detection can take place.
-- Other dependencies for this repository are located in the `requirements.txt` file, and should be installed during the `build_env.sh` script, also located in the `bootstrap` repository.
-- Everything below this is from the `README` on the orginal repository that was forked, for reference sake.
+ 1. If the script `build_stack.sh` script has been run via `bootstrap`, this repository should already be cloned. The following instructions are for downloading the valid model.
+ 2. The model is located at [this link](https://drive.google.com/drive/folders/1PufGmj1jHq3HSHr23Vne7UqQ2AETOgY4). The `models.zip` file should be downloaded and then extracted to the main directory for this repo (eg. the path `deep-high-resolution-net.pytorch/models/pytorch` should be a legitimate path). Once this model is placed, the detection can take place.
 
 ## Inputs & Outputs
 - The main file currently interfaced in this repository is `demo.py`, located under the `demo/` directory.
@@ -12,6 +10,9 @@
 - The output of the demo is both a video overlaying the detected key points and a numpy data file containing the joint positions at all key frames.
 - Post processing of these datapoints require the correct fps of the video, which is currently manually inputted into the prototype config file. This will be detected automatically in later additions.
 
+# Note: Everything below this is from the `README` on the orginal repository that was forked, for reference sake.
+
+
 # Deep High-Resolution Representation Learning for Human Pose Estimation (CVPR 2019)
 ## News
 - [2021/04/12] Welcome to check out our recent work on bottom-up pose estimation (CVPR 2021) [HRNet-DEKR](https://github.com/HRNet/DEKR)!
From d9e5135ce59e6adc851d1192b683511d974b150c Mon Sep 17 00:00:00 2001 From: John Harrington Date: Wed, 4 Aug 2021 13:08:03 -0400 Subject: [PATCH 19/23] Updating config parser to not parse when not needed --- demo/demo.py | 40 ++++++++++++++++++---------------------- lib/config/default.py | 34 +++++++++++++++++----------------- 2 files changed, 35 insertions(+), 39 deletions(-) diff --git a/demo/demo.py b/demo/demo.py index 5ec67d01..918533a4 100644 --- a/demo/demo.py +++ b/demo/demo.py @@ -190,30 +190,26 @@ def box_to_center_scale(box, model_image_width, model_image_height): return center, scale -def parse_args(): +# def parse_args(): - parser = argparse.ArgumentParser(description='Train keypoints network') - # general - parser.add_argument('--cfg', type=str, default=f'{cur_dir}/inference-config.yaml') - parser.add_argument('--video', type=str) - parser.add_argument('--write',action='store_true') - parser.add_argument('--showFps',action='store_true') - parser.add_argument('--output_dir',type=str, default='/') +# parser = argparse.ArgumentParser(description='Train keypoints network') +# # general +# parser.add_argument('--cfg', type=str, default=f'{cur_dir}/inference-config.yaml') +# parser.add_argument('--video', type=str) +# parser.add_argument('--write',action='store_true') +# parser.add_argument('--showFps',action='store_true') +# parser.add_argument('--output_dir',type=str, default='/') - parser.add_argument('opts', - help='Modify config options using the command-line', - default=None, - nargs=argparse.REMAINDER) +# parser.add_argument('opts', +# help='Modify config options using the command-line', +# default=None, +# nargs=argparse.REMAINDER) - args = parser.parse_args() - - # args expected by supporting codebase - args.modelDir = '' - args.logDir = '' - args.dataDir = '' - args.prevModelDir = '' - return args +# args = parser.parse_args() +class Bunch: + def __init__(self, **kwds): + self.__dict__.update(kwds) def get_deepHRnet_keypoints(video, output_dir=None, output_video=False, save_kpts=False, custom_model=None): @@ -223,8 +219,8 @@ def get_deepHRnet_keypoints(video, output_dir=None, output_video=False, save_kpt torch.backends.cudnn.deterministic = cfg.CUDNN.DETERMINISTIC torch.backends.cudnn.enabled = cfg.CUDNN.ENABLED - args = parse_args() - update_config(cfg, args) + #args = parses_args(video, output_dir) + update_config(cfg, Bunch(cfg=f'{cur_dir}/inference-config.yaml', opts=None)) box_model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True) box_model.to(CTX) diff --git a/lib/config/default.py b/lib/config/default.py index 72d3faf3..34a46dd5 100644 --- a/lib/config/default.py +++ b/lib/config/default.py @@ -126,29 +126,29 @@ def update_config(cfg, args): cfg.defrost() cfg.merge_from_file(args.cfg) - cfg.merge_from_list(args.opts) + # cfg.merge_from_list(args.opts) - if args.modelDir: - cfg.OUTPUT_DIR = args.modelDir + # if args.modelDir: + # cfg.OUTPUT_DIR = args.modelDir - if args.logDir: - cfg.LOG_DIR = args.logDir + # if args.logDir: + # cfg.LOG_DIR = args.logDir - if args.dataDir: - cfg.DATA_DIR = args.dataDir + # if args.dataDir: + # cfg.DATA_DIR = args.dataDir - cfg.DATASET.ROOT = os.path.join( - cfg.DATA_DIR, cfg.DATASET.ROOT - ) + # cfg.DATASET.ROOT = os.path.join( + # cfg.DATA_DIR, cfg.DATASET.ROOT + # ) - cfg.MODEL.PRETRAINED = os.path.join( - cfg.DATA_DIR, cfg.MODEL.PRETRAINED - ) + # cfg.MODEL.PRETRAINED = os.path.join( + # cfg.DATA_DIR, cfg.MODEL.PRETRAINED + # ) - if cfg.TEST.MODEL_FILE: - cfg.TEST.MODEL_FILE = os.path.join( - cfg.DATA_DIR, 
cfg.TEST.MODEL_FILE - ) + # if cfg.TEST.MODEL_FILE: + # cfg.TEST.MODEL_FILE = os.path.join( + # cfg.DATA_DIR, cfg.TEST.MODEL_FILE + # ) cfg.freeze() From 3134ae91714c4facd4bea6cc8d7e8eaa9072893c Mon Sep 17 00:00:00 2001 From: John Harrington Date: Thu, 5 Aug 2021 15:57:44 -0400 Subject: [PATCH 20/23] Adding confidence outputs to pose_estimation --- demo/demo.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/demo/demo.py b/demo/demo.py index 918533a4..ab0e449a 100644 --- a/demo/demo.py +++ b/demo/demo.py @@ -139,13 +139,13 @@ def get_pose_estimation_prediction(pose_model, image, center, scale): with torch.no_grad(): # compute output heatmap output = pose_model(model_input) - preds, _ = get_final_preds( + preds, max_vals = get_final_preds( cfg, output.clone().cpu().numpy(), np.asarray([center]), np.asarray([scale])) - return preds + return np.concatenate((preds, max_vals),2) def box_to_center_scale(box, model_image_width, model_image_height): @@ -284,18 +284,18 @@ def get_deepHRnet_keypoints(video, output_dir=None, output_video=False, save_kpt keypoints = np.array([kpt]) else: keypoints = np.append(keypoints, [kpt], axis = 0) - draw_pose(kpt,image_bgr) # draw the poses + #draw_pose(kpt,image_bgr) # draw the poses else: if keypoints is None: - keypoints = np.array([[[np.nan, np.nan]]*len(COCO_KEYPOINT_INDEXES)]) + keypoints = np.array([[[0, 0, 0]]*len(COCO_KEYPOINT_INDEXES)]) else: - keypoints = np.append(keypoints, [[[np.nan, np.nan]]*len(COCO_KEYPOINT_INDEXES)], axis=0) + keypoints = np.append(keypoints, [[[0, 0, 0]]*len(COCO_KEYPOINT_INDEXES)], axis=0) else: #Fill undetected frames with zero vectors if keypoints is None: - keypoints = np.array([[[np.nan, np.nan]]*len(COCO_KEYPOINT_INDEXES)]) + keypoints = np.array([[[0, 0, 0]]*len(COCO_KEYPOINT_INDEXES)]) else: - keypoints = np.append(keypoints, [[[np.nan, np.nan]]*len(COCO_KEYPOINT_INDEXES)], axis=0) + keypoints = np.append(keypoints, [[[0, 0, 0]]*len(COCO_KEYPOINT_INDEXES)], axis=0) if output_video: out.write(image_bgr) @@ -303,7 +303,6 @@ def get_deepHRnet_keypoints(video, output_dir=None, output_video=False, save_kpt else: print('Video ended') break - if save_kpts: np.save(f"{output_dir}/keypoints", keypoints) print(f'keypoint saved to {output_dir}/keypoints.npy') From 2355340fec022273a2dfa9ce1099aec9e31f2720 Mon Sep 17 00:00:00 2001 From: Russell Montalbano Date: Fri, 17 Dec 2021 15:12:34 -0500 Subject: [PATCH 21/23] Choose correct person kpts --- demo/demo.py | 41 ++++++++++++++++++++++++++++++++++++----- 1 file changed, 36 insertions(+), 5 deletions(-) diff --git a/demo/demo.py b/demo/demo.py index ab0e449a..52e098ef 100644 --- a/demo/demo.py +++ b/demo/demo.py @@ -22,6 +22,7 @@ from config import update_config from core.function import get_final_preds from utils.transforms import get_affine_transform +import pose_estimation.sort as Sort import os cur_dir = os.path.dirname(os.path.realpath(__file__)) @@ -94,7 +95,19 @@ def draw_bbox(box,img): cv2.rectangle(img, box[0], box[1], color=(0, 255, 0),thickness=3) -def get_person_detection_boxes(model, img, threshold=0.5): +def get_id_num(tracked_boxes): + max_area = 0 + id_num = 0 + for box in tracked_boxes: + box_area = (box[2] - box[0]) * (box[3] - box[1]) + if box_area > max_area: + max_area = box_area + id_num = box[4] + + return id_num + + +def get_person_detection_boxes(model, img, tracker, id_num, threshold=0.5): pred = model(img) pred_classes = [COCO_INSTANCE_CATEGORY_NAMES[i] for i in list(pred[0]['labels'].cpu().numpy())] # Get the Prediction 
Score @@ -105,15 +118,30 @@ def get_person_detection_boxes(model, img, threshold=0.5): return [] # Get list of index with score greater than threshold pred_t = [pred_score.index(x) for x in pred_score if x > threshold][-1] - pred_boxes = pred_boxes[:1] + pred_boxes = pred_boxes[:pred_t+1] pred_classes = pred_classes[:pred_t+1] person_boxes = [] for idx, box in enumerate(pred_boxes): if pred_classes[idx] == 'person': + # Create array of structure [bb_x1, bb_y1, bb_x2, bb_y2, score] for use with SORT + box = [coord for pos in box for coord in pos] + box.append(pred_score[idx]) person_boxes.append(box) + + # Get ID's for each person + person_boxes = np.array(person_boxes) + boxes_tracked = tracker.update(person_boxes) + + # If this is the first frame, get the ID of the bigger bounding box (person more in focus) + if id_num is None: + id_num = get_id_num(boxes_tracked) + + # Turn into [[(x1, y2), (x2, y2)]] + person_box = [box for box in boxes_tracked if box[4] == id_num][0] + person_box = [[(person_box[0], person_box[1]), (person_box[2], person_box[3])]] - return person_boxes + return person_box, id_num def get_pose_estimation_prediction(pose_model, image, center, scale): @@ -254,6 +282,9 @@ def get_deepHRnet_keypoints(video, output_dir=None, output_video=False, save_kpt vid_fps = vidcap.get(cv2.CAP_PROP_FPS) out = cv2.VideoWriter(save_path,fourcc, vid_fps, (int(vidcap.get(3)),int(vidcap.get(4)))) + tracker = Sort.Sort(max_age=3) + id_num = None + frame_num = 0 while True: ret, image_bgr = vidcap.read() @@ -269,14 +300,14 @@ def get_deepHRnet_keypoints(video, output_dir=None, output_video=False, save_kpt input.append(img_tensor) # object detection box - pred_boxes = get_person_detection_boxes(box_model, input, threshold=0.95) + pred_boxes, id_num = get_person_detection_boxes(box_model, input, tracker, id_num, threshold=0.95) # pose estimation if len(pred_boxes) >= 1: for box in pred_boxes: center, scale = box_to_center_scale(box, cfg.MODEL.IMAGE_SIZE[0], cfg.MODEL.IMAGE_SIZE[1]) image_pose = image.copy() if cfg.DATASET.COLOR_RGB else image_bgr.copy() - pose_preds = get_pose_estimation_prediction(pose_model, image_pose, center, scale) + pose_preds = get_pose_estimation_prediction(pose_model, image_pose, center, scale) if len(pose_preds)>=1: for i, kpt in enumerate(pose_preds): name = COCO_KEYPOINT_INDEXES[i] From 676df87cb28140c0edc08eaad38b09e0b82ecfa1 Mon Sep 17 00:00:00 2001 From: Russell Montalbano Date: Fri, 31 Dec 2021 10:45:25 -0500 Subject: [PATCH 22/23] add comments --- demo/demo.py | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/demo/demo.py b/demo/demo.py index 52e098ef..0974fb28 100644 --- a/demo/demo.py +++ b/demo/demo.py @@ -115,7 +115,8 @@ def get_person_detection_boxes(model, img, tracker, id_num, threshold=0.5): for i in list(pred[0]['boxes'].detach().cpu().numpy())] # Bounding boxes pred_score = list(pred[0]['scores'].detach().cpu().numpy()) if not pred_score or max(pred_score) threshold][-1] pred_boxes = pred_boxes[:pred_t+1] @@ -124,7 +125,7 @@ def get_person_detection_boxes(model, img, tracker, id_num, threshold=0.5): person_boxes = [] for idx, box in enumerate(pred_boxes): if pred_classes[idx] == 'person': - # Create array of structure [bb_x1, bb_y1, bb_x2, bb_y2, score] for use with SORT + # Create array of structure [bb_x1, bb_y1, bb_x2, bb_y2, score] for use with SORT tracker box = [coord for pos in box for coord in pos] box.append(pred_score[idx]) person_boxes.append(box) @@ -133,15 +134,20 @@ def 
get_person_detection_boxes(model, img, tracker, id_num, threshold=0.5): person_boxes = np.array(person_boxes) boxes_tracked = tracker.update(person_boxes) - # If this is the first frame, get the ID of the bigger bounding box (person more in focus) + # If this is the first frame, get the ID of the bigger bounding box (person more in focus, most likely the thrower) if id_num is None: id_num = get_id_num(boxes_tracked) # Turn into [[(x1, y2), (x2, y2)]] - person_box = [box for box in boxes_tracked if box[4] == id_num][0] - person_box = [[(person_box[0], person_box[1]), (person_box[2], person_box[3])]] + try: + person_box = [box for box in boxes_tracked if box[4] == id_num][0] + person_box = [[(person_box[0], person_box[1]), (person_box[2], person_box[3])]] + return person_box, id_num - return person_box, id_num + # If detections weren't made for our thrower in a frame for some reason, return nothing to be smoothed later + # As long as the thrower is detected within the next 3 frames, it will be assigned the same ID as before + except IndexError: + return [], id_num def get_pose_estimation_prediction(pose_model, image, center, scale): @@ -282,6 +288,7 @@ def get_deepHRnet_keypoints(video, output_dir=None, output_video=False, save_kpt vid_fps = vidcap.get(cv2.CAP_PROP_FPS) out = cv2.VideoWriter(save_path,fourcc, vid_fps, (int(vidcap.get(3)),int(vidcap.get(4)))) + # Initialize SORT Tracker tracker = Sort.Sort(max_age=3) id_num = None From 1fe10ca135248a915afc1ee5e6e46d2722d8ce08 Mon Sep 17 00:00:00 2001 From: Russell Montalbano Date: Sun, 9 Jan 2022 14:29:43 -0500 Subject: [PATCH 23/23] add comments --- demo/demo.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/demo/demo.py b/demo/demo.py index 0974fb28..b2210ad8 100644 --- a/demo/demo.py +++ b/demo/demo.py @@ -96,6 +96,9 @@ def draw_bbox(box,img): def get_id_num(tracked_boxes): + """ + Get the SORT tracker ID number of the bounding box with the biggest area + """ max_area = 0 id_num = 0 for box in tracked_boxes: @@ -145,7 +148,7 @@ def get_person_detection_boxes(model, img, tracker, id_num, threshold=0.5): return person_box, id_num # If detections weren't made for our thrower in a frame for some reason, return nothing to be smoothed later - # As long as the thrower is detected within the next 3 frames, it will be assigned the same ID as before + # As long as the thrower is detected within the next "max_age" frames, it will be assigned the same ID as before except IndexError: return [], id_num @@ -245,7 +248,7 @@ class Bunch: def __init__(self, **kwds): self.__dict__.update(kwds) -def get_deepHRnet_keypoints(video, output_dir=None, output_video=False, save_kpts=False, custom_model=None): +def get_deepHRnet_keypoints(video, output_dir=None, output_video=False, save_kpts=False, custom_model=None, max_age=3): keypoints = None # cudnn related setting @@ -289,7 +292,7 @@ def get_deepHRnet_keypoints(video, output_dir=None, output_video=False, save_kpt out = cv2.VideoWriter(save_path,fourcc, vid_fps, (int(vidcap.get(3)),int(vidcap.get(4)))) # Initialize SORT Tracker - tracker = Sort.Sort(max_age=3) + tracker = Sort.Sort(max_age=max_age) id_num = None frame_num = 0
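
A minimal usage sketch of the interface as it stands after PATCH 23/23: `demo/demo.py` no longer defines a `main()` entry point and instead exposes `get_deepHRnet_keypoints(video, output_dir=None, output_video=False, save_kpts=False, custom_model=None, max_age=3)`. The path handling, the input clip name, and the output directory below are assumptions for illustration; only the signature and the output format are taken from the patches above.

```python
# Hypothetical driver script, assumed to be run from the repository root.
import sys

sys.path.insert(0, "demo")                 # assumption: makes demo/demo.py importable as `demo`
from demo import get_deepHRnet_keypoints   # demo/demo.py as of PATCH 23/23

keypoints = get_deepHRnet_keypoints(
    video="videos/example.avi",   # hypothetical input clip (.avi, per the README notes above)
    output_dir="output",          # directory that receives the output video and keypoints.npy
    output_video=True,            # write processed frames (draw_pose is commented out as of PATCH 20/23)
    save_kpts=True,               # np.save(f"{output_dir}/keypoints", keypoints)
    custom_model=None,            # optional path to a .pth checkpoint; defaults to cfg.TEST.MODEL_FILE
    max_age=3,                    # SORT tracker max_age (PATCH 23/23)
)

# Expected shape: (num_frames, 17, 3), i.e. (x, y, confidence) per COCO keypoint,
# with all-zero rows for frames in which the tracked person was not detected.
print(keypoints.shape)
```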