@@ -3,11 +3,7 @@
 from __future__ import print_function
 
 import argparse
-import csv
-import os
-import shutil
 
-from PIL import Image
 import torch
 import torch.nn.parallel
 import torch.backends.cudnn as cudnn
@@ -20,7 +16,6 @@
 import numpy as np
 import time
 
-
 import _init_paths
 import models
 from config import cfg
@@ -107,7 +102,7 @@ def get_person_detection_boxes(model, img, threshold=0.5):
         return []
     # Get list of index with score greater than threshold
     pred_t = [pred_score.index(x) for x in pred_score if x > threshold][-1]
-    pred_boxes = pred_boxes[:pred_t+1]
+    pred_boxes = pred_boxes[:1]
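+    # keep only the top-scoring detection: the demo now assumes a single person per frame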
     pred_classes = pred_classes[:pred_t+1]
 
     person_boxes = []
@@ -191,6 +186,7 @@ def box_to_center_scale(box, model_image_width, model_image_height):
 
     return center, scale
 
+
 def parse_args():
     parser = argparse.ArgumentParser(description='Train keypoints network')
     # general
@@ -200,6 +196,7 @@ def parse_args():
     parser.add_argument('--image', type=str)
     parser.add_argument('--write', action='store_true')
     parser.add_argument('--showFps', action='store_true')
+    parser.add_argument('--output_dir', type=str, default='.')
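+    # output.avi and keypoints.npy are written under this directory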
 
     parser.add_argument('opts',
                         help='Modify config options using the command-line',
@@ -217,6 +214,8 @@ def parse_args():
 
 
 def main():
+
+    keypoints = None
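+    # accumulates every detected pose's keypoints across frames; saved to disk at the end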
     # cudnn related setting
     cudnn.benchmark = cfg.CUDNN.BENCHMARK
     torch.backends.cudnn.deterministic = cfg.CUDNN.DETERMINISTIC
@@ -243,101 +242,61 @@ def main():
     pose_model.to(CTX)
     pose_model.eval()
 
-    # Loading an video or an image or webcam
-    if args.webcam:
-        vidcap = cv2.VideoCapture(0)
-    elif args.video:
-        vidcap = cv2.VideoCapture(args.video)
-    elif args.image:
-        image_bgr = cv2.imread(args.image)
-    else:
-        print('please use --video or --webcam or --image to define the input.')
-        return
-
-    if args.webcam or args.video:
-        if args.write:
-            save_path = 'output.avi'
-            fourcc = cv2.VideoWriter_fourcc(*'XVID')
-            out = cv2.VideoWriter(save_path, fourcc, 24.0, (int(vidcap.get(3)), int(vidcap.get(4))))
-        while True:
-            ret, image_bgr = vidcap.read()
-            if ret:
-                last_time = time.time()
-                image = image_bgr[:, :, [2, 1, 0]]
-
-                input = []
-                img = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
-                img_tensor = torch.from_numpy(img/255.).permute(2, 0, 1).float().to(CTX)
-                input.append(img_tensor)
-
-                # object detection box
-                pred_boxes = get_person_detection_boxes(box_model, input, threshold=0.9)
-
-                # pose estimation
-                if len(pred_boxes) >= 1:
-                    for box in pred_boxes:
-                        center, scale = box_to_center_scale(box, cfg.MODEL.IMAGE_SIZE[0], cfg.MODEL.IMAGE_SIZE[1])
-                        image_pose = image.copy() if cfg.DATASET.COLOR_RGB else image_bgr.copy()
-                        pose_preds = get_pose_estimation_prediction(pose_model, image_pose, center, scale)
-                        if len(pose_preds) >= 1:
-                            for kpt in pose_preds:
-                                draw_pose(kpt, image_bgr)  # draw the poses
-
-                if args.showFps:
-                    fps = 1/(time.time()-last_time)
-                    img = cv2.putText(image_bgr, 'fps: '+"%.2f" % (fps), (25, 40), cv2.FONT_HERSHEY_SIMPLEX, 1.2, (0, 255, 0), 2)
-
-                if args.write:
-                    out.write(image_bgr)
-
-                cv2.imshow('demo', image_bgr)
-                if cv2.waitKey(1) & 0XFF == ord('q'):
-                    break
-            else:
-                print('cannot load the video.')
-                break
-
-        cv2.destroyAllWindows()
-        vidcap.release()
-        if args.write:
-            print('video has been saved as {}'.format(save_path))
-            out.release()
+    # Load the input video and set up the output writer
+    vidcap = cv2.VideoCapture(args.video)
+    save_path = args.output_dir + "/output.avi"
+    fourcc = cv2.VideoWriter_fourcc(*'XVID')
+    vid_fps = vidcap.get(cv2.CAP_PROP_FPS)
+    out = cv2.VideoWriter(save_path, fourcc, vid_fps, (int(vidcap.get(3)), int(vidcap.get(4))))
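+    # note: vidcap.get(3) and vidcap.get(4) are CAP_PROP_FRAME_WIDTH and CAP_PROP_FRAME_HEIGHT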
+
+    while True:
+        ret, image_bgr = vidcap.read()
+        if ret:
+            last_time = time.time()
+            image = image_bgr[:, :, [2, 1, 0]]
+
+            input = []
+            img = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
+            img_tensor = torch.from_numpy(img/255.).permute(2, 0, 1).float().to(CTX)
+            input.append(img_tensor)
+
+            # object detection box
+            pred_boxes = get_person_detection_boxes(box_model, input, threshold=0.95)
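+            # stricter person-detection threshold than the old 0.9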
+
+            # pose estimation
+            if len(pred_boxes) >= 1:
+                for box in pred_boxes:
+                    center, scale = box_to_center_scale(box, cfg.MODEL.IMAGE_SIZE[0], cfg.MODEL.IMAGE_SIZE[1])
+                    image_pose = image.copy() if cfg.DATASET.COLOR_RGB else image_bgr.copy()
+                    pose_preds = get_pose_estimation_prediction(pose_model, image_pose, center, scale)
+                    if len(pose_preds) >= 1:
+                        for kpt in pose_preds:
+                            if keypoints is None:
+                                keypoints = np.array([kpt])
+                            else:
+                                keypoints = np.append(keypoints, [kpt], axis=0)
+                            draw_pose(kpt, image_bgr)  # draw the poses
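+                            # keypoints grows to shape (total_poses, 17, 2) over the run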
+
+            if args.showFps:
+                fps = 1/(time.time()-last_time)
+                img = cv2.putText(image_bgr, 'fps: '+"%.2f" % (fps), (25, 40), cv2.FONT_HERSHEY_SIMPLEX, 1.2, (0, 255, 0), 2)
+
+            if args.write:
+                out.write(image_bgr)
+
+        else:
+            print('Video ended')
+            break
+
+    np.save(f"{args.output_dir}/keypoints", keypoints)
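+    # np.save appends the .npy extension automatically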
+    print(f'keypoints saved to {args.output_dir}/keypoints.npy')
+    cv2.destroyAllWindows()
+    vidcap.release()
+    if args.write:
+        print('video has been saved as {}'.format(save_path))
+        out.release()
+
 
-    else:
-        # estimate on the image
-        last_time = time.time()
-        image = image_bgr[:, :, [2, 1, 0]]
-
-        input = []
-        img = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
-        img_tensor = torch.from_numpy(img/255.).permute(2, 0, 1).float().to(CTX)
-        input.append(img_tensor)
-
-        # object detection box
-        pred_boxes = get_person_detection_boxes(box_model, input, threshold=0.9)
-
-        # pose estimation
-        if len(pred_boxes) >= 1:
-            for box in pred_boxes:
-                center, scale = box_to_center_scale(box, cfg.MODEL.IMAGE_SIZE[0], cfg.MODEL.IMAGE_SIZE[1])
-                image_pose = image.copy() if cfg.DATASET.COLOR_RGB else image_bgr.copy()
-                pose_preds = get_pose_estimation_prediction(pose_model, image_pose, center, scale)
-                if len(pose_preds) >= 1:
-                    for kpt in pose_preds:
-                        draw_pose(kpt, image_bgr)  # draw the poses
-
-        if args.showFps:
-            fps = 1/(time.time()-last_time)
-            img = cv2.putText(image_bgr, 'fps: '+"%.2f" % (fps), (25, 40), cv2.FONT_HERSHEY_SIMPLEX, 1.2, (0, 255, 0), 2)
-
-        if args.write:
-            save_path = 'output.jpg'
-            cv2.imwrite(save_path, image_bgr)
-            print('the result image has been saved as {}'.format(save_path))
-
-        cv2.imshow('demo', image_bgr)
-        if cv2.waitKey(0) & 0XFF == ord('q'):
-            cv2.destroyAllWindows()
-
 if __name__ == '__main__':
     main()
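
A quick way to inspect the saved keypoints afterwards (a minimal sketch, assuming the default --output_dir and that at least one pose was detected):

    import numpy as np

    kpts = np.load("keypoints.npy")  # written by the np.save call above
    print(kpts.shape)                # expected (total_poses, 17, 2): 17 COCO joints as (x, y)
    print(kpts[0])                   # keypoints of the first detected pose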