-
Notifications
You must be signed in to change notification settings - Fork 209
/
Copy pathskeleton_based_demo.py
217 lines (187 loc) · 6.83 KB
/
skeleton_based_demo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
import argparse
import os
import os.path as osp
import shutil
import cv2
import mmcv
import numpy as np
import torch
from easycv.file.utils import is_url_path
from easycv.predictors.pose_predictor import PoseTopDownPredictor
from easycv.predictors.video_classifier import STGCNPredictor
try:
import moviepy.editor as mpy
except ImportError:
raise ImportError('Please install moviepy to enable output file')
# Overlay style for drawing the predicted action label onto output frames
# (used with cv2.putText in main()).
FONTFACE = cv2.FONT_HERSHEY_DUPLEX
FONTSCALE = 0.75
FONTCOLOR = (255, 255, 255)  # BGR, white
THICKNESS = 1
LINETYPE = 1
# Scratch directory for downloaded videos and extracted frames.
TMP_DIR = './tmp'
def parse_args():
    """Parse and return the command-line arguments for this demo.

    Returns:
        argparse.Namespace: parsed options (video source, model config /
        checkpoint paths for detection, pose and skeleton models, plus
        device and preprocessing settings).
    """
    parser = argparse.ArgumentParser(
        description='Video classification demo based skeleton.')
    # (flag, keyword-arguments) pairs, registered in one loop below.
    options = [
        ('--video',
         dict(
             default=
             'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/demos/videos/ntu_sample.avi',
             help='video file/url')),
        ('--out_file',
         dict(default=f'{TMP_DIR}/demo_show.mp4', help='output filename')),
        ('--config',
         dict(
             default=
             'configs/video_recognition/stgcn/stgcn_80e_ntu60_xsub_keypoint.py',
             help='skeleton model config file path')),
        ('--checkpoint',
         dict(
             default=
             'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/video/skeleton_based/stgcn/stgcn_80e_ntu60_xsub.pth',
             help='skeleton model checkpoint file/url')),
        ('--det-config',
         dict(
             default='configs/detection/yolox/yolox_s_8xb16_300e_coco.py',
             help='human detection config file path')),
        ('--det-checkpoint',
         dict(
             default=
             'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/yolox/yolox_s_bs16_lr002/epoch_300.pt',
             help='human detection checkpoint file/url')),
        ('--det-predictor-type',
         dict(default='YoloXPredictor', help='detection predictor type')),
        ('--pose-config',
         dict(
             default='configs/pose/hrnet_w48_coco_256x192_udp.py',
             help='human pose estimation config file path')),
        ('--pose-checkpoint',
         dict(
             default=
             'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/pose/top_down_hrnet/pose_hrnet_epoch_210_export.pt',
             help='human pose estimation checkpoint file/url')),
        ('--bbox-thr',
         dict(
             type=float,
             default=0.5,
             help='the threshold of human detection score')),
        ('--device',
         dict(type=str, default='cuda:0', help='CPU/CUDA device option')),
        ('--short-side',
         dict(
             type=int,
             default=480,
             help='specify the short-side length of the image')),
    ]
    for flag, kwargs in options:
        parser.add_argument(flag, **kwargs)
    return parser.parse_args()
def frame_extraction(video_path, short_side):
    """Extract frames from a video into ``TMP_DIR/<video_name>/``.

    Args:
        video_path (str): Local path or URL of the video. URLs are first
            downloaded into ``TMP_DIR``.
        short_side (int): Target length of the shorter image side; every
            frame is rescaled so its short side equals this value while the
            aspect ratio is preserved.

    Returns:
        tuple: ``(frame_paths, frames)`` — the written JPEG file paths and
        the corresponding resized BGR frames (in frame order).
    """
    if is_url_path(video_path):
        from torch.hub import download_url_to_file
        cache_video_path = os.path.join(TMP_DIR, os.path.basename(video_path))
        # Bug fix: original message was a plain string, so the literal text
        # "{cache_video_path}" was printed instead of the actual path.
        print(
            f'Download video file from remote to local path "{cache_video_path}"...'
        )
        download_url_to_file(video_path, cache_video_path)
        video_path = cache_video_path
    # Load the video, extract frames into ./tmp/video_name
    target_dir = osp.join(TMP_DIR, osp.basename(osp.splitext(video_path)[0]))
    os.makedirs(target_dir, exist_ok=True)
    # Should be able to handle videos up to several hours
    frame_tmpl = osp.join(target_dir, 'img_{:06d}.jpg')
    vid = cv2.VideoCapture(video_path)
    frames = []
    frame_paths = []
    flag, frame = vid.read()
    cnt = 0
    new_h, new_w = None, None
    while flag:
        if new_h is None:
            # Compute the target size once, from the first decoded frame.
            h, w, _ = frame.shape
            # np.inf: the capitalized alias `np.Inf` was removed in NumPy 2.0.
            new_w, new_h = mmcv.rescale_size((w, h), (short_side, np.inf))
        frame = mmcv.imresize(frame, (new_w, new_h))
        frames.append(frame)
        frame_path = frame_tmpl.format(cnt + 1)
        frame_paths.append(frame_path)
        cv2.imwrite(frame_path, frame)
        cnt += 1
        flag, frame = vid.read()
    # Release the capture handle explicitly instead of leaking it.
    vid.release()
    return frame_paths, frames
