Commit fb3e4c2

Merge pull request #4 from InfiniteSkyAI/jh.module_refactor
Modified demo script to be class based
2 parents 6e31718 + 3b43ab1 commit fb3e4c2

File tree

1 file changed, +17 −21 lines changed


demo/demo.py

+17 −21
@@ -14,7 +14,7 @@
 import torchvision
 import cv2
 import numpy as np
-import time
+import os

 import _init_paths
 import models
@@ -211,16 +211,15 @@ def parse_args():
     return args


-def main():
+def get_deepHRnet_keypoints(video, output_dir=None, output_video=False, save_kpts=False):

     keypoints = None
     # cudnn related setting
     cudnn.benchmark = cfg.CUDNN.BENCHMARK
     torch.backends.cudnn.deterministic = cfg.CUDNN.DETERMINISTIC
     torch.backends.cudnn.enabled = cfg.CUDNN.ENABLED

-    args = parse_args()
-    update_config(cfg, args)
+    #update_config(cfg, args)

     box_model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
     box_model.to(CTX)
@@ -245,16 +244,17 @@ def main():
     pose_model.eval()

     # Loading an video or an video
-    vidcap = cv2.VideoCapture(args.video)
-    save_path = args.output_dir + "/output.avi"
-    fourcc = cv2.VideoWriter_fourcc(*'XVID')
-    vid_fps = vidcap.get(cv2.CAP_PROP_FPS)
-    out = cv2.VideoWriter(save_path,fourcc, vid_fps, (int(vidcap.get(3)),int(vidcap.get(4))))
+    vidcap = cv2.VideoCapture(video)
+    vid_name, vid_type = os.path.splitext(video)
+    if output_dir:
+        save_path = output_dir + f"/{vid_name}_deephrnet_output.{vid_type}"
+        fourcc = cv2.VideoWriter_fourcc(*'XVID')
+        vid_fps = vidcap.get(cv2.CAP_PROP_FPS)
+        out = cv2.VideoWriter(save_path,fourcc, vid_fps, (int(vidcap.get(3)),int(vidcap.get(4))))

     while True:
         ret, image_bgr = vidcap.read()
         if ret:
-            last_time = time.time()
             image = image_bgr[:, :, [2, 1, 0]]

             input = []
@@ -291,25 +291,21 @@ def main():
             else:
                 keypoints = np.append(keypoints, [[[np.nan, np.nan]]*len(COCO_KEYPOINT_INDEXES)], axis=0)

-            if args.showFps:
-                fps = 1/(time.time()-last_time)
-                img = cv2.putText(image_bgr, 'fps: '+ "%.2f"%(fps), (25, 40), cv2.FONT_HERSHEY_SIMPLEX, 1.2, (0, 255, 0), 2)
-
-            if args.write:
+            if output_video:
                 out.write(image_bgr)

         else:
             print('Video ended')
             break
+
+    if save_kpts:
+        np.save(f"{output_dir}/keypoints", keypoints)
+        print(f'keypoint saved to {output_dir}/keypoints.npy')

-    np.save(f"{args.output_dir}/keypoints", keypoints)
-    print(f'keypoint saved to {args.output_dir}/keypoints.npy')
     cv2.destroyAllWindows()
     vidcap.release()
-    if args.write:
+    if output_video:
         print('video has been saved as {}'.format(save_path))
         out.release()

-
-if __name__ == '__main__':
-    main()
+    return keypoints
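
For context, a minimal sketch of how the refactored entry point might be called from another module. The import path, file name, and output directory below are placeholders assumed for illustration; only the function name and its signature come from this diff.

    from demo.demo import get_deepHRnet_keypoints   # import path is an assumption

    # "squat.mp4" and "results" are hypothetical inputs, not part of the commit.
    keypoints = get_deepHRnet_keypoints(
        "squat.mp4",            # source video to run detection + pose estimation on
        output_dir="results",   # directory for the annotated video and saved keypoints
        output_video=True,      # write annotated frames via cv2.VideoWriter
        save_kpts=True,         # also dump results/keypoints.npy
    )

    # keypoints is a NumPy array with one entry per frame
    # (NaN entries for frames where no person was detected).
    print(keypoints.shape)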
