Commit 543c816

Merge pull request #1 from InfiniteSkyAI/jh.no_model_saved
Jh.no model saved
2 parents (3ba4125 + b3386c1), commit 543c816

File tree: 2 files changed, +65 -105 lines

demo/demo.py (+60 -101)
@@ -3,11 +3,7 @@
 from __future__ import print_function
 
 import argparse
-import csv
-import os
-import shutil
 
-from PIL import Image
 import torch
 import torch.nn.parallel
 import torch.backends.cudnn as cudnn
@@ -20,7 +16,6 @@
 import numpy as np
 import time
 
-
 import _init_paths
 import models
 from config import cfg
@@ -107,7 +102,7 @@ def get_person_detection_boxes(model, img, threshold=0.5):
         return []
     # Get list of index with score greater than threshold
     pred_t = [pred_score.index(x) for x in pred_score if x > threshold][-1]
-    pred_boxes = pred_boxes[:pred_t+1]
+    pred_boxes = pred_boxes[:1]
     pred_classes = pred_classes[:pred_t+1]
 
     person_boxes = []
@@ -191,6 +186,7 @@ def box_to_center_scale(box, model_image_width, model_image_height):
 
     return center, scale
 
+
 def parse_args():
     parser = argparse.ArgumentParser(description='Train keypoints network')
     # general
@@ -200,6 +196,7 @@ def parse_args():
     parser.add_argument('--image',type=str)
     parser.add_argument('--write',action='store_true')
     parser.add_argument('--showFps',action='store_true')
+    parser.add_argument('--output_dir',type=str, default='/')
 
     parser.add_argument('opts',
                         help='Modify config options using the command-line',
@@ -217,6 +214,8 @@
 
 
 def main():
+
+    keypoints = None
     # cudnn related setting
     cudnn.benchmark = cfg.CUDNN.BENCHMARK
     torch.backends.cudnn.deterministic = cfg.CUDNN.DETERMINISTIC
@@ -243,101 +242,61 @@ def main():
     pose_model.to(CTX)
     pose_model.eval()
 
-    # Loading a video or an image or webcam
-    if args.webcam:
-        vidcap = cv2.VideoCapture(0)
-    elif args.video:
-        vidcap = cv2.VideoCapture(args.video)
-    elif args.image:
-        image_bgr = cv2.imread(args.image)
-    else:
-        print('please use --video or --webcam or --image to define the input.')
-        return
-
-    if args.webcam or args.video:
-        if args.write:
-            save_path = 'output.avi'
-            fourcc = cv2.VideoWriter_fourcc(*'XVID')
-            out = cv2.VideoWriter(save_path,fourcc, 24.0, (int(vidcap.get(3)),int(vidcap.get(4))))
-        while True:
-            ret, image_bgr = vidcap.read()
-            if ret:
-                last_time = time.time()
-                image = image_bgr[:, :, [2, 1, 0]]
-
-                input = []
-                img = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
-                img_tensor = torch.from_numpy(img/255.).permute(2,0,1).float().to(CTX)
-                input.append(img_tensor)
-
-                # object detection box
-                pred_boxes = get_person_detection_boxes(box_model, input, threshold=0.9)
-
-                # pose estimation
-                if len(pred_boxes) >= 1:
-                    for box in pred_boxes:
-                        center, scale = box_to_center_scale(box, cfg.MODEL.IMAGE_SIZE[0], cfg.MODEL.IMAGE_SIZE[1])
-                        image_pose = image.copy() if cfg.DATASET.COLOR_RGB else image_bgr.copy()
-                        pose_preds = get_pose_estimation_prediction(pose_model, image_pose, center, scale)
-                        if len(pose_preds)>=1:
-                            for kpt in pose_preds:
-                                draw_pose(kpt,image_bgr) # draw the poses
-
-                if args.showFps:
-                    fps = 1/(time.time()-last_time)
-                    img = cv2.putText(image_bgr, 'fps: '+ "%.2f"%(fps), (25, 40), cv2.FONT_HERSHEY_SIMPLEX, 1.2, (0, 255, 0), 2)
-
-                if args.write:
-                    out.write(image_bgr)
-
-                cv2.imshow('demo',image_bgr)
-                if cv2.waitKey(1) & 0XFF==ord('q'):
-                    break
-            else:
-                print('cannot load the video.')
-                break
-
-        cv2.destroyAllWindows()
-        vidcap.release()
-        if args.write:
-            print('video has been saved as {}'.format(save_path))
-            out.release()
+    # Loading a video
+    vidcap = cv2.VideoCapture(args.video)
+    save_path = args.output_dir + "/output.avi"
+    fourcc = cv2.VideoWriter_fourcc(*'XVID')
+    vid_fps = vidcap.get(cv2.CAP_PROP_FPS)
+    out = cv2.VideoWriter(save_path,fourcc, vid_fps, (int(vidcap.get(3)),int(vidcap.get(4))))
+
+    while True:
+        ret, image_bgr = vidcap.read()
+        if ret:
+            last_time = time.time()
+            image = image_bgr[:, :, [2, 1, 0]]
+
+            input = []
+            img = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
+            img_tensor = torch.from_numpy(img/255.).permute(2,0,1).float().to(CTX)
+            input.append(img_tensor)
+
+            # object detection box
+            pred_boxes = get_person_detection_boxes(box_model, input, threshold=0.95)
+
+            # pose estimation
+            if len(pred_boxes) >= 1:
+                for box in pred_boxes:
+                    center, scale = box_to_center_scale(box, cfg.MODEL.IMAGE_SIZE[0], cfg.MODEL.IMAGE_SIZE[1])
+                    image_pose = image.copy() if cfg.DATASET.COLOR_RGB else image_bgr.copy()
+                    pose_preds = get_pose_estimation_prediction(pose_model, image_pose, center, scale)
+                    if len(pose_preds)>=1:
+                        for i, kpt in enumerate(pose_preds):
+                            name = COCO_KEYPOINT_INDEXES[i]
+                            if keypoints is None:
+                                keypoints = np.array([kpt])
+                            else:
+                                keypoints = np.append(keypoints, [kpt], axis = 0)
+                            draw_pose(kpt,image_bgr) # draw the poses
+
+            if args.showFps:
+                fps = 1/(time.time()-last_time)
+                img = cv2.putText(image_bgr, 'fps: '+ "%.2f"%(fps), (25, 40), cv2.FONT_HERSHEY_SIMPLEX, 1.2, (0, 255, 0), 2)
+
+            if args.write:
+                out.write(image_bgr)
+
+        else:
+            print('Video ended')
+            break
+
+    np.save(f"{args.output_dir}/keypoints", keypoints)
+    print(f'keypoint saved to {args.output_dir}/keypoints.npy')
+    cv2.destroyAllWindows()
+    vidcap.release()
+    if args.write:
+        print('video has been saved as {}'.format(save_path))
+        out.release()
+
 
-    else:
-        # estimate on the image
-        last_time = time.time()
-        image = image_bgr[:, :, [2, 1, 0]]
-
-        input = []
-        img = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
-        img_tensor = torch.from_numpy(img/255.).permute(2,0,1).float().to(CTX)
-        input.append(img_tensor)
-
-        # object detection box
-        pred_boxes = get_person_detection_boxes(box_model, input, threshold=0.9)
-
-        # pose estimation
-        if len(pred_boxes) >= 1:
-            for box in pred_boxes:
-                center, scale = box_to_center_scale(box, cfg.MODEL.IMAGE_SIZE[0], cfg.MODEL.IMAGE_SIZE[1])
-                image_pose = image.copy() if cfg.DATASET.COLOR_RGB else image_bgr.copy()
-                pose_preds = get_pose_estimation_prediction(pose_model, image_pose, center, scale)
-                if len(pose_preds)>=1:
-                    for kpt in pose_preds:
-                        draw_pose(kpt,image_bgr) # draw the poses
-
-        if args.showFps:
-            fps = 1/(time.time()-last_time)
-            img = cv2.putText(image_bgr, 'fps: '+ "%.2f"%(fps), (25, 40), cv2.FONT_HERSHEY_SIMPLEX, 1.2, (0, 255, 0), 2)
-
-        if args.write:
-            save_path = 'output.jpg'
-            cv2.imwrite(save_path,image_bgr)
-            print('the result image has been saved as {}'.format(save_path))
-
-        cv2.imshow('demo',image_bgr)
-        if cv2.waitKey(0) & 0XFF==ord('q'):
-            cv2.destroyAllWindows()
-
 if __name__ == '__main__':
     main()
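
Note: with this change the demo keeps only the top-scoring person box (pred_boxes[:1]), appends each per-frame pose prediction to a single keypoints array, and writes it to {output_dir}/keypoints.npy when the video ends. A downstream script might read it back as below. This is a minimal sketch: the output path and the (num_joints, 2) per-pose shape are assumptions inferred from this diff, not something the repository documents.

import numpy as np

# Load the array saved by demo.py; the path assumes the demo was run with
# --output_dir out (e.g. python demo/demo.py --video input.mp4 --write --output_dir out).
keypoints = np.load("out/keypoints.npy")

# Assumed shape: (num_detections_over_all_frames, num_joints, 2), joints in
# COCO order (index 0 = nose), (x, y) pixel coordinates per joint.
print(keypoints.shape)
nose_xy = keypoints[:, 0, :]  # trajectory of the nose across frames
print(nose_xy[:5])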

requirements.txt (+5 -4)
@@ -1,11 +1,12 @@
-EasyDict==1.7
-opencv-python==3.4.1.15
-shapely==1.6.4
+EasyDict>=1.7
+opencv-python
+shapely>=1.6.4
 Cython
 scipy
 pandas
 pyyaml
 json_tricks
 scikit-image
 yacs>=0.1.5
-tensorboardX==1.6
+tensorboardX>=1.6
+torchvision
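
The pins here were loosened from exact versions (==) to minimums (>=), and torchvision was added as an explicit dependency. One illustrative way to check that an existing environment satisfies the loosened pins (a sketch for convenience, not part of the repository):

from importlib.metadata import version, PackageNotFoundError  # Python 3.8+

# Package names as they appear in requirements.txt.
for pkg in ("EasyDict", "opencv-python", "shapely", "yacs", "tensorboardX", "torchvision"):
    try:
        print(f"{pkg}: {version(pkg)}")
    except PackageNotFoundError:
        print(f"{pkg}: not installed")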