def main():
    """Demo entry point: classify the action shown in a video from 2D skeletons.

    Pipeline: extract frames, run human detection + top-down pose estimation
    per frame, pack the pose sequences into the annotation dict expected by
    the ST-GCN predictor, predict an action label, then render the label on
    every frame and write the result video to ``args.out_file``.
    """
    args = parse_args()
    if not osp.exists(TMP_DIR):
        os.makedirs(TMP_DIR)
    # Decode the video (downloading it first if it is a URL) into resized
    # frames, both on disk and in memory.
    frame_paths, original_frames = frame_extraction(args.video,
                                                    args.short_side)
    num_frame = len(frame_paths)
    h, w, _ = original_frames[0].shape
    # Get Human detection results
    pose_predictor = PoseTopDownPredictor(
        model_path=args.pose_checkpoint,
        config_file=args.pose_config,
        detection_predictor_config=dict(
            type=args.det_predictor_type,
            model_path=args.det_checkpoint,
            config_file=args.det_config,
        ),
        bbox_thr=args.bbox_thr,
        cat_id=0,  # person category id
    )
    video_cls_predictor = STGCNPredictor(
        model_path=args.checkpoint,
        config_file=args.config,
        ori_image_size=(w, h),
        label_map=None)
    pose_results = pose_predictor(original_frames)
    # Release cached GPU memory held by the pose model before classification.
    torch.cuda.empty_cache()
    # Minimal annotation dict in the format the STGCN predictor consumes;
    # keypoint arrays are attached below.
    fake_anno = dict(
        frame_dir='',
        label=-1,  # dummy label; inference does not use it
        img_shape=(h, w),
        original_shape=(h, w),
        start_index=0,
        modality='Pose',
        total_frames=num_frame)
    # Size the (person, frame, keypoint, ...) arrays for the busiest frame.
    num_person = max([len(x) for x in pose_results])
    num_keypoint = 17  # presumably the COCO keypoint layout — TODO confirm
    keypoints = np.zeros((num_person, num_frame, num_keypoint, 2),
                         dtype=np.float16)
    keypoints_score = np.zeros((num_person, num_frame, num_keypoint),
                               dtype=np.float16)
    for i, poses in enumerate(pose_results):
        if len(poses) < 1:
            continue  # no person detected in this frame; leave zeros
        _keypoint = poses['keypoints']  # shape = (num_person, num_keypoint, 3)
        for j, pose in enumerate(_keypoint):
            keypoints[j, i] = pose[:, :2]  # x, y coordinates
            keypoints_score[j, i] = pose[:, 2]  # per-keypoint confidence
    fake_anno['keypoint'] = keypoints
    fake_anno['keypoint_score'] = keypoints_score
    results = video_cls_predictor([fake_anno])
    action_label = results[0]['class_name'][0]
    print(f'action label: {action_label}')
    # Draw pose skeletons where a person was found; otherwise keep the raw frame.
    vis_frames = [
        pose_predictor.show_result(original_frames[i], pose_results[i])
        if len(pose_results[i]) > 0 else original_frames[i]
        for i in range(num_frame)
    ]
    for frame in vis_frames:
        cv2.putText(frame, action_label, (10, 30), FONTFACE, FONTSCALE,
                    FONTCOLOR, THICKNESS, LINETYPE)
    # moviepy expects RGB; OpenCV frames are BGR, so reverse the channel axis.
    vid = mpy.ImageSequenceClip([x[:, :, ::-1] for x in vis_frames], fps=24)
    vid.write_videofile(args.out_file, remove_temp=True)
    print(f'Write video to {args.out_file} successfully!')
    # Clean up the extracted-frame directory (the output video is kept).
    tmp_frame_dir = osp.dirname(frame_paths[0])
    shutil.rmtree(tmp_frame_dir)


if __name__ == '__main__':
    main()