import cv2
import numpy as np

+import sys
+sys.path.append("../lib")
+import time

-import _init_paths
+# import _init_paths

import models
from config import cfg
from config import update_config
-from core.function import get_final_preds
+from core.inference import get_final_preds
from utils.transforms import get_affine_transform

+CTX = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
+
+
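A note on the new `CTX` constant: it centralizes device selection, so the `.to(CTX)` calls added throughout this commit run the same code path whether or not a GPU is present, replacing the hard-coded `.cuda()` usage removed below. A minimal sketch of the idiom (the tensor shape here is illustrative only):

import torch

CTX = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

x = torch.zeros(1, 3, 256, 192)  # dummy NCHW batch
x = x.to(CTX)                    # moves to GPU when available, stays on CPU otherwise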
COCO_KEYPOINT_INDEXES = {
    0: 'nose',
    1: 'left_eye',
@@ -67,57 +73,53 @@ def get_person_detection_boxes(model, img, threshold=0.5):
    pil_image = Image.fromarray(img)  # Load the image
    transform = transforms.Compose([transforms.ToTensor()])  # Define the PyTorch transform
    transformed_img = transform(pil_image)  # Apply the transform to the image
-    pred = model([transformed_img])  # Pass the image to the model
+    pred = model([transformed_img.to(CTX)])  # Pass the image to the model
+    # Gather the detections; person boxes are filtered out below
    pred_classes = [COCO_INSTANCE_CATEGORY_NAMES[i]
-                    for i in list(pred[0]['labels'].numpy())]  # Get the prediction classes
+                    for i in list(pred[0]['labels'].cpu().numpy())]  # Get the prediction classes
    pred_boxes = [[(i[0], i[1]), (i[2], i[3])]
-                  for i in list(pred[0]['boxes'].detach().numpy())]  # Bounding boxes
-    pred_score = list(pred[0]['scores'].detach().numpy())
-    if not pred_score:
-        return []
-    # Get list of indexes with score greater than threshold
-    pred_t = [pred_score.index(x) for x in pred_score if x > threshold][-1]
-    pred_boxes = pred_boxes[:pred_t + 1]
-    pred_classes = pred_classes[:pred_t + 1]
+                  for i in list(pred[0]['boxes'].cpu().detach().numpy())]  # Bounding boxes
+    pred_scores = list(pred[0]['scores'].cpu().detach().numpy())

    person_boxes = []
-    for idx, box in enumerate(pred_boxes):
-        if pred_classes[idx] == 'person':
-            person_boxes.append(box)
+    # Keep boxes whose score exceeds the threshold and whose class is person
+    for pred_class, pred_box, pred_score in zip(pred_classes, pred_boxes, pred_scores):
+        if (pred_score > threshold) and (pred_class == 'person'):
+            person_boxes.append(pred_box)

    return person_boxes

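For orientation, a hedged usage sketch of the rewritten helper, assuming `box_model` is the torchvision Faster R-CNN constructed in `main()` and `image_rgb` is an RGB numpy frame:

person_boxes = get_person_detection_boxes(box_model, image_rgb, threshold=0.9)
for (top_left, bottom_right) in person_boxes:
    # each box is [(x1, y1), (x2, y2)] in float pixel coordinates
    print(top_left, bottom_right)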
-def get_pose_estimation_prediction(pose_model, image, center, scale):
+def get_pose_estimation_prediction(pose_model, image, centers, scales, transform):
    rotation = 0

    # pose estimation transformation
-    trans = get_affine_transform(center, scale, rotation, cfg.MODEL.IMAGE_SIZE)
-    model_input = cv2.warpAffine(
-        image,
-        trans,
-        (int(cfg.MODEL.IMAGE_SIZE[0]), int(cfg.MODEL.IMAGE_SIZE[1])),
-        flags=cv2.INTER_LINEAR)
-    transform = transforms.Compose([
-        transforms.ToTensor(),
-        transforms.Normalize(mean=[0.485, 0.456, 0.406],
-                             std=[0.229, 0.224, 0.225]),
-    ])
-
-    # pose estimation inference
-    model_input = transform(model_input).unsqueeze(0)
-    # switch to evaluate mode
-    pose_model.eval()
-    with torch.no_grad():
-        # compute output heatmap
-        output = pose_model(model_input)
-        preds, _ = get_final_preds(
-            cfg,
-            output.clone().cpu().numpy(),
-            np.asarray([center]),
-            np.asarray([scale]))
-
-    return preds
+    model_inputs = []
+    for center, scale in zip(centers, scales):
+        trans = get_affine_transform(center, scale, rotation, cfg.MODEL.IMAGE_SIZE)
+        # Crop a smaller image around each person
+        model_input = cv2.warpAffine(
+            image,
+            trans,
+            (int(cfg.MODEL.IMAGE_SIZE[0]), int(cfg.MODEL.IMAGE_SIZE[1])),
+            flags=cv2.INTER_LINEAR)
+
+        # hwc -> chw
+        model_input = transform(model_input)  # .unsqueeze(0)
+        model_inputs.append(model_input)
+
+    # n * chw -> nchw
+    model_inputs = torch.stack(model_inputs)
+
+    # compute output heatmap
+    output = pose_model(model_inputs.to(CTX))
+    coords, _ = get_final_preds(
+        cfg,
+        output.cpu().detach().numpy(),
+        np.asarray(centers),
+        np.asarray(scales))
+
+    return coords

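The reworked function batches every detected person into a single forward pass instead of one pass per person. A sketch of the expected call and shapes, with illustrative values (not taken from the demo video):

centers = [np.array([320.0, 240.0]), np.array([512.0, 256.0])]  # two hypothetical people
scales = [np.array([1.2, 1.6]), np.array([1.0, 1.3])]
coords = get_pose_estimation_prediction(pose_model, image_rgb, centers, scales,
                                        transform=pose_transform)
# coords: numpy array of shape (2, 17, 2) -- per person, 17 COCO keypoints,
# each (x, y) mapped back to original-image coordinates by get_final_preds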
def box_to_center_scale(box, model_image_width, model_image_height):
@@ -163,15 +165,11 @@ def box_to_center_scale(box, model_image_width, model_image_height):
def prepare_output_dirs(prefix='/output/'):
-    pose_dir = prefix + 'poses/'
-    box_dir = prefix + 'boxes/'
+    pose_dir = os.path.join(prefix, "pose")
    if os.path.exists(pose_dir) and os.path.isdir(pose_dir):
        shutil.rmtree(pose_dir)
-    if os.path.exists(box_dir) and os.path.isdir(box_dir):
-        shutil.rmtree(box_dir)
    os.makedirs(pose_dir, exist_ok=True)
-    os.makedirs(box_dir, exist_ok=True)
-    return pose_dir, box_dir
+    return pose_dir
def parse_args():
@@ -199,20 +197,26 @@ def parse_args():
def main():
+    # transformation
+    pose_transform = transforms.Compose([
+        transforms.ToTensor(),
+        transforms.Normalize(mean=[0.485, 0.456, 0.406],
+                             std=[0.229, 0.224, 0.225]),
+    ])
+
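The mean/std values above are the standard ImageNet statistics that HRNet-style backbones expect. `ToTensor` scales uint8 HWC pixels into a float CHW tensor in [0, 1] before `Normalize` standardizes each channel; a small sketch of the dtype/shape flow (the crop size is illustrative):

crop = np.zeros((256, 192, 3), dtype=np.uint8)  # HWC uint8 crop, e.g. from warpAffine
x = pose_transform(crop)                        # float32 tensor of shape (3, 256, 192)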
    # cudnn related setting
    cudnn.benchmark = cfg.CUDNN.BENCHMARK
    torch.backends.cudnn.deterministic = cfg.CUDNN.DETERMINISTIC
    torch.backends.cudnn.enabled = cfg.CUDNN.ENABLED

    args = parse_args()
    update_config(cfg, args)
-    pose_dir, box_dir = prepare_output_dirs(args.outputDir)
-    csv_output_filename = args.outputDir + 'pose-data.csv'
+    pose_dir = prepare_output_dirs(args.outputDir)
    csv_output_rows = []

    box_model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
+    box_model.to(CTX)
    box_model.eval()
-
    pose_model = eval('models.' + cfg.MODEL.NAME + '.get_pose_net')(
        cfg, is_train=False
    )
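The `eval('models.' + cfg.MODEL.NAME + '.get_pose_net')` expression resolves the constructor named in the experiment config. Assuming the stock HRNet YAMLs where `cfg.MODEL.NAME` is `pose_hrnet` (an assumption; the config is not shown in this diff), it is equivalent to:

from models import pose_hrnet
pose_model = pose_hrnet.get_pose_net(cfg, is_train=False)  # assumes MODEL.NAME == 'pose_hrnet'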
@@ -223,76 +227,114 @@ def main():
    else:
        print('expected model defined in config at TEST.MODEL_FILE')

-    pose_model = torch.nn.DataParallel(pose_model, device_ids=cfg.GPUS).cuda()
+    pose_model.to(CTX)
+    pose_model.eval()

    # Load the video
    vidcap = cv2.VideoCapture(args.videoFile)
    fps = vidcap.get(cv2.CAP_PROP_FPS)
    if fps < args.inferenceFps:
        print('desired inference fps is ' + str(args.inferenceFps) + ' but video fps is ' + str(fps))
        exit()
-    every_nth_frame = round(fps / args.inferenceFps)
+    skip_frame_cnt = round(fps / args.inferenceFps)
+    frame_width = int(vidcap.get(cv2.CAP_PROP_FRAME_WIDTH))
+    frame_height = int(vidcap.get(cv2.CAP_PROP_FRAME_HEIGHT))
+    outcap = cv2.VideoWriter('{}/{}_pose.avi'.format(args.outputDir, os.path.splitext(os.path.basename(args.videoFile))[0]),
+                             cv2.VideoWriter_fourcc('M', 'J', 'P', 'G'), int(skip_frame_cnt), (frame_width, frame_height))

-    success, image_bgr = vidcap.read()
    count = 0
+    while vidcap.isOpened():
+        total_now = time.time()
+        ret, image_bgr = vidcap.read()
+        count += 1

-    while success:
-        if count % every_nth_frame != 0:
-            success, image_bgr = vidcap.read()
-            count += 1
+        if not ret:
            continue

-        image = image_bgr[:, :, [2, 1, 0]]
-        count_str = str(count).zfill(32)
+        if count % skip_frame_cnt != 0:
+            continue
+
+        image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
+
+        # Clone two copies: one for person detection, one for pose estimation
+        if cfg.DATASET.COLOR_RGB:
+            image_per = image_rgb.copy()
+            image_pose = image_rgb.copy()
+        else:
+            image_per = image_bgr.copy()
+            image_pose = image_bgr.copy()
+
+        # Clone one more copy for debug drawing
+        image_debug = image_bgr.copy()

        # object detection box
-        pred_boxes = get_person_detection_boxes(box_model, image, threshold=0.8)
-        if args.writeBoxFrames:
-            image_bgr_box = image_bgr.copy()
-            for box in pred_boxes:
-                cv2.rectangle(image_bgr_box, box[0], box[1], color=(0, 255, 0),
-                              thickness=3)  # Draw a rectangle with the box coordinates
-            cv2.imwrite(box_dir + 'box%s.jpg' % count_str, image_bgr_box)
+        now = time.time()
+        pred_boxes = get_person_detection_boxes(box_model, image_per, threshold=0.9)
+        then = time.time()
+        print("Find person bbox in: {} sec".format(then - now))
+
+        # No people found; move on to the next frame
        if not pred_boxes:
-            success, image_bgr = vidcap.read()
            count += 1
            continue

-        # pose estimation
-        box = pred_boxes[0]  # assume there is only 1 person
-        center, scale = box_to_center_scale(box, cfg.MODEL.IMAGE_SIZE[0], cfg.MODEL.IMAGE_SIZE[1])
-        image_pose = image.copy() if cfg.DATASET.COLOR_RGB else image_bgr.copy()
-        pose_preds = get_pose_estimation_prediction(pose_model, image_pose, center, scale)
+        if args.writeBoxFrames:
+            for box in pred_boxes:
+                cv2.rectangle(image_debug, box[0], box[1], color=(0, 255, 0),
+                              thickness=3)  # Draw a rectangle with the box coordinates
+
+        # pose estimation: for multiple people
+        centers = []
+        scales = []
+        for box in pred_boxes:
+            center, scale = box_to_center_scale(box, cfg.MODEL.IMAGE_SIZE[0], cfg.MODEL.IMAGE_SIZE[1])
+            centers.append(center)
+            scales.append(scale)
+
+        now = time.time()
+        pose_preds = get_pose_estimation_prediction(pose_model, image_pose, centers, scales, transform=pose_transform)
+        then = time.time()
+        print("Find person pose in: {} sec".format(then - now))

        new_csv_row = []
-        for _, mat in enumerate(pose_preds[0]):
-            x_coord, y_coord = int(mat[0]), int(mat[1])
-            cv2.circle(image_bgr, (x_coord, y_coord), 4, (255, 0, 0), 2)
-            new_csv_row.extend([x_coord, y_coord])
+        for coords in pose_preds:
+            # Draw each keypoint on the image
+            for coord in coords:
+                x_coord, y_coord = int(coord[0]), int(coord[1])
+                cv2.circle(image_debug, (x_coord, y_coord), 4, (255, 0, 0), 2)
+                new_csv_row.extend([x_coord, y_coord])
+
+        total_then = time.time()
+
+        text = "{:03.2f} sec".format(total_then - total_now)
+        cv2.putText(image_debug, text, (100, 50), cv2.FONT_HERSHEY_SIMPLEX,
+                    1, (0, 0, 255), 2, cv2.LINE_AA)
+
+        cv2.imshow("pos", image_debug)
+        if cv2.waitKey(1) & 0xFF == ord('q'):
+            break

        csv_output_rows.append(new_csv_row)
-        cv2.imwrite(pose_dir + 'pose%s.jpg' % count_str, image_bgr)
+        img_file = os.path.join(pose_dir, 'pose_{:08d}.jpg'.format(count))
+        cv2.imwrite(img_file, image_debug)
+        outcap.write(image_debug)

-        # get next frame
-        success, image_bgr = vidcap.read()
-        count += 1

    # write csv
    csv_headers = ['frame']
    for keypoint in COCO_KEYPOINT_INDEXES.values():
        csv_headers.extend([keypoint + '_x', keypoint + '_y'])

+    csv_output_filename = os.path.join(args.outputDir, 'pose-data.csv')
    with open(csv_output_filename, 'w', newline='') as csvfile:
        csvwriter = csv.writer(csvfile)
        csvwriter.writerow(csv_headers)
        csvwriter.writerows(csv_output_rows)
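For reference, the header row built from `COCO_KEYPOINT_INDEXES` begins as follows; each processed frame contributes one row, and when several people are detected their x/y pairs are concatenated onto that same row:

frame,nose_x,nose_y,left_eye_x,left_eye_y,right_eye_x,right_eye_y,...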

-    os.system("ffmpeg -y -r "
-              + str(args.inferenceFps)
-              + " -pattern_type glob -i '"
-              + pose_dir
-              + "/*.jpg' -c:v libx264 -vf fps="
-              + str(args.inferenceFps) + " -pix_fmt yuv420p /output/movie.mp4")
+    vidcap.release()
+    outcap.release()
+
+    cv2.destroyAllWindows()

if __name__ == '__main__':