diff --git a/README.md b/README.md index 5882990e..40b7720e 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,10 @@ -# Deep High-Resolution Representation Learning for Human Pose Estimation (accepted to CVPR2019) +# Deep High-Resolution Representation Learning for Human Pose Estimation (CVPR 2019) ## News -- If you are interested in internship or research positions related to computer vision in ByteDance AI Lab, feel free to contact me(leoxiaobin-at-gmail.com). +- [2021/04/12] Welcome to check out our recent work on bottom-up pose estimation (CVPR 2021) [HRNet-DEKR](https://github.com/HRNet/DEKR)! +- [2020/07/05] [A very nice blog](https://towardsdatascience.com/overview-of-human-pose-estimation-neural-networks-hrnet-higherhrnet-architectures-and-faq-1954b2f8b249) from Towards Data Science introducing HRNet and HigherHRNet for human pose estimation. +- [2020/03/13] A longer version is accepted by TPAMI: [Deep High-Resolution Representation Learning for Visual Recognition](https://arxiv.org/pdf/1908.07919.pdf). It includes more HRNet applications, and the codes are available: [semantic segmentation](https://github.com/HRNet/HRNet-Semantic-Segmentation), [objection detection](https://github.com/HRNet/HRNet-Object-Detection), [facial landmark detection](https://github.com/HRNet/HRNet-Facial-Landmark-Detection), and [image classification](https://github.com/HRNet/HRNet-Image-Classification). +- [2020/02/01] We have added demo code for HRNet. Thanks [Alex Simes](https://github.com/alex9311). +- Visualization code for showing the pose estimation results. Thanks Depu! - [2019/08/27] HigherHRNet is now on [ArXiv](https://arxiv.org/abs/1908.10357), which is a bottom-up approach for human pose estimation powerd by HRNet. We will also release code and models at [Higher-HRNet-Human-Pose-Estimation](https://github.com/HRNet/Higher-HRNet-Human-Pose-Estimation), stay tuned! - Our new work [High-Resolution Representations for Labeling Pixels and Regions](https://arxiv.org/abs/1904.04514) is available at [HRNet](https://github.com/HRNet). Our HRNet has been applied to a wide range of vision tasks, such as [image classification](https://github.com/HRNet/HRNet-Image-Classification), [objection detection](https://github.com/HRNet/HRNet-Object-Detection), [semantic segmentation](https://github.com/HRNet/HRNet-Semantic-Segmentation) and [facial landmark](https://github.com/HRNet/HRNet-Facial-Landmark-Detection). @@ -219,9 +223,30 @@ python tools/train.py \ --cfg experiments/coco/hrnet/w32_256x192_adam_lr1e-3.yaml \ ``` +### Visualization + +#### Visualizing predictions on COCO val + +``` +python visualization/plot_coco.py \ + --prediction output/coco/w48_384x288_adam_lr1e-3/results/keypoints_val2017_results_0.json \ + --save-path visualization/results + +``` + + + + + ### Other applications -Many other dense prediction tasks, such as segmentation, face alignment and object detection, etc. have been benefited by HRNet. More information can be found at [Deep High-Resolution Representation Learning](https://jingdongwang2017.github.io/Projects/HRNet/). +Many other dense prediction tasks, such as segmentation, face alignment and object detection, etc. have been benefited by HRNet. More information can be found at [High-Resolution Networks](https://github.com/HRNet). + +### Other implementation +[mmpose](https://github.com/open-mmlab/mmpose)
+[ModelScope (in Chinese)](https://modelscope.cn/models/damo/cv_hrnetv2w32_body-2d-keypoints_image/summary)<br>
+[timm](https://huggingface.co/docs/timm/main/en/models/hrnet) + ### Citation If you use our code or models in your research, please cite with: @@ -239,4 +264,13 @@ If you use our code or models in your research, please cite with: booktitle = {European Conference on Computer Vision (ECCV)}, year = {2018} } + +@article{WangSCJDZLMTWLX19, + title={Deep High-Resolution Representation Learning for Visual Recognition}, + author={Jingdong Wang and Ke Sun and Tianheng Cheng and + Borui Jiang and Chaorui Deng and Yang Zhao and Dong Liu and Yadong Mu and + Mingkui Tan and Xinggang Wang and Wenyu Liu and Bin Xiao}, + journal = {TPAMI} + year={2019} +} ``` diff --git a/_config.yml b/_config.yml new file mode 100644 index 00000000..c4192631 --- /dev/null +++ b/_config.yml @@ -0,0 +1 @@ +theme: jekyll-theme-cayman \ No newline at end of file diff --git a/demo/.gitignore b/demo/.gitignore new file mode 100644 index 00000000..04f267c7 --- /dev/null +++ b/demo/.gitignore @@ -0,0 +1,3 @@ +output +models +videos diff --git a/demo/Dockerfile b/demo/Dockerfile new file mode 100644 index 00000000..d8ceaf4e --- /dev/null +++ b/demo/Dockerfile @@ -0,0 +1,112 @@ +FROM nvidia/cuda:10.2-cudnn7-devel-ubuntu16.04 + +ENV OPENCV_VERSION="3.4.6" + +# Basic toolchain +RUN apt-get update && apt-get install -y \ + apt-utils \ + build-essential \ + git \ + wget \ + unzip \ + yasm \ + pkg-config \ + libcurl4-openssl-dev \ + zlib1g-dev \ + htop \ + cmake \ + nano \ + python3-pip \ + python3-dev \ + python3-tk \ + libx264-dev \ + && cd /usr/local/bin \ + && ln -s /usr/bin/python3 python \ + && pip3 install --upgrade pip \ + && apt-get autoremove -y + +# Getting OpenCV dependencies available with apt +RUN apt-get update && apt-get install -y \ + libeigen3-dev \ + libjpeg-dev \ + libpng-dev \ + libtiff-dev \ + libjasper-dev \ + libswscale-dev \ + libavcodec-dev \ + libavformat-dev && \ + apt-get autoremove -y + +# Getting other dependencies +RUN apt-get update && apt-get install -y \ + cppcheck \ + graphviz \ + doxygen \ + p7zip-full \ + libdlib18 \ + libdlib-dev && \ + apt-get autoremove -y + + +# Install OpenCV + OpenCV contrib (takes forever) +RUN mkdir -p /tmp && \ + cd /tmp && \ + wget --no-check-certificate -O opencv.zip https://github.com/opencv/opencv/archive/${OPENCV_VERSION}.zip && \ + wget --no-check-certificate -O opencv_contrib.zip https://github.com/opencv/opencv_contrib/archive/${OPENCV_VERSION}.zip && \ + unzip opencv.zip && \ + unzip opencv_contrib.zip && \ + mkdir opencv-${OPENCV_VERSION}/build && \ + cd opencv-${OPENCV_VERSION}/build && \ + cmake -D CMAKE_BUILD_TYPE=RELEASE \ + -D CMAKE_INSTALL_PREFIX=/usr/local \ + -D WITH_CUDA=ON \ + -D CUDA_FAST_MATH=1 \ + -D WITH_CUBLAS=1 \ + -D WITH_FFMPEG=ON \ + -D WITH_OPENCL=ON \ + -D WITH_V4L=ON \ + -D WITH_OPENGL=ON \ + -D OPENCV_EXTRA_MODULES_PATH=/tmp/opencv_contrib-${OPENCV_VERSION}/modules \ + .. 
&& \ + make -j$(nproc) && \ + make install && \ + echo "/usr/local/lib" > /etc/ld.so.conf.d/opencv.conf && \ + ldconfig && \ + cd /tmp && \ + rm -rf opencv-${OPENCV_VERSION} opencv.zip opencv_contrib-${OPENCV_VERSION} opencv_contrib.zip && \ + cd / + +# Compile and install ffmpeg from source +RUN git clone https://github.com/FFmpeg/FFmpeg /root/ffmpeg && \ + cd /root/ffmpeg && \ + ./configure --enable-gpl --enable-libx264 --enable-nonfree --disable-shared --extra-cflags=-I/usr/local/include && \ + make -j8 && make install -j8 + +# clone deep-high-resolution-net +ARG POSE_ROOT=/pose_root +RUN git clone https://github.com/leoxiaobin/deep-high-resolution-net.pytorch.git $POSE_ROOT +WORKDIR $POSE_ROOT +RUN mkdir output && mkdir log + +RUN pip3 install -r requirements.txt && \ + pip3 install torch==1.1.0 \ + torchvision==0.3.0 \ + opencv-python \ + pillow==6.2.1 + +# build deep-high-resolution-net lib +WORKDIR $POSE_ROOT/lib +RUN make + +# install COCO API +ARG COCOAPI=/cocoapi +RUN git clone https://github.com/cocodataset/cocoapi.git $COCOAPI +WORKDIR $COCOAPI/PythonAPI +# Install into global site-packages +RUN make install + +# download fasterrcnn pretrained model for person detection +RUN python -c "import torchvision; model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True); model.eval()" + +COPY inference.py $POSE_ROOT/tools +COPY inference-config.yaml $POSE_ROOT/ diff --git a/demo/README.md b/demo/README.md new file mode 100644 index 00000000..aff81f44 --- /dev/null +++ b/demo/README.md @@ -0,0 +1,75 @@ +# HRNet inference + +Running inference with deep-high-resolution-net.pytorch (Docker is optional). + +## Prep +1. Download the researchers' pretrained pose estimator from [google drive](https://drive.google.com/drive/folders/1hOTihvbyIxsm5ygDpbUuJ7O_tzv4oXjC?usp=sharing) to this directory under `models/` +2. Put the video file you'd like to run inference on in this directory under `videos/` +3. (OPTIONAL) Build the docker container in this directory with `./build-docker.sh` (this can take a while because it compiles OpenCV) +4. Update the `inference-config.yaml` file to reflect the number of GPUs you have available and which trained model you want to use. + +## Running the Model +### 1. Running on the video +``` +python demo/inference.py --cfg demo/inference-config.yaml \ + --videoFile ../../multi_people.mp4 \ + --writeBoxFrames \ + --outputDir output \ + TEST.MODEL_FILE ../models/pytorch/pose_coco/pose_hrnet_w32_256x192.pth + +``` + +The above command will create a video in the *output* directory and many pose images in the *output/pose* directory. +Even with a GPU (a GTX 1080 in my case), person detection takes nearly **0.06 sec** and pose estimation will + take nearly **0.07 sec**. In total, inference time per frame will be **0.13 sec**, roughly 8 fps. So if you need real-time (fps >= 20) + pose estimation, you should try another approach. + +**===Result===** + +Some example output images: + +![1 person](inference_1.jpg) +Fig: 1 person inference + +![3 person](inference_3.jpg) +Fig: 3 person inference + +![3 person](inference_5.jpg) +Fig: 3 person inference + +### 2. Demo with more common functions +Remember to update `TEST.MODEL_FILE` in `demo/inference-config.yaml` according to your model path. + +`demo.py` provides the following functions: + +- use `--webcam` when the input is a real-time camera. +- use `--video [video-path]` when the input is a video. +- use `--image [image-path]` when the input is an image. <br>
+- use `--write` to save the image, camera or video result. +- use `--showFps` to show the fps (this fps includes the detection part). +- draw connections between joints. + +#### (1) the input is a real-time carema +```python +python demo/demo.py --webcam --showFps --write +``` + +#### (2) the input is a video +```python +python demo/demo.py --video test.mp4 --showFps --write +``` +#### (3) the input is a image + +```python +python demo/demo.py --image test.jpg --showFps --write +``` + +**===Result===** + +![show_fps](inference_6.jpg) + +Fig: show fps + +![multi-people](inference_7.jpg) + +Fig: multi-people \ No newline at end of file diff --git a/demo/_init_paths.py b/demo/_init_paths.py new file mode 100644 index 00000000..b1aea8fe --- /dev/null +++ b/demo/_init_paths.py @@ -0,0 +1,27 @@ +# ------------------------------------------------------------------------------ +# pose.pytorch +# Copyright (c) 2018-present Microsoft +# Licensed under The Apache-2.0 License [see LICENSE for details] +# Written by Bin Xiao (Bin.Xiao@microsoft.com) +# ------------------------------------------------------------------------------ + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os.path as osp +import sys + + +def add_path(path): + if path not in sys.path: + sys.path.insert(0, path) + + +this_dir = osp.dirname(__file__) + +lib_path = osp.join(this_dir, '..', 'lib') +add_path(lib_path) + +mm_path = osp.join(this_dir, '..', 'lib/poseeval/py-motmetrics') +add_path(mm_path) diff --git a/demo/build-docker.sh b/demo/build-docker.sh new file mode 100755 index 00000000..a4b1aab4 --- /dev/null +++ b/demo/build-docker.sh @@ -0,0 +1 @@ +docker build -t hrnet_demo_inference . diff --git a/demo/demo.py b/demo/demo.py new file mode 100644 index 00000000..d482e838 --- /dev/null +++ b/demo/demo.py @@ -0,0 +1,343 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import argparse +import csv +import os +import shutil + +from PIL import Image +import torch +import torch.nn.parallel +import torch.backends.cudnn as cudnn +import torch.optim +import torch.utils.data +import torch.utils.data.distributed +import torchvision.transforms as transforms +import torchvision +import cv2 +import numpy as np +import time + + +import _init_paths +import models +from config import cfg +from config import update_config +from core.function import get_final_preds +from utils.transforms import get_affine_transform + +COCO_KEYPOINT_INDEXES = { + 0: 'nose', + 1: 'left_eye', + 2: 'right_eye', + 3: 'left_ear', + 4: 'right_ear', + 5: 'left_shoulder', + 6: 'right_shoulder', + 7: 'left_elbow', + 8: 'right_elbow', + 9: 'left_wrist', + 10: 'right_wrist', + 11: 'left_hip', + 12: 'right_hip', + 13: 'left_knee', + 14: 'right_knee', + 15: 'left_ankle', + 16: 'right_ankle' +} + +COCO_INSTANCE_CATEGORY_NAMES = [ + '__background__', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', + 'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A', 'stop sign', + 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow', + 'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack', 'umbrella', 'N/A', 'N/A', + 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', + 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket', + 'bottle', 'N/A', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', + 'banana', 'apple', 'sandwich', 'orange', 
'broccoli', 'carrot', 'hot dog', 'pizza', + 'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'N/A', 'dining table', + 'N/A', 'N/A', 'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', + 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'N/A', 'book', + 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush' +] + +SKELETON = [ + [1,3],[1,0],[2,4],[2,0],[0,5],[0,6],[5,7],[7,9],[6,8],[8,10],[5,11],[6,12],[11,12],[11,13],[13,15],[12,14],[14,16] +] + +CocoColors = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0], [85, 255, 0], [0, 255, 0], + [0, 255, 85], [0, 255, 170], [0, 255, 255], [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255], + [170, 0, 255], [255, 0, 255], [255, 0, 170], [255, 0, 85]] + +NUM_KPTS = 17 + +CTX = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') + +def draw_pose(keypoints,img): + """draw the keypoints and the skeletons. + :params keypoints: the shape should be equal to [17,2] + :params img: the image to draw on + """ + assert keypoints.shape == (NUM_KPTS,2) + for i in range(len(SKELETON)): + kpt_a, kpt_b = SKELETON[i][0], SKELETON[i][1] + x_a, y_a = keypoints[kpt_a][0],keypoints[kpt_a][1] + x_b, y_b = keypoints[kpt_b][0],keypoints[kpt_b][1] + cv2.circle(img, (int(x_a), int(y_a)), 6, CocoColors[i], -1) + cv2.circle(img, (int(x_b), int(y_b)), 6, CocoColors[i], -1) + cv2.line(img, (int(x_a), int(y_a)), (int(x_b), int(y_b)), CocoColors[i], 2) + +def draw_bbox(box,img): + """draw the detected bounding box on the image. + :param box: two (x, y) corner points of the box + :param img: the image to draw on + """ + cv2.rectangle(img, box[0], box[1], color=(0, 255, 0),thickness=3) + + +def get_person_detection_boxes(model, img, threshold=0.5): + pred = model(img) + pred_classes = [COCO_INSTANCE_CATEGORY_NAMES[i] + for i in list(pred[0]['labels'].cpu().numpy())] # Get the predicted classes + pred_boxes = [[(i[0], i[1]), (i[2], i[3])] + for i in list(pred[0]['boxes'].detach().cpu().numpy())] # Bounding boxes + pred_score = list(pred[0]['scores'].detach().cpu().numpy()) + if not pred_score or max(pred_score) < threshold: + return [] + # Get the index of the last prediction with a score above the threshold + pred_t = [pred_score.index(x) for x in pred_score if x > threshold][-1] + pred_boxes = pred_boxes[:pred_t+1] + pred_classes = pred_classes[:pred_t+1] + + person_boxes = [] + for idx, box in enumerate(pred_boxes): + if pred_classes[idx] == 'person': + person_boxes.append(box) + + return person_boxes + + +def get_pose_estimation_prediction(pose_model, image, center, scale): + rotation = 0 + + # pose estimation transformation + trans = get_affine_transform(center, scale, rotation, cfg.MODEL.IMAGE_SIZE) + model_input = cv2.warpAffine( + image, + trans, + (int(cfg.MODEL.IMAGE_SIZE[0]), int(cfg.MODEL.IMAGE_SIZE[1])), + flags=cv2.INTER_LINEAR) + transform = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize(mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225]), + ]) + + # pose estimation inference + model_input = transform(model_input).unsqueeze(0) + # switch to evaluate mode + pose_model.eval() + with torch.no_grad(): + # compute output heatmap + output = pose_model(model_input) + preds, _ = get_final_preds( + cfg, + output.clone().cpu().numpy(), + np.asarray([center]), + np.asarray([scale])) + + return preds + + +def box_to_center_scale(box, model_image_width, model_image_height): + """convert a box to center,scale information required for pose transformation + Parameters + ---------- + box : list of tuple + list of length 2 with two tuples of floats representing + bottom left and top right corner of a box + model_image_width : int + model_image_height : int + + Returns + ------- + (numpy array, numpy
array) + Two numpy arrays, coordinates for the center of the box and the scale of the box + """ + center = np.zeros((2), dtype=np.float32) + + bottom_left_corner = box[0] + top_right_corner = box[1] + box_width = top_right_corner[0]-bottom_left_corner[0] + box_height = top_right_corner[1]-bottom_left_corner[1] + bottom_left_x = bottom_left_corner[0] + bottom_left_y = bottom_left_corner[1] + center[0] = bottom_left_x + box_width * 0.5 + center[1] = bottom_left_y + box_height * 0.5 + + aspect_ratio = model_image_width * 1.0 / model_image_height + pixel_std = 200 + + if box_width > aspect_ratio * box_height: + box_height = box_width * 1.0 / aspect_ratio + elif box_width < aspect_ratio * box_height: + box_width = box_height * aspect_ratio + scale = np.array( + [box_width * 1.0 / pixel_std, box_height * 1.0 / pixel_std], + dtype=np.float32) + if center[0] != -1: + scale = scale * 1.25 + + return center, scale + +def parse_args(): + parser = argparse.ArgumentParser(description='Train keypoints network') + # general + parser.add_argument('--cfg', type=str, default='demo/inference-config.yaml') + parser.add_argument('--video', type=str) + parser.add_argument('--webcam',action='store_true') + parser.add_argument('--image',type=str) + parser.add_argument('--write',action='store_true') + parser.add_argument('--showFps',action='store_true') + + parser.add_argument('opts', + help='Modify config options using the command-line', + default=None, + nargs=argparse.REMAINDER) + + args = parser.parse_args() + + # args expected by supporting codebase + args.modelDir = '' + args.logDir = '' + args.dataDir = '' + args.prevModelDir = '' + return args + + +def main(): + # cudnn related setting + cudnn.benchmark = cfg.CUDNN.BENCHMARK + torch.backends.cudnn.deterministic = cfg.CUDNN.DETERMINISTIC + torch.backends.cudnn.enabled = cfg.CUDNN.ENABLED + + args = parse_args() + update_config(cfg, args) + + box_model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True) + box_model.to(CTX) + box_model.eval() + + pose_model = eval('models.'+cfg.MODEL.NAME+'.get_pose_net')( + cfg, is_train=False + ) + + if cfg.TEST.MODEL_FILE: + print('=> loading model from {}'.format(cfg.TEST.MODEL_FILE)) + pose_model.load_state_dict(torch.load(cfg.TEST.MODEL_FILE), strict=False) + else: + print('expected model defined in config at TEST.MODEL_FILE') + + pose_model = torch.nn.DataParallel(pose_model, device_ids=cfg.GPUS) + pose_model.to(CTX) + pose_model.eval() + + # Loading an video or an image or webcam + if args.webcam: + vidcap = cv2.VideoCapture(0) + elif args.video: + vidcap = cv2.VideoCapture(args.video) + elif args.image: + image_bgr = cv2.imread(args.image) + else: + print('please use --video or --webcam or --image to define the input.') + return + + if args.webcam or args.video: + if args.write: + save_path = 'output.avi' + fourcc = cv2.VideoWriter_fourcc(*'XVID') + out = cv2.VideoWriter(save_path,fourcc, 24.0, (int(vidcap.get(3)),int(vidcap.get(4)))) + while True: + ret, image_bgr = vidcap.read() + if ret: + last_time = time.time() + image = image_bgr[:, :, [2, 1, 0]] + + input = [] + img = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB) + img_tensor = torch.from_numpy(img/255.).permute(2,0,1).float().to(CTX) + input.append(img_tensor) + + # object detection box + pred_boxes = get_person_detection_boxes(box_model, input, threshold=0.9) + + # pose estimation + if len(pred_boxes) >= 1: + for box in pred_boxes: + center, scale = box_to_center_scale(box, cfg.MODEL.IMAGE_SIZE[0], cfg.MODEL.IMAGE_SIZE[1]) + image_pose = 
image.copy() if cfg.DATASET.COLOR_RGB else image_bgr.copy() + pose_preds = get_pose_estimation_prediction(pose_model, image_pose, center, scale) + if len(pose_preds)>=1: + for kpt in pose_preds: + draw_pose(kpt,image_bgr) # draw the poses + + if args.showFps: + fps = 1/(time.time()-last_time) + img = cv2.putText(image_bgr, 'fps: '+ "%.2f"%(fps), (25, 40), cv2.FONT_HERSHEY_SIMPLEX, 1.2, (0, 255, 0), 2) + + if args.write: + out.write(image_bgr) + + cv2.imshow('demo',image_bgr) + if cv2.waitKey(1) & 0XFF==ord('q'): + break + else: + print('cannot load the video.') + break + + cv2.destroyAllWindows() + vidcap.release() + if args.write: + print('video has been saved as {}'.format(save_path)) + out.release() + + else: + # estimate on the image + last_time = time.time() + image = image_bgr[:, :, [2, 1, 0]] + + input = [] + img = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB) + img_tensor = torch.from_numpy(img/255.).permute(2,0,1).float().to(CTX) + input.append(img_tensor) + + # object detection box + pred_boxes = get_person_detection_boxes(box_model, input, threshold=0.9) + + # pose estimation + if len(pred_boxes) >= 1: + for box in pred_boxes: + center, scale = box_to_center_scale(box, cfg.MODEL.IMAGE_SIZE[0], cfg.MODEL.IMAGE_SIZE[1]) + image_pose = image.copy() if cfg.DATASET.COLOR_RGB else image_bgr.copy() + pose_preds = get_pose_estimation_prediction(pose_model, image_pose, center, scale) + if len(pose_preds)>=1: + for kpt in pose_preds: + draw_pose(kpt,image_bgr) # draw the poses + + if args.showFps: + fps = 1/(time.time()-last_time) + img = cv2.putText(image_bgr, 'fps: '+ "%.2f"%(fps), (25, 40), cv2.FONT_HERSHEY_SIMPLEX, 1.2, (0, 255, 0), 2) + + if args.write: + save_path = 'output.jpg' + cv2.imwrite(save_path,image_bgr) + print('the result image has been saved as {}'.format(save_path)) + + cv2.imshow('demo',image_bgr) + if cv2.waitKey(0) & 0XFF==ord('q'): + cv2.destroyAllWindows() + +if __name__ == '__main__': + main() diff --git a/demo/hrnet-demo.gif b/demo/hrnet-demo.gif new file mode 100644 index 00000000..30e37079 Binary files /dev/null and b/demo/hrnet-demo.gif differ diff --git a/demo/inference-config.yaml b/demo/inference-config.yaml new file mode 100644 index 00000000..14bce176 --- /dev/null +++ b/demo/inference-config.yaml @@ -0,0 +1,127 @@ +AUTO_RESUME: true +CUDNN: + BENCHMARK: true + DETERMINISTIC: false + ENABLED: true +DATA_DIR: '' +GPUS: (0,) +OUTPUT_DIR: 'output' +LOG_DIR: 'log' +WORKERS: 24 +PRINT_FREQ: 100 + +DATASET: + COLOR_RGB: true + DATASET: 'coco' + DATA_FORMAT: jpg + FLIP: true + NUM_JOINTS_HALF_BODY: 8 + PROB_HALF_BODY: 0.3 + ROOT: 'data/coco/' + ROT_FACTOR: 45 + SCALE_FACTOR: 0.35 + TEST_SET: 'val2017' + TRAIN_SET: 'train2017' +MODEL: + INIT_WEIGHTS: true + NAME: pose_hrnet + NUM_JOINTS: 17 + PRETRAINED: 'models/pytorch/pose_coco/pose_hrnet_w32_384x288.pth' + TARGET_TYPE: gaussian + IMAGE_SIZE: + - 288 + - 384 + HEATMAP_SIZE: + - 72 + - 96 + SIGMA: 3 + EXTRA: + PRETRAINED_LAYERS: + - 'conv1' + - 'bn1' + - 'conv2' + - 'bn2' + - 'layer1' + - 'transition1' + - 'stage2' + - 'transition2' + - 'stage3' + - 'transition3' + - 'stage4' + FINAL_CONV_KERNEL: 1 + STAGE2: + NUM_MODULES: 1 + NUM_BRANCHES: 2 + BLOCK: BASIC + NUM_BLOCKS: + - 4 + - 4 + NUM_CHANNELS: + - 32 + - 64 + FUSE_METHOD: SUM + STAGE3: + NUM_MODULES: 4 + NUM_BRANCHES: 3 + BLOCK: BASIC + NUM_BLOCKS: + - 4 + - 4 + - 4 + NUM_CHANNELS: + - 32 + - 64 + - 128 + FUSE_METHOD: SUM + STAGE4: + NUM_MODULES: 3 + NUM_BRANCHES: 4 + BLOCK: BASIC + NUM_BLOCKS: + - 4 + - 4 + - 4 + - 4 + NUM_CHANNELS: + - 32 + - 64 + - 128 + - 
256 + FUSE_METHOD: SUM +LOSS: + USE_TARGET_WEIGHT: true +TRAIN: + BATCH_SIZE_PER_GPU: 32 + SHUFFLE: true + BEGIN_EPOCH: 0 + END_EPOCH: 210 + OPTIMIZER: adam + LR: 0.001 + LR_FACTOR: 0.1 + LR_STEP: + - 170 + - 200 + WD: 0.0001 + GAMMA1: 0.99 + GAMMA2: 0.0 + MOMENTUM: 0.9 + NESTEROV: false +TEST: + BATCH_SIZE_PER_GPU: 32 + COCO_BBOX_FILE: 'data/coco/person_detection_results/COCO_val2017_detections_AP_H_56_person.json' + BBOX_THRE: 1.0 + IMAGE_THRE: 0.0 + IN_VIS_THRE: 0.2 + MODEL_FILE: 'models/pytorch/pose_coco/pose_hrnet_w32_384x288.pth' + NMS_THRE: 1.0 + OKS_THRE: 0.9 + USE_GT_BBOX: true + FLIP_TEST: true + POST_PROCESS: true + SHIFT_HEATMAP: true +DEBUG: + DEBUG: true + SAVE_BATCH_IMAGES_GT: true + SAVE_BATCH_IMAGES_PRED: true + SAVE_HEATMAPS_GT: true + SAVE_HEATMAPS_PRED: true diff --git a/demo/inference.py b/demo/inference.py new file mode 100644 index 00000000..efff86a7 --- /dev/null +++ b/demo/inference.py @@ -0,0 +1,341 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import argparse +import csv +import os +import shutil + +from PIL import Image +import torch +import torch.nn.parallel +import torch.backends.cudnn as cudnn +import torch.optim +import torch.utils.data +import torch.utils.data.distributed +import torchvision.transforms as transforms +import torchvision +import cv2 +import numpy as np + +import sys +sys.path.append("../lib") +import time + +# import _init_paths +import models +from config import cfg +from config import update_config +from core.inference import get_final_preds +from utils.transforms import get_affine_transform + +CTX = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') + + +COCO_KEYPOINT_INDEXES = { + 0: 'nose', + 1: 'left_eye', + 2: 'right_eye', + 3: 'left_ear', + 4: 'right_ear', + 5: 'left_shoulder', + 6: 'right_shoulder', + 7: 'left_elbow', + 8: 'right_elbow', + 9: 'left_wrist', + 10: 'right_wrist', + 11: 'left_hip', + 12: 'right_hip', + 13: 'left_knee', + 14: 'right_knee', + 15: 'left_ankle', + 16: 'right_ankle' +} + +COCO_INSTANCE_CATEGORY_NAMES = [ + '__background__', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', + 'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A', 'stop sign', + 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow', + 'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack', 'umbrella', 'N/A', 'N/A', + 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', + 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket', + 'bottle', 'N/A', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', + 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', + 'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'N/A', 'dining table', + 'N/A', 'N/A', 'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', + 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'N/A', 'book', + 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush' +] + + +def get_person_detection_boxes(model, img, threshold=0.5): + pil_image = Image.fromarray(img) # Load the image + transform = transforms.Compose([transforms.ToTensor()]) # Defing PyTorch Transform + transformed_img = transform(pil_image) # Apply the transform to the image + pred = model([transformed_img.to(CTX)]) # Pass the image to the model + # Use the first detected person + pred_classes = [COCO_INSTANCE_CATEGORY_NAMES[i] + for i in 
list(pred[0]['labels'].cpu().numpy())] # Get the Prediction Score + pred_boxes = [[(i[0], i[1]), (i[2], i[3])] + for i in list(pred[0]['boxes'].cpu().detach().numpy())] # Bounding boxes + pred_scores = list(pred[0]['scores'].cpu().detach().numpy()) + + person_boxes = [] + # Select box has score larger than threshold and is person + for pred_class, pred_box, pred_score in zip(pred_classes, pred_boxes, pred_scores): + if (pred_score > threshold) and (pred_class == 'person'): + person_boxes.append(pred_box) + + return person_boxes + + +def get_pose_estimation_prediction(pose_model, image, centers, scales, transform): + rotation = 0 + + # pose estimation transformation + model_inputs = [] + for center, scale in zip(centers, scales): + trans = get_affine_transform(center, scale, rotation, cfg.MODEL.IMAGE_SIZE) + # Crop smaller image of people + model_input = cv2.warpAffine( + image, + trans, + (int(cfg.MODEL.IMAGE_SIZE[0]), int(cfg.MODEL.IMAGE_SIZE[1])), + flags=cv2.INTER_LINEAR) + + # hwc -> 1chw + model_input = transform(model_input)#.unsqueeze(0) + model_inputs.append(model_input) + + # n * 1chw -> nchw + model_inputs = torch.stack(model_inputs) + + # compute output heatmap + output = pose_model(model_inputs.to(CTX)) + coords, _ = get_final_preds( + cfg, + output.cpu().detach().numpy(), + np.asarray(centers), + np.asarray(scales)) + + return coords + + +def box_to_center_scale(box, model_image_width, model_image_height): + """convert a box to center,scale information required for pose transformation + Parameters + ---------- + box : list of tuple + list of length 2 with two tuples of floats representing + bottom left and top right corner of a box + model_image_width : int + model_image_height : int + + Returns + ------- + (numpy array, numpy array) + Two numpy arrays, coordinates for the center of the box and the scale of the box + """ + center = np.zeros((2), dtype=np.float32) + + bottom_left_corner = box[0] + top_right_corner = box[1] + box_width = top_right_corner[0]-bottom_left_corner[0] + box_height = top_right_corner[1]-bottom_left_corner[1] + bottom_left_x = bottom_left_corner[0] + bottom_left_y = bottom_left_corner[1] + center[0] = bottom_left_x + box_width * 0.5 + center[1] = bottom_left_y + box_height * 0.5 + + aspect_ratio = model_image_width * 1.0 / model_image_height + pixel_std = 200 + + if box_width > aspect_ratio * box_height: + box_height = box_width * 1.0 / aspect_ratio + elif box_width < aspect_ratio * box_height: + box_width = box_height * aspect_ratio + scale = np.array( + [box_width * 1.0 / pixel_std, box_height * 1.0 / pixel_std], + dtype=np.float32) + if center[0] != -1: + scale = scale * 1.25 + + return center, scale + + +def prepare_output_dirs(prefix='/output/'): + pose_dir = os.path.join(prefix, "pose") + if os.path.exists(pose_dir) and os.path.isdir(pose_dir): + shutil.rmtree(pose_dir) + os.makedirs(pose_dir, exist_ok=True) + return pose_dir + + +def parse_args(): + parser = argparse.ArgumentParser(description='Train keypoints network') + # general + parser.add_argument('--cfg', type=str, required=True) + parser.add_argument('--videoFile', type=str, required=True) + parser.add_argument('--outputDir', type=str, default='/output/') + parser.add_argument('--inferenceFps', type=int, default=10) + parser.add_argument('--writeBoxFrames', action='store_true') + + parser.add_argument('opts', + help='Modify config options using the command-line', + default=None, + nargs=argparse.REMAINDER) + + args = parser.parse_args() + + # args expected by supporting codebase + 
args.modelDir = '' + args.logDir = '' + args.dataDir = '' + args.prevModelDir = '' + return args + + +def main(): + # transformation + pose_transform = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize(mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225]), + ]) + + # cudnn related setting + cudnn.benchmark = cfg.CUDNN.BENCHMARK + torch.backends.cudnn.deterministic = cfg.CUDNN.DETERMINISTIC + torch.backends.cudnn.enabled = cfg.CUDNN.ENABLED + + args = parse_args() + update_config(cfg, args) + pose_dir = prepare_output_dirs(args.outputDir) + csv_output_rows = [] + + box_model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True) + box_model.to(CTX) + box_model.eval() + pose_model = eval('models.'+cfg.MODEL.NAME+'.get_pose_net')( + cfg, is_train=False + ) + + if cfg.TEST.MODEL_FILE: + print('=> loading model from {}'.format(cfg.TEST.MODEL_FILE)) + pose_model.load_state_dict(torch.load(cfg.TEST.MODEL_FILE), strict=False) + else: + print('expected model defined in config at TEST.MODEL_FILE') + + pose_model.to(CTX) + pose_model.eval() + + # Loading an video + vidcap = cv2.VideoCapture(args.videoFile) + fps = vidcap.get(cv2.CAP_PROP_FPS) + if fps < args.inferenceFps: + print('desired inference fps is '+str(args.inferenceFps)+' but video fps is '+str(fps)) + exit() + skip_frame_cnt = round(fps / args.inferenceFps) + frame_width = int(vidcap.get(cv2.CAP_PROP_FRAME_WIDTH)) + frame_height = int(vidcap.get(cv2.CAP_PROP_FRAME_HEIGHT)) + outcap = cv2.VideoWriter('{}/{}_pose.avi'.format(args.outputDir, os.path.splitext(os.path.basename(args.videoFile))[0]), + cv2.VideoWriter_fourcc('M', 'J', 'P', 'G'), int(skip_frame_cnt), (frame_width, frame_height)) + + count = 0 + while vidcap.isOpened(): + total_now = time.time() + ret, image_bgr = vidcap.read() + count += 1 + + if not ret: + continue + + if count % skip_frame_cnt != 0: + continue + + image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB) + + # Clone 2 image for person detection and pose estimation + if cfg.DATASET.COLOR_RGB: + image_per = image_rgb.copy() + image_pose = image_rgb.copy() + else: + image_per = image_bgr.copy() + image_pose = image_bgr.copy() + + # Clone 1 image for debugging purpose + image_debug = image_bgr.copy() + + # object detection box + now = time.time() + pred_boxes = get_person_detection_boxes(box_model, image_per, threshold=0.9) + then = time.time() + print("Find person bbox in: {} sec".format(then - now)) + + # Can not find people. 
Move to next frame + if not pred_boxes: + count += 1 + continue + + if args.writeBoxFrames: + for box in pred_boxes: + cv2.rectangle(image_debug, box[0], box[1], color=(0, 255, 0), + thickness=3) # Draw Rectangle with the coordinates + + # pose estimation : for multiple people + centers = [] + scales = [] + for box in pred_boxes: + center, scale = box_to_center_scale(box, cfg.MODEL.IMAGE_SIZE[0], cfg.MODEL.IMAGE_SIZE[1]) + centers.append(center) + scales.append(scale) + + now = time.time() + pose_preds = get_pose_estimation_prediction(pose_model, image_pose, centers, scales, transform=pose_transform) + then = time.time() + print("Find person pose in: {} sec".format(then - now)) + + new_csv_row = [] + for coords in pose_preds: + # Draw each point on image + for coord in coords: + x_coord, y_coord = int(coord[0]), int(coord[1]) + cv2.circle(image_debug, (x_coord, y_coord), 4, (255, 0, 0), 2) + new_csv_row.extend([x_coord, y_coord]) + + total_then = time.time() + + text = "{:03.2f} sec".format(total_then - total_now) + cv2.putText(image_debug, text, (100, 50), cv2.FONT_HERSHEY_SIMPLEX, + 1, (0, 0, 255), 2, cv2.LINE_AA) + + cv2.imshow("pos", image_debug) + if cv2.waitKey(1) & 0xFF == ord('q'): + break + + csv_output_rows.append(new_csv_row) + img_file = os.path.join(pose_dir, 'pose_{:08d}.jpg'.format(count)) + cv2.imwrite(img_file, image_debug) + outcap.write(image_debug) + + + # write csv + csv_headers = ['frame'] + for keypoint in COCO_KEYPOINT_INDEXES.values(): + csv_headers.extend([keypoint+'_x', keypoint+'_y']) + + csv_output_filename = os.path.join(args.outputDir, 'pose-data.csv') + with open(csv_output_filename, 'w', newline='') as csvfile: + csvwriter = csv.writer(csvfile) + csvwriter.writerow(csv_headers) + csvwriter.writerows(csv_output_rows) + + vidcap.release() + outcap.release() + + cv2.destroyAllWindows() + + +if __name__ == '__main__': + main() diff --git a/demo/inference_1.jpg b/demo/inference_1.jpg new file mode 100644 index 00000000..2ca29d1b Binary files /dev/null and b/demo/inference_1.jpg differ diff --git a/demo/inference_3.jpg b/demo/inference_3.jpg new file mode 100644 index 00000000..7f20915c Binary files /dev/null and b/demo/inference_3.jpg differ diff --git a/demo/inference_5.jpg b/demo/inference_5.jpg new file mode 100644 index 00000000..d7b7117c Binary files /dev/null and b/demo/inference_5.jpg differ diff --git a/demo/inference_6.jpg b/demo/inference_6.jpg new file mode 100644 index 00000000..cc1183c0 Binary files /dev/null and b/demo/inference_6.jpg differ diff --git a/demo/inference_7.jpg b/demo/inference_7.jpg new file mode 100644 index 00000000..9629b300 Binary files /dev/null and b/demo/inference_7.jpg differ diff --git a/figures/visualization/coco/score_610_id_2685_000000002685.png b/figures/visualization/coco/score_610_id_2685_000000002685.png new file mode 100644 index 00000000..615ff258 Binary files /dev/null and b/figures/visualization/coco/score_610_id_2685_000000002685.png differ diff --git a/figures/visualization/coco/score_710_id_153229_000000153229.png b/figures/visualization/coco/score_710_id_153229_000000153229.png new file mode 100644 index 00000000..61d73e3a Binary files /dev/null and b/figures/visualization/coco/score_710_id_153229_000000153229.png differ diff --git a/figures/visualization/coco/score_755_id_343561_000000343561.png b/figures/visualization/coco/score_755_id_343561_000000343561.png new file mode 100644 index 00000000..f114bd42 Binary files /dev/null and b/figures/visualization/coco/score_755_id_343561_000000343561.png differ diff 
--git a/figures/visualization/coco/score_755_id_559842_000000559842.png b/figures/visualization/coco/score_755_id_559842_000000559842.png new file mode 100644 index 00000000..7123e65e Binary files /dev/null and b/figures/visualization/coco/score_755_id_559842_000000559842.png differ diff --git a/figures/visualization/coco/score_770_id_6954_000000006954.png b/figures/visualization/coco/score_770_id_6954_000000006954.png new file mode 100644 index 00000000..caba54b0 Binary files /dev/null and b/figures/visualization/coco/score_770_id_6954_000000006954.png differ diff --git a/figures/visualization/coco/score_919_id_53626_000000053626.png b/figures/visualization/coco/score_919_id_53626_000000053626.png new file mode 100644 index 00000000..3efcd62a Binary files /dev/null and b/figures/visualization/coco/score_919_id_53626_000000053626.png differ diff --git a/lib/core/function.py b/lib/core/function.py index dadff0fa..1bc19daa 100755 --- a/lib/core/function.py +++ b/lib/core/function.py @@ -124,10 +124,7 @@ def validate(config, val_loader, val_dataset, model, criterion, output_dir, output = outputs if config.TEST.FLIP_TEST: - # this part is ugly, because pytorch has not supported negative index - # input_flipped = model(input[:, :, :, ::-1]) - input_flipped = np.flip(input.cpu().numpy(), 3).copy() - input_flipped = torch.from_numpy(input_flipped).cuda() + input_flipped = input.flip(3) outputs_flipped = model(input_flipped) if isinstance(outputs_flipped, list): diff --git a/lib/models/pose_hrnet.py b/lib/models/pose_hrnet.py index ea65f419..09ff346a 100644 --- a/lib/models/pose_hrnet.py +++ b/lib/models/pose_hrnet.py @@ -275,7 +275,7 @@ class PoseHighResolutionNet(nn.Module): def __init__(self, cfg, **kwargs): self.inplanes = 64 - extra = cfg.MODEL.EXTRA + extra = cfg['MODEL']['EXTRA'] super(PoseHighResolutionNet, self).__init__() # stem net @@ -288,7 +288,7 @@ def __init__(self, cfg, **kwargs): self.relu = nn.ReLU(inplace=True) self.layer1 = self._make_layer(Bottleneck, 64, 4) - self.stage2_cfg = cfg['MODEL']['EXTRA']['STAGE2'] + self.stage2_cfg = extra['STAGE2'] num_channels = self.stage2_cfg['NUM_CHANNELS'] block = blocks_dict[self.stage2_cfg['BLOCK']] num_channels = [ @@ -298,7 +298,7 @@ def __init__(self, cfg, **kwargs): self.stage2, pre_stage_channels = self._make_stage( self.stage2_cfg, num_channels) - self.stage3_cfg = cfg['MODEL']['EXTRA']['STAGE3'] + self.stage3_cfg = extra['STAGE3'] num_channels = self.stage3_cfg['NUM_CHANNELS'] block = blocks_dict[self.stage3_cfg['BLOCK']] num_channels = [ @@ -309,7 +309,7 @@ def __init__(self, cfg, **kwargs): self.stage3, pre_stage_channels = self._make_stage( self.stage3_cfg, num_channels) - self.stage4_cfg = cfg['MODEL']['EXTRA']['STAGE4'] + self.stage4_cfg = extra['STAGE4'] num_channels = self.stage4_cfg['NUM_CHANNELS'] block = blocks_dict[self.stage4_cfg['BLOCK']] num_channels = [ @@ -322,13 +322,13 @@ def __init__(self, cfg, **kwargs): self.final_layer = nn.Conv2d( in_channels=pre_stage_channels[0], - out_channels=cfg.MODEL.NUM_JOINTS, - kernel_size=extra.FINAL_CONV_KERNEL, + out_channels=cfg['MODEL']['NUM_JOINTS'], + kernel_size=extra['FINAL_CONV_KERNEL'], stride=1, - padding=1 if extra.FINAL_CONV_KERNEL == 3 else 0 + padding=1 if extra['FINAL_CONV_KERNEL'] == 3 else 0 ) - self.pretrained_layers = cfg['MODEL']['EXTRA']['PRETRAINED_LAYERS'] + self.pretrained_layers = extra['PRETRAINED_LAYERS'] def _make_transition_layer( self, num_channels_pre_layer, num_channels_cur_layer): @@ -495,7 +495,7 @@ def init_weights(self, pretrained=''): def 
get_pose_net(cfg, is_train, **kwargs): model = PoseHighResolutionNet(cfg, **kwargs) - if is_train and cfg.MODEL.INIT_WEIGHTS: - model.init_weights(cfg.MODEL.PRETRAINED) + if is_train and cfg['MODEL']['INIT_WEIGHTS']: + model.init_weights(cfg['MODEL']['PRETRAINED']) return model diff --git a/requirements.txt b/requirements.txt index 18c9ee11..14f225c7 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,4 +8,4 @@ pyyaml json_tricks scikit-image yacs>=0.1.5 -tensorboardX>=1.6 +tensorboardX==1.6 diff --git a/visualization/plot_coco.py b/visualization/plot_coco.py new file mode 100644 index 00000000..c0e79039 --- /dev/null +++ b/visualization/plot_coco.py @@ -0,0 +1,309 @@ +# ------------------------------------------------------------------------------ +# Copyright (c) Microsoft +# Licensed under the MIT License. +# Written by Ke Sun (sunk@mail.ustc.edu.cn) +# Modified by Depu Meng (mdp@mail.ustc.edu.cn) +# ------------------------------------------------------------------------------ + +import argparse +import numpy as np +import matplotlib.pyplot as plt +import cv2 +import json +import matplotlib.lines as mlines +import matplotlib.patches as mpatches +from pycocotools.coco import COCO +from pycocotools.cocoeval import COCOeval +import os + + +class ColorStyle: + def __init__(self, color, link_pairs, point_color): + self.color = color + self.link_pairs = link_pairs + self.point_color = point_color + + for i in range(len(self.color)): + self.link_pairs[i].append(tuple(np.array(self.color[i])/255.)) + + self.ring_color = [] + for i in range(len(self.point_color)): + self.ring_color.append(tuple(np.array(self.point_color[i])/255.)) + +# Xiaochu Style +# (R,G,B) +color1 = [(179,0,0),(228,26,28),(255,255,51), + (49,163,84), (0,109,45), (255,255,51), + (240,2,127),(240,2,127),(240,2,127), (240,2,127), (240,2,127), + (217,95,14), (254,153,41),(255,255,51), + (44,127,184),(0,0,255)] + +link_pairs1 = [ + [15, 13], [13, 11], [11, 5], + [12, 14], [14, 16], [12, 6], + [3, 1],[1, 2],[1, 0],[0, 2],[2,4], + [9, 7], [7,5], [5, 6], + [6, 8], [8, 10], + ] + +point_color1 = [(240,2,127),(240,2,127),(240,2,127), + (240,2,127), (240,2,127), + (255,255,51),(255,255,51), + (254,153,41),(44,127,184), + (217,95,14),(0,0,255), + (255,255,51),(255,255,51),(228,26,28), + (49,163,84),(252,176,243),(0,176,240), + (255,255,0),(169, 209, 142), + (255,255,0),(169, 209, 142), + (255,255,0),(169, 209, 142)] + +xiaochu_style = ColorStyle(color1, link_pairs1, point_color1) + + +# Chunhua Style +# (R,G,B) +color2 = [(252,176,243),(252,176,243),(252,176,243), + (0,176,240), (0,176,240), (0,176,240), + (240,2,127),(240,2,127),(240,2,127), (240,2,127), (240,2,127), + (255,255,0), (255,255,0),(169, 209, 142), + (169, 209, 142),(169, 209, 142)] + +link_pairs2 = [ + [15, 13], [13, 11], [11, 5], + [12, 14], [14, 16], [12, 6], + [3, 1],[1, 2],[1, 0],[0, 2],[2,4], + [9, 7], [7,5], [5, 6], [6, 8], [8, 10], + ] + +point_color2 = [(240,2,127),(240,2,127),(240,2,127), + (240,2,127), (240,2,127), + (255,255,0),(169, 209, 142), + (255,255,0),(169, 209, 142), + (255,255,0),(169, 209, 142), + (252,176,243),(0,176,240),(252,176,243), + (0,176,240),(252,176,243),(0,176,240), + (255,255,0),(169, 209, 142), + (255,255,0),(169, 209, 142), + (255,255,0),(169, 209, 142)] + +chunhua_style = ColorStyle(color2, link_pairs2, point_color2) + +def parse_args(): + parser = argparse.ArgumentParser(description='Visualize COCO predictions') + # general + parser.add_argument('--image-path', + help='Path of COCO val images', + type=str, + 
default='data/coco/images/val2017/' + ) + + parser.add_argument('--gt-anno', + help='Path of COCO val annotation', + type=str, + default='data/coco/annotations/person_keypoints_val2017.json' + ) + + parser.add_argument('--save-path', + help="Path to save the visualizations", + type=str, + default='visualization/coco/') + + parser.add_argument('--prediction', + help="Prediction file to visualize", + type=str, + required=True) + + parser.add_argument('--style', + help="Style of the visualization: Chunhua style or Xiaochu style", + type=str, + default='chunhua') + + args = parser.parse_args() + + return args + + +def map_joint_dict(joints): + joints_dict = {} + for i in range(joints.shape[0]): + x = int(joints[i][0]) + y = int(joints[i][1]) + id = i + joints_dict[id] = (x, y) + + return joints_dict + +def plot(data, gt_file, img_path, save_path, + link_pairs, ring_color, save=True): + + # joints + coco = COCO(gt_file) + coco_dt = coco.loadRes(data) + coco_eval = COCOeval(coco, coco_dt, 'keypoints') + coco_eval._prepare() + gts_ = coco_eval._gts + dts_ = coco_eval._dts + + p = coco_eval.params + p.imgIds = list(np.unique(p.imgIds)) + if p.useCats: + p.catIds = list(np.unique(p.catIds)) + p.maxDets = sorted(p.maxDets) + + # loop through images, area range, max detection number + catIds = p.catIds if p.useCats else [-1] + threshold = 0.3 + joint_thres = 0.2 + for catId in catIds: + for imgId in p.imgIds[:5000]: + # dimention here should be Nxm + gts = gts_[imgId, catId] + dts = dts_[imgId, catId] + inds = np.argsort([-d['score'] for d in dts], kind='mergesort') + dts = [dts[i] for i in inds] + if len(dts) > p.maxDets[-1]: + dts = dts[0:p.maxDets[-1]] + if len(gts) == 0 or len(dts) == 0: + continue + + sum_score = 0 + num_box = 0 + img_name = str(imgId).zfill(12) + + # Read Images + img_file = img_path + img_name + '.jpg' + data_numpy = cv2.imread(img_file, cv2.IMREAD_COLOR | cv2.IMREAD_IGNORE_ORIENTATION) + h = data_numpy.shape[0] + w = data_numpy.shape[1] + + # Plot + fig = plt.figure(figsize=(w/100, h/100), dpi=100) + ax = plt.subplot(1,1,1) + bk = plt.imshow(data_numpy[:,:,::-1]) + bk.set_zorder(-1) + print(img_name) + for j, gt in enumerate(gts): + # matching dt_box and gt_box + bb = gt['bbox'] + x0 = bb[0] - bb[2]; x1 = bb[0] + bb[2] * 2 + y0 = bb[1] - bb[3]; y1 = bb[1] + bb[3] * 2 + + # create bounds for ignore regions(double the gt bbox) + g = np.array(gt['keypoints']) + #xg = g[0::3]; yg = g[1::3]; + vg = g[2::3] + + for i, dt in enumerate(dts): + # Calculate IoU + dt_bb = dt['bbox'] + dt_x0 = dt_bb[0] - dt_bb[2]; dt_x1 = dt_bb[0] + dt_bb[2] * 2 + dt_y0 = dt_bb[1] - dt_bb[3]; dt_y1 = dt_bb[1] + dt_bb[3] * 2 + + ol_x = min(x1, dt_x1) - max(x0, dt_x0) + ol_y = min(y1, dt_y1) - max(y0, dt_y0) + ol_area = ol_x * ol_y + s_x = max(x1, dt_x1) - min(x0, dt_x0) + s_y = max(y1, dt_y1) - min(y0, dt_y0) + sum_area = s_x * s_y + iou = ol_area / (sum_area + np.spacing(1)) + score = dt['score'] + + if iou < 0.1 or score < threshold: + continue + else: + print('iou: ', iou) + dt_w = dt_x1 - dt_x0 + dt_h = dt_y1 - dt_y0 + ref = min(dt_w, dt_h) + num_box += 1 + sum_score += dt['score'] + dt_joints = np.array(dt['keypoints']).reshape(17,-1) + joints_dict = map_joint_dict(dt_joints) + + # stick + for k, link_pair in enumerate(link_pairs): + if link_pair[0] in joints_dict \ + and link_pair[1] in joints_dict: + if dt_joints[link_pair[0],2] < joint_thres \ + or dt_joints[link_pair[1],2] < joint_thres \ + or vg[link_pair[0]] == 0 \ + or vg[link_pair[1]] == 0: + continue + if k in range(6,11): + lw = 1 + else: + 
lw = ref / 100. + line = mlines.Line2D( + np.array([joints_dict[link_pair[0]][0], + joints_dict[link_pair[1]][0]]), + np.array([joints_dict[link_pair[0]][1], + joints_dict[link_pair[1]][1]]), + ls='-', lw=lw, alpha=1, color=link_pair[2],) + line.set_zorder(0) + ax.add_line(line) + # black ring + for k in range(dt_joints.shape[0]): + if dt_joints[k,2] < joint_thres \ + or vg[link_pair[0]] == 0 \ + or vg[link_pair[1]] == 0: + continue + if dt_joints[k,0] > w or dt_joints[k,1] > h: + continue + if k in range(5): + radius = 1 + else: + radius = ref / 100 + + circle = mpatches.Circle(tuple(dt_joints[k,:2]), + radius=radius, + ec='black', + fc=ring_color[k], + alpha=1, + linewidth=1) + circle.set_zorder(1) + ax.add_patch(circle) + + avg_score = (sum_score / (num_box+np.spacing(1)))*1000 + + plt.gca().xaxis.set_major_locator(plt.NullLocator()) + plt.gca().yaxis.set_major_locator(plt.NullLocator()) + plt.axis('off') + plt.subplots_adjust(top=1,bottom=0,left=0,right=1,hspace=0,wspace=0) + plt.margins(0,0) + if save: + plt.savefig(save_path + \ + 'score_'+str(int(avg_score))+ \ + '_id_'+str(imgId)+ \ + '_'+img_name + '.png', + format='png', bbox_inches='tight', dpi=100) + plt.savefig(save_path +'id_'+str(imgId)+ '.pdf', format='pdf', + bbox_inches='tight', dpi=100) + # plt.show() + plt.close() + +if __name__ == '__main__': + + args = parse_args() + if args.style == 'xiaochu': + # Xiaochu Style + colorstyle = xiaochu_style + elif args.style == 'chunhua': + # Chunhua Style + colorstyle = chunhua_style + else: + raise Exception('Invalid color style') + + save_path = args.save_path + img_path = args.image_path + if not os.path.exists(save_path): + try: + os.makedirs(save_path) + except Exception: + print('Failed to create {}'.format(save_path)) + + + with open(args.prediction) as f: + data = json.load(f) + gt_file = args.gt_anno + plot(data, gt_file, img_path, save_path, colorstyle.link_pairs, colorstyle.ring_color, save=True) +
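A note on how `plot()` in `visualization/plot_coco.py` pairs a prediction with a ground-truth person before drawing: both COCO boxes are expanded to three times their width and height, and the overlap of the expanded boxes is divided by the area of the rectangle spanning both, so the ratio is a loose overlap measure rather than a true IoU. Predictions with a ratio below 0.1 or a score below 0.3 are skipped. The minimal sketch below restates just that computation; the helper name `loose_overlap` and the example boxes are illustrative, not part of the patch.

```python
import numpy as np

def loose_overlap(gt_bbox, dt_bbox):
    """Overlap ratio used by plot() to match a detection to a ground-truth box.

    Boxes are COCO-style [x, y, w, h]; each is expanded to span [x - w, x + 2w]
    and [y - h, y + 2h] before the overlap is computed.
    """
    x0, x1 = gt_bbox[0] - gt_bbox[2], gt_bbox[0] + gt_bbox[2] * 2
    y0, y1 = gt_bbox[1] - gt_bbox[3], gt_bbox[1] + gt_bbox[3] * 2
    dt_x0, dt_x1 = dt_bbox[0] - dt_bbox[2], dt_bbox[0] + dt_bbox[2] * 2
    dt_y0, dt_y1 = dt_bbox[1] - dt_bbox[3], dt_bbox[1] + dt_bbox[3] * 2

    # overlap of the expanded boxes over the rectangle spanning both of them
    ol_area = (min(x1, dt_x1) - max(x0, dt_x0)) * (min(y1, dt_y1) - max(y0, dt_y0))
    span_area = (max(x1, dt_x1) - min(x0, dt_x0)) * (max(y1, dt_y1) - min(y0, dt_y0))
    return ol_area / (span_area + np.spacing(1))

# Two nearly coincident person boxes give a high ratio, so the detection would be drawn.
print(loose_overlap([100, 100, 50, 120], [105, 95, 48, 130]))  # ~0.87
```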