diff --git a/README.md b/README.md
index 5882990e..40b7720e 100644
--- a/README.md
+++ b/README.md
@@ -1,6 +1,10 @@
-# Deep High-Resolution Representation Learning for Human Pose Estimation (accepted to CVPR2019)
+# Deep High-Resolution Representation Learning for Human Pose Estimation (CVPR 2019)
## News
-- If you are interested in internship or research positions related to computer vision in ByteDance AI Lab, feel free to contact me(leoxiaobin-at-gmail.com).
+- [2021/04/12] Check out our recent work on bottom-up pose estimation (CVPR 2021), [HRNet-DEKR](https://github.com/HRNet/DEKR)!
+- [2020/07/05] [A very nice blog post](https://towardsdatascience.com/overview-of-human-pose-estimation-neural-networks-hrnet-higherhrnet-architectures-and-faq-1954b2f8b249) from Towards Data Science introducing HRNet and HigherHRNet for human pose estimation.
+- [2020/03/13] A longer version has been accepted by TPAMI: [Deep High-Resolution Representation Learning for Visual Recognition](https://arxiv.org/pdf/1908.07919.pdf). It includes more HRNet applications, and the code is available: [semantic segmentation](https://github.com/HRNet/HRNet-Semantic-Segmentation), [object detection](https://github.com/HRNet/HRNet-Object-Detection), [facial landmark detection](https://github.com/HRNet/HRNet-Facial-Landmark-Detection), and [image classification](https://github.com/HRNet/HRNet-Image-Classification).
+- [2020/02/01] We have added demo code for HRNet. Thanks to [Alex Simes](https://github.com/alex9311).
+- We have added visualization code for showing pose estimation results. Thanks to Depu!
- [2019/08/27] HigherHRNet is now on [ArXiv](https://arxiv.org/abs/1908.10357), which is a bottom-up approach for human pose estimation powerd by HRNet. We will also release code and models at [Higher-HRNet-Human-Pose-Estimation](https://github.com/HRNet/Higher-HRNet-Human-Pose-Estimation), stay tuned!
- Our new work [High-Resolution Representations for Labeling Pixels and Regions](https://arxiv.org/abs/1904.04514) is available at [HRNet](https://github.com/HRNet). Our HRNet has been applied to a wide range of vision tasks, such as [image classification](https://github.com/HRNet/HRNet-Image-Classification), [objection detection](https://github.com/HRNet/HRNet-Object-Detection), [semantic segmentation](https://github.com/HRNet/HRNet-Semantic-Segmentation) and [facial landmark](https://github.com/HRNet/HRNet-Facial-Landmark-Detection).
@@ -219,9 +223,30 @@ python tools/train.py \
--cfg experiments/coco/hrnet/w32_256x192_adam_lr1e-3.yaml \
```
+### Visualization
+
+#### Visualizing predictions on COCO val
+
+```
+python visualization/plot_coco.py \
+ --prediction output/coco/w48_384x288_adam_lr1e-3/results/keypoints_val2017_results_0.json \
+ --save-path visualization/results
+
+```
+
+
+

+
+

### Other applications
-Many other dense prediction tasks, such as segmentation, face alignment and object detection, etc. have been benefited by HRNet. More information can be found at [Deep High-Resolution Representation Learning](https://jingdongwang2017.github.io/Projects/HRNet/).
+Many other dense prediction tasks, such as segmentation, face alignment, and object detection, have benefited from HRNet. More information can be found at [High-Resolution Networks](https://github.com/HRNet).
+
+### Other implementations
+- [mmpose](https://github.com/open-mmlab/mmpose)
+- [ModelScope (Chinese)](https://modelscope.cn/models/damo/cv_hrnetv2w32_body-2d-keypoints_image/summary)
+- [timm](https://huggingface.co/docs/timm/main/en/models/hrnet) (see the loading sketch below)
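+
+The timm port, for instance, exposes HRNet backbones by name. Below is a minimal loading sketch; the model name `hrnet_w32` and the `num_classes=0` feature-extraction convention are timm assumptions rather than something defined in this repository.
+
+```
+import timm
+import torch
+
+# Load an HRNet-W32 backbone from timm as a pooled feature extractor
+# (assumes timm is installed and provides the 'hrnet_w32' model name).
+model = timm.create_model('hrnet_w32', pretrained=True, num_classes=0)
+model.eval()
+
+with torch.no_grad():
+    features = model(torch.randn(1, 3, 224, 224))
+print(features.shape)  # pooled feature vector, e.g. (1, 2048)
+```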
+
### Citation
If you use our code or models in your research, please cite with:
@@ -239,4 +264,13 @@ If you use our code or models in your research, please cite with:
booktitle = {European Conference on Computer Vision (ECCV)},
year = {2018}
}
+
+@article{WangSCJDZLMTWLX19,
+ title={Deep High-Resolution Representation Learning for Visual Recognition},
+ author={Jingdong Wang and Ke Sun and Tianheng Cheng and
+ Borui Jiang and Chaorui Deng and Yang Zhao and Dong Liu and Yadong Mu and
+ Mingkui Tan and Xinggang Wang and Wenyu Liu and Bin Xiao},
+  journal = {TPAMI},
+ year={2019}
+}
```
diff --git a/_config.yml b/_config.yml
new file mode 100644
index 00000000..c4192631
--- /dev/null
+++ b/_config.yml
@@ -0,0 +1 @@
+theme: jekyll-theme-cayman
\ No newline at end of file
diff --git a/demo/.gitignore b/demo/.gitignore
new file mode 100644
index 00000000..04f267c7
--- /dev/null
+++ b/demo/.gitignore
@@ -0,0 +1,3 @@
+output
+models
+videos
diff --git a/demo/Dockerfile b/demo/Dockerfile
new file mode 100644
index 00000000..d8ceaf4e
--- /dev/null
+++ b/demo/Dockerfile
@@ -0,0 +1,112 @@
+FROM nvidia/cuda:10.2-cudnn7-devel-ubuntu16.04
+
+ENV OPENCV_VERSION="3.4.6"
+
+# Basic toolchain
+RUN apt-get update && apt-get install -y \
+ apt-utils \
+ build-essential \
+ git \
+ wget \
+ unzip \
+ yasm \
+ pkg-config \
+ libcurl4-openssl-dev \
+ zlib1g-dev \
+ htop \
+ cmake \
+ nano \
+ python3-pip \
+ python3-dev \
+ python3-tk \
+ libx264-dev \
+ && cd /usr/local/bin \
+ && ln -s /usr/bin/python3 python \
+ && pip3 install --upgrade pip \
+ && apt-get autoremove -y
+
+# Getting OpenCV dependencies available with apt
+RUN apt-get update && apt-get install -y \
+ libeigen3-dev \
+ libjpeg-dev \
+ libpng-dev \
+ libtiff-dev \
+ libjasper-dev \
+ libswscale-dev \
+ libavcodec-dev \
+ libavformat-dev && \
+ apt-get autoremove -y
+
+# Getting other dependencies
+RUN apt-get update && apt-get install -y \
+ cppcheck \
+ graphviz \
+ doxygen \
+ p7zip-full \
+ libdlib18 \
+ libdlib-dev && \
+ apt-get autoremove -y
+
+
+# Install OpenCV + OpenCV contrib (takes forever)
+RUN mkdir -p /tmp && \
+ cd /tmp && \
+ wget --no-check-certificate -O opencv.zip https://github.com/opencv/opencv/archive/${OPENCV_VERSION}.zip && \
+ wget --no-check-certificate -O opencv_contrib.zip https://github.com/opencv/opencv_contrib/archive/${OPENCV_VERSION}.zip && \
+ unzip opencv.zip && \
+ unzip opencv_contrib.zip && \
+ mkdir opencv-${OPENCV_VERSION}/build && \
+ cd opencv-${OPENCV_VERSION}/build && \
+ cmake -D CMAKE_BUILD_TYPE=RELEASE \
+ -D CMAKE_INSTALL_PREFIX=/usr/local \
+ -D WITH_CUDA=ON \
+ -D CUDA_FAST_MATH=1 \
+ -D WITH_CUBLAS=1 \
+ -D WITH_FFMPEG=ON \
+ -D WITH_OPENCL=ON \
+ -D WITH_V4L=ON \
+ -D WITH_OPENGL=ON \
+ -D OPENCV_EXTRA_MODULES_PATH=/tmp/opencv_contrib-${OPENCV_VERSION}/modules \
+ .. && \
+ make -j$(nproc) && \
+ make install && \
+ echo "/usr/local/lib" > /etc/ld.so.conf.d/opencv.conf && \
+ ldconfig && \
+ cd /tmp && \
+ rm -rf opencv-${OPENCV_VERSION} opencv.zip opencv_contrib-${OPENCV_VERSION} opencv_contrib.zip && \
+ cd /
+
+# Compile and install ffmpeg from source
+RUN git clone https://github.com/FFmpeg/FFmpeg /root/ffmpeg && \
+ cd /root/ffmpeg && \
+ ./configure --enable-gpl --enable-libx264 --enable-nonfree --disable-shared --extra-cflags=-I/usr/local/include && \
+ make -j8 && make install -j8
+
+# clone deep-high-resolution-net
+ARG POSE_ROOT=/pose_root
+RUN git clone https://github.com/leoxiaobin/deep-high-resolution-net.pytorch.git $POSE_ROOT
+WORKDIR $POSE_ROOT
+RUN mkdir output && mkdir log
+
+RUN pip3 install -r requirements.txt && \
+ pip3 install torch==1.1.0 \
+ torchvision==0.3.0 \
+ opencv-python \
+ pillow==6.2.1
+
+# build deep-high-resolution-net lib
+WORKDIR $POSE_ROOT/lib
+RUN make
+
+# install COCO API
+ARG COCOAPI=/cocoapi
+RUN git clone https://github.com/cocodataset/cocoapi.git $COCOAPI
+WORKDIR $COCOAPI/PythonAPI
+# Install into global site-packages
+RUN make install
+
+# download the pretrained Faster R-CNN model for person detection
+RUN python -c "import torchvision; model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True); model.eval()"
+
+COPY inference.py $POSE_ROOT/tools
+COPY inference-config.yaml $POSE_ROOT/
diff --git a/demo/README.md b/demo/README.md
new file mode 100644
index 00000000..aff81f44
--- /dev/null
+++ b/demo/README.md
@@ -0,0 +1,75 @@
+# HRNet inference
+
+Run inference with deep-high-resolution-net.pytorch without using Docker.
+
+## Prep
+1. Download the researchers' pretrained pose estimator from [Google Drive](https://drive.google.com/drive/folders/1hOTihvbyIxsm5ygDpbUuJ7O_tzv4oXjC?usp=sharing) to this directory under `models/`
+2. Put the video file you'd like to run inference on in this directory under `videos/`
+3. (Optional) Build the Docker image in this directory with `./build-docker.sh` (this can take a while because it compiles OpenCV); a sample build-and-run invocation is sketched after this list.
+4. Update the `inference-config.yaml` file to reflect the number of GPUs you have available and the trained model you want to use.
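+
+If you built the Docker image in step 3, a typical invocation might look like the sketch below. The mounted paths and the GPU flag are assumptions about a common setup (the image clones the repository into `/pose_root`); adjust them to your environment.
+
+```
+# Build the image; build-docker.sh tags it as hrnet_demo_inference
+./build-docker.sh
+
+# Run it with the local models/, videos/ and output/ directories mounted into /pose_root
+# (on older Docker versions, use --runtime=nvidia instead of --gpus all)
+docker run --rm -it --gpus all \
+    -v $(pwd)/models:/pose_root/models \
+    -v $(pwd)/videos:/pose_root/videos \
+    -v $(pwd)/output:/pose_root/output \
+    hrnet_demo_inference bash
+```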
+
+## Running the Model
+### 1. Running on the video
+```
+python demo/inference.py --cfg demo/inference-config.yaml \
+ --videoFile ../../multi_people.mp4 \
+ --writeBoxFrames \
+ --outputDir output \
+ TEST.MODEL_FILE ../models/pytorch/pose_coco/pose_hrnet_w32_256x192.pth
+
+```
+
+The above command will create a video under the *output* directory and many pose images under the *output/pose* directory.
+Even with a GPU (a GTX 1080 in my case), person detection takes nearly **0.06 sec** and pose estimation takes nearly **0.07 sec** per frame.
+In total, inference time per frame is about **0.13 sec**, i.e. roughly 8 fps. So if you need real-time (fps >= 20)
+pose estimation, you should try another approach.
+
+**===Result===**
+
+Some output images are shown below:
+
+
+Fig: 1 person inference
+
+
+Fig: 3 person inference
+
+
+Fig: 3 person inference
+
+### 2. Demo with more common functions
+Remember to update `TEST.MODEL_FILE` in `demo/inference-config.yaml` according to your model path.
+
+`demo.py` provides the following functions:
+
+- use `--webcam` when the input is a real-time camera.
+- use `--video [video-path]` when the input is a video.
+- use `--image [image-path]` when the input is an image.
+- use `--write` to save the image, camera or video result.
+- use `--showFps` to show the fps (this fps includes the detection part).
+- draw connections between joints.
+
+#### (1) The input is a real-time camera
+```
+python demo/demo.py --webcam --showFps --write
+```
+
+#### (2) The input is a video
+```
+python demo/demo.py --video test.mp4 --showFps --write
+```
+#### (3) The input is an image
+
+```
+python demo/demo.py --image test.jpg --showFps --write
+```
+
+**===Result===**
+
+
+
+Fig: show fps
+
+
+
+Fig: multi-people
\ No newline at end of file
diff --git a/demo/_init_paths.py b/demo/_init_paths.py
new file mode 100644
index 00000000..b1aea8fe
--- /dev/null
+++ b/demo/_init_paths.py
@@ -0,0 +1,27 @@
+# ------------------------------------------------------------------------------
+# pose.pytorch
+# Copyright (c) 2018-present Microsoft
+# Licensed under The Apache-2.0 License [see LICENSE for details]
+# Written by Bin Xiao (Bin.Xiao@microsoft.com)
+# ------------------------------------------------------------------------------
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os.path as osp
+import sys
+
+
+def add_path(path):
+ if path not in sys.path:
+ sys.path.insert(0, path)
+
+
+this_dir = osp.dirname(__file__)
+
+lib_path = osp.join(this_dir, '..', 'lib')
+add_path(lib_path)
+
+mm_path = osp.join(this_dir, '..', 'lib/poseeval/py-motmetrics')
+add_path(mm_path)
diff --git a/demo/build-docker.sh b/demo/build-docker.sh
new file mode 100755
index 00000000..a4b1aab4
--- /dev/null
+++ b/demo/build-docker.sh
@@ -0,0 +1 @@
+docker build -t hrnet_demo_inference .
diff --git a/demo/demo.py b/demo/demo.py
new file mode 100644
index 00000000..d482e838
--- /dev/null
+++ b/demo/demo.py
@@ -0,0 +1,343 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import argparse
+import csv
+import os
+import shutil
+
+from PIL import Image
+import torch
+import torch.nn.parallel
+import torch.backends.cudnn as cudnn
+import torch.optim
+import torch.utils.data
+import torch.utils.data.distributed
+import torchvision.transforms as transforms
+import torchvision
+import cv2
+import numpy as np
+import time
+
+
+import _init_paths
+import models
+from config import cfg
+from config import update_config
+from core.function import get_final_preds
+from utils.transforms import get_affine_transform
+
+COCO_KEYPOINT_INDEXES = {
+ 0: 'nose',
+ 1: 'left_eye',
+ 2: 'right_eye',
+ 3: 'left_ear',
+ 4: 'right_ear',
+ 5: 'left_shoulder',
+ 6: 'right_shoulder',
+ 7: 'left_elbow',
+ 8: 'right_elbow',
+ 9: 'left_wrist',
+ 10: 'right_wrist',
+ 11: 'left_hip',
+ 12: 'right_hip',
+ 13: 'left_knee',
+ 14: 'right_knee',
+ 15: 'left_ankle',
+ 16: 'right_ankle'
+}
+
+COCO_INSTANCE_CATEGORY_NAMES = [
+ '__background__', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
+ 'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A', 'stop sign',
+ 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
+ 'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack', 'umbrella', 'N/A', 'N/A',
+ 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
+ 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket',
+ 'bottle', 'N/A', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl',
+ 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
+ 'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'N/A', 'dining table',
+ 'N/A', 'N/A', 'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone',
+ 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'N/A', 'book',
+ 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush'
+]
+
+SKELETON = [
+ [1,3],[1,0],[2,4],[2,0],[0,5],[0,6],[5,7],[7,9],[6,8],[8,10],[5,11],[6,12],[11,12],[11,13],[13,15],[12,14],[14,16]
+]
+
+CocoColors = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0], [85, 255, 0], [0, 255, 0],
+ [0, 255, 85], [0, 255, 170], [0, 255, 255], [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255],
+ [170, 0, 255], [255, 0, 255], [255, 0, 170], [255, 0, 85]]
+
+NUM_KPTS = 17
+
+CTX = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
+
+def draw_pose(keypoints,img):
+ """draw the keypoints and the skeletons.
+ :params keypoints: the shape should be equal to [17,2]
+ :params img:
+ """
+ assert keypoints.shape == (NUM_KPTS,2)
+ for i in range(len(SKELETON)):
+ kpt_a, kpt_b = SKELETON[i][0], SKELETON[i][1]
+ x_a, y_a = keypoints[kpt_a][0],keypoints[kpt_a][1]
+ x_b, y_b = keypoints[kpt_b][0],keypoints[kpt_b][1]
+ cv2.circle(img, (int(x_a), int(y_a)), 6, CocoColors[i], -1)
+ cv2.circle(img, (int(x_b), int(y_b)), 6, CocoColors[i], -1)
+ cv2.line(img, (int(x_a), int(y_a)), (int(x_b), int(y_b)), CocoColors[i], 2)
+
+def draw_bbox(box,img):
+ """draw the detected bounding box on the image.
+ :param img:
+ """
+ cv2.rectangle(img, box[0], box[1], color=(0, 255, 0),thickness=3)
+
+
+def get_person_detection_boxes(model, img, threshold=0.5):
+ pred = model(img)
+ pred_classes = [COCO_INSTANCE_CATEGORY_NAMES[i]
+ for i in list(pred[0]['labels'].cpu().numpy())] # Get the Prediction Score
+ pred_boxes = [[(i[0], i[1]), (i[2], i[3])]
+ for i in list(pred[0]['boxes'].detach().cpu().numpy())] # Bounding boxes
+ pred_score = list(pred[0]['scores'].detach().cpu().numpy())
+    if not pred_score or max(pred_score) < threshold:
+        return []
+    # Get list of indices with score greater than threshold
+    pred_t = [pred_score.index(x) for x in pred_score if x > threshold][-1]
+ pred_boxes = pred_boxes[:pred_t+1]
+ pred_classes = pred_classes[:pred_t+1]
+
+ person_boxes = []
+ for idx, box in enumerate(pred_boxes):
+ if pred_classes[idx] == 'person':
+ person_boxes.append(box)
+
+ return person_boxes
+
+
+def get_pose_estimation_prediction(pose_model, image, center, scale):
+ rotation = 0
+
+ # pose estimation transformation
+ trans = get_affine_transform(center, scale, rotation, cfg.MODEL.IMAGE_SIZE)
+ model_input = cv2.warpAffine(
+ image,
+ trans,
+ (int(cfg.MODEL.IMAGE_SIZE[0]), int(cfg.MODEL.IMAGE_SIZE[1])),
+ flags=cv2.INTER_LINEAR)
+ transform = transforms.Compose([
+ transforms.ToTensor(),
+ transforms.Normalize(mean=[0.485, 0.456, 0.406],
+ std=[0.229, 0.224, 0.225]),
+ ])
+
+ # pose estimation inference
+ model_input = transform(model_input).unsqueeze(0)
+ # switch to evaluate mode
+ pose_model.eval()
+ with torch.no_grad():
+ # compute output heatmap
+ output = pose_model(model_input)
+ preds, _ = get_final_preds(
+ cfg,
+ output.clone().cpu().numpy(),
+ np.asarray([center]),
+ np.asarray([scale]))
+
+ return preds
+
+
+def box_to_center_scale(box, model_image_width, model_image_height):
+ """convert a box to center,scale information required for pose transformation
+ Parameters
+ ----------
+ box : list of tuple
+ list of length 2 with two tuples of floats representing
+ bottom left and top right corner of a box
+ model_image_width : int
+ model_image_height : int
+
+ Returns
+ -------
+ (numpy array, numpy array)
+ Two numpy arrays, coordinates for the center of the box and the scale of the box
+ """
+ center = np.zeros((2), dtype=np.float32)
+
+ bottom_left_corner = box[0]
+ top_right_corner = box[1]
+ box_width = top_right_corner[0]-bottom_left_corner[0]
+ box_height = top_right_corner[1]-bottom_left_corner[1]
+ bottom_left_x = bottom_left_corner[0]
+ bottom_left_y = bottom_left_corner[1]
+ center[0] = bottom_left_x + box_width * 0.5
+ center[1] = bottom_left_y + box_height * 0.5
+
+ aspect_ratio = model_image_width * 1.0 / model_image_height
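+    # the returned scale is expressed in units of 200 pixels, as expected by get_affine_transform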
+ pixel_std = 200
+
+ if box_width > aspect_ratio * box_height:
+ box_height = box_width * 1.0 / aspect_ratio
+ elif box_width < aspect_ratio * box_height:
+ box_width = box_height * aspect_ratio
+ scale = np.array(
+ [box_width * 1.0 / pixel_std, box_height * 1.0 / pixel_std],
+ dtype=np.float32)
+ if center[0] != -1:
+ scale = scale * 1.25
+
+ return center, scale
+
+def parse_args():
+ parser = argparse.ArgumentParser(description='Train keypoints network')
+ # general
+ parser.add_argument('--cfg', type=str, default='demo/inference-config.yaml')
+ parser.add_argument('--video', type=str)
+ parser.add_argument('--webcam',action='store_true')
+ parser.add_argument('--image',type=str)
+ parser.add_argument('--write',action='store_true')
+ parser.add_argument('--showFps',action='store_true')
+
+ parser.add_argument('opts',
+ help='Modify config options using the command-line',
+ default=None,
+ nargs=argparse.REMAINDER)
+
+ args = parser.parse_args()
+
+ # args expected by supporting codebase
+ args.modelDir = ''
+ args.logDir = ''
+ args.dataDir = ''
+ args.prevModelDir = ''
+ return args
+
+
+def main():
+ # cudnn related setting
+ cudnn.benchmark = cfg.CUDNN.BENCHMARK
+ torch.backends.cudnn.deterministic = cfg.CUDNN.DETERMINISTIC
+ torch.backends.cudnn.enabled = cfg.CUDNN.ENABLED
+
+ args = parse_args()
+ update_config(cfg, args)
+
+ box_model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
+ box_model.to(CTX)
+ box_model.eval()
+
+ pose_model = eval('models.'+cfg.MODEL.NAME+'.get_pose_net')(
+ cfg, is_train=False
+ )
+
+ if cfg.TEST.MODEL_FILE:
+ print('=> loading model from {}'.format(cfg.TEST.MODEL_FILE))
+ pose_model.load_state_dict(torch.load(cfg.TEST.MODEL_FILE), strict=False)
+ else:
+ print('expected model defined in config at TEST.MODEL_FILE')
+
+ pose_model = torch.nn.DataParallel(pose_model, device_ids=cfg.GPUS)
+ pose_model.to(CTX)
+ pose_model.eval()
+
+    # Load a video, an image, or the webcam
+ if args.webcam:
+ vidcap = cv2.VideoCapture(0)
+ elif args.video:
+ vidcap = cv2.VideoCapture(args.video)
+ elif args.image:
+ image_bgr = cv2.imread(args.image)
+ else:
+ print('please use --video or --webcam or --image to define the input.')
+ return
+
+ if args.webcam or args.video:
+ if args.write:
+ save_path = 'output.avi'
+ fourcc = cv2.VideoWriter_fourcc(*'XVID')
+ out = cv2.VideoWriter(save_path,fourcc, 24.0, (int(vidcap.get(3)),int(vidcap.get(4))))
+ while True:
+ ret, image_bgr = vidcap.read()
+ if ret:
+ last_time = time.time()
+ image = image_bgr[:, :, [2, 1, 0]]
+
+ input = []
+ img = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
+ img_tensor = torch.from_numpy(img/255.).permute(2,0,1).float().to(CTX)
+ input.append(img_tensor)
+
+ # object detection box
+ pred_boxes = get_person_detection_boxes(box_model, input, threshold=0.9)
+
+ # pose estimation
+ if len(pred_boxes) >= 1:
+ for box in pred_boxes:
+ center, scale = box_to_center_scale(box, cfg.MODEL.IMAGE_SIZE[0], cfg.MODEL.IMAGE_SIZE[1])
+ image_pose = image.copy() if cfg.DATASET.COLOR_RGB else image_bgr.copy()
+ pose_preds = get_pose_estimation_prediction(pose_model, image_pose, center, scale)
+ if len(pose_preds)>=1:
+ for kpt in pose_preds:
+ draw_pose(kpt,image_bgr) # draw the poses
+
+ if args.showFps:
+ fps = 1/(time.time()-last_time)
+ img = cv2.putText(image_bgr, 'fps: '+ "%.2f"%(fps), (25, 40), cv2.FONT_HERSHEY_SIMPLEX, 1.2, (0, 255, 0), 2)
+
+ if args.write:
+ out.write(image_bgr)
+
+ cv2.imshow('demo',image_bgr)
+ if cv2.waitKey(1) & 0XFF==ord('q'):
+ break
+ else:
+ print('cannot load the video.')
+ break
+
+ cv2.destroyAllWindows()
+ vidcap.release()
+ if args.write:
+ print('video has been saved as {}'.format(save_path))
+ out.release()
+
+ else:
+ # estimate on the image
+ last_time = time.time()
+ image = image_bgr[:, :, [2, 1, 0]]
+
+ input = []
+ img = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
+ img_tensor = torch.from_numpy(img/255.).permute(2,0,1).float().to(CTX)
+ input.append(img_tensor)
+
+ # object detection box
+ pred_boxes = get_person_detection_boxes(box_model, input, threshold=0.9)
+
+ # pose estimation
+ if len(pred_boxes) >= 1:
+ for box in pred_boxes:
+ center, scale = box_to_center_scale(box, cfg.MODEL.IMAGE_SIZE[0], cfg.MODEL.IMAGE_SIZE[1])
+ image_pose = image.copy() if cfg.DATASET.COLOR_RGB else image_bgr.copy()
+ pose_preds = get_pose_estimation_prediction(pose_model, image_pose, center, scale)
+ if len(pose_preds)>=1:
+ for kpt in pose_preds:
+ draw_pose(kpt,image_bgr) # draw the poses
+
+ if args.showFps:
+ fps = 1/(time.time()-last_time)
+ img = cv2.putText(image_bgr, 'fps: '+ "%.2f"%(fps), (25, 40), cv2.FONT_HERSHEY_SIMPLEX, 1.2, (0, 255, 0), 2)
+
+ if args.write:
+ save_path = 'output.jpg'
+ cv2.imwrite(save_path,image_bgr)
+ print('the result image has been saved as {}'.format(save_path))
+
+ cv2.imshow('demo',image_bgr)
+ if cv2.waitKey(0) & 0XFF==ord('q'):
+ cv2.destroyAllWindows()
+
+if __name__ == '__main__':
+ main()
diff --git a/demo/hrnet-demo.gif b/demo/hrnet-demo.gif
new file mode 100644
index 00000000..30e37079
Binary files /dev/null and b/demo/hrnet-demo.gif differ
diff --git a/demo/inference-config.yaml b/demo/inference-config.yaml
new file mode 100644
index 00000000..14bce176
--- /dev/null
+++ b/demo/inference-config.yaml
@@ -0,0 +1,127 @@
+AUTO_RESUME: true
+CUDNN:
+ BENCHMARK: true
+ DETERMINISTIC: false
+ ENABLED: true
+DATA_DIR: ''
+GPUS: (0,)
+OUTPUT_DIR: 'output'
+LOG_DIR: 'log'
+WORKERS: 24
+PRINT_FREQ: 100
+
+DATASET:
+ COLOR_RGB: true
+ DATASET: 'coco'
+ DATA_FORMAT: jpg
+ FLIP: true
+ NUM_JOINTS_HALF_BODY: 8
+ PROB_HALF_BODY: 0.3
+ ROOT: 'data/coco/'
+ ROT_FACTOR: 45
+ SCALE_FACTOR: 0.35
+ TEST_SET: 'val2017'
+ TRAIN_SET: 'train2017'
+MODEL:
+ INIT_WEIGHTS: true
+ NAME: pose_hrnet
+ NUM_JOINTS: 17
+ PRETRAINED: 'models/pytorch/pose_coco/pose_hrnet_w32_384x288.pth'
+ TARGET_TYPE: gaussian
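+  # IMAGE_SIZE and HEATMAP_SIZE are given as [width, height]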
+ IMAGE_SIZE:
+ - 288
+ - 384
+ HEATMAP_SIZE:
+ - 72
+ - 96
+ SIGMA: 3
+ EXTRA:
+ PRETRAINED_LAYERS:
+ - 'conv1'
+ - 'bn1'
+ - 'conv2'
+ - 'bn2'
+ - 'layer1'
+ - 'transition1'
+ - 'stage2'
+ - 'transition2'
+ - 'stage3'
+ - 'transition3'
+ - 'stage4'
+ FINAL_CONV_KERNEL: 1
+ STAGE2:
+ NUM_MODULES: 1
+ NUM_BRANCHES: 2
+ BLOCK: BASIC
+ NUM_BLOCKS:
+ - 4
+ - 4
+ NUM_CHANNELS:
+ - 32
+ - 64
+ FUSE_METHOD: SUM
+ STAGE3:
+ NUM_MODULES: 4
+ NUM_BRANCHES: 3
+ BLOCK: BASIC
+ NUM_BLOCKS:
+ - 4
+ - 4
+ - 4
+ NUM_CHANNELS:
+ - 32
+ - 64
+ - 128
+ FUSE_METHOD: SUM
+ STAGE4:
+ NUM_MODULES: 3
+ NUM_BRANCHES: 4
+ BLOCK: BASIC
+ NUM_BLOCKS:
+ - 4
+ - 4
+ - 4
+ - 4
+ NUM_CHANNELS:
+ - 32
+ - 64
+ - 128
+ - 256
+ FUSE_METHOD: SUM
+LOSS:
+ USE_TARGET_WEIGHT: true
+TRAIN:
+ BATCH_SIZE_PER_GPU: 32
+ SHUFFLE: true
+ BEGIN_EPOCH: 0
+ END_EPOCH: 210
+ OPTIMIZER: adam
+ LR: 0.001
+ LR_FACTOR: 0.1
+ LR_STEP:
+ - 170
+ - 200
+ WD: 0.0001
+ GAMMA1: 0.99
+ GAMMA2: 0.0
+ MOMENTUM: 0.9
+ NESTEROV: false
+TEST:
+ BATCH_SIZE_PER_GPU: 32
+ COCO_BBOX_FILE: 'data/coco/person_detection_results/COCO_val2017_detections_AP_H_56_person.json'
+ BBOX_THRE: 1.0
+ IMAGE_THRE: 0.0
+ IN_VIS_THRE: 0.2
+ MODEL_FILE: 'models/pytorch/pose_coco/pose_hrnet_w32_384x288.pth'
+ NMS_THRE: 1.0
+ OKS_THRE: 0.9
+ USE_GT_BBOX: true
+ FLIP_TEST: true
+ POST_PROCESS: true
+ SHIFT_HEATMAP: true
+DEBUG:
+ DEBUG: true
+ SAVE_BATCH_IMAGES_GT: true
+ SAVE_BATCH_IMAGES_PRED: true
+ SAVE_HEATMAPS_GT: true
+ SAVE_HEATMAPS_PRED: true
diff --git a/demo/inference.py b/demo/inference.py
new file mode 100644
index 00000000..efff86a7
--- /dev/null
+++ b/demo/inference.py
@@ -0,0 +1,341 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import argparse
+import csv
+import os
+import shutil
+
+from PIL import Image
+import torch
+import torch.nn.parallel
+import torch.backends.cudnn as cudnn
+import torch.optim
+import torch.utils.data
+import torch.utils.data.distributed
+import torchvision.transforms as transforms
+import torchvision
+import cv2
+import numpy as np
+
+import sys
+sys.path.append("../lib")
+import time
+
+# import _init_paths
+import models
+from config import cfg
+from config import update_config
+from core.inference import get_final_preds
+from utils.transforms import get_affine_transform
+
+CTX = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
+
+
+COCO_KEYPOINT_INDEXES = {
+ 0: 'nose',
+ 1: 'left_eye',
+ 2: 'right_eye',
+ 3: 'left_ear',
+ 4: 'right_ear',
+ 5: 'left_shoulder',
+ 6: 'right_shoulder',
+ 7: 'left_elbow',
+ 8: 'right_elbow',
+ 9: 'left_wrist',
+ 10: 'right_wrist',
+ 11: 'left_hip',
+ 12: 'right_hip',
+ 13: 'left_knee',
+ 14: 'right_knee',
+ 15: 'left_ankle',
+ 16: 'right_ankle'
+}
+
+COCO_INSTANCE_CATEGORY_NAMES = [
+ '__background__', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
+ 'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A', 'stop sign',
+ 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
+ 'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack', 'umbrella', 'N/A', 'N/A',
+ 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
+ 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket',
+ 'bottle', 'N/A', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl',
+ 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
+ 'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'N/A', 'dining table',
+ 'N/A', 'N/A', 'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone',
+ 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'N/A', 'book',
+ 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush'
+]
+
+
+def get_person_detection_boxes(model, img, threshold=0.5):
+ pil_image = Image.fromarray(img) # Load the image
+    transform = transforms.Compose([transforms.ToTensor()])  # Define the PyTorch transform
+ transformed_img = transform(pil_image) # Apply the transform to the image
+ pred = model([transformed_img.to(CTX)]) # Pass the image to the model
+ # Use the first detected person
+ pred_classes = [COCO_INSTANCE_CATEGORY_NAMES[i]
+ for i in list(pred[0]['labels'].cpu().numpy())] # Get the Prediction Score
+ pred_boxes = [[(i[0], i[1]), (i[2], i[3])]
+ for i in list(pred[0]['boxes'].cpu().detach().numpy())] # Bounding boxes
+ pred_scores = list(pred[0]['scores'].cpu().detach().numpy())
+
+ person_boxes = []
+ # Select box has score larger than threshold and is person
+ for pred_class, pred_box, pred_score in zip(pred_classes, pred_boxes, pred_scores):
+ if (pred_score > threshold) and (pred_class == 'person'):
+ person_boxes.append(pred_box)
+
+ return person_boxes
+
+
+def get_pose_estimation_prediction(pose_model, image, centers, scales, transform):
+ rotation = 0
+
+ # pose estimation transformation
+ model_inputs = []
+ for center, scale in zip(centers, scales):
+ trans = get_affine_transform(center, scale, rotation, cfg.MODEL.IMAGE_SIZE)
+ # Crop smaller image of people
+ model_input = cv2.warpAffine(
+ image,
+ trans,
+ (int(cfg.MODEL.IMAGE_SIZE[0]), int(cfg.MODEL.IMAGE_SIZE[1])),
+ flags=cv2.INTER_LINEAR)
+
+ # hwc -> 1chw
+ model_input = transform(model_input)#.unsqueeze(0)
+ model_inputs.append(model_input)
+
+ # n * 1chw -> nchw
+ model_inputs = torch.stack(model_inputs)
+
+ # compute output heatmap
+ output = pose_model(model_inputs.to(CTX))
+ coords, _ = get_final_preds(
+ cfg,
+ output.cpu().detach().numpy(),
+ np.asarray(centers),
+ np.asarray(scales))
+
+ return coords
+
+
+def box_to_center_scale(box, model_image_width, model_image_height):
+ """convert a box to center,scale information required for pose transformation
+ Parameters
+ ----------
+ box : list of tuple
+ list of length 2 with two tuples of floats representing
+ bottom left and top right corner of a box
+ model_image_width : int
+ model_image_height : int
+
+ Returns
+ -------
+ (numpy array, numpy array)
+ Two numpy arrays, coordinates for the center of the box and the scale of the box
+ """
+ center = np.zeros((2), dtype=np.float32)
+
+ bottom_left_corner = box[0]
+ top_right_corner = box[1]
+ box_width = top_right_corner[0]-bottom_left_corner[0]
+ box_height = top_right_corner[1]-bottom_left_corner[1]
+ bottom_left_x = bottom_left_corner[0]
+ bottom_left_y = bottom_left_corner[1]
+ center[0] = bottom_left_x + box_width * 0.5
+ center[1] = bottom_left_y + box_height * 0.5
+
+ aspect_ratio = model_image_width * 1.0 / model_image_height
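+    # the returned scale is expressed in units of 200 pixels, as expected by get_affine_transform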
+ pixel_std = 200
+
+ if box_width > aspect_ratio * box_height:
+ box_height = box_width * 1.0 / aspect_ratio
+ elif box_width < aspect_ratio * box_height:
+ box_width = box_height * aspect_ratio
+ scale = np.array(
+ [box_width * 1.0 / pixel_std, box_height * 1.0 / pixel_std],
+ dtype=np.float32)
+ if center[0] != -1:
+ scale = scale * 1.25
+
+ return center, scale
+
+
+def prepare_output_dirs(prefix='/output/'):
+ pose_dir = os.path.join(prefix, "pose")
+ if os.path.exists(pose_dir) and os.path.isdir(pose_dir):
+ shutil.rmtree(pose_dir)
+ os.makedirs(pose_dir, exist_ok=True)
+ return pose_dir
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(description='Train keypoints network')
+ # general
+ parser.add_argument('--cfg', type=str, required=True)
+ parser.add_argument('--videoFile', type=str, required=True)
+ parser.add_argument('--outputDir', type=str, default='/output/')
+ parser.add_argument('--inferenceFps', type=int, default=10)
+ parser.add_argument('--writeBoxFrames', action='store_true')
+
+ parser.add_argument('opts',
+ help='Modify config options using the command-line',
+ default=None,
+ nargs=argparse.REMAINDER)
+
+ args = parser.parse_args()
+
+ # args expected by supporting codebase
+ args.modelDir = ''
+ args.logDir = ''
+ args.dataDir = ''
+ args.prevModelDir = ''
+ return args
+
+
+def main():
+ # transformation
+ pose_transform = transforms.Compose([
+ transforms.ToTensor(),
+ transforms.Normalize(mean=[0.485, 0.456, 0.406],
+ std=[0.229, 0.224, 0.225]),
+ ])
+
+ # cudnn related setting
+ cudnn.benchmark = cfg.CUDNN.BENCHMARK
+ torch.backends.cudnn.deterministic = cfg.CUDNN.DETERMINISTIC
+ torch.backends.cudnn.enabled = cfg.CUDNN.ENABLED
+
+ args = parse_args()
+ update_config(cfg, args)
+ pose_dir = prepare_output_dirs(args.outputDir)
+ csv_output_rows = []
+
+ box_model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
+ box_model.to(CTX)
+ box_model.eval()
+ pose_model = eval('models.'+cfg.MODEL.NAME+'.get_pose_net')(
+ cfg, is_train=False
+ )
+
+ if cfg.TEST.MODEL_FILE:
+ print('=> loading model from {}'.format(cfg.TEST.MODEL_FILE))
+ pose_model.load_state_dict(torch.load(cfg.TEST.MODEL_FILE), strict=False)
+ else:
+ print('expected model defined in config at TEST.MODEL_FILE')
+
+ pose_model.to(CTX)
+ pose_model.eval()
+
+    # Load the video
+ vidcap = cv2.VideoCapture(args.videoFile)
+ fps = vidcap.get(cv2.CAP_PROP_FPS)
+ if fps < args.inferenceFps:
+ print('desired inference fps is '+str(args.inferenceFps)+' but video fps is '+str(fps))
+ exit()
+ skip_frame_cnt = round(fps / args.inferenceFps)
+ frame_width = int(vidcap.get(cv2.CAP_PROP_FRAME_WIDTH))
+ frame_height = int(vidcap.get(cv2.CAP_PROP_FRAME_HEIGHT))
+ outcap = cv2.VideoWriter('{}/{}_pose.avi'.format(args.outputDir, os.path.splitext(os.path.basename(args.videoFile))[0]),
+ cv2.VideoWriter_fourcc('M', 'J', 'P', 'G'), int(skip_frame_cnt), (frame_width, frame_height))
+
+ count = 0
+ while vidcap.isOpened():
+ total_now = time.time()
+ ret, image_bgr = vidcap.read()
+ count += 1
+
+ if not ret:
+ continue
+
+ if count % skip_frame_cnt != 0:
+ continue
+
+ image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
+
+ # Clone 2 image for person detection and pose estimation
+ if cfg.DATASET.COLOR_RGB:
+ image_per = image_rgb.copy()
+ image_pose = image_rgb.copy()
+ else:
+ image_per = image_bgr.copy()
+ image_pose = image_bgr.copy()
+
+ # Clone 1 image for debugging purpose
+ image_debug = image_bgr.copy()
+
+ # object detection box
+ now = time.time()
+ pred_boxes = get_person_detection_boxes(box_model, image_per, threshold=0.9)
+ then = time.time()
+ print("Find person bbox in: {} sec".format(then - now))
+
+ # Can not find people. Move to next frame
+ if not pred_boxes:
+ count += 1
+ continue
+
+ if args.writeBoxFrames:
+ for box in pred_boxes:
+ cv2.rectangle(image_debug, box[0], box[1], color=(0, 255, 0),
+ thickness=3) # Draw Rectangle with the coordinates
+
+ # pose estimation : for multiple people
+ centers = []
+ scales = []
+ for box in pred_boxes:
+ center, scale = box_to_center_scale(box, cfg.MODEL.IMAGE_SIZE[0], cfg.MODEL.IMAGE_SIZE[1])
+ centers.append(center)
+ scales.append(scale)
+
+ now = time.time()
+ pose_preds = get_pose_estimation_prediction(pose_model, image_pose, centers, scales, transform=pose_transform)
+ then = time.time()
+ print("Find person pose in: {} sec".format(then - now))
+
+ new_csv_row = []
+ for coords in pose_preds:
+ # Draw each point on image
+ for coord in coords:
+ x_coord, y_coord = int(coord[0]), int(coord[1])
+ cv2.circle(image_debug, (x_coord, y_coord), 4, (255, 0, 0), 2)
+ new_csv_row.extend([x_coord, y_coord])
+
+ total_then = time.time()
+
+ text = "{:03.2f} sec".format(total_then - total_now)
+ cv2.putText(image_debug, text, (100, 50), cv2.FONT_HERSHEY_SIMPLEX,
+ 1, (0, 0, 255), 2, cv2.LINE_AA)
+
+ cv2.imshow("pos", image_debug)
+ if cv2.waitKey(1) & 0xFF == ord('q'):
+ break
+
+ csv_output_rows.append(new_csv_row)
+ img_file = os.path.join(pose_dir, 'pose_{:08d}.jpg'.format(count))
+ cv2.imwrite(img_file, image_debug)
+ outcap.write(image_debug)
+
+
+ # write csv
+ csv_headers = ['frame']
+ for keypoint in COCO_KEYPOINT_INDEXES.values():
+ csv_headers.extend([keypoint+'_x', keypoint+'_y'])
+
+ csv_output_filename = os.path.join(args.outputDir, 'pose-data.csv')
+ with open(csv_output_filename, 'w', newline='') as csvfile:
+ csvwriter = csv.writer(csvfile)
+ csvwriter.writerow(csv_headers)
+ csvwriter.writerows(csv_output_rows)
+
+ vidcap.release()
+ outcap.release()
+
+ cv2.destroyAllWindows()
+
+
+if __name__ == '__main__':
+ main()
diff --git a/demo/inference_1.jpg b/demo/inference_1.jpg
new file mode 100644
index 00000000..2ca29d1b
Binary files /dev/null and b/demo/inference_1.jpg differ
diff --git a/demo/inference_3.jpg b/demo/inference_3.jpg
new file mode 100644
index 00000000..7f20915c
Binary files /dev/null and b/demo/inference_3.jpg differ
diff --git a/demo/inference_5.jpg b/demo/inference_5.jpg
new file mode 100644
index 00000000..d7b7117c
Binary files /dev/null and b/demo/inference_5.jpg differ
diff --git a/demo/inference_6.jpg b/demo/inference_6.jpg
new file mode 100644
index 00000000..cc1183c0
Binary files /dev/null and b/demo/inference_6.jpg differ
diff --git a/demo/inference_7.jpg b/demo/inference_7.jpg
new file mode 100644
index 00000000..9629b300
Binary files /dev/null and b/demo/inference_7.jpg differ
diff --git a/figures/visualization/coco/score_610_id_2685_000000002685.png b/figures/visualization/coco/score_610_id_2685_000000002685.png
new file mode 100644
index 00000000..615ff258
Binary files /dev/null and b/figures/visualization/coco/score_610_id_2685_000000002685.png differ
diff --git a/figures/visualization/coco/score_710_id_153229_000000153229.png b/figures/visualization/coco/score_710_id_153229_000000153229.png
new file mode 100644
index 00000000..61d73e3a
Binary files /dev/null and b/figures/visualization/coco/score_710_id_153229_000000153229.png differ
diff --git a/figures/visualization/coco/score_755_id_343561_000000343561.png b/figures/visualization/coco/score_755_id_343561_000000343561.png
new file mode 100644
index 00000000..f114bd42
Binary files /dev/null and b/figures/visualization/coco/score_755_id_343561_000000343561.png differ
diff --git a/figures/visualization/coco/score_755_id_559842_000000559842.png b/figures/visualization/coco/score_755_id_559842_000000559842.png
new file mode 100644
index 00000000..7123e65e
Binary files /dev/null and b/figures/visualization/coco/score_755_id_559842_000000559842.png differ
diff --git a/figures/visualization/coco/score_770_id_6954_000000006954.png b/figures/visualization/coco/score_770_id_6954_000000006954.png
new file mode 100644
index 00000000..caba54b0
Binary files /dev/null and b/figures/visualization/coco/score_770_id_6954_000000006954.png differ
diff --git a/figures/visualization/coco/score_919_id_53626_000000053626.png b/figures/visualization/coco/score_919_id_53626_000000053626.png
new file mode 100644
index 00000000..3efcd62a
Binary files /dev/null and b/figures/visualization/coco/score_919_id_53626_000000053626.png differ
diff --git a/lib/core/function.py b/lib/core/function.py
index dadff0fa..1bc19daa 100755
--- a/lib/core/function.py
+++ b/lib/core/function.py
@@ -124,10 +124,7 @@ def validate(config, val_loader, val_dataset, model, criterion, output_dir,
output = outputs
if config.TEST.FLIP_TEST:
- # this part is ugly, because pytorch has not supported negative index
- # input_flipped = model(input[:, :, :, ::-1])
- input_flipped = np.flip(input.cpu().numpy(), 3).copy()
- input_flipped = torch.from_numpy(input_flipped).cuda()
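+            # flip along the width dimension directly on the tensor (replaces the old numpy round-trip)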
+ input_flipped = input.flip(3)
outputs_flipped = model(input_flipped)
if isinstance(outputs_flipped, list):
diff --git a/lib/models/pose_hrnet.py b/lib/models/pose_hrnet.py
index ea65f419..09ff346a 100644
--- a/lib/models/pose_hrnet.py
+++ b/lib/models/pose_hrnet.py
@@ -275,7 +275,7 @@ class PoseHighResolutionNet(nn.Module):
def __init__(self, cfg, **kwargs):
self.inplanes = 64
- extra = cfg.MODEL.EXTRA
+ extra = cfg['MODEL']['EXTRA']
super(PoseHighResolutionNet, self).__init__()
# stem net
@@ -288,7 +288,7 @@ def __init__(self, cfg, **kwargs):
self.relu = nn.ReLU(inplace=True)
self.layer1 = self._make_layer(Bottleneck, 64, 4)
- self.stage2_cfg = cfg['MODEL']['EXTRA']['STAGE2']
+ self.stage2_cfg = extra['STAGE2']
num_channels = self.stage2_cfg['NUM_CHANNELS']
block = blocks_dict[self.stage2_cfg['BLOCK']]
num_channels = [
@@ -298,7 +298,7 @@ def __init__(self, cfg, **kwargs):
self.stage2, pre_stage_channels = self._make_stage(
self.stage2_cfg, num_channels)
- self.stage3_cfg = cfg['MODEL']['EXTRA']['STAGE3']
+ self.stage3_cfg = extra['STAGE3']
num_channels = self.stage3_cfg['NUM_CHANNELS']
block = blocks_dict[self.stage3_cfg['BLOCK']]
num_channels = [
@@ -309,7 +309,7 @@ def __init__(self, cfg, **kwargs):
self.stage3, pre_stage_channels = self._make_stage(
self.stage3_cfg, num_channels)
- self.stage4_cfg = cfg['MODEL']['EXTRA']['STAGE4']
+ self.stage4_cfg = extra['STAGE4']
num_channels = self.stage4_cfg['NUM_CHANNELS']
block = blocks_dict[self.stage4_cfg['BLOCK']]
num_channels = [
@@ -322,13 +322,13 @@ def __init__(self, cfg, **kwargs):
self.final_layer = nn.Conv2d(
in_channels=pre_stage_channels[0],
- out_channels=cfg.MODEL.NUM_JOINTS,
- kernel_size=extra.FINAL_CONV_KERNEL,
+ out_channels=cfg['MODEL']['NUM_JOINTS'],
+ kernel_size=extra['FINAL_CONV_KERNEL'],
stride=1,
- padding=1 if extra.FINAL_CONV_KERNEL == 3 else 0
+ padding=1 if extra['FINAL_CONV_KERNEL'] == 3 else 0
)
- self.pretrained_layers = cfg['MODEL']['EXTRA']['PRETRAINED_LAYERS']
+ self.pretrained_layers = extra['PRETRAINED_LAYERS']
def _make_transition_layer(
self, num_channels_pre_layer, num_channels_cur_layer):
@@ -495,7 +495,7 @@ def init_weights(self, pretrained=''):
def get_pose_net(cfg, is_train, **kwargs):
model = PoseHighResolutionNet(cfg, **kwargs)
- if is_train and cfg.MODEL.INIT_WEIGHTS:
- model.init_weights(cfg.MODEL.PRETRAINED)
+ if is_train and cfg['MODEL']['INIT_WEIGHTS']:
+ model.init_weights(cfg['MODEL']['PRETRAINED'])
return model
diff --git a/requirements.txt b/requirements.txt
index 18c9ee11..14f225c7 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -8,4 +8,4 @@ pyyaml
json_tricks
scikit-image
yacs>=0.1.5
-tensorboardX>=1.6
+tensorboardX==1.6
diff --git a/visualization/plot_coco.py b/visualization/plot_coco.py
new file mode 100644
index 00000000..c0e79039
--- /dev/null
+++ b/visualization/plot_coco.py
@@ -0,0 +1,309 @@
+# ------------------------------------------------------------------------------
+# Copyright (c) Microsoft
+# Licensed under the MIT License.
+# Written by Ke Sun (sunk@mail.ustc.edu.cn)
+# Modified by Depu Meng (mdp@mail.ustc.edu.cn)
+# ------------------------------------------------------------------------------
+
+import argparse
+import numpy as np
+import matplotlib.pyplot as plt
+import cv2
+import json
+import matplotlib.lines as mlines
+import matplotlib.patches as mpatches
+from pycocotools.coco import COCO
+from pycocotools.cocoeval import COCOeval
+import os
+
+
+class ColorStyle:
+ def __init__(self, color, link_pairs, point_color):
+ self.color = color
+ self.link_pairs = link_pairs
+ self.point_color = point_color
+
+ for i in range(len(self.color)):
+ self.link_pairs[i].append(tuple(np.array(self.color[i])/255.))
+
+ self.ring_color = []
+ for i in range(len(self.point_color)):
+ self.ring_color.append(tuple(np.array(self.point_color[i])/255.))
+
+# Xiaochu Style
+# (R,G,B)
+color1 = [(179,0,0),(228,26,28),(255,255,51),
+ (49,163,84), (0,109,45), (255,255,51),
+ (240,2,127),(240,2,127),(240,2,127), (240,2,127), (240,2,127),
+ (217,95,14), (254,153,41),(255,255,51),
+ (44,127,184),(0,0,255)]
+
+link_pairs1 = [
+ [15, 13], [13, 11], [11, 5],
+ [12, 14], [14, 16], [12, 6],
+ [3, 1],[1, 2],[1, 0],[0, 2],[2,4],
+ [9, 7], [7,5], [5, 6],
+ [6, 8], [8, 10],
+ ]
+
+point_color1 = [(240,2,127),(240,2,127),(240,2,127),
+ (240,2,127), (240,2,127),
+ (255,255,51),(255,255,51),
+ (254,153,41),(44,127,184),
+ (217,95,14),(0,0,255),
+ (255,255,51),(255,255,51),(228,26,28),
+ (49,163,84),(252,176,243),(0,176,240),
+ (255,255,0),(169, 209, 142),
+ (255,255,0),(169, 209, 142),
+ (255,255,0),(169, 209, 142)]
+
+xiaochu_style = ColorStyle(color1, link_pairs1, point_color1)
+
+
+# Chunhua Style
+# (R,G,B)
+color2 = [(252,176,243),(252,176,243),(252,176,243),
+ (0,176,240), (0,176,240), (0,176,240),
+ (240,2,127),(240,2,127),(240,2,127), (240,2,127), (240,2,127),
+ (255,255,0), (255,255,0),(169, 209, 142),
+ (169, 209, 142),(169, 209, 142)]
+
+link_pairs2 = [
+ [15, 13], [13, 11], [11, 5],
+ [12, 14], [14, 16], [12, 6],
+ [3, 1],[1, 2],[1, 0],[0, 2],[2,4],
+ [9, 7], [7,5], [5, 6], [6, 8], [8, 10],
+ ]
+
+point_color2 = [(240,2,127),(240,2,127),(240,2,127),
+ (240,2,127), (240,2,127),
+ (255,255,0),(169, 209, 142),
+ (255,255,0),(169, 209, 142),
+ (255,255,0),(169, 209, 142),
+ (252,176,243),(0,176,240),(252,176,243),
+ (0,176,240),(252,176,243),(0,176,240),
+ (255,255,0),(169, 209, 142),
+ (255,255,0),(169, 209, 142),
+ (255,255,0),(169, 209, 142)]
+
+chunhua_style = ColorStyle(color2, link_pairs2, point_color2)
+
+def parse_args():
+ parser = argparse.ArgumentParser(description='Visualize COCO predictions')
+ # general
+ parser.add_argument('--image-path',
+ help='Path of COCO val images',
+ type=str,
+ default='data/coco/images/val2017/'
+ )
+
+ parser.add_argument('--gt-anno',
+ help='Path of COCO val annotation',
+ type=str,
+ default='data/coco/annotations/person_keypoints_val2017.json'
+ )
+
+ parser.add_argument('--save-path',
+ help="Path to save the visualizations",
+ type=str,
+ default='visualization/coco/')
+
+ parser.add_argument('--prediction',
+ help="Prediction file to visualize",
+ type=str,
+ required=True)
+
+ parser.add_argument('--style',
+ help="Style of the visualization: Chunhua style or Xiaochu style",
+ type=str,
+ default='chunhua')
+
+ args = parser.parse_args()
+
+ return args
+
+
+def map_joint_dict(joints):
+ joints_dict = {}
+ for i in range(joints.shape[0]):
+ x = int(joints[i][0])
+ y = int(joints[i][1])
+ id = i
+ joints_dict[id] = (x, y)
+
+ return joints_dict
+
+def plot(data, gt_file, img_path, save_path,
+ link_pairs, ring_color, save=True):
+
+ # joints
+ coco = COCO(gt_file)
+ coco_dt = coco.loadRes(data)
+ coco_eval = COCOeval(coco, coco_dt, 'keypoints')
+ coco_eval._prepare()
+ gts_ = coco_eval._gts
+ dts_ = coco_eval._dts
+
+ p = coco_eval.params
+ p.imgIds = list(np.unique(p.imgIds))
+ if p.useCats:
+ p.catIds = list(np.unique(p.catIds))
+ p.maxDets = sorted(p.maxDets)
+
+ # loop through images, area range, max detection number
+ catIds = p.catIds if p.useCats else [-1]
+ threshold = 0.3
+ joint_thres = 0.2
+ for catId in catIds:
+ for imgId in p.imgIds[:5000]:
+            # dimension here should be Nxm
+ gts = gts_[imgId, catId]
+ dts = dts_[imgId, catId]
+ inds = np.argsort([-d['score'] for d in dts], kind='mergesort')
+ dts = [dts[i] for i in inds]
+ if len(dts) > p.maxDets[-1]:
+ dts = dts[0:p.maxDets[-1]]
+ if len(gts) == 0 or len(dts) == 0:
+ continue
+
+ sum_score = 0
+ num_box = 0
+ img_name = str(imgId).zfill(12)
+
+ # Read Images
+ img_file = img_path + img_name + '.jpg'
+ data_numpy = cv2.imread(img_file, cv2.IMREAD_COLOR | cv2.IMREAD_IGNORE_ORIENTATION)
+ h = data_numpy.shape[0]
+ w = data_numpy.shape[1]
+
+ # Plot
+ fig = plt.figure(figsize=(w/100, h/100), dpi=100)
+ ax = plt.subplot(1,1,1)
+ bk = plt.imshow(data_numpy[:,:,::-1])
+ bk.set_zorder(-1)
+ print(img_name)
+ for j, gt in enumerate(gts):
+ # matching dt_box and gt_box
+ bb = gt['bbox']
+ x0 = bb[0] - bb[2]; x1 = bb[0] + bb[2] * 2
+ y0 = bb[1] - bb[3]; y1 = bb[1] + bb[3] * 2
+
+ # create bounds for ignore regions(double the gt bbox)
+ g = np.array(gt['keypoints'])
+ #xg = g[0::3]; yg = g[1::3];
+ vg = g[2::3]
+
+ for i, dt in enumerate(dts):
+ # Calculate IoU
+ dt_bb = dt['bbox']
+ dt_x0 = dt_bb[0] - dt_bb[2]; dt_x1 = dt_bb[0] + dt_bb[2] * 2
+ dt_y0 = dt_bb[1] - dt_bb[3]; dt_y1 = dt_bb[1] + dt_bb[3] * 2
+
+ ol_x = min(x1, dt_x1) - max(x0, dt_x0)
+ ol_y = min(y1, dt_y1) - max(y0, dt_y0)
+ ol_area = ol_x * ol_y
+ s_x = max(x1, dt_x1) - min(x0, dt_x0)
+ s_y = max(y1, dt_y1) - min(y0, dt_y0)
+ sum_area = s_x * s_y
+ iou = ol_area / (sum_area + np.spacing(1))
+ score = dt['score']
+
+ if iou < 0.1 or score < threshold:
+ continue
+ else:
+ print('iou: ', iou)
+ dt_w = dt_x1 - dt_x0
+ dt_h = dt_y1 - dt_y0
+ ref = min(dt_w, dt_h)
+ num_box += 1
+ sum_score += dt['score']
+ dt_joints = np.array(dt['keypoints']).reshape(17,-1)
+ joints_dict = map_joint_dict(dt_joints)
+
+ # stick
+ for k, link_pair in enumerate(link_pairs):
+ if link_pair[0] in joints_dict \
+ and link_pair[1] in joints_dict:
+ if dt_joints[link_pair[0],2] < joint_thres \
+ or dt_joints[link_pair[1],2] < joint_thres \
+ or vg[link_pair[0]] == 0 \
+ or vg[link_pair[1]] == 0:
+ continue
+ if k in range(6,11):
+ lw = 1
+ else:
+ lw = ref / 100.
+ line = mlines.Line2D(
+ np.array([joints_dict[link_pair[0]][0],
+ joints_dict[link_pair[1]][0]]),
+ np.array([joints_dict[link_pair[0]][1],
+ joints_dict[link_pair[1]][1]]),
+ ls='-', lw=lw, alpha=1, color=link_pair[2],)
+ line.set_zorder(0)
+ ax.add_line(line)
+ # black ring
+ for k in range(dt_joints.shape[0]):
+ if dt_joints[k,2] < joint_thres \
+ or vg[link_pair[0]] == 0 \
+ or vg[link_pair[1]] == 0:
+ continue
+ if dt_joints[k,0] > w or dt_joints[k,1] > h:
+ continue
+ if k in range(5):
+ radius = 1
+ else:
+ radius = ref / 100
+
+ circle = mpatches.Circle(tuple(dt_joints[k,:2]),
+ radius=radius,
+ ec='black',
+ fc=ring_color[k],
+ alpha=1,
+ linewidth=1)
+ circle.set_zorder(1)
+ ax.add_patch(circle)
+
+ avg_score = (sum_score / (num_box+np.spacing(1)))*1000
+
+ plt.gca().xaxis.set_major_locator(plt.NullLocator())
+ plt.gca().yaxis.set_major_locator(plt.NullLocator())
+ plt.axis('off')
+ plt.subplots_adjust(top=1,bottom=0,left=0,right=1,hspace=0,wspace=0)
+ plt.margins(0,0)
+ if save:
+ plt.savefig(save_path + \
+ 'score_'+str(np.int(avg_score))+ \
+ '_id_'+str(imgId)+ \
+ '_'+img_name + '.png',
+                            format='png', bbox_inches='tight', dpi=100)
+ plt.savefig(save_path +'id_'+str(imgId)+ '.pdf', format='pdf',
+                            bbox_inches='tight', dpi=100)
+ # plt.show()
+ plt.close()
+
+if __name__ == '__main__':
+
+ args = parse_args()
+ if args.style == 'xiaochu':
+ # Xiaochu Style
+ colorstyle = xiaochu_style
+ elif args.style == 'chunhua':
+ # Chunhua Style
+ colorstyle = chunhua_style
+ else:
+ raise Exception('Invalid color style')
+
+ save_path = args.save_path
+ img_path = args.image_path
+ if not os.path.exists(save_path):
+ try:
+ os.makedirs(save_path)
+ except Exception:
+            print('Failed to make {}'.format(save_path))
+
+
+ with open(args.prediction) as f:
+ data = json.load(f)
+ gt_file = args.gt_anno
+ plot(data, gt_file, img_path, save_path, colorstyle.link_pairs, colorstyle.ring_color, save=True)
+