[feat] Add Batch Inference Script #5

Merged
merged 1 commit on Oct 29, 2021

258 changes: 258 additions & 0 deletions demo/demo_without_detection_batch.py
@@ -0,0 +1,258 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import csv
import os
import shutil

from PIL import Image
import torch
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transforms
import torchvision
import cv2
import numpy as np
import time


import _init_paths
import models
from config import cfg
from config import update_config
from core.function import get_final_preds
from utils.transforms import get_affine_transform

# total time spent in model forward passes, accumulated across batches (seconds)
time_processing_total = 0

COCO_KEYPOINT_INDEXES = {
    0: 'nose',
    1: 'left_eye',
    2: 'right_eye',
    3: 'left_ear',
    4: 'right_ear',
    5: 'left_shoulder',
    6: 'right_shoulder',
    7: 'left_elbow',
    8: 'right_elbow',
    9: 'left_wrist',
    10: 'right_wrist',
    11: 'left_hip',
    12: 'right_hip',
    13: 'left_knee',
    14: 'right_knee',
    15: 'left_ankle',
    16: 'right_ankle'
}

COCO_INSTANCE_CATEGORY_NAMES = [
    '__background__', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
    'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A', 'stop sign',
    'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
    'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack', 'umbrella', 'N/A', 'N/A',
    'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
    'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket',
    'bottle', 'N/A', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl',
    'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
    'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'N/A', 'dining table',
    'N/A', 'N/A', 'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone',
    'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'N/A', 'book',
    'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush'
]

# limb connections as pairs of keypoint indices from COCO_KEYPOINT_INDEXES
SKELETON = [
    [1, 3], [1, 0], [2, 4], [2, 0], [0, 5], [0, 6], [5, 7], [7, 9], [6, 8], [8, 10],
    [5, 11], [6, 12], [11, 12], [11, 13], [13, 15], [12, 14], [14, 16]
]

CocoColors = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0],
              [85, 255, 0], [0, 255, 0], [0, 255, 85], [0, 255, 170], [0, 255, 255],
              [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255], [170, 0, 255],
              [255, 0, 255], [255, 0, 170], [255, 0, 85]]

NUM_KPTS = 17

CTX = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')


def box_to_center_scale(box, model_image_width, model_image_height):
    """Convert a box to the center/scale information required for pose transformation.

    Parameters
    ----------
    box : list of tuple
        list of length 2 with two (x, y) tuples of floats representing
        opposite corners of a box
    model_image_width : int
    model_image_height : int

    Returns
    -------
    (numpy array, numpy array)
        Two numpy arrays: the coordinates of the center of the box and the
        scale of the box
    """
    center = np.zeros((2), dtype=np.float32)

    bottom_left_corner = box[0]
    top_right_corner = box[1]
    box_width = top_right_corner[0] - bottom_left_corner[0]
    box_height = top_right_corner[1] - bottom_left_corner[1]
    bottom_left_x = bottom_left_corner[0]
    bottom_left_y = bottom_left_corner[1]
    center[0] = bottom_left_x + box_width * 0.5
    center[1] = bottom_left_y + box_height * 0.5

    # match the box to the model input's aspect ratio, then express its size in
    # units of pixel_std (the fixed 200-pixel convention of the COCO pipeline)
    aspect_ratio = model_image_width * 1.0 / model_image_height
    pixel_std = 200

    if box_width > aspect_ratio * box_height:
        box_height = box_width * 1.0 / aspect_ratio
    elif box_width < aspect_ratio * box_height:
        box_width = box_height * aspect_ratio
    scale = np.array(
        [box_width * 1.0 / pixel_std, box_height * 1.0 / pixel_std],
        dtype=np.float32)
    if center[0] != -1:
        scale = scale * 1.25

    return center, scale


def get_model(args):
    update_config(cfg, args)
    pose_model = eval('models.' + cfg.MODEL.NAME + '.get_pose_net')(cfg, is_train=False)

    if cfg.TEST.MODEL_FILE:
        print('=> loading model from {}'.format(cfg.TEST.MODEL_FILE))
        pose_model.load_state_dict(torch.load(cfg.TEST.MODEL_FILE), strict=False)
    else:
        print('=> no checkpoint found: set TEST.MODEL_FILE in the config to a model path')

    pose_model = torch.nn.DataParallel(pose_model, device_ids=cfg.GPUS)
    pose_model.to(CTX)
    pose_model.eval()
    return pose_model


def preprocess(image):
    # OpenCV images are BGR; reorder channels to RGB for the normalization below
    image_permuted = image[:, :, [2, 1, 0]]
    h, w, _ = image_permuted.shape
    box = [[0, 0], [w, h]]  # treat the whole image as the person box
    rotation = 0
    center, scale = box_to_center_scale(box, cfg.MODEL.IMAGE_SIZE[0], cfg.MODEL.IMAGE_SIZE[1])

    # pose estimation transformation
    trans = get_affine_transform(center, scale, rotation, cfg.MODEL.IMAGE_SIZE)
    image_warped = cv2.warpAffine(
        image_permuted, trans,
        (int(cfg.MODEL.IMAGE_SIZE[0]), int(cfg.MODEL.IMAGE_SIZE[1])),
        flags=cv2.INTER_LINEAR)
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
    image_tensor = transform(image_warped)
    return image_tensor, center, scale



def inference_batch(model, batch_input, batch_center, batch_scale):
    model.eval()
    with torch.no_grad():
        # compute output heatmaps, timing only the forward pass
        ##############################################################################
        if torch.cuda.is_available():
            torch.cuda.current_stream().synchronize()
        time_start = time.time()
        ##############################################################################
        batch_output = model(batch_input)
        ##############################################################################
        if torch.cuda.is_available():
            torch.cuda.current_stream().synchronize()
        time_end = time.time()
        time_processing = time_end - time_start
        global time_processing_total
        time_processing_total += time_processing
        ##############################################################################
        batch_output = batch_output.detach().cpu().numpy()

        # post-process each heatmap individually with its own center/scale
        list_keypoints = []
        for output, center, scale in zip(batch_output, batch_center, batch_scale):
            keypoints, _ = get_final_preds(cfg, output[np.newaxis, :, :],
                                           np.asarray([center]), np.asarray([scale]))
            list_keypoints.append(keypoints)

        return list_keypoints
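# Note: each element of list_keypoints has shape (1, num_joints, 2), holding
# per-joint (x, y) coordinates mapped back into the original image frame.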


def draw_pose(keypoints, img):
    """Draw the keypoints and the skeleton.
    :param keypoints: array of shape [17, 2]
    :param img: image to draw on (modified in place)
    """
    assert keypoints.shape == (NUM_KPTS, 2)
    for i in range(len(SKELETON)):
        kpt_a, kpt_b = SKELETON[i][0], SKELETON[i][1]
        x_a, y_a = keypoints[kpt_a][0], keypoints[kpt_a][1]
        x_b, y_b = keypoints[kpt_b][0], keypoints[kpt_b][1]
        cv2.circle(img, (int(x_a), int(y_a)), 6, CocoColors[i], -1)
        cv2.circle(img, (int(x_b), int(y_b)), 6, CocoColors[i], -1)
        cv2.line(img, (int(x_a), int(y_a)), (int(x_b), int(y_b)), CocoColors[i], 2)
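# Example (hypothetical wiring): overlay the first predicted pose on a frame:
#   draw_pose(list_keypoints[0][0], image)  # [0][0] selects the (17, 2) array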


def parse_args():
    parser = argparse.ArgumentParser(description='Without Detection Demo')
    parser.add_argument('--image', type=str, default='sunglassman.jpg')
    parser.add_argument('--cfg', type=str, default='demo/inference-config.yaml')
    parser.add_argument('--batch_size', type=int, default=16)
    parser.add_argument('opts',
                        help='Modify config options using the command-line',
                        default=None,
                        nargs=argparse.REMAINDER)
    args = parser.parse_args()
    # update_config expects these attributes even though the demo leaves them empty
    args.modelDir = ''
    args.logDir = ''
    args.dataDir = ''
    args.prevModelDir = ''
    return args
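# Any trailing "KEY VALUE" pairs land in args.opts and are merged into cfg by
# update_config, e.g. TEST.MODEL_FILE <path/to/checkpoint> (path hypothetical).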


if __name__ == '__main__':
    cudnn.benchmark = cfg.CUDNN.BENCHMARK
    torch.backends.cudnn.deterministic = cfg.CUDNN.DETERMINISTIC
    torch.backends.cudnn.enabled = cfg.CUDNN.ENABLED

    args = parse_args()
    model_ = get_model(args)
    print('Total', sum(p.numel() for p in model_.parameters()), 'parameters')

    image_ = cv2.imread(args.image)

    COUNT_ITER = 100
    print('Timecheck Start!')
    # ##############################################################################
    # torch.cuda.current_stream().synchronize()
    # time_start = time.time()
    # ##############################################################################
    for iteration in range(COUNT_ITER):
        list_image_tensor = []
        list_center = []
        list_scale = []
        for index_image in range(args.batch_size):
            # perturb the image slightly each step so batch elements differ;
            # widen the dtype first so the addition cannot wrap around uint8
            image_ = np.clip(image_.astype(np.int16) + ((index_image + iteration) % 10) - 5,
                             0, 255).astype(np.uint8)
            image_tensor_, center_, scale_ = preprocess(image_)
            list_image_tensor.append(image_tensor_)
            list_center.append(center_[np.newaxis, :])
            list_scale.append(scale_[np.newaxis, :])

        batch_image_tensor = torch.stack(list_image_tensor, dim=0)
        batch_center_ = np.concatenate(list_center, axis=0)
        batch_scale_ = np.concatenate(list_scale, axis=0)

        list_keypoints_ = inference_batch(model_, batch_image_tensor, batch_center_, batch_scale_)
    # ##############################################################################
    # torch.cuda.current_stream().synchronize()
    # time_end = time.time()
    # time_processing = time_end - time_start
    # time_processing_total += time_processing
    # ##############################################################################
    print('Average forward time per image (s):',
          time_processing_total / (COUNT_ITER * args.batch_size))
2 changes: 1 addition & 1 deletion demo/inference-config_w32_256x192.yaml
@@ -107,7 +107,7 @@ TRAIN:
  MOMENTUM: 0.9
  NESTEROV: false
TEST:
-  BATCH_SIZE_PER_GPU: 32
+  BATCH_SIZE_PER_GPU: 512
  COCO_BBOX_FILE: 'data/coco/person_detection_results/COCO_val2017_detections_AP_H_56_person.json'
  BBOX_THRE: 1.0
  IMAGE_THRE: 0.0
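For reference, a minimal sketch of driving the new helpers directly (assumes it runs from the repo root so demo/_init_paths.py resolves, and that TEST.MODEL_FILE points at a valid checkpoint; nothing below is part of the diff itself):

import sys
sys.path.insert(0, 'demo')  # so _init_paths and the demo module import cleanly

import cv2
import numpy as np
import torch
from demo_without_detection_batch import get_model, preprocess, inference_batch, parse_args

args = parse_args()   # --image, --cfg, --batch_size, plus trailing KEY VALUE overrides
model = get_model(args)

image = cv2.imread(args.image)
tensors, centers, scales = [], [], []
for _ in range(args.batch_size):
    tensor, center, scale = preprocess(image)
    tensors.append(tensor)
    centers.append(center[np.newaxis, :])
    scales.append(scale[np.newaxis, :])

list_keypoints = inference_batch(model,
                                 torch.stack(tensors, dim=0),
                                 np.concatenate(centers, axis=0),
                                 np.concatenate(scales, axis=0))
print(len(list_keypoints))      # == args.batch_size
print(list_keypoints[0].shape)  # (1, 17, 2): one pose per image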