diff --git a/README.md b/README.md
index fb2b5631..40b7720e 100644
--- a/README.md
+++ b/README.md
@@ -1,4 +1,12 @@
-# Deep High-Resolution Representation Learning for Human Pose Estimation(accepted to CVPR2019)
+# Deep High-Resolution Representation Learning for Human Pose Estimation (CVPR 2019)
+## News
+- [2021/04/12] Welcome to check out our recent work on bottom-up pose estimation (CVPR 2021) [HRNet-DEKR](https://github.com/HRNet/DEKR)!
+- [2020/07/05] [A very nice blog](https://towardsdatascience.com/overview-of-human-pose-estimation-neural-networks-hrnet-higherhrnet-architectures-and-faq-1954b2f8b249) from Towards Data Science introducing HRNet and HigherHRNet for human pose estimation.
+- [2020/03/13] A longer version is accepted by TPAMI: [Deep High-Resolution Representation Learning for Visual Recognition](https://arxiv.org/pdf/1908.07919.pdf). It includes more HRNet applications, and the code is available: [semantic segmentation](https://github.com/HRNet/HRNet-Semantic-Segmentation), [object detection](https://github.com/HRNet/HRNet-Object-Detection), [facial landmark detection](https://github.com/HRNet/HRNet-Facial-Landmark-Detection), and [image classification](https://github.com/HRNet/HRNet-Image-Classification).
+- [2020/02/01] We have added demo code for HRNet. Thanks [Alex Simes](https://github.com/alex9311). 
+- We have added visualization code for showing the pose estimation results. Thanks Depu!
+- [2019/08/27] HigherHRNet is now on [arXiv](https://arxiv.org/abs/1908.10357); it is a bottom-up approach for human pose estimation powered by HRNet. We will also release code and models at [Higher-HRNet-Human-Pose-Estimation](https://github.com/HRNet/Higher-HRNet-Human-Pose-Estimation), stay tuned!
+- Our new work [High-Resolution Representations for Labeling Pixels and Regions](https://arxiv.org/abs/1904.04514) is available at [HRNet](https://github.com/HRNet). Our HRNet has been applied to a wide range of vision tasks, such as [image classification](https://github.com/HRNet/HRNet-Image-Classification), [object detection](https://github.com/HRNet/HRNet-Object-Detection), [semantic segmentation](https://github.com/HRNet/HRNet-Semantic-Segmentation) and [facial landmark detection](https://github.com/HRNet/HRNet-Facial-Landmark-Detection).
 
 ## Introduction
 This is an official pytorch implementation of [*Deep High-Resolution Representation Learning for Human Pose Estimation*](https://arxiv.org/abs/1902.09212). 
@@ -215,16 +223,37 @@ python tools/train.py \
     --cfg experiments/coco/hrnet/w32_256x192_adam_lr1e-3.yaml \
 ```
 
+### Visualization
+
+#### Visualizing predictions on COCO val
+
+```
+python visualization/plot_coco.py \
+    --prediction output/coco/w48_384x288_adam_lr1e-3/results/keypoints_val2017_results_0.json \
+    --save-path visualization/results
+
+```
+
+
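+The `--prediction` argument points at the COCO keypoint results JSON written during evaluation on the val set. A minimal sketch of the format that `plot_coco.py` feeds to `COCO.loadRes`, shown only for illustration (the path is the one used in the command above):
+
+```python
+import json
+
+with open('output/coco/w48_384x288_adam_lr1e-3/results/keypoints_val2017_results_0.json') as f:
+    preds = json.load(f)
+
+# each entry follows the COCO keypoint results format:
+# {"image_id": int, "category_id": 1, "keypoints": [x1, y1, s1, ..., x17, y17, s17], "score": float}
+print(len(preds), sorted(preds[0].keys()))
+```
+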
+<img src="figures\visualization\coco\score_610_id_2685_000000002685.png" height="215"><img src="figures\visualization\coco\score_710_id_153229_000000153229.png" height="215"><img src="figures\visualization\coco\score_755_id_343561_000000343561.png" height="215">
+
+<img src="figures\visualization\coco\score_755_id_559842_000000559842.png" height="209"><img src="figures\visualization\coco\score_770_id_6954_000000006954.png" height="209"><img src="figures\visualization\coco\score_919_id_53626_000000053626.png" height="209">
 
 ### Other applications
-Many other dense prediction tasks, such as segmentation, face alignment and object detection, etc. have been benefited by HRNet. More information can be found at [Deep High-Resolution Representation Learning](https://jingdongwang2017.github.io/Projects/HRNet/).
+Many other dense prediction tasks, such as segmentation, face alignment, and object detection, have benefited from HRNet. More information can be found at [High-Resolution Networks](https://github.com/HRNet).
+
+### Other implementations
+[mmpose](https://github.com/open-mmlab/mmpose)<br/>
+[ModelScope (Chinese)](https://modelscope.cn/models/damo/cv_hrnetv2w32_body-2d-keypoints_image/summary)<br/>
+[timm](https://huggingface.co/docs/timm/main/en/models/hrnet)
+
 
 ### Citation
 If you use our code or models in your research, please cite with:
 ```
-@inproceedings{SunXLWang2019,
+@inproceedings{sun2019deep,
   title={Deep High-Resolution Representation Learning for Human Pose Estimation},
-  author={Ke Sun, Bin Xiao, Dong Liu, and Jingdong Wang},
+  author={Sun, Ke and Xiao, Bin and Liu, Dong and Wang, Jingdong},
   booktitle={CVPR},
   year={2019}
 }
@@ -235,4 +264,13 @@ If you use our code or models in your research, please cite with:
     booktitle = {European Conference on Computer Vision (ECCV)},
     year = {2018}
 }
+
+@article{WangSCJDZLMTWLX19,
+  title={Deep High-Resolution Representation Learning for Visual Recognition},
+  author={Jingdong Wang and Ke Sun and Tianheng Cheng and 
+          Borui Jiang and Chaorui Deng and Yang Zhao and Dong Liu and Yadong Mu and 
+          Mingkui Tan and Xinggang Wang and Wenyu Liu and Bin Xiao},
+  journal={TPAMI},
+  year={2019}
+}
 ```
diff --git a/_config.yml b/_config.yml
new file mode 100644
index 00000000..c4192631
--- /dev/null
+++ b/_config.yml
@@ -0,0 +1 @@
+theme: jekyll-theme-cayman
\ No newline at end of file
diff --git a/demo/.gitignore b/demo/.gitignore
new file mode 100644
index 00000000..04f267c7
--- /dev/null
+++ b/demo/.gitignore
@@ -0,0 +1,3 @@
+output
+models
+videos
diff --git a/demo/Dockerfile b/demo/Dockerfile
new file mode 100644
index 00000000..d8ceaf4e
--- /dev/null
+++ b/demo/Dockerfile
@@ -0,0 +1,112 @@
+FROM nvidia/cuda:10.2-cudnn7-devel-ubuntu16.04
+
+ENV OPENCV_VERSION="3.4.6"
+
+# Basic toolchain
+RUN apt-get update && apt-get install -y \
+        apt-utils \
+        build-essential \
+        git \
+        wget \
+        unzip \
+        yasm \
+        pkg-config \
+        libcurl4-openssl-dev \
+        zlib1g-dev \
+        htop \
+        cmake \
+        nano \
+        python3-pip \
+        python3-dev \
+        python3-tk \
+        libx264-dev \
+    && cd /usr/local/bin \
+    && ln -s /usr/bin/python3 python \
+    && pip3 install --upgrade pip \
+    && apt-get autoremove -y
+
+# Getting OpenCV dependencies available with apt
+RUN apt-get update && apt-get install -y \
+        libeigen3-dev \
+        libjpeg-dev \
+        libpng-dev \
+        libtiff-dev \
+        libjasper-dev \
+        libswscale-dev \
+        libavcodec-dev \
+        libavformat-dev && \
+    apt-get autoremove -y
+
+# Getting other dependencies
+RUN apt-get update && apt-get install -y \
+        cppcheck \
+        graphviz \
+        doxygen \
+        p7zip-full \
+        libdlib18 \
+        libdlib-dev && \
+    apt-get autoremove -y
+
+
+# Install OpenCV + OpenCV contrib (takes forever)
+RUN mkdir -p /tmp && \
+    cd /tmp && \
+    wget --no-check-certificate -O opencv.zip https://github.com/opencv/opencv/archive/${OPENCV_VERSION}.zip && \
+    wget --no-check-certificate -O opencv_contrib.zip https://github.com/opencv/opencv_contrib/archive/${OPENCV_VERSION}.zip && \
+    unzip opencv.zip && \
+    unzip opencv_contrib.zip && \
+    mkdir opencv-${OPENCV_VERSION}/build && \
+    cd opencv-${OPENCV_VERSION}/build && \
+    cmake -D CMAKE_BUILD_TYPE=RELEASE \
+        -D CMAKE_INSTALL_PREFIX=/usr/local \
+        -D WITH_CUDA=ON \
+        -D CUDA_FAST_MATH=1 \
+        -D WITH_CUBLAS=1 \
+        -D WITH_FFMPEG=ON \
+        -D WITH_OPENCL=ON \
+        -D WITH_V4L=ON \
+        -D WITH_OPENGL=ON \
+        -D OPENCV_EXTRA_MODULES_PATH=/tmp/opencv_contrib-${OPENCV_VERSION}/modules \
+        .. && \
+    make -j$(nproc) && \
+    make install && \
+    echo "/usr/local/lib" > /etc/ld.so.conf.d/opencv.conf && \
+    ldconfig && \
+    cd /tmp && \
+    rm -rf opencv-${OPENCV_VERSION} opencv.zip opencv_contrib-${OPENCV_VERSION} opencv_contrib.zip && \
+    cd /
+
+# Compile and install ffmpeg from source
+RUN git clone https://github.com/FFmpeg/FFmpeg /root/ffmpeg && \
+    cd /root/ffmpeg && \
+    ./configure --enable-gpl --enable-libx264 --enable-nonfree --disable-shared --extra-cflags=-I/usr/local/include && \
+    make -j8 && make install -j8
+
+# clone deep-high-resolution-net
+ARG POSE_ROOT=/pose_root
+RUN git clone https://github.com/leoxiaobin/deep-high-resolution-net.pytorch.git $POSE_ROOT
+WORKDIR $POSE_ROOT
+RUN mkdir output && mkdir log
+
+RUN pip3 install -r requirements.txt && \
+    pip3 install torch==1.1.0 \
+    torchvision==0.3.0 \
+    opencv-python \
+    pillow==6.2.1
+
+# build deep-high-resolution-net lib
+WORKDIR $POSE_ROOT/lib
+RUN make
+
+# install COCO API
+ARG COCOAPI=/cocoapi
+RUN git clone https://github.com/cocodataset/cocoapi.git $COCOAPI
+WORKDIR $COCOAPI/PythonAPI
+# Install into global site-packages
+RUN make install
+
+# download the pretrained Faster R-CNN model for person detection
+RUN python -c "import torchvision; model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True); model.eval()"
+
+COPY inference.py $POSE_ROOT/tools
+COPY inference-config.yaml $POSE_ROOT/
diff --git a/demo/README.md b/demo/README.md
new file mode 100644
index 00000000..aff81f44
--- /dev/null
+++ b/demo/README.md
@@ -0,0 +1,75 @@
+# HRNet inference demo
+
+Run inference with deep-high-resolution-net.pytorch without using Docker.
+
+## Prep
+1. Download the researchers' pretrained pose estimator from [Google Drive](https://drive.google.com/drive/folders/1hOTihvbyIxsm5ygDpbUuJ7O_tzv4oXjC?usp=sharing) into this directory under `models/`.
+2. Put the video file you'd like to run inference on in this directory under `videos/`.
+3. (Optional) Build the Docker image in this directory with `./build-docker.sh` (this can take a while because it compiles OpenCV).
+4. Update the `inference-config.yaml` file to reflect the number of GPUs you have available and which trained model you want to use (a quick sanity check is sketched below).
+
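+A minimal sanity check for the two fields that usually need editing, `GPUS` and `TEST.MODEL_FILE` (a sketch, assuming PyYAML is installed and that you run it from the repo root):
+
+```python
+import os
+import yaml
+
+# load the demo config and report the fields that usually need editing
+with open('demo/inference-config.yaml') as f:
+    cfg = yaml.safe_load(f)
+
+print('GPUS:', cfg['GPUS'])
+print('TEST.MODEL_FILE:', cfg['TEST']['MODEL_FILE'])
+print('model file exists:', os.path.exists(cfg['TEST']['MODEL_FILE']))
+```
+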
+## Running the Model
+### 1. Running on the video
+```
+python demo/inference.py --cfg demo/inference-config.yaml \
+    --videoFile ../../multi_people.mp4 \
+    --writeBoxFrames \
+    --outputDir output \
+    TEST.MODEL_FILE ../models/pytorch/pose_coco/pose_hrnet_w32_256x192.pth 
+
+```
+
+The above command will create a video under the *output* directory and a set of pose images under the *output/pose* directory; it also writes the per-frame keypoint coordinates to *pose-data.csv* in the output directory (see the sketch below).
+Even with a GPU (a GTX 1080 in my case), person detection takes nearly **0.06 sec** and pose estimation nearly **0.07 sec** per frame.
+In total, inference takes about **0.13 sec** per frame, i.e. roughly 8 fps. So if you need real-time (fps >= 20)
+pose estimation, you should try another approach.
+
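+A minimal sketch for reading the saved keypoints back, assuming the column layout written by `inference.py` (one row per processed frame with detections; the header covers the 17 COCO joints of the first detected person):
+
+```python
+import csv
+
+# 'output' matches the --outputDir used in the command above
+with open('output/pose-data.csv') as f:
+    reader = csv.DictReader(f)
+    for row in reader:
+        # e.g. the nose position of the first detected person in this frame
+        print(row['frame'], row['nose_x'], row['nose_y'])
+```
+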
+**===Result===**
+
+Some example output images are shown below:
+
+![1 person](inference_1.jpg)
+Fig: 1 person inference
+
+![3 person](inference_3.jpg)
+Fig: 3 person inference
+
+![3 person](inference_5.jpg)
+Fig: 3 person inference
+
+### 2. Demo with more common functions
+Remember to update `TEST.MODEL_FILE` in `demo/inference-config.yaml` according to your model path.
+
+`demo.py` provides the following functions:
+
+- use `--webcam` when the input is a real-time camera.
+- use `--video [video-path]`  when the input is a video.
+- use `--image [image-path]` when the input is an image.
+- use `--write` to save the image, camera or video result.
+- use `--showFps` to show the fps (this fps includes the detection part).
+- draw connections between joints.
+
+#### (1) the input is a real-time camera
+```
+python demo/demo.py --webcam --showFps --write
+```
+
+#### (2) the input is a video
+```
+python demo/demo.py --video test.mp4 --showFps --write
+```
+#### (3) the input is an image
+
+```
+python demo/demo.py --image test.jpg --showFps --write
+```
+
+**===Result===**
+
+![show_fps](inference_6.jpg)
+
+Fig: show fps
+
+![multi-people](inference_7.jpg)
+
+Fig: multi-people
\ No newline at end of file
diff --git a/demo/_init_paths.py b/demo/_init_paths.py
new file mode 100644
index 00000000..b1aea8fe
--- /dev/null
+++ b/demo/_init_paths.py
@@ -0,0 +1,27 @@
+# ------------------------------------------------------------------------------
+# pose.pytorch
+# Copyright (c) 2018-present Microsoft
+# Licensed under The Apache-2.0 License [see LICENSE for details]
+# Written by Bin Xiao (Bin.Xiao@microsoft.com)
+# ------------------------------------------------------------------------------
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os.path as osp
+import sys
+
+
+def add_path(path):
+    if path not in sys.path:
+        sys.path.insert(0, path)
+
+
+this_dir = osp.dirname(__file__)
+
+lib_path = osp.join(this_dir, '..', 'lib')
+add_path(lib_path)
+
+mm_path = osp.join(this_dir, '..', 'lib/poseeval/py-motmetrics')
+add_path(mm_path)
diff --git a/demo/build-docker.sh b/demo/build-docker.sh
new file mode 100755
index 00000000..a4b1aab4
--- /dev/null
+++ b/demo/build-docker.sh
@@ -0,0 +1 @@
+docker build -t hrnet_demo_inference .
diff --git a/demo/demo.py b/demo/demo.py
new file mode 100644
index 00000000..d482e838
--- /dev/null
+++ b/demo/demo.py
@@ -0,0 +1,343 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import argparse
+import csv
+import os
+import shutil
+
+from PIL import Image
+import torch
+import torch.nn.parallel
+import torch.backends.cudnn as cudnn
+import torch.optim
+import torch.utils.data
+import torch.utils.data.distributed
+import torchvision.transforms as transforms
+import torchvision
+import cv2
+import numpy as np
+import time
+
+
+import _init_paths
+import models
+from config import cfg
+from config import update_config
+from core.function import get_final_preds
+from utils.transforms import get_affine_transform
+
+COCO_KEYPOINT_INDEXES = {
+    0: 'nose',
+    1: 'left_eye',
+    2: 'right_eye',
+    3: 'left_ear',
+    4: 'right_ear',
+    5: 'left_shoulder',
+    6: 'right_shoulder',
+    7: 'left_elbow',
+    8: 'right_elbow',
+    9: 'left_wrist',
+    10: 'right_wrist',
+    11: 'left_hip',
+    12: 'right_hip',
+    13: 'left_knee',
+    14: 'right_knee',
+    15: 'left_ankle',
+    16: 'right_ankle'
+}
+
+COCO_INSTANCE_CATEGORY_NAMES = [
+    '__background__', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
+    'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A', 'stop sign',
+    'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
+    'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack', 'umbrella', 'N/A', 'N/A',
+    'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
+    'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket',
+    'bottle', 'N/A', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl',
+    'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
+    'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'N/A', 'dining table',
+    'N/A', 'N/A', 'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone',
+    'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'N/A', 'book',
+    'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush'
+]
+
+SKELETON = [
+    [1,3],[1,0],[2,4],[2,0],[0,5],[0,6],[5,7],[7,9],[6,8],[8,10],[5,11],[6,12],[11,12],[11,13],[13,15],[12,14],[14,16]
+]
+
+CocoColors = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0], [85, 255, 0], [0, 255, 0],
+              [0, 255, 85], [0, 255, 170], [0, 255, 255], [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255],
+              [170, 0, 255], [255, 0, 255], [255, 0, 170], [255, 0, 85]]
+
+NUM_KPTS = 17
+
+CTX = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
+
+def draw_pose(keypoints,img):
+    """draw the keypoints and the skeletons.
+    :params keypoints: the shape should be equal to [17,2]
+    :params img:
+    """
+    assert keypoints.shape == (NUM_KPTS,2)
+    for i in range(len(SKELETON)):
+        kpt_a, kpt_b = SKELETON[i][0], SKELETON[i][1]
+        x_a, y_a = keypoints[kpt_a][0],keypoints[kpt_a][1]
+        x_b, y_b = keypoints[kpt_b][0],keypoints[kpt_b][1] 
+        cv2.circle(img, (int(x_a), int(y_a)), 6, CocoColors[i], -1)
+        cv2.circle(img, (int(x_b), int(y_b)), 6, CocoColors[i], -1)
+        cv2.line(img, (int(x_a), int(y_a)), (int(x_b), int(y_b)), CocoColors[i], 2)
+
+def draw_bbox(box,img):
+    """draw the detected bounding box on the image.
+    :param box: [(x1, y1), (x2, y2)] corners of the detected box
+    :param img:
+    """
+    cv2.rectangle(img, box[0], box[1], color=(0, 255, 0),thickness=3)
+
+
+def get_person_detection_boxes(model, img, threshold=0.5):
+    pred = model(img)
+    pred_classes = [COCO_INSTANCE_CATEGORY_NAMES[i]
+                    for i in list(pred[0]['labels'].cpu().numpy())]  # Predicted class names
+    pred_boxes = [[(i[0], i[1]), (i[2], i[3])]
+                  for i in list(pred[0]['boxes'].detach().cpu().numpy())]  # Bounding boxes
+    pred_score = list(pred[0]['scores'].detach().cpu().numpy())
+    if not pred_score or max(pred_score)<threshold:
+        return []
+    # Get list of index with score greater than threshold
+    pred_t = [pred_score.index(x) for x in pred_score if x > threshold][-1]
+    pred_boxes = pred_boxes[:pred_t+1]
+    pred_classes = pred_classes[:pred_t+1]
+
+    person_boxes = []
+    for idx, box in enumerate(pred_boxes):
+        if pred_classes[idx] == 'person':
+            person_boxes.append(box)
+
+    return person_boxes
+
+
+def get_pose_estimation_prediction(pose_model, image, center, scale):
+    rotation = 0
+
+    # pose estimation transformation
+    trans = get_affine_transform(center, scale, rotation, cfg.MODEL.IMAGE_SIZE)
+    model_input = cv2.warpAffine(
+        image,
+        trans,
+        (int(cfg.MODEL.IMAGE_SIZE[0]), int(cfg.MODEL.IMAGE_SIZE[1])),
+        flags=cv2.INTER_LINEAR)
+    transform = transforms.Compose([
+        transforms.ToTensor(),
+        transforms.Normalize(mean=[0.485, 0.456, 0.406],
+                             std=[0.229, 0.224, 0.225]),
+    ])
+
+    # pose estimation inference
+    model_input = transform(model_input).unsqueeze(0)
+    # switch to evaluate mode
+    pose_model.eval()
+    with torch.no_grad():
+        # compute output heatmap
+        output = pose_model(model_input)
+        preds, _ = get_final_preds(
+            cfg,
+            output.clone().cpu().numpy(),
+            np.asarray([center]),
+            np.asarray([scale]))
+
+        return preds
+
+
+def box_to_center_scale(box, model_image_width, model_image_height):
+    """convert a box to center,scale information required for pose transformation
+    Parameters
+    ----------
+    box : list of tuple
+        list of length 2 with two tuples of floats representing
+        bottom left and top right corner of a box
+    model_image_width : int
+    model_image_height : int
+
+    Returns
+    -------
+    (numpy array, numpy array)
+        Two numpy arrays, coordinates for the center of the box and the scale of the box
+    """
+    center = np.zeros((2), dtype=np.float32)
+
+    bottom_left_corner = box[0]
+    top_right_corner = box[1]
+    box_width = top_right_corner[0]-bottom_left_corner[0]
+    box_height = top_right_corner[1]-bottom_left_corner[1]
+    bottom_left_x = bottom_left_corner[0]
+    bottom_left_y = bottom_left_corner[1]
+    center[0] = bottom_left_x + box_width * 0.5
+    center[1] = bottom_left_y + box_height * 0.5
+
+    aspect_ratio = model_image_width * 1.0 / model_image_height
+    pixel_std = 200
+
+    if box_width > aspect_ratio * box_height:
+        box_height = box_width * 1.0 / aspect_ratio
+    elif box_width < aspect_ratio * box_height:
+        box_width = box_height * aspect_ratio
+    scale = np.array(
+        [box_width * 1.0 / pixel_std, box_height * 1.0 / pixel_std],
+        dtype=np.float32)
+    if center[0] != -1:
+        scale = scale * 1.25
+
+    return center, scale
+
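+# Worked example for box_to_center_scale (illustrative numbers): with a 288x384 model
+# input (aspect ratio 0.75), box ((100., 200.), (200., 400.)) is centered at (150., 300.),
+# then widened from 100x200 to 150x200 to match the aspect ratio,
+# giving scale = (150/200, 200/200) * 1.25 = (0.9375, 1.25).
+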
+def parse_args():
+    parser = argparse.ArgumentParser(description='Train keypoints network')
+    # general
+    parser.add_argument('--cfg', type=str, default='demo/inference-config.yaml')
+    parser.add_argument('--video', type=str)
+    parser.add_argument('--webcam',action='store_true')
+    parser.add_argument('--image',type=str)
+    parser.add_argument('--write',action='store_true')
+    parser.add_argument('--showFps',action='store_true')
+
+    parser.add_argument('opts',
+                        help='Modify config options using the command-line',
+                        default=None,
+                        nargs=argparse.REMAINDER)
+
+    args = parser.parse_args()
+
+    # args expected by supporting codebase  
+    args.modelDir = ''
+    args.logDir = ''
+    args.dataDir = ''
+    args.prevModelDir = ''
+    return args
+
+
+def main():
+    # cudnn related setting
+    cudnn.benchmark = cfg.CUDNN.BENCHMARK
+    torch.backends.cudnn.deterministic = cfg.CUDNN.DETERMINISTIC
+    torch.backends.cudnn.enabled = cfg.CUDNN.ENABLED
+
+    args = parse_args()
+    update_config(cfg, args)
+
+    box_model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
+    box_model.to(CTX)
+    box_model.eval()
+
+    pose_model = eval('models.'+cfg.MODEL.NAME+'.get_pose_net')(
+        cfg, is_train=False
+    )
+
+    if cfg.TEST.MODEL_FILE:
+        print('=> loading model from {}'.format(cfg.TEST.MODEL_FILE))
+        pose_model.load_state_dict(torch.load(cfg.TEST.MODEL_FILE), strict=False)
+    else:
+        print('expected model defined in config at TEST.MODEL_FILE')
+
+    pose_model = torch.nn.DataParallel(pose_model, device_ids=cfg.GPUS)
+    pose_model.to(CTX)
+    pose_model.eval()
+
+    # Load a video, an image, or the webcam
+    if args.webcam:
+        vidcap = cv2.VideoCapture(0)
+    elif args.video:
+        vidcap = cv2.VideoCapture(args.video)
+    elif args.image:
+        image_bgr = cv2.imread(args.image)
+    else:
+        print('please use --video or --webcam or --image to define the input.')
+        return 
+
+    if args.webcam or args.video:
+        if args.write:
+            save_path = 'output.avi'
+            fourcc = cv2.VideoWriter_fourcc(*'XVID')
+            out = cv2.VideoWriter(save_path,fourcc, 24.0, (int(vidcap.get(3)),int(vidcap.get(4))))
+        while True:
+            ret, image_bgr = vidcap.read()
+            if ret:
+                last_time = time.time()
+                image = image_bgr[:, :, [2, 1, 0]]
+
+                input = []
+                img = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
+                img_tensor = torch.from_numpy(img/255.).permute(2,0,1).float().to(CTX)
+                input.append(img_tensor)
+
+                # object detection box
+                pred_boxes = get_person_detection_boxes(box_model, input, threshold=0.9)
+
+                # pose estimation
+                if len(pred_boxes) >= 1:
+                    for box in pred_boxes:
+                        center, scale = box_to_center_scale(box, cfg.MODEL.IMAGE_SIZE[0], cfg.MODEL.IMAGE_SIZE[1])
+                        image_pose = image.copy() if cfg.DATASET.COLOR_RGB else image_bgr.copy()
+                        pose_preds = get_pose_estimation_prediction(pose_model, image_pose, center, scale)
+                        if len(pose_preds)>=1:
+                            for kpt in pose_preds:
+                                draw_pose(kpt,image_bgr) # draw the poses
+
+                if args.showFps:
+                    fps = 1/(time.time()-last_time)
+                    img = cv2.putText(image_bgr, 'fps: '+ "%.2f"%(fps), (25, 40), cv2.FONT_HERSHEY_SIMPLEX, 1.2, (0, 255, 0), 2)
+
+                if args.write:
+                    out.write(image_bgr)
+
+                cv2.imshow('demo',image_bgr)
+                if cv2.waitKey(1) & 0XFF==ord('q'):
+                    break
+            else:
+                print('cannot load the video.')
+                break
+
+        cv2.destroyAllWindows()
+        vidcap.release()
+        if args.write:
+            print('video has been saved as {}'.format(save_path))
+            out.release()
+
+    else:
+        # estimate on the image
+        last_time = time.time()
+        image = image_bgr[:, :, [2, 1, 0]]
+
+        input = []
+        img = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
+        img_tensor = torch.from_numpy(img/255.).permute(2,0,1).float().to(CTX)
+        input.append(img_tensor)
+
+        # object detection box
+        pred_boxes = get_person_detection_boxes(box_model, input, threshold=0.9)
+
+        # pose estimation
+        if len(pred_boxes) >= 1:
+            for box in pred_boxes:
+                center, scale = box_to_center_scale(box, cfg.MODEL.IMAGE_SIZE[0], cfg.MODEL.IMAGE_SIZE[1])
+                image_pose = image.copy() if cfg.DATASET.COLOR_RGB else image_bgr.copy()
+                pose_preds = get_pose_estimation_prediction(pose_model, image_pose, center, scale)
+                if len(pose_preds)>=1:
+                    for kpt in pose_preds:
+                        draw_pose(kpt,image_bgr) # draw the poses
+        
+        if args.showFps:
+            fps = 1/(time.time()-last_time)
+            img = cv2.putText(image_bgr, 'fps: '+ "%.2f"%(fps), (25, 40), cv2.FONT_HERSHEY_SIMPLEX, 1.2, (0, 255, 0), 2)
+        
+        if args.write:
+            save_path = 'output.jpg'
+            cv2.imwrite(save_path,image_bgr)
+            print('the result image has been saved as {}'.format(save_path))
+
+        cv2.imshow('demo',image_bgr)
+        if cv2.waitKey(0) & 0XFF==ord('q'):
+            cv2.destroyAllWindows()
+        
+if __name__ == '__main__':
+    main()
diff --git a/demo/hrnet-demo.gif b/demo/hrnet-demo.gif
new file mode 100644
index 00000000..30e37079
Binary files /dev/null and b/demo/hrnet-demo.gif differ
diff --git a/demo/inference-config.yaml b/demo/inference-config.yaml
new file mode 100644
index 00000000..14bce176
--- /dev/null
+++ b/demo/inference-config.yaml
@@ -0,0 +1,127 @@
+AUTO_RESUME: true
+CUDNN:
+  BENCHMARK: true
+  DETERMINISTIC: false
+  ENABLED: true
+DATA_DIR: ''
+GPUS: (0,)
+OUTPUT_DIR: 'output'
+LOG_DIR: 'log'
+WORKERS: 24
+PRINT_FREQ: 100
+
+DATASET:
+  COLOR_RGB: true
+  DATASET: 'coco'
+  DATA_FORMAT: jpg
+  FLIP: true
+  NUM_JOINTS_HALF_BODY: 8
+  PROB_HALF_BODY: 0.3
+  ROOT: 'data/coco/'
+  ROT_FACTOR: 45
+  SCALE_FACTOR: 0.35
+  TEST_SET: 'val2017'
+  TRAIN_SET: 'train2017'
+MODEL:
+  INIT_WEIGHTS: true
+  NAME: pose_hrnet
+  NUM_JOINTS: 17
+  PRETRAINED: 'models/pytorch/pose_coco/pose_hrnet_w32_384x288.pth'
+  TARGET_TYPE: gaussian
+  IMAGE_SIZE:
+  - 288
+  - 384
+  HEATMAP_SIZE:
+  - 72
+  - 96
+  SIGMA: 3
+  EXTRA:
+    PRETRAINED_LAYERS:
+    - 'conv1'
+    - 'bn1'
+    - 'conv2'
+    - 'bn2'
+    - 'layer1'
+    - 'transition1'
+    - 'stage2'
+    - 'transition2'
+    - 'stage3'
+    - 'transition3'
+    - 'stage4'
+    FINAL_CONV_KERNEL: 1
+    STAGE2:
+      NUM_MODULES: 1
+      NUM_BRANCHES: 2
+      BLOCK: BASIC
+      NUM_BLOCKS:
+      - 4
+      - 4
+      NUM_CHANNELS:
+      - 32
+      - 64
+      FUSE_METHOD: SUM
+    STAGE3:
+      NUM_MODULES: 4
+      NUM_BRANCHES: 3
+      BLOCK: BASIC
+      NUM_BLOCKS:
+      - 4
+      - 4
+      - 4
+      NUM_CHANNELS:
+      - 32
+      - 64
+      - 128
+      FUSE_METHOD: SUM
+    STAGE4:
+      NUM_MODULES: 3
+      NUM_BRANCHES: 4
+      BLOCK: BASIC
+      NUM_BLOCKS:
+      - 4
+      - 4
+      - 4
+      - 4
+      NUM_CHANNELS:
+      - 32
+      - 64
+      - 128
+      - 256
+      FUSE_METHOD: SUM
+LOSS:
+  USE_TARGET_WEIGHT: true
+TRAIN:
+  BATCH_SIZE_PER_GPU: 32
+  SHUFFLE: true
+  BEGIN_EPOCH: 0
+  END_EPOCH: 210
+  OPTIMIZER: adam
+  LR: 0.001
+  LR_FACTOR: 0.1
+  LR_STEP:
+  - 170
+  - 200
+  WD: 0.0001
+  GAMMA1: 0.99
+  GAMMA2: 0.0
+  MOMENTUM: 0.9
+  NESTEROV: false
+TEST:
+  BATCH_SIZE_PER_GPU: 32
+  COCO_BBOX_FILE: 'data/coco/person_detection_results/COCO_val2017_detections_AP_H_56_person.json'
+  BBOX_THRE: 1.0
+  IMAGE_THRE: 0.0
+  IN_VIS_THRE: 0.2
+  MODEL_FILE: 'models/pytorch/pose_coco/pose_hrnet_w32_384x288.pth'
+  NMS_THRE: 1.0
+  OKS_THRE: 0.9
+  USE_GT_BBOX: true
+  FLIP_TEST: true
+  POST_PROCESS: true
+  SHIFT_HEATMAP: true
+DEBUG:
+  DEBUG: true
+  SAVE_BATCH_IMAGES_GT: true
+  SAVE_BATCH_IMAGES_PRED: true
+  SAVE_HEATMAPS_GT: true
+  SAVE_HEATMAPS_PRED: true
diff --git a/demo/inference.py b/demo/inference.py
new file mode 100644
index 00000000..efff86a7
--- /dev/null
+++ b/demo/inference.py
@@ -0,0 +1,341 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import argparse
+import csv
+import os
+import shutil
+
+from PIL import Image
+import torch
+import torch.nn.parallel
+import torch.backends.cudnn as cudnn
+import torch.optim
+import torch.utils.data
+import torch.utils.data.distributed
+import torchvision.transforms as transforms
+import torchvision
+import cv2
+import numpy as np
+
+import sys
+sys.path.append("../lib")
+import time
+
+# import _init_paths
+import models
+from config import cfg
+from config import update_config
+from core.inference import get_final_preds
+from utils.transforms import get_affine_transform
+
+CTX = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
+
+
+COCO_KEYPOINT_INDEXES = {
+    0: 'nose',
+    1: 'left_eye',
+    2: 'right_eye',
+    3: 'left_ear',
+    4: 'right_ear',
+    5: 'left_shoulder',
+    6: 'right_shoulder',
+    7: 'left_elbow',
+    8: 'right_elbow',
+    9: 'left_wrist',
+    10: 'right_wrist',
+    11: 'left_hip',
+    12: 'right_hip',
+    13: 'left_knee',
+    14: 'right_knee',
+    15: 'left_ankle',
+    16: 'right_ankle'
+}
+
+COCO_INSTANCE_CATEGORY_NAMES = [
+    '__background__', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
+    'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A', 'stop sign',
+    'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
+    'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack', 'umbrella', 'N/A', 'N/A',
+    'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
+    'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket',
+    'bottle', 'N/A', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl',
+    'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
+    'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'N/A', 'dining table',
+    'N/A', 'N/A', 'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone',
+    'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'N/A', 'book',
+    'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush'
+]
+
+
+def get_person_detection_boxes(model, img, threshold=0.5):
+    pil_image = Image.fromarray(img)  # Load the image
+    transform = transforms.Compose([transforms.ToTensor()])  # Define the PyTorch transform
+    transformed_img = transform(pil_image)  # Apply the transform to the image
+    pred = model([transformed_img.to(CTX)])  # Pass the image to the model
+    # Collect the predicted classes, boxes and scores
+    pred_classes = [COCO_INSTANCE_CATEGORY_NAMES[i]
+                    for i in list(pred[0]['labels'].cpu().numpy())]  # Predicted class names
+    pred_boxes = [[(i[0], i[1]), (i[2], i[3])]
+                  for i in list(pred[0]['boxes'].cpu().detach().numpy())]  # Bounding boxes
+    pred_scores = list(pred[0]['scores'].cpu().detach().numpy())
+
+    person_boxes = []
+    # Keep boxes whose score is above the threshold and whose class is 'person'
+    for pred_class, pred_box, pred_score in zip(pred_classes, pred_boxes, pred_scores):
+        if (pred_score > threshold) and (pred_class == 'person'):
+            person_boxes.append(pred_box)
+
+    return person_boxes
+
+
+def get_pose_estimation_prediction(pose_model, image, centers, scales, transform):
+    rotation = 0
+
+    # pose estimation transformation
+    model_inputs = []
+    for center, scale in zip(centers, scales):
+        trans = get_affine_transform(center, scale, rotation, cfg.MODEL.IMAGE_SIZE)
+        # Crop smaller image of people
+        model_input = cv2.warpAffine(
+            image,
+            trans,
+            (int(cfg.MODEL.IMAGE_SIZE[0]), int(cfg.MODEL.IMAGE_SIZE[1])),
+            flags=cv2.INTER_LINEAR)
+
+        # hwc -> 1chw
+        model_input = transform(model_input)#.unsqueeze(0)
+        model_inputs.append(model_input)
+
+    # n * 1chw -> nchw
+    model_inputs = torch.stack(model_inputs)
+
+    # compute output heatmap
+    output = pose_model(model_inputs.to(CTX))
+    coords, _ = get_final_preds(
+        cfg,
+        output.cpu().detach().numpy(),
+        np.asarray(centers),
+        np.asarray(scales))
+
+    return coords
+
+
+def box_to_center_scale(box, model_image_width, model_image_height):
+    """convert a box to center,scale information required for pose transformation
+    Parameters
+    ----------
+    box : list of tuple
+        list of length 2 with two tuples of floats representing
+        bottom left and top right corner of a box
+    model_image_width : int
+    model_image_height : int
+
+    Returns
+    -------
+    (numpy array, numpy array)
+        Two numpy arrays, coordinates for the center of the box and the scale of the box
+    """
+    center = np.zeros((2), dtype=np.float32)
+
+    bottom_left_corner = box[0]
+    top_right_corner = box[1]
+    box_width = top_right_corner[0]-bottom_left_corner[0]
+    box_height = top_right_corner[1]-bottom_left_corner[1]
+    bottom_left_x = bottom_left_corner[0]
+    bottom_left_y = bottom_left_corner[1]
+    center[0] = bottom_left_x + box_width * 0.5
+    center[1] = bottom_left_y + box_height * 0.5
+
+    aspect_ratio = model_image_width * 1.0 / model_image_height
+    pixel_std = 200
+
+    if box_width > aspect_ratio * box_height:
+        box_height = box_width * 1.0 / aspect_ratio
+    elif box_width < aspect_ratio * box_height:
+        box_width = box_height * aspect_ratio
+    scale = np.array(
+        [box_width * 1.0 / pixel_std, box_height * 1.0 / pixel_std],
+        dtype=np.float32)
+    if center[0] != -1:
+        scale = scale * 1.25
+
+    return center, scale
+
+
+def prepare_output_dirs(prefix='/output/'):
+    pose_dir = os.path.join(prefix, "pose")
+    if os.path.exists(pose_dir) and os.path.isdir(pose_dir):
+        shutil.rmtree(pose_dir)
+    os.makedirs(pose_dir, exist_ok=True)
+    return pose_dir
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(description='Train keypoints network')
+    # general
+    parser.add_argument('--cfg', type=str, required=True)
+    parser.add_argument('--videoFile', type=str, required=True)
+    parser.add_argument('--outputDir', type=str, default='/output/')
+    parser.add_argument('--inferenceFps', type=int, default=10)
+    parser.add_argument('--writeBoxFrames', action='store_true')
+
+    parser.add_argument('opts',
+                        help='Modify config options using the command-line',
+                        default=None,
+                        nargs=argparse.REMAINDER)
+
+    args = parser.parse_args()
+
+    # args expected by supporting codebase
+    args.modelDir = ''
+    args.logDir = ''
+    args.dataDir = ''
+    args.prevModelDir = ''
+    return args
+
+
+def main():
+    # transformation
+    pose_transform = transforms.Compose([
+        transforms.ToTensor(),
+        transforms.Normalize(mean=[0.485, 0.456, 0.406],
+                             std=[0.229, 0.224, 0.225]),
+    ])
+
+    # cudnn related setting
+    cudnn.benchmark = cfg.CUDNN.BENCHMARK
+    torch.backends.cudnn.deterministic = cfg.CUDNN.DETERMINISTIC
+    torch.backends.cudnn.enabled = cfg.CUDNN.ENABLED
+
+    args = parse_args()
+    update_config(cfg, args)
+    pose_dir = prepare_output_dirs(args.outputDir)
+    csv_output_rows = []
+
+    box_model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
+    box_model.to(CTX)
+    box_model.eval()
+    pose_model = eval('models.'+cfg.MODEL.NAME+'.get_pose_net')(
+        cfg, is_train=False
+    )
+
+    if cfg.TEST.MODEL_FILE:
+        print('=> loading model from {}'.format(cfg.TEST.MODEL_FILE))
+        pose_model.load_state_dict(torch.load(cfg.TEST.MODEL_FILE), strict=False)
+    else:
+        print('expected model defined in config at TEST.MODEL_FILE')
+
+    pose_model.to(CTX)
+    pose_model.eval()
+
+    # Load the video
+    vidcap = cv2.VideoCapture(args.videoFile)
+    fps = vidcap.get(cv2.CAP_PROP_FPS)
+    if fps < args.inferenceFps:
+        print('desired inference fps is '+str(args.inferenceFps)+' but video fps is '+str(fps))
+        exit()
+    skip_frame_cnt = round(fps / args.inferenceFps)
+    frame_width = int(vidcap.get(cv2.CAP_PROP_FRAME_WIDTH))
+    frame_height = int(vidcap.get(cv2.CAP_PROP_FRAME_HEIGHT))
+    outcap = cv2.VideoWriter('{}/{}_pose.avi'.format(args.outputDir, os.path.splitext(os.path.basename(args.videoFile))[0]),
+                             cv2.VideoWriter_fourcc('M', 'J', 'P', 'G'), int(skip_frame_cnt), (frame_width, frame_height))
+
+    count = 0
+    while vidcap.isOpened():
+        total_now = time.time()
+        ret, image_bgr = vidcap.read()
+        count += 1
+
+        if not ret:
+            # end of video (or failed read): stop instead of looping forever
+            break
+
+        if count % skip_frame_cnt != 0:
+            continue
+
+        image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
+
+        # Clone 2 image for person detection and pose estimation
+        if cfg.DATASET.COLOR_RGB:
+            image_per = image_rgb.copy()
+            image_pose = image_rgb.copy()
+        else:
+            image_per = image_bgr.copy()
+            image_pose = image_bgr.copy()
+
+        # Clone 1 image for debugging purpose
+        image_debug = image_bgr.copy()
+
+        # object detection box
+        now = time.time()
+        pred_boxes = get_person_detection_boxes(box_model, image_per, threshold=0.9)
+        then = time.time()
+        print("Find person bbox in: {} sec".format(then - now))
+
+        # No people found; move on to the next frame
+        if not pred_boxes:
+            count += 1
+            continue
+
+        if args.writeBoxFrames:
+            for box in pred_boxes:
+                cv2.rectangle(image_debug, box[0], box[1], color=(0, 255, 0),
+                              thickness=3)  # Draw Rectangle with the coordinates
+
+        # pose estimation : for multiple people
+        centers = []
+        scales = []
+        for box in pred_boxes:
+            center, scale = box_to_center_scale(box, cfg.MODEL.IMAGE_SIZE[0], cfg.MODEL.IMAGE_SIZE[1])
+            centers.append(center)
+            scales.append(scale)
+
+        now = time.time()
+        pose_preds = get_pose_estimation_prediction(pose_model, image_pose, centers, scales, transform=pose_transform)
+        then = time.time()
+        print("Find person pose in: {} sec".format(then - now))
+
+        new_csv_row = [count]  # start the row with the frame index to match the 'frame' header
+        for coords in pose_preds:
+            # Draw each point on image
+            for coord in coords:
+                x_coord, y_coord = int(coord[0]), int(coord[1])
+                cv2.circle(image_debug, (x_coord, y_coord), 4, (255, 0, 0), 2)
+                new_csv_row.extend([x_coord, y_coord])
+
+        total_then = time.time()
+
+        text = "{:03.2f} sec".format(total_then - total_now)
+        cv2.putText(image_debug, text, (100, 50), cv2.FONT_HERSHEY_SIMPLEX,
+                            1, (0, 0, 255), 2, cv2.LINE_AA)
+
+        cv2.imshow("pos", image_debug)
+        if cv2.waitKey(1) & 0xFF == ord('q'):
+            break
+
+        csv_output_rows.append(new_csv_row)
+        img_file = os.path.join(pose_dir, 'pose_{:08d}.jpg'.format(count))
+        cv2.imwrite(img_file, image_debug)
+        outcap.write(image_debug)
+
+
+    # write csv
+    csv_headers = ['frame']
+    for keypoint in COCO_KEYPOINT_INDEXES.values():
+        csv_headers.extend([keypoint+'_x', keypoint+'_y'])
+
+    csv_output_filename = os.path.join(args.outputDir, 'pose-data.csv')
+    with open(csv_output_filename, 'w', newline='') as csvfile:
+        csvwriter = csv.writer(csvfile)
+        csvwriter.writerow(csv_headers)
+        csvwriter.writerows(csv_output_rows)
+
+    vidcap.release()
+    outcap.release()
+
+    cv2.destroyAllWindows()
+
+
+if __name__ == '__main__':
+    main()
diff --git a/demo/inference_1.jpg b/demo/inference_1.jpg
new file mode 100644
index 00000000..2ca29d1b
Binary files /dev/null and b/demo/inference_1.jpg differ
diff --git a/demo/inference_3.jpg b/demo/inference_3.jpg
new file mode 100644
index 00000000..7f20915c
Binary files /dev/null and b/demo/inference_3.jpg differ
diff --git a/demo/inference_5.jpg b/demo/inference_5.jpg
new file mode 100644
index 00000000..d7b7117c
Binary files /dev/null and b/demo/inference_5.jpg differ
diff --git a/demo/inference_6.jpg b/demo/inference_6.jpg
new file mode 100644
index 00000000..cc1183c0
Binary files /dev/null and b/demo/inference_6.jpg differ
diff --git a/demo/inference_7.jpg b/demo/inference_7.jpg
new file mode 100644
index 00000000..9629b300
Binary files /dev/null and b/demo/inference_7.jpg differ
diff --git a/figures/visualization/coco/score_610_id_2685_000000002685.png b/figures/visualization/coco/score_610_id_2685_000000002685.png
new file mode 100644
index 00000000..615ff258
Binary files /dev/null and b/figures/visualization/coco/score_610_id_2685_000000002685.png differ
diff --git a/figures/visualization/coco/score_710_id_153229_000000153229.png b/figures/visualization/coco/score_710_id_153229_000000153229.png
new file mode 100644
index 00000000..61d73e3a
Binary files /dev/null and b/figures/visualization/coco/score_710_id_153229_000000153229.png differ
diff --git a/figures/visualization/coco/score_755_id_343561_000000343561.png b/figures/visualization/coco/score_755_id_343561_000000343561.png
new file mode 100644
index 00000000..f114bd42
Binary files /dev/null and b/figures/visualization/coco/score_755_id_343561_000000343561.png differ
diff --git a/figures/visualization/coco/score_755_id_559842_000000559842.png b/figures/visualization/coco/score_755_id_559842_000000559842.png
new file mode 100644
index 00000000..7123e65e
Binary files /dev/null and b/figures/visualization/coco/score_755_id_559842_000000559842.png differ
diff --git a/figures/visualization/coco/score_770_id_6954_000000006954.png b/figures/visualization/coco/score_770_id_6954_000000006954.png
new file mode 100644
index 00000000..caba54b0
Binary files /dev/null and b/figures/visualization/coco/score_770_id_6954_000000006954.png differ
diff --git a/figures/visualization/coco/score_919_id_53626_000000053626.png b/figures/visualization/coco/score_919_id_53626_000000053626.png
new file mode 100644
index 00000000..3efcd62a
Binary files /dev/null and b/figures/visualization/coco/score_919_id_53626_000000053626.png differ
diff --git a/lib/core/function.py b/lib/core/function.py
index dadff0fa..1bc19daa 100755
--- a/lib/core/function.py
+++ b/lib/core/function.py
@@ -124,10 +124,7 @@ def validate(config, val_loader, val_dataset, model, criterion, output_dir,
                 output = outputs
 
             if config.TEST.FLIP_TEST:
-                # this part is ugly, because pytorch has not supported negative index
-                # input_flipped = model(input[:, :, :, ::-1])
-                input_flipped = np.flip(input.cpu().numpy(), 3).copy()
-                input_flipped = torch.from_numpy(input_flipped).cuda()
+                input_flipped = input.flip(3)
                 outputs_flipped = model(input_flipped)
 
                 if isinstance(outputs_flipped, list):
diff --git a/lib/models/pose_hrnet.py b/lib/models/pose_hrnet.py
index ea65f419..09ff346a 100644
--- a/lib/models/pose_hrnet.py
+++ b/lib/models/pose_hrnet.py
@@ -275,7 +275,7 @@ class PoseHighResolutionNet(nn.Module):
 
     def __init__(self, cfg, **kwargs):
         self.inplanes = 64
-        extra = cfg.MODEL.EXTRA
+        extra = cfg['MODEL']['EXTRA']
         super(PoseHighResolutionNet, self).__init__()
 
         # stem net
@@ -288,7 +288,7 @@ def __init__(self, cfg, **kwargs):
         self.relu = nn.ReLU(inplace=True)
         self.layer1 = self._make_layer(Bottleneck, 64, 4)
 
-        self.stage2_cfg = cfg['MODEL']['EXTRA']['STAGE2']
+        self.stage2_cfg = extra['STAGE2']
         num_channels = self.stage2_cfg['NUM_CHANNELS']
         block = blocks_dict[self.stage2_cfg['BLOCK']]
         num_channels = [
@@ -298,7 +298,7 @@ def __init__(self, cfg, **kwargs):
         self.stage2, pre_stage_channels = self._make_stage(
             self.stage2_cfg, num_channels)
 
-        self.stage3_cfg = cfg['MODEL']['EXTRA']['STAGE3']
+        self.stage3_cfg = extra['STAGE3']
         num_channels = self.stage3_cfg['NUM_CHANNELS']
         block = blocks_dict[self.stage3_cfg['BLOCK']]
         num_channels = [
@@ -309,7 +309,7 @@ def __init__(self, cfg, **kwargs):
         self.stage3, pre_stage_channels = self._make_stage(
             self.stage3_cfg, num_channels)
 
-        self.stage4_cfg = cfg['MODEL']['EXTRA']['STAGE4']
+        self.stage4_cfg = extra['STAGE4']
         num_channels = self.stage4_cfg['NUM_CHANNELS']
         block = blocks_dict[self.stage4_cfg['BLOCK']]
         num_channels = [
@@ -322,13 +322,13 @@ def __init__(self, cfg, **kwargs):
 
         self.final_layer = nn.Conv2d(
             in_channels=pre_stage_channels[0],
-            out_channels=cfg.MODEL.NUM_JOINTS,
-            kernel_size=extra.FINAL_CONV_KERNEL,
+            out_channels=cfg['MODEL']['NUM_JOINTS'],
+            kernel_size=extra['FINAL_CONV_KERNEL'],
             stride=1,
-            padding=1 if extra.FINAL_CONV_KERNEL == 3 else 0
+            padding=1 if extra['FINAL_CONV_KERNEL'] == 3 else 0
         )
 
-        self.pretrained_layers = cfg['MODEL']['EXTRA']['PRETRAINED_LAYERS']
+        self.pretrained_layers = extra['PRETRAINED_LAYERS']
 
     def _make_transition_layer(
             self, num_channels_pre_layer, num_channels_cur_layer):
@@ -495,7 +495,7 @@ def init_weights(self, pretrained=''):
 def get_pose_net(cfg, is_train, **kwargs):
     model = PoseHighResolutionNet(cfg, **kwargs)
 
-    if is_train and cfg.MODEL.INIT_WEIGHTS:
-        model.init_weights(cfg.MODEL.PRETRAINED)
+    if is_train and cfg['MODEL']['INIT_WEIGHTS']:
+        model.init_weights(cfg['MODEL']['PRETRAINED'])
 
     return model
diff --git a/lib/utils/utils.py b/lib/utils/utils.py
index 9561baa3..5c31ca10 100644
--- a/lib/utils/utils.py
+++ b/lib/utils/utils.py
@@ -150,6 +150,7 @@ def hook(module, input, output):
            and module != model:
             hooks.append(module.register_forward_hook(hook))
 
+    model.eval()
     model.apply(add_hooks)
 
     space_len = item_length
diff --git a/requirements.txt b/requirements.txt
index 18c9ee11..14f225c7 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -8,4 +8,4 @@ pyyaml
 json_tricks
 scikit-image
 yacs>=0.1.5
-tensorboardX>=1.6
+tensorboardX==1.6
diff --git a/tools/test.py b/tools/test.py
index b32ee746..cfa333f5 100755
--- a/tools/test.py
+++ b/tools/test.py
@@ -45,7 +45,6 @@ def parse_args():
                         default=None,
                         nargs=argparse.REMAINDER)
 
-    # philly
     parser.add_argument('--modelDir',
                         help='model directory',
                         type=str,
@@ -67,37 +66,10 @@ def parse_args():
     return args
 
 
-def copy_prev_models(prev_models_dir, model_dir):
-    import shutil
-
-    vc_folder = '/hdfs/' \
-        + '/' + os.environ['PHILLY_VC']
-    source = prev_models_dir
-    # If path is set as "sys/jobs/application_1533861538020_2366/models" prefix with the location of vc folder
-    source = vc_folder + '/' + source if not source.startswith(vc_folder) \
-        else source
-    destination = model_dir
-
-    if os.path.exists(source) and os.path.exists(destination):
-        for file in os.listdir(source):
-            source_file = os.path.join(source, file)
-            destination_file = os.path.join(destination, file)
-            if not os.path.exists(destination_file):
-                print("=> copying {0} to {1}".format(
-                    source_file, destination_file))
-                shutil.copytree(source_file, destination_file)
-    else:
-        print('=> {} or {} does not exist'.format(source, destination))
-
-
 def main():
     args = parse_args()
     update_config(cfg, args)
 
-    if args.prevModelDir and args.modelDir:
-        # copy pre models for philly
-        copy_prev_models(args.prevModelDir, args.modelDir)
-
     logger, final_output_dir, tb_log_dir = create_logger(
         cfg, args.cfg, 'valid')
 
diff --git a/tools/train.py b/tools/train.py
index 50e8273f..039c5487 100755
--- a/tools/train.py
+++ b/tools/train.py
@@ -73,37 +73,10 @@ def parse_args():
     return args
 
 
-def copy_prev_models(prev_models_dir, model_dir):
-    import shutil
-
-    vc_folder = '/hdfs/' \
-        + '/' + os.environ['PHILLY_VC']
-    source = prev_models_dir
-    # If path is set as "sys/jobs/application_1533861538020_2366/models" prefix with the location of vc folder
-    source = vc_folder + '/' + source if not source.startswith(vc_folder) \
-        else source
-    destination = model_dir
-
-    if os.path.exists(source) and os.path.exists(destination):
-        for file in os.listdir(source):
-            source_file = os.path.join(source, file)
-            destination_file = os.path.join(destination, file)
-            if not os.path.exists(destination_file):
-                print("=> copying {0} to {1}".format(
-                    source_file, destination_file))
-            shutil.copytree(source_file, destination_file)
-    else:
-        print('=> {} or {} does not exist'.format(source, destination))
-
-
 def main():
     args = parse_args()
     update_config(cfg, args)
 
-    if args.prevModelDir and args.modelDir:
-        # copy pre models for philly
-        copy_prev_models(args.prevModelDir, args.modelDir)
-
     logger, final_output_dir, tb_log_dir = create_logger(
         cfg, args.cfg, 'train')
 
diff --git a/visualization/plot_coco.py b/visualization/plot_coco.py
new file mode 100644
index 00000000..c0e79039
--- /dev/null
+++ b/visualization/plot_coco.py
@@ -0,0 +1,309 @@
+# ------------------------------------------------------------------------------
+# Copyright (c) Microsoft
+# Licensed under the MIT License.
+# Written by Ke Sun (sunk@mail.ustc.edu.cn)
+# Modified by Depu Meng (mdp@mail.ustc.edu.cn)
+# ------------------------------------------------------------------------------
+
+import argparse
+import numpy as np
+import matplotlib.pyplot as plt
+import cv2
+import json
+import matplotlib.lines as mlines
+import matplotlib.patches as mpatches
+from pycocotools.coco import COCO
+from pycocotools.cocoeval import COCOeval
+import os
+
+
+class ColorStyle:
+    def __init__(self, color, link_pairs, point_color):
+        self.color = color
+        self.link_pairs = link_pairs
+        self.point_color = point_color
+
+        for i in range(len(self.color)):
+            self.link_pairs[i].append(tuple(np.array(self.color[i])/255.))
+
+        self.ring_color = []
+        for i in range(len(self.point_color)):
+            self.ring_color.append(tuple(np.array(self.point_color[i])/255.))
+        
+# Xiaochu Style
+# (R,G,B)
+color1 = [(179,0,0),(228,26,28),(255,255,51),
+    (49,163,84), (0,109,45), (255,255,51),
+    (240,2,127),(240,2,127),(240,2,127), (240,2,127), (240,2,127), 
+    (217,95,14), (254,153,41),(255,255,51),
+    (44,127,184),(0,0,255)]
+
+link_pairs1 = [
+        [15, 13], [13, 11], [11, 5], 
+        [12, 14], [14, 16], [12, 6], 
+        [3, 1],[1, 2],[1, 0],[0, 2],[2,4],
+        [9, 7], [7,5], [5, 6],
+        [6, 8], [8, 10],
+        ]
+
+point_color1 = [(240,2,127),(240,2,127),(240,2,127), 
+            (240,2,127), (240,2,127), 
+            (255,255,51),(255,255,51),
+            (254,153,41),(44,127,184),
+            (217,95,14),(0,0,255),
+            (255,255,51),(255,255,51),(228,26,28),
+            (49,163,84),(252,176,243),(0,176,240),
+            (255,255,0),(169, 209, 142),
+            (255,255,0),(169, 209, 142),
+            (255,255,0),(169, 209, 142)]
+
+xiaochu_style = ColorStyle(color1, link_pairs1, point_color1)
+
+
+# Chunhua Style
+# (R,G,B)
+color2 = [(252,176,243),(252,176,243),(252,176,243),
+    (0,176,240), (0,176,240), (0,176,240),
+    (240,2,127),(240,2,127),(240,2,127), (240,2,127), (240,2,127), 
+    (255,255,0), (255,255,0),(169, 209, 142),
+    (169, 209, 142),(169, 209, 142)]
+
+link_pairs2 = [
+        [15, 13], [13, 11], [11, 5], 
+        [12, 14], [14, 16], [12, 6], 
+        [3, 1],[1, 2],[1, 0],[0, 2],[2,4],
+        [9, 7], [7,5], [5, 6], [6, 8], [8, 10],
+        ]
+
+point_color2 = [(240,2,127),(240,2,127),(240,2,127), 
+            (240,2,127), (240,2,127), 
+            (255,255,0),(169, 209, 142),
+            (255,255,0),(169, 209, 142),
+            (255,255,0),(169, 209, 142),
+            (252,176,243),(0,176,240),(252,176,243),
+            (0,176,240),(252,176,243),(0,176,240),
+            (255,255,0),(169, 209, 142),
+            (255,255,0),(169, 209, 142),
+            (255,255,0),(169, 209, 142)]
+
+chunhua_style = ColorStyle(color2, link_pairs2, point_color2)
+
+def parse_args():
+    parser = argparse.ArgumentParser(description='Visualize COCO predictions')
+    # general
+    parser.add_argument('--image-path',
+                        help='Path of COCO val images',
+                        type=str,
+                        default='data/coco/images/val2017/'
+                        )
+
+    parser.add_argument('--gt-anno',
+                        help='Path of COCO val annotation',
+                        type=str,
+                        default='data/coco/annotations/person_keypoints_val2017.json'
+                        )
+
+    parser.add_argument('--save-path',
+                        help="Path to save the visualizations",
+                        type=str,
+                        default='visualization/coco/')
+
+    parser.add_argument('--prediction',
+                        help="Prediction file to visualize",
+                        type=str,
+                        required=True)
+
+    parser.add_argument('--style',
+                        help="Style of the visualization: Chunhua style or Xiaochu style",
+                        type=str,
+                        default='chunhua')
+
+    args = parser.parse_args()
+
+    return args
+
+
+def map_joint_dict(joints):
+    joints_dict = {}
+    for i in range(joints.shape[0]):
+        x = int(joints[i][0])
+        y = int(joints[i][1])
+        id = i
+        joints_dict[id] = (x, y)
+        
+    return joints_dict
+
+def plot(data, gt_file, img_path, save_path, 
+         link_pairs, ring_color, save=True):
+    
+    # joints
+    coco = COCO(gt_file)
+    coco_dt = coco.loadRes(data)
+    coco_eval = COCOeval(coco, coco_dt, 'keypoints')
+    coco_eval._prepare()
+    gts_ = coco_eval._gts
+    dts_ = coco_eval._dts
+    
+    p = coco_eval.params
+    p.imgIds = list(np.unique(p.imgIds))
+    if p.useCats:
+        p.catIds = list(np.unique(p.catIds))
+    p.maxDets = sorted(p.maxDets)
+
+    # loop through images, area range, max detection number
+    catIds = p.catIds if p.useCats else [-1]
+    threshold = 0.3
+    joint_thres = 0.2
+    for catId in catIds:
+        for imgId in p.imgIds[:5000]:
+            # dimension here should be Nxm
+            gts = gts_[imgId, catId]
+            dts = dts_[imgId, catId]
+            inds = np.argsort([-d['score'] for d in dts], kind='mergesort')
+            dts = [dts[i] for i in inds]
+            if len(dts) > p.maxDets[-1]:
+                dts = dts[0:p.maxDets[-1]]
+            if len(gts) == 0 or len(dts) == 0:
+                continue
+            
+            sum_score = 0
+            num_box = 0
+            img_name = str(imgId).zfill(12)
+            
+            # Read Images
+            img_file = img_path + img_name + '.jpg'
+            data_numpy = cv2.imread(img_file, cv2.IMREAD_COLOR | cv2.IMREAD_IGNORE_ORIENTATION)
+            h = data_numpy.shape[0]
+            w = data_numpy.shape[1]
+            
+            # Plot
+            fig = plt.figure(figsize=(w/100, h/100), dpi=100)
+            ax = plt.subplot(1,1,1)
+            bk = plt.imshow(data_numpy[:,:,::-1])
+            bk.set_zorder(-1)
+            print(img_name)
+            for j, gt in enumerate(gts):
+                # matching dt_box and gt_box
+                bb = gt['bbox']
+                x0 = bb[0] - bb[2]; x1 = bb[0] + bb[2] * 2
+                y0 = bb[1] - bb[3]; y1 = bb[1] + bb[3] * 2
+
+                # create bounds for ignore regions(double the gt bbox)
+                g = np.array(gt['keypoints'])
+                #xg = g[0::3]; yg = g[1::3]; 
+                vg = g[2::3]     
+            
+                for i, dt in enumerate(dts):
+                    # Calculate IoU
+                    dt_bb = dt['bbox']
+                    dt_x0 = dt_bb[0] - dt_bb[2]; dt_x1 = dt_bb[0] + dt_bb[2] * 2
+                    dt_y0 = dt_bb[1] - dt_bb[3]; dt_y1 = dt_bb[1] + dt_bb[3] * 2          
+                    
+                    ol_x = min(x1, dt_x1) - max(x0, dt_x0)
+                    ol_y = min(y1, dt_y1) - max(y0, dt_y0)
+                    ol_area = ol_x * ol_y
+                    s_x = max(x1, dt_x1) - min(x0, dt_x0)
+                    s_y = max(y1, dt_y1) - min(y0, dt_y0)
+                    sum_area = s_x * s_y
+                    iou = ol_area / (sum_area + np.spacing(1))                    
+                    score = dt['score']
+                    
+                    if iou < 0.1 or score < threshold:
+                        continue
+                    else:
+                        print('iou: ', iou)
+                        dt_w = dt_x1 - dt_x0
+                        dt_h = dt_y1 - dt_y0
+                        ref = min(dt_w, dt_h)
+                        num_box += 1
+                        sum_score += dt['score']
+                        dt_joints = np.array(dt['keypoints']).reshape(17,-1)
+                        joints_dict = map_joint_dict(dt_joints)
+                        
+                        # stick 
+                        for k, link_pair in enumerate(link_pairs):
+                            if link_pair[0] in joints_dict \
+                            and link_pair[1] in joints_dict:
+                                if dt_joints[link_pair[0],2] < joint_thres \
+                                    or dt_joints[link_pair[1],2] < joint_thres \
+                                    or vg[link_pair[0]] == 0 \
+                                    or vg[link_pair[1]] == 0:
+                                    continue
+                            if k in range(6,11):
+                                lw = 1
+                            else:
+                                lw = ref / 100.
+                            line = mlines.Line2D(
+                                    np.array([joints_dict[link_pair[0]][0],
+                                              joints_dict[link_pair[1]][0]]),
+                                    np.array([joints_dict[link_pair[0]][1],
+                                              joints_dict[link_pair[1]][1]]),
+                                    ls='-', lw=lw, alpha=1, color=link_pair[2],)
+                            line.set_zorder(0)
+                            ax.add_line(line)
+                        # black ring
+                        for k in range(dt_joints.shape[0]):
+                            if dt_joints[k,2] < joint_thres \
+                                or vg[link_pair[0]] == 0 \
+                                or vg[link_pair[1]] == 0:
+                                continue
+                            if dt_joints[k,0] > w or dt_joints[k,1] > h:
+                                continue
+                            if k in range(5):
+                                radius = 1
+                            else:
+                                radius = ref / 100
+                    
+                            circle = mpatches.Circle(tuple(dt_joints[k,:2]), 
+                                                     radius=radius, 
+                                                     ec='black', 
+                                                     fc=ring_color[k], 
+                                                     alpha=1, 
+                                                     linewidth=1)
+                            circle.set_zorder(1)
+                            ax.add_patch(circle)
+        
+            avg_score = (sum_score / (num_box+np.spacing(1)))*1000
+        
+            plt.gca().xaxis.set_major_locator(plt.NullLocator())
+            plt.gca().yaxis.set_major_locator(plt.NullLocator())
+            plt.axis('off')
+            plt.subplots_adjust(top=1,bottom=0,left=0,right=1,hspace=0,wspace=0)        
+            plt.margins(0,0)
+            if save:
+                plt.savefig(save_path + \
+                           'score_'+str(int(avg_score))+ \
+                           '_id_'+str(imgId)+ \
+                           '_'+img_name + '.png',
+                           format='png', bbox_inches='tight', dpi=100)
+                plt.savefig(save_path +'id_'+str(imgId)+ '.pdf', format='pdf',
+                            bbox_inches='tight', dpi=100)
+            # plt.show()
+            plt.close()
+
+if __name__ == '__main__':
+
+    args = parse_args()
+    if args.style == 'xiaochu':
+        # Xiaochu Style
+        colorstyle = xiaochu_style
+    elif args.style == 'chunhua':
+        # Chunhua Style
+        colorstyle = chunhua_style
+    else:
+        raise Exception('Invalid color style')
+    
+    save_path = args.save_path
+    img_path = args.image_path
+    if not os.path.exists(save_path):
+        try:
+            os.makedirs(save_path)
+        except Exception:
+            print('Fail to make {}'.format(save_path))
+
+    
+    with open(args.prediction) as f:
+        data = json.load(f)
+    gt_file = args.gt_anno
+    plot(data, gt_file, img_path, save_path, colorstyle.link_pairs, colorstyle.ring_color, save=True)
+