diff --git a/README.md b/README.md
index 3baff399..40b7720e 100644
--- a/README.md
+++ b/README.md
@@ -1,5 +1,8 @@
 # Deep High-Resolution Representation Learning for Human Pose Estimation (CVPR 2019)
 ## News
+- [2021/04/12] Welcome to check out our recent work on bottom-up pose estimation (CVPR 2021) [HRNet-DEKR](https://github.com/HRNet/DEKR)!
+- [2020/07/05] [A very nice blog](https://towardsdatascience.com/overview-of-human-pose-estimation-neural-networks-hrnet-higherhrnet-architectures-and-faq-1954b2f8b249) from Towards Data Science introducing HRNet and HigherHRNet for human pose estimation.
+- [2020/03/13] A longer version has been accepted by TPAMI: [Deep High-Resolution Representation Learning for Visual Recognition](https://arxiv.org/pdf/1908.07919.pdf). It includes more HRNet applications, and the code is available: [semantic segmentation](https://github.com/HRNet/HRNet-Semantic-Segmentation), [object detection](https://github.com/HRNet/HRNet-Object-Detection), [facial landmark detection](https://github.com/HRNet/HRNet-Facial-Landmark-Detection), and [image classification](https://github.com/HRNet/HRNet-Image-Classification).
 - [2020/02/01] We have added demo code for HRNet. Thanks [Alex Simes](https://github.com/alex9311).
 - Visualization code for showing the pose estimation results. Thanks Depu!
 - [2019/08/27] HigherHRNet is now on [ArXiv](https://arxiv.org/abs/1908.10357), which is a bottom-up approach for human pose estimation powerd by HRNet. We will also release code and models at [Higher-HRNet-Human-Pose-Estimation](https://github.com/HRNet/Higher-HRNet-Human-Pose-Estimation), stay tuned!
@@ -239,6 +242,12 @@ python visualization/plot_coco.py \
 
 ### Other applications
 Many other dense prediction tasks, such as segmentation, face alignment and object detection, etc. have been benefited by HRNet. More information can be found at [High-Resolution Networks](https://github.com/HRNet).
+### Other implementations
+[mmpose](https://github.com/open-mmlab/mmpose)
+[ModelScope (in Chinese)](https://modelscope.cn/models/damo/cv_hrnetv2w32_body-2d-keypoints_image/summary)
+[timm](https://huggingface.co/docs/timm/main/en/models/hrnet)
+
+
 ### Citation
 If you use our code or models in your research, please cite with:
 ```
diff --git a/demo/README.md b/demo/README.md
index 35d590b8..aff81f44 100644
--- a/demo/README.md
+++ b/demo/README.md
@@ -1,41 +1,75 @@
-This demo code is meant to be run on a video and includes a person detector.
-[Nvidia-docker](https://github.com/NVIDIA/nvidia-docker) and GPUs are required.
-It only expects there to be one person in each frame of video, though the code could easily be extended to support multiple people.
+# Inference with HRNet
-### Prep
+Run inference with deep-high-resolution-net.pytorch without using Docker.
+
+## Prep
 1. Download the researchers' pretrained pose estimator from [google drive](https://drive.google.com/drive/folders/1hOTihvbyIxsm5ygDpbUuJ7O_tzv4oXjC?usp=sharing) to this directory under `models/`
 2. Put the video file you'd like to infer on in this directory under `videos`
-3. build the docker container in this directory with `./build-docker.sh` (this can take time because it involves compiling opencv)
-4. update the `inference-config.yaml` file to reflect the number of GPUs you have available
+3. (OPTIONAL) build the docker container in this directory with `./build-docker.sh` (this can take time because it involves compiling opencv)
+4. update the `inference-config.yaml` file to reflect the number of GPUs you have available and which trained model you want to use.
+
+## Running the Model
+### 1. Running on a video
+```
+python demo/inference.py --cfg demo/inference-config.yaml \
+    --videoFile ../../multi_people.mp4 \
+    --writeBoxFrames \
+    --outputDir output \
+    TEST.MODEL_FILE ../models/pytorch/pose_coco/pose_hrnet_w32_256x192.pth
-### Running the Model
-Start your docker container with:
 ```
-nvidia-docker run --rm -it \
-  -v $(pwd)/output:/output \
-  -v $(pwd)/videos:/videos \
-  -v $(pwd)/models:/models \
-  -w /pose_root \
-  hrnet_demo_inference \
-  /bin/bash
+
+The above command will create a video under the *output* directory and many pose images under the *output/pose* directory.
+Even with a GPU (a GTX 1080 in my case), person detection takes nearly **0.06 sec** per frame and pose estimation nearly
+**0.07 sec**, so total inference time is about **0.13 sec** per frame, i.e. roughly 7-8 FPS. If you need real-time (>= 20 FPS)
+pose estimation, you should try another approach.
+
+**===Result===**
+
+Some example outputs:
+
+![1 person](inference_1.jpg)
+Fig: 1-person inference
+
+![3 people](inference_3.jpg)
+Fig: 3-person inference
+
+![3 people](inference_5.jpg)
+Fig: 3-person inference
+
+### 2. Demo with more common functions
+Remember to update `TEST.MODEL_FILE` in `demo/inference-config.yaml` according to your model path.
+
+`demo.py` provides the following functions:
+
+- use `--webcam` when the input is a real-time camera.
+- use `--video [video-path]` when the input is a video.
+- use `--image [image-path]` when the input is an image.
+- use `--write` to save the image, camera or video result.
+- use `--showFps` to show the FPS (this FPS includes the detection part).
+- draw connections between joints.
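+
+These options can be combined, and config options such as `TEST.MODEL_FILE` can also be overridden from the command line (the same override mechanism used for `demo/inference.py` above), so you do not have to edit the YAML file. For example, to run on a video, assuming the w32 384x288 checkpoint sits at the path configured in `demo/inference-config.yaml` (adjust the path to wherever you downloaded the model):
+
+```
+python demo/demo.py --video test.mp4 --showFps --write \
+    --cfg demo/inference-config.yaml \
+    TEST.MODEL_FILE models/pytorch/pose_coco/pose_hrnet_w32_384x288.pth
+```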
+
+#### (1) the input is a real-time camera
+```
+python demo/demo.py --webcam --showFps --write
 ```
-Once the container is running, you can run inference with:
+#### (2) the input is a video
+```
+python demo/demo.py --video test.mp4 --showFps --write
 ```
-python tools/inference.py \
-    --cfg inference-config.yaml \
-    --videoFile /videos/my-video.mp4 \
-    --inferenceFps 10 \
-    --writeBoxFrames \
-    TEST.MODEL_FILE \
-    /models/pytorch/pose_coco/pose_hrnet_w32_384x288.pth
+#### (3) the input is an image
+
+```
+python demo/demo.py --image test.jpg --showFps --write
 ```
-The command above will output frames with boxes,
-frames with poses,
-a video with poses,
-and a csv with the keypoint coordinates for each frame.
+**===Result===**
+
+![show_fps](inference_6.jpg)
+
+Fig: show FPS
-![](hrnet-demo.gif)
+![multi-people](inference_7.jpg)
-Original source for demo video above is licensed for `Free for commercial use No attribution required` by [Pixabay](https://pixabay.com/service/license/)
+Fig: multi-people
\ No newline at end of file
diff --git a/demo/_init_paths.py b/demo/_init_paths.py
new file mode 100644
index 00000000..b1aea8fe
--- /dev/null
+++ b/demo/_init_paths.py
@@ -0,0 +1,27 @@
+# ------------------------------------------------------------------------------
+# pose.pytorch
+# Copyright (c) 2018-present Microsoft
+# Licensed under The Apache-2.0 License [see LICENSE for details]
+# Written by Bin Xiao (Bin.Xiao@microsoft.com)
+# ------------------------------------------------------------------------------
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os.path as osp
+import sys
+
+
+def add_path(path):
+    if path not in sys.path:
+        sys.path.insert(0, path)
+
+
+this_dir = osp.dirname(__file__)
+
+lib_path = osp.join(this_dir, '..', 'lib')
+add_path(lib_path)
+
+mm_path = osp.join(this_dir, '..', 'lib/poseeval/py-motmetrics')
+add_path(mm_path)
diff --git a/demo/demo.py b/demo/demo.py
new file mode 100644
index 00000000..d482e838
--- /dev/null
+++ b/demo/demo.py
@@ -0,0 +1,343 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import argparse
+import csv
+import os
+import shutil
+
+from PIL import Image
+import torch
+import torch.nn.parallel
+import torch.backends.cudnn as cudnn
+import torch.optim
+import torch.utils.data
+import torch.utils.data.distributed
+import torchvision.transforms as transforms
+import torchvision
+import cv2
+import numpy as np
+import time
+
+
+import _init_paths
+import models
+from config import cfg
+from config import update_config
+from core.function import get_final_preds
+from utils.transforms import get_affine_transform
+
+COCO_KEYPOINT_INDEXES = {
+    0: 'nose',
+    1: 'left_eye',
+    2: 'right_eye',
+    3: 'left_ear',
+    4: 'right_ear',
+    5: 'left_shoulder',
+    6: 'right_shoulder',
+    7: 'left_elbow',
+    8: 'right_elbow',
+    9: 'left_wrist',
+    10: 'right_wrist',
+    11: 'left_hip',
+    12: 'right_hip',
+    13: 'left_knee',
+    14: 'right_knee',
+    15: 'left_ankle',
+    16: 'right_ankle'
+}
+
+COCO_INSTANCE_CATEGORY_NAMES = [
+    '__background__', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
+    'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A', 'stop sign',
+    'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
+    'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack', 'umbrella', 'N/A', 'N/A',
+    'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
+    'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket',
+    'bottle', 'N/A', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl',
+    'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
+    'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'N/A', 'dining table',
+    'N/A', 'N/A', 'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone',
+    'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'N/A', 'book',
+    'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush'
+]
+
+SKELETON = [
+    [1,3],[1,0],[2,4],[2,0],[0,5],[0,6],[5,7],[7,9],[6,8],[8,10],[5,11],[6,12],[11,12],[11,13],[13,15],[12,14],[14,16]
+]
+
+CocoColors = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0], [85, 255, 0], [0, 255, 0],
+              [0, 255, 85], [0, 255, 170], [0, 255, 255], [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255],
+              [170, 0, 255], [255, 0, 255], [255, 0, 170], [255, 0, 85]]
+
+NUM_KPTS = 17
+
+CTX = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
+
+def draw_pose(keypoints, img):
+    """draw the keypoints and the skeletons.
+    :params keypoints: the shape should be equal to [17,2]
+    :params img:
+    """
+    assert keypoints.shape == (NUM_KPTS, 2)
+    for i in range(len(SKELETON)):
+        kpt_a, kpt_b = SKELETON[i][0], SKELETON[i][1]
+        x_a, y_a = keypoints[kpt_a][0], keypoints[kpt_a][1]
+        x_b, y_b = keypoints[kpt_b][0], keypoints[kpt_b][1]
+        cv2.circle(img, (int(x_a), int(y_a)), 6, CocoColors[i], -1)
+        cv2.circle(img, (int(x_b), int(y_b)), 6, CocoColors[i], -1)
+        cv2.line(img, (int(x_a), int(y_a)), (int(x_b), int(y_b)), CocoColors[i], 2)
+
+def draw_bbox(box, img):
+    """draw the detected bounding box on the image.
+    :param img:
+    """
+    cv2.rectangle(img, box[0], box[1], color=(0, 255, 0), thickness=3)
+
+
+def get_person_detection_boxes(model, img, threshold=0.5):
+    pred = model(img)
+    pred_classes = [COCO_INSTANCE_CATEGORY_NAMES[i]
+                    for i in list(pred[0]['labels'].cpu().numpy())]  # Get the Prediction Score
+    pred_boxes = [[(i[0], i[1]), (i[2], i[3])]
+                  for i in list(pred[0]['boxes'].detach().cpu().numpy())]  # Bounding boxes
+    pred_score = list(pred[0]['scores'].detach().cpu().numpy())
+    if not pred_score or max(pred_score) < threshold:
+        return []
+    # Get list of index with score greater than threshold
+    pred_t = [pred_score.index(x) for x in pred_score if x > threshold][-1]
+    pred_boxes = pred_boxes[:pred_t+1]
+    pred_classes = pred_classes[:pred_t+1]
+
+    person_boxes = []
+    for idx, box in enumerate(pred_boxes):
+        if pred_classes[idx] == 'person':
+            person_boxes.append(box)
+
+    return person_boxes
+
+
+def get_pose_estimation_prediction(pose_model, image, center, scale):
+    rotation = 0
+
+    # pose estimation transformation
+    trans = get_affine_transform(center, scale, rotation, cfg.MODEL.IMAGE_SIZE)
+    model_input = cv2.warpAffine(
+        image,
+        trans,
+        (int(cfg.MODEL.IMAGE_SIZE[0]), int(cfg.MODEL.IMAGE_SIZE[1])),
+        flags=cv2.INTER_LINEAR)
+    transform = transforms.Compose([
+        transforms.ToTensor(),
+        transforms.Normalize(mean=[0.485, 0.456, 0.406],
+                             std=[0.229, 0.224, 0.225]),
+    ])
+
+    # pose estimation inference
+    model_input = transform(model_input).unsqueeze(0)
+    # switch to evaluate mode
+    pose_model.eval()
+    with torch.no_grad():
+        # compute output heatmap
+        output = pose_model(model_input)
+        preds, _ = get_final_preds(
+            cfg,
+            output.clone().cpu().numpy(),
+            np.asarray([center]),
+            np.asarray([scale]))
+
+        return preds
+
+
+def box_to_center_scale(box, model_image_width, model_image_height):
+    """convert a box to center,scale information required for pose transformation
+    Parameters
+    ----------
+    box : list
of tuple + list of length 2 with two tuples of floats representing + bottom left and top right corner of a box + model_image_width : int + model_image_height : int + + Returns + ------- + (numpy array, numpy array) + Two numpy arrays, coordinates for the center of the box and the scale of the box + """ + center = np.zeros((2), dtype=np.float32) + + bottom_left_corner = box[0] + top_right_corner = box[1] + box_width = top_right_corner[0]-bottom_left_corner[0] + box_height = top_right_corner[1]-bottom_left_corner[1] + bottom_left_x = bottom_left_corner[0] + bottom_left_y = bottom_left_corner[1] + center[0] = bottom_left_x + box_width * 0.5 + center[1] = bottom_left_y + box_height * 0.5 + + aspect_ratio = model_image_width * 1.0 / model_image_height + pixel_std = 200 + + if box_width > aspect_ratio * box_height: + box_height = box_width * 1.0 / aspect_ratio + elif box_width < aspect_ratio * box_height: + box_width = box_height * aspect_ratio + scale = np.array( + [box_width * 1.0 / pixel_std, box_height * 1.0 / pixel_std], + dtype=np.float32) + if center[0] != -1: + scale = scale * 1.25 + + return center, scale + +def parse_args(): + parser = argparse.ArgumentParser(description='Train keypoints network') + # general + parser.add_argument('--cfg', type=str, default='demo/inference-config.yaml') + parser.add_argument('--video', type=str) + parser.add_argument('--webcam',action='store_true') + parser.add_argument('--image',type=str) + parser.add_argument('--write',action='store_true') + parser.add_argument('--showFps',action='store_true') + + parser.add_argument('opts', + help='Modify config options using the command-line', + default=None, + nargs=argparse.REMAINDER) + + args = parser.parse_args() + + # args expected by supporting codebase + args.modelDir = '' + args.logDir = '' + args.dataDir = '' + args.prevModelDir = '' + return args + + +def main(): + # cudnn related setting + cudnn.benchmark = cfg.CUDNN.BENCHMARK + torch.backends.cudnn.deterministic = cfg.CUDNN.DETERMINISTIC + torch.backends.cudnn.enabled = cfg.CUDNN.ENABLED + + args = parse_args() + update_config(cfg, args) + + box_model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True) + box_model.to(CTX) + box_model.eval() + + pose_model = eval('models.'+cfg.MODEL.NAME+'.get_pose_net')( + cfg, is_train=False + ) + + if cfg.TEST.MODEL_FILE: + print('=> loading model from {}'.format(cfg.TEST.MODEL_FILE)) + pose_model.load_state_dict(torch.load(cfg.TEST.MODEL_FILE), strict=False) + else: + print('expected model defined in config at TEST.MODEL_FILE') + + pose_model = torch.nn.DataParallel(pose_model, device_ids=cfg.GPUS) + pose_model.to(CTX) + pose_model.eval() + + # Loading an video or an image or webcam + if args.webcam: + vidcap = cv2.VideoCapture(0) + elif args.video: + vidcap = cv2.VideoCapture(args.video) + elif args.image: + image_bgr = cv2.imread(args.image) + else: + print('please use --video or --webcam or --image to define the input.') + return + + if args.webcam or args.video: + if args.write: + save_path = 'output.avi' + fourcc = cv2.VideoWriter_fourcc(*'XVID') + out = cv2.VideoWriter(save_path,fourcc, 24.0, (int(vidcap.get(3)),int(vidcap.get(4)))) + while True: + ret, image_bgr = vidcap.read() + if ret: + last_time = time.time() + image = image_bgr[:, :, [2, 1, 0]] + + input = [] + img = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB) + img_tensor = torch.from_numpy(img/255.).permute(2,0,1).float().to(CTX) + input.append(img_tensor) + + # object detection box + pred_boxes = get_person_detection_boxes(box_model, 
input, threshold=0.9) + + # pose estimation + if len(pred_boxes) >= 1: + for box in pred_boxes: + center, scale = box_to_center_scale(box, cfg.MODEL.IMAGE_SIZE[0], cfg.MODEL.IMAGE_SIZE[1]) + image_pose = image.copy() if cfg.DATASET.COLOR_RGB else image_bgr.copy() + pose_preds = get_pose_estimation_prediction(pose_model, image_pose, center, scale) + if len(pose_preds)>=1: + for kpt in pose_preds: + draw_pose(kpt,image_bgr) # draw the poses + + if args.showFps: + fps = 1/(time.time()-last_time) + img = cv2.putText(image_bgr, 'fps: '+ "%.2f"%(fps), (25, 40), cv2.FONT_HERSHEY_SIMPLEX, 1.2, (0, 255, 0), 2) + + if args.write: + out.write(image_bgr) + + cv2.imshow('demo',image_bgr) + if cv2.waitKey(1) & 0XFF==ord('q'): + break + else: + print('cannot load the video.') + break + + cv2.destroyAllWindows() + vidcap.release() + if args.write: + print('video has been saved as {}'.format(save_path)) + out.release() + + else: + # estimate on the image + last_time = time.time() + image = image_bgr[:, :, [2, 1, 0]] + + input = [] + img = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB) + img_tensor = torch.from_numpy(img/255.).permute(2,0,1).float().to(CTX) + input.append(img_tensor) + + # object detection box + pred_boxes = get_person_detection_boxes(box_model, input, threshold=0.9) + + # pose estimation + if len(pred_boxes) >= 1: + for box in pred_boxes: + center, scale = box_to_center_scale(box, cfg.MODEL.IMAGE_SIZE[0], cfg.MODEL.IMAGE_SIZE[1]) + image_pose = image.copy() if cfg.DATASET.COLOR_RGB else image_bgr.copy() + pose_preds = get_pose_estimation_prediction(pose_model, image_pose, center, scale) + if len(pose_preds)>=1: + for kpt in pose_preds: + draw_pose(kpt,image_bgr) # draw the poses + + if args.showFps: + fps = 1/(time.time()-last_time) + img = cv2.putText(image_bgr, 'fps: '+ "%.2f"%(fps), (25, 40), cv2.FONT_HERSHEY_SIMPLEX, 1.2, (0, 255, 0), 2) + + if args.write: + save_path = 'output.jpg' + cv2.imwrite(save_path,image_bgr) + print('the result image has been saved as {}'.format(save_path)) + + cv2.imshow('demo',image_bgr) + if cv2.waitKey(0) & 0XFF==ord('q'): + cv2.destroyAllWindows() + +if __name__ == '__main__': + main() diff --git a/demo/inference-config.yaml b/demo/inference-config.yaml index 9e57cf20..14bce176 100644 --- a/demo/inference-config.yaml +++ b/demo/inference-config.yaml @@ -26,7 +26,7 @@ MODEL: INIT_WEIGHTS: true NAME: pose_hrnet NUM_JOINTS: 17 - PRETRAINED: 'models/pytorch/imagenet/hrnet_w32-36af842e.pth' + PRETRAINED: 'models/pytorch/pose_coco/pose_hrnet_w32_384x288.pth' TARGET_TYPE: gaussian IMAGE_SIZE: - 288 @@ -112,7 +112,7 @@ TEST: BBOX_THRE: 1.0 IMAGE_THRE: 0.0 IN_VIS_THRE: 0.2 - MODEL_FILE: '' + MODEL_FILE: 'models/pytorch/pose_coco/pose_hrnet_w32_384x288.pth' NMS_THRE: 1.0 OKS_THRE: 0.9 USE_GT_BBOX: true diff --git a/demo/inference.py b/demo/inference.py index bee22bf8..efff86a7 100644 --- a/demo/inference.py +++ b/demo/inference.py @@ -19,14 +19,20 @@ import cv2 import numpy as np +import sys +sys.path.append("../lib") +import time -import _init_paths +# import _init_paths import models from config import cfg from config import update_config -from core.function import get_final_preds +from core.inference import get_final_preds from utils.transforms import get_affine_transform +CTX = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') + + COCO_KEYPOINT_INDEXES = { 0: 'nose', 1: 'left_eye', @@ -67,57 +73,53 @@ def get_person_detection_boxes(model, img, threshold=0.5): pil_image = Image.fromarray(img) # Load the image transform = 
transforms.Compose([transforms.ToTensor()]) # Defing PyTorch Transform transformed_img = transform(pil_image) # Apply the transform to the image - pred = model([transformed_img]) # Pass the image to the model + pred = model([transformed_img.to(CTX)]) # Pass the image to the model + # Use the first detected person pred_classes = [COCO_INSTANCE_CATEGORY_NAMES[i] - for i in list(pred[0]['labels'].numpy())] # Get the Prediction Score + for i in list(pred[0]['labels'].cpu().numpy())] # Get the Prediction Score pred_boxes = [[(i[0], i[1]), (i[2], i[3])] - for i in list(pred[0]['boxes'].detach().numpy())] # Bounding boxes - pred_score = list(pred[0]['scores'].detach().numpy()) - if not pred_score: - return [] - # Get list of index with score greater than threshold - pred_t = [pred_score.index(x) for x in pred_score if x > threshold][-1] - pred_boxes = pred_boxes[:pred_t+1] - pred_classes = pred_classes[:pred_t+1] + for i in list(pred[0]['boxes'].cpu().detach().numpy())] # Bounding boxes + pred_scores = list(pred[0]['scores'].cpu().detach().numpy()) person_boxes = [] - for idx, box in enumerate(pred_boxes): - if pred_classes[idx] == 'person': - person_boxes.append(box) + # Select box has score larger than threshold and is person + for pred_class, pred_box, pred_score in zip(pred_classes, pred_boxes, pred_scores): + if (pred_score > threshold) and (pred_class == 'person'): + person_boxes.append(pred_box) return person_boxes -def get_pose_estimation_prediction(pose_model, image, center, scale): +def get_pose_estimation_prediction(pose_model, image, centers, scales, transform): rotation = 0 # pose estimation transformation - trans = get_affine_transform(center, scale, rotation, cfg.MODEL.IMAGE_SIZE) - model_input = cv2.warpAffine( - image, - trans, - (int(cfg.MODEL.IMAGE_SIZE[0]), int(cfg.MODEL.IMAGE_SIZE[1])), - flags=cv2.INTER_LINEAR) - transform = transforms.Compose([ - transforms.ToTensor(), - transforms.Normalize(mean=[0.485, 0.456, 0.406], - std=[0.229, 0.224, 0.225]), - ]) - - # pose estimation inference - model_input = transform(model_input).unsqueeze(0) - # switch to evaluate mode - pose_model.eval() - with torch.no_grad(): - # compute output heatmap - output = pose_model(model_input) - preds, _ = get_final_preds( - cfg, - output.clone().cpu().numpy(), - np.asarray([center]), - np.asarray([scale])) - - return preds + model_inputs = [] + for center, scale in zip(centers, scales): + trans = get_affine_transform(center, scale, rotation, cfg.MODEL.IMAGE_SIZE) + # Crop smaller image of people + model_input = cv2.warpAffine( + image, + trans, + (int(cfg.MODEL.IMAGE_SIZE[0]), int(cfg.MODEL.IMAGE_SIZE[1])), + flags=cv2.INTER_LINEAR) + + # hwc -> 1chw + model_input = transform(model_input)#.unsqueeze(0) + model_inputs.append(model_input) + + # n * 1chw -> nchw + model_inputs = torch.stack(model_inputs) + + # compute output heatmap + output = pose_model(model_inputs.to(CTX)) + coords, _ = get_final_preds( + cfg, + output.cpu().detach().numpy(), + np.asarray(centers), + np.asarray(scales)) + + return coords def box_to_center_scale(box, model_image_width, model_image_height): @@ -163,15 +165,11 @@ def box_to_center_scale(box, model_image_width, model_image_height): def prepare_output_dirs(prefix='/output/'): - pose_dir = prefix+'poses/' - box_dir = prefix+'boxes/' + pose_dir = os.path.join(prefix, "pose") if os.path.exists(pose_dir) and os.path.isdir(pose_dir): shutil.rmtree(pose_dir) - if os.path.exists(box_dir) and os.path.isdir(box_dir): - shutil.rmtree(box_dir) os.makedirs(pose_dir, exist_ok=True) - 
os.makedirs(box_dir, exist_ok=True) - return pose_dir, box_dir + return pose_dir def parse_args(): @@ -199,6 +197,13 @@ def parse_args(): def main(): + # transformation + pose_transform = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize(mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225]), + ]) + # cudnn related setting cudnn.benchmark = cfg.CUDNN.BENCHMARK torch.backends.cudnn.deterministic = cfg.CUDNN.DETERMINISTIC @@ -206,13 +211,12 @@ def main(): args = parse_args() update_config(cfg, args) - pose_dir, box_dir = prepare_output_dirs(args.outputDir) - csv_output_filename = args.outputDir+'pose-data.csv' + pose_dir = prepare_output_dirs(args.outputDir) csv_output_rows = [] box_model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True) + box_model.to(CTX) box_model.eval() - pose_model = eval('models.'+cfg.MODEL.NAME+'.get_pose_net')( cfg, is_train=False ) @@ -223,7 +227,8 @@ def main(): else: print('expected model defined in config at TEST.MODEL_FILE') - pose_model = torch.nn.DataParallel(pose_model, device_ids=cfg.GPUS).cuda() + pose_model.to(CTX) + pose_model.eval() # Loading an video vidcap = cv2.VideoCapture(args.videoFile) @@ -231,68 +236,105 @@ def main(): if fps < args.inferenceFps: print('desired inference fps is '+str(args.inferenceFps)+' but video fps is '+str(fps)) exit() - every_nth_frame = round(fps/args.inferenceFps) + skip_frame_cnt = round(fps / args.inferenceFps) + frame_width = int(vidcap.get(cv2.CAP_PROP_FRAME_WIDTH)) + frame_height = int(vidcap.get(cv2.CAP_PROP_FRAME_HEIGHT)) + outcap = cv2.VideoWriter('{}/{}_pose.avi'.format(args.outputDir, os.path.splitext(os.path.basename(args.videoFile))[0]), + cv2.VideoWriter_fourcc('M', 'J', 'P', 'G'), int(skip_frame_cnt), (frame_width, frame_height)) - success, image_bgr = vidcap.read() count = 0 + while vidcap.isOpened(): + total_now = time.time() + ret, image_bgr = vidcap.read() + count += 1 - while success: - if count % every_nth_frame != 0: - success, image_bgr = vidcap.read() - count += 1 + if not ret: continue - image = image_bgr[:, :, [2, 1, 0]] - count_str = str(count).zfill(32) + if count % skip_frame_cnt != 0: + continue + + image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB) + + # Clone 2 image for person detection and pose estimation + if cfg.DATASET.COLOR_RGB: + image_per = image_rgb.copy() + image_pose = image_rgb.copy() + else: + image_per = image_bgr.copy() + image_pose = image_bgr.copy() + + # Clone 1 image for debugging purpose + image_debug = image_bgr.copy() # object detection box - pred_boxes = get_person_detection_boxes(box_model, image, threshold=0.8) - if args.writeBoxFrames: - image_bgr_box = image_bgr.copy() - for box in pred_boxes: - cv2.rectangle(image_bgr_box, box[0], box[1], color=(0, 255, 0), - thickness=3) # Draw Rectangle with the coordinates - cv2.imwrite(box_dir+'box%s.jpg' % count_str, image_bgr_box) + now = time.time() + pred_boxes = get_person_detection_boxes(box_model, image_per, threshold=0.9) + then = time.time() + print("Find person bbox in: {} sec".format(then - now)) + + # Can not find people. 
Move to next frame if not pred_boxes: - success, image_bgr = vidcap.read() count += 1 continue - # pose estimation - box = pred_boxes[0] # assume there is only 1 person - center, scale = box_to_center_scale(box, cfg.MODEL.IMAGE_SIZE[0], cfg.MODEL.IMAGE_SIZE[1]) - image_pose = image.copy() if cfg.DATASET.COLOR_RGB else image_bgr.copy() - pose_preds = get_pose_estimation_prediction(pose_model, image_pose, center, scale) + if args.writeBoxFrames: + for box in pred_boxes: + cv2.rectangle(image_debug, box[0], box[1], color=(0, 255, 0), + thickness=3) # Draw Rectangle with the coordinates + + # pose estimation : for multiple people + centers = [] + scales = [] + for box in pred_boxes: + center, scale = box_to_center_scale(box, cfg.MODEL.IMAGE_SIZE[0], cfg.MODEL.IMAGE_SIZE[1]) + centers.append(center) + scales.append(scale) + + now = time.time() + pose_preds = get_pose_estimation_prediction(pose_model, image_pose, centers, scales, transform=pose_transform) + then = time.time() + print("Find person pose in: {} sec".format(then - now)) new_csv_row = [] - for _, mat in enumerate(pose_preds[0]): - x_coord, y_coord = int(mat[0]), int(mat[1]) - cv2.circle(image_bgr, (x_coord, y_coord), 4, (255, 0, 0), 2) - new_csv_row.extend([x_coord, y_coord]) + for coords in pose_preds: + # Draw each point on image + for coord in coords: + x_coord, y_coord = int(coord[0]), int(coord[1]) + cv2.circle(image_debug, (x_coord, y_coord), 4, (255, 0, 0), 2) + new_csv_row.extend([x_coord, y_coord]) + + total_then = time.time() + + text = "{:03.2f} sec".format(total_then - total_now) + cv2.putText(image_debug, text, (100, 50), cv2.FONT_HERSHEY_SIMPLEX, + 1, (0, 0, 255), 2, cv2.LINE_AA) + + cv2.imshow("pos", image_debug) + if cv2.waitKey(1) & 0xFF == ord('q'): + break csv_output_rows.append(new_csv_row) - cv2.imwrite(pose_dir+'pose%s.jpg' % count_str, image_bgr) + img_file = os.path.join(pose_dir, 'pose_{:08d}.jpg'.format(count)) + cv2.imwrite(img_file, image_debug) + outcap.write(image_debug) - # get next frame - success, image_bgr = vidcap.read() - count += 1 # write csv csv_headers = ['frame'] for keypoint in COCO_KEYPOINT_INDEXES.values(): csv_headers.extend([keypoint+'_x', keypoint+'_y']) + csv_output_filename = os.path.join(args.outputDir, 'pose-data.csv') with open(csv_output_filename, 'w', newline='') as csvfile: csvwriter = csv.writer(csvfile) csvwriter.writerow(csv_headers) csvwriter.writerows(csv_output_rows) - os.system("ffmpeg -y -r " - + str(args.inferenceFps) - + " -pattern_type glob -i '" - + pose_dir - + "/*.jpg' -c:v libx264 -vf fps=" - + str(args.inferenceFps)+" -pix_fmt yuv420p /output/movie.mp4") + vidcap.release() + outcap.release() + + cv2.destroyAllWindows() if __name__ == '__main__': diff --git a/demo/inference_1.jpg b/demo/inference_1.jpg new file mode 100644 index 00000000..2ca29d1b Binary files /dev/null and b/demo/inference_1.jpg differ diff --git a/demo/inference_3.jpg b/demo/inference_3.jpg new file mode 100644 index 00000000..7f20915c Binary files /dev/null and b/demo/inference_3.jpg differ diff --git a/demo/inference_5.jpg b/demo/inference_5.jpg new file mode 100644 index 00000000..d7b7117c Binary files /dev/null and b/demo/inference_5.jpg differ diff --git a/demo/inference_6.jpg b/demo/inference_6.jpg new file mode 100644 index 00000000..cc1183c0 Binary files /dev/null and b/demo/inference_6.jpg differ diff --git a/demo/inference_7.jpg b/demo/inference_7.jpg new file mode 100644 index 00000000..9629b300 Binary files /dev/null and b/demo/inference_7.jpg differ