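# Usage sketch: the script name, config path, and checkpoint path below are
# placeholders for a pose_hrnet experiment file and its pretrained weights;
# adjust them to your setup. Extra KEY VALUE pairs after the flags are passed
# through to the config via the 'opts' argument.
#
#   python demo.py \
#       --cfg experiments/coco/hrnet/w32_256x192_adam_lr1e-3.yaml \
#       --videoFile input.mp4 \
#       --outputDir /output/ \
#       --inferenceFps 10 \
#       --writeBoxFrames \
#       TEST.MODEL_FILE models/pose_hrnet_w32_256x192.pth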
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import csv
import os
import shutil

from PIL import Image
import torch
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transforms
import torchvision
import cv2
import numpy as np


import _init_paths
import models
from config import cfg
from config import update_config
from core.function import get_final_preds
from utils.transforms import get_affine_transform
COCO_KEYPOINT_INDEXES = {
    0: 'nose',
    1: 'left_eye',
    2: 'right_eye',
    3: 'left_ear',
    4: 'right_ear',
    5: 'left_shoulder',
    6: 'right_shoulder',
    7: 'left_elbow',
    8: 'right_elbow',
    9: 'left_wrist',
    10: 'right_wrist',
    11: 'left_hip',
    12: 'right_hip',
    13: 'left_knee',
    14: 'right_knee',
    15: 'left_ankle',
    16: 'right_ankle'
}
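# This is the standard COCO keypoint order. The CSV columns written in main()
# follow it, which assumes the pose checkpoint was trained on COCO so that
# joint i of the model output corresponds to index i above.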

COCO_INSTANCE_CATEGORY_NAMES = [
    '__background__', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
    'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A', 'stop sign',
    'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
    'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack', 'umbrella', 'N/A', 'N/A',
    'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
    'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket',
    'bottle', 'N/A', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl',
    'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
    'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'N/A', 'dining table',
    'N/A', 'N/A', 'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone',
    'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'N/A', 'book',
    'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush'
]

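# A note on the detector interface: torchvision's detection models in eval
# mode take a list of 3xHxW float tensors and return one dict per image with
# 'boxes' (Nx4, [x1, y1, x2, y2]), 'labels', and 'scores', with detections in
# descending score order; the prefix slicing below relies on that ordering.
# The 'N/A' entries above pad category ids that COCO defines but never
# annotates, so labels can index the list directly.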

def get_person_detection_boxes(model, img, threshold=0.5):
    pil_image = Image.fromarray(img)  # load the image (expects an RGB numpy array)
    transform = transforms.Compose([transforms.ToTensor()])  # define the PyTorch transform
    transformed_img = transform(pil_image)  # apply the transform to the image
    pred = model([transformed_img])  # pass the image to the model
    pred_classes = [COCO_INSTANCE_CATEGORY_NAMES[i]
                    for i in list(pred[0]['labels'].numpy())]  # predicted class names
    pred_boxes = [[(i[0], i[1]), (i[2], i[3])]
                  for i in list(pred[0]['boxes'].detach().numpy())]  # bounding boxes
    pred_score = list(pred[0]['scores'].detach().numpy())
    if not pred_score:
        return []
    # Get the indices with a score greater than the threshold; since scores are
    # sorted in descending order, these form a prefix of the prediction lists.
    keep = [idx for idx, score in enumerate(pred_score) if score > threshold]
    if not keep:
        return []
    pred_boxes = pred_boxes[:keep[-1] + 1]
    pred_classes = pred_classes[:keep[-1] + 1]

    person_boxes = []
    for idx, box in enumerate(pred_boxes):
        if pred_classes[idx] == 'person':
            person_boxes.append(box)

    return person_boxes
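# The returned boxes have the form [[(x1, y1), (x2, y2)], ...] in pixel
# coordinates, e.g. (assuming rgb_frame is an RGB numpy array):
#   boxes = get_person_detection_boxes(box_model, rgb_frame, threshold=0.8)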


def get_pose_estimation_prediction(pose_model, image, center, scale):
    rotation = 0

    # pose estimation transformation: crop the box region and warp it to the
    # model's input resolution
    trans = get_affine_transform(center, scale, rotation, cfg.MODEL.IMAGE_SIZE)
    model_input = cv2.warpAffine(
        image,
        trans,
        (int(cfg.MODEL.IMAGE_SIZE[0]), int(cfg.MODEL.IMAGE_SIZE[1])),
        flags=cv2.INTER_LINEAR)
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],  # ImageNet statistics
                             std=[0.229, 0.224, 0.225]),
    ])

    # pose estimation inference
    model_input = transform(model_input).unsqueeze(0)
    # switch to evaluate mode
    pose_model.eval()
    with torch.no_grad():
        # compute output heatmap
        output = pose_model(model_input)
        # decode the heatmaps back into coordinates in the original image
        preds, _ = get_final_preds(
            cfg,
            output.clone().cpu().numpy(),
            np.asarray([center]),
            np.asarray([scale]))

    return preds
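# preds has shape (1, num_joints, 2): one (x, y) pair per keypoint, already
# mapped back from heatmap space to the original image coordinates.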


def box_to_center_scale(box, model_image_width, model_image_height):
    """convert a box to the center, scale information required for pose transformation
    Parameters
    ----------
    box : list of tuple
        list of length 2 with two tuples of floats representing
        the top-left and bottom-right corners of a box

    model_image_width : int
    model_image_height : int

    Returns
    -------
    (numpy array, numpy array)
        Two numpy arrays, coordinates for the center of the box and the scale of the box
    """
    center = np.zeros((2), dtype=np.float32)

    top_left_corner = box[0]
    bottom_right_corner = box[1]
    box_width = bottom_right_corner[0] - top_left_corner[0]
    box_height = bottom_right_corner[1] - top_left_corner[1]
    center[0] = top_left_corner[0] + box_width * 0.5
    center[1] = top_left_corner[1] + box_height * 0.5

    aspect_ratio = model_image_width * 1.0 / model_image_height
    pixel_std = 200  # scale is expressed in units of 200 pixels (a COCO convention)

    # pad the box out to the model's aspect ratio so the affine crop does not distort it
    if box_width > aspect_ratio * box_height:
        box_height = box_width * 1.0 / aspect_ratio
    elif box_width < aspect_ratio * box_height:
        box_width = box_height * aspect_ratio
    scale = np.array(
        [box_width * 1.0 / pixel_std, box_height * 1.0 / pixel_std],
        dtype=np.float32)
    if center[0] != -1:
        scale = scale * 1.25  # enlarge the crop slightly so limbs are not clipped

    return center, scale
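# A quick worked example (hypothetical numbers): for a 100x300 box with its
# top-left corner at the origin and a 288x384 model input, aspect_ratio is
# 0.75, so the width is padded to 225 to match it, giving
#   center = (50.0, 150.0)
#   scale  = (225 / 200, 300 / 200) * 1.25 = (1.40625, 1.875)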


def prepare_output_dirs(prefix='/output/'):
    pose_dir = prefix + 'poses/'
    box_dir = prefix + 'boxes/'
    if os.path.exists(pose_dir) and os.path.isdir(pose_dir):
        shutil.rmtree(pose_dir)
    if os.path.exists(box_dir) and os.path.isdir(box_dir):
        shutil.rmtree(box_dir)
    os.makedirs(pose_dir, exist_ok=True)
    os.makedirs(box_dir, exist_ok=True)
    return pose_dir, box_dir
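# Note that any poses/ and boxes/ directories from a previous run are wiped
# above, so stale frames cannot leak into the ffmpeg glob at the end of main().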


def parse_args():
    parser = argparse.ArgumentParser(description='Run keypoint inference on a video')
    # general
    parser.add_argument('--cfg', type=str, required=True)
    parser.add_argument('--videoFile', type=str, required=True)
    parser.add_argument('--outputDir', type=str, default='/output/')
    parser.add_argument('--inferenceFps', type=int, default=10)
    parser.add_argument('--writeBoxFrames', action='store_true')

    parser.add_argument('opts',
                        help='Modify config options using the command-line',
                        default=None,
                        nargs=argparse.REMAINDER)

    args = parser.parse_args()

    # args expected by supporting codebase
    args.modelDir = ''
    args.logDir = ''
    args.dataDir = ''
    args.prevModelDir = ''
    return args


def main():
    # cudnn related setting
    cudnn.benchmark = cfg.CUDNN.BENCHMARK
    torch.backends.cudnn.deterministic = cfg.CUDNN.DETERMINISTIC
    torch.backends.cudnn.enabled = cfg.CUDNN.ENABLED

    args = parse_args()
    update_config(cfg, args)
    pose_dir, box_dir = prepare_output_dirs(args.outputDir)
    csv_output_filename = args.outputDir + 'pose-data.csv'
    csv_output_rows = []

    box_model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
    box_model.eval()

    # look up the constructor for the configured pose network, e.g. models.pose_hrnet
    pose_model = eval('models.' + cfg.MODEL.NAME + '.get_pose_net')(
        cfg, is_train=False
    )

    if cfg.TEST.MODEL_FILE:
        print('=> loading model from {}'.format(cfg.TEST.MODEL_FILE))
        pose_model.load_state_dict(torch.load(cfg.TEST.MODEL_FILE), strict=False)
    else:
        raise ValueError('expected model defined in config at TEST.MODEL_FILE')

    pose_model = torch.nn.DataParallel(pose_model, device_ids=cfg.GPUS).cuda()

    # Load the video and work out how many frames to skip between inferences
    vidcap = cv2.VideoCapture(args.videoFile)
    fps = vidcap.get(cv2.CAP_PROP_FPS)
    if fps < args.inferenceFps:
        print('desired inference fps is ' + str(args.inferenceFps) + ' but video fps is ' + str(fps))
        exit()
    every_nth_frame = round(fps / args.inferenceFps)

    success, image_bgr = vidcap.read()
    count = 0

    while success:
        if count % every_nth_frame != 0:
            success, image_bgr = vidcap.read()
            count += 1
            continue

        image = image_bgr[:, :, [2, 1, 0]]  # BGR -> RGB for the detector
        count_str = str(count).zfill(32)

        # object detection box
        pred_boxes = get_person_detection_boxes(box_model, image, threshold=0.8)
        if args.writeBoxFrames:
            image_bgr_box = image_bgr.copy()
            for box in pred_boxes:
                # cv2.rectangle expects integer pixel coordinates
                cv2.rectangle(image_bgr_box,
                              (int(box[0][0]), int(box[0][1])),
                              (int(box[1][0]), int(box[1][1])),
                              color=(0, 255, 0),
                              thickness=3)  # draw rectangle with the coordinates
            cv2.imwrite(box_dir + 'box%s.jpg' % count_str, image_bgr_box)
        if not pred_boxes:
            success, image_bgr = vidcap.read()
            count += 1
            continue

        # pose estimation
        box = pred_boxes[0]  # assume there is only 1 person
        center, scale = box_to_center_scale(box, cfg.MODEL.IMAGE_SIZE[0], cfg.MODEL.IMAGE_SIZE[1])
        image_pose = image.copy() if cfg.DATASET.COLOR_RGB else image_bgr.copy()
        pose_preds = get_pose_estimation_prediction(pose_model, image_pose, center, scale)

        new_csv_row = [count]  # the 'frame' column promised by the CSV header
        for _, mat in enumerate(pose_preds[0]):
            x_coord, y_coord = int(mat[0]), int(mat[1])
            cv2.circle(image_bgr, (x_coord, y_coord), 4, (255, 0, 0), 2)
            new_csv_row.extend([x_coord, y_coord])

        csv_output_rows.append(new_csv_row)
        cv2.imwrite(pose_dir + 'pose%s.jpg' % count_str, image_bgr)

        # get next frame
        success, image_bgr = vidcap.read()
        count += 1
    # write csv
    csv_headers = ['frame']
    for keypoint in COCO_KEYPOINT_INDEXES.values():
        csv_headers.extend([keypoint + '_x', keypoint + '_y'])

    with open(csv_output_filename, 'w', newline='') as csvfile:
        csvwriter = csv.writer(csvfile)
        csvwriter.writerow(csv_headers)
        csvwriter.writerows(csv_output_rows)

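    # Stitch the annotated frames back into a video. The flags: '-r' before
    # '-i' sets the input frame rate, '-pattern_type glob' matches the
    # zero-padded jpg names in order, and '-pix_fmt yuv420p' keeps the H.264
    # output playable in most players. This assumes ffmpeg is on the PATH.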
    os.system("ffmpeg -y -r "
              + str(args.inferenceFps)
              + " -pattern_type glob -i '"
              + pose_dir
              + "/*.jpg' -c:v libx264 -vf fps="
              + str(args.inferenceFps) + " -pix_fmt yuv420p " + args.outputDir + "movie.mp4")


if __name__ == '__main__':
    main()