
Commit 69934d5

gachiemchiep authored and leoxiaobin committed
Re-write demo/inference.py
* Now works with multiple people
* No longer requires Docker
* Force person detection and pose estimation to run on the same context (GPU or CPU)
* Add some example output
1 parent be81104 commit 69934d5
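The "same context" point above amounts to picking one torch device up front and moving both models (and every input tensor) to it. A minimal sketch of that pattern, mirroring the CTX handling in the rewritten demo/inference.py shown in the diff below:

```python
import torch
import torchvision

# One execution context for the whole pipeline: detection and pose
# estimation both run on the GPU if one is available, otherwise the CPU.
CTX = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Person detector (Faster R-CNN from torchvision) moved to the shared context.
box_model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
box_model.to(CTX)
box_model.eval()

# The HRNet pose model is built from the repo config in the same way and is
# likewise moved with pose_model.to(CTX); inputs are sent via tensor.to(CTX).
```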

File tree

5 files changed: +154 -116 lines changed


demo/README.md

+25 -29
@@ -1,41 +1,37 @@
-This demo code is meant to be run on a video and includes a person detector.
-[Nvidia-docker](https://github.com/NVIDIA/nvidia-docker) and GPUs are required.
-It only expects there to be one person in each frame of video, though the code could easily be extended to support multiple people.
+# Inference hrnet
 
-### Prep
+Running inference with deep-high-resolution-net.pytorch without Docker.
+
+## Prep
 1. Download the researchers' pretrained pose estimator from [google drive](https://drive.google.com/drive/folders/1hOTihvbyIxsm5ygDpbUuJ7O_tzv4oXjC?usp=sharing) to this directory under `models/`
 2. Put the video file you'd like to infer on in this directory under `videos`
 3. build the docker container in this directory with `./build-docker.sh` (this can take time because it involves compiling opencv)
 4. update the `inference-config.yaml` file to reflect the number of GPUs you have available
 
-### Running the Model
-Start your docker container with:
-```
-nvidia-docker run --rm -it \
-    -v $(pwd)/output:/output \
-    -v $(pwd)/videos:/videos \
-    -v $(pwd)/models:/models \
-    -w /pose_root \
-    hrnet_demo_inference \
-    /bin/bash
+## Running the Model
 ```
+python inference.py --cfg inference-config.yaml \
+    --videoFile ../../multi_people.mp4 \
+    --writeBoxFrames \
+    --outputDir output \
+    TEST.MODEL_FILE ../models/pytorch/pose_coco/pose_hrnet_w32_256x192.pth
 
-Once the container is running, you can run inference with:
-```
-python tools/inference.py \
-    --cfg inference-config.yaml \
-    --videoFile /videos/my-video.mp4 \
-    --inferenceFps 10 \
-    --writeBoxFrames \
-    TEST.MODEL_FILE \
-    /models/pytorch/pose_coco/pose_hrnet_w32_384x288.pth
 ```
 
-The command above will output frames with boxes,
-frames with poses,
-a video with poses,
-and a csv with the keypoint coordinates for each frame.
+The above command will create a video under the *output* directory and many pose images under the *output/pose* directory.
+Even with a GPU (a GTX 1080 in my case), person detection takes nearly **0.06 sec** and pose estimation takes
+nearly **0.07 sec** per frame. In total, inference takes about **0.13 sec** per frame, i.e. under 10 fps. If you need real-time
+(fps >= 20) pose estimation, you should try another approach.
+
+## Result
+
+Some example output images:
+
+![1 person](inference_1.jpg)
+Fig: 1 person inference
 
-![](hrnet-demo.gif)
+![3 person](inference_3.jpg)
+Fig: 3 person inference
 
-Original source for demo video above is licensed for `Free for commercial use No attribution required` by [Pixabay](https://pixabay.com/service/license/)
+![3 person](inference_5.jpg)
+Fig: 3 person inference
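The per-stage numbers quoted in the README above (detection vs. pose time) come from plain wall-clock timing around each call, the same time.time() pattern the rewritten script prints per frame. A minimal, self-contained sketch of that measurement; the stage passed to run_timed is a stand-in for whatever callable you want to time:

```python
import time

def run_timed(stage_name, fn, *args, **kwargs):
    # Wall-clock timing around one stage, as inference.py does when it prints
    # "Find person bbox in: ... sec" and "Find person pose in: ... sec".
    start = time.time()
    result = fn(*args, **kwargs)
    print("{} took {:.3f} sec".format(stage_name, time.time() - start))
    return result

# hypothetical usage with names from the demo script:
# boxes = run_timed("person detection", get_person_detection_boxes,
#                   box_model, image_per, threshold=0.9)
```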

demo/inference.py

+129 -87
@@ -19,14 +19,20 @@
 import cv2
 import numpy as np
 
+import sys
+sys.path.append("../lib")
+import time
 
-import _init_paths
+# import _init_paths
 import models
 from config import cfg
 from config import update_config
-from core.function import get_final_preds
+from core.inference import get_final_preds
 from utils.transforms import get_affine_transform
 
+CTX = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
+
+
 COCO_KEYPOINT_INDEXES = {
     0: 'nose',
     1: 'left_eye',
@@ -67,57 +73,53 @@ def get_person_detection_boxes(model, img, threshold=0.5):
     pil_image = Image.fromarray(img)  # Load the image
     transform = transforms.Compose([transforms.ToTensor()])  # Defing PyTorch Transform
     transformed_img = transform(pil_image)  # Apply the transform to the image
-    pred = model([transformed_img])  # Pass the image to the model
+    pred = model([transformed_img.to(CTX)])  # Pass the image to the model
+    # Use the first detected person
     pred_classes = [COCO_INSTANCE_CATEGORY_NAMES[i]
-                    for i in list(pred[0]['labels'].numpy())]  # Get the Prediction Score
+                    for i in list(pred[0]['labels'].cpu().numpy())]  # Get the Prediction Score
     pred_boxes = [[(i[0], i[1]), (i[2], i[3])]
-                  for i in list(pred[0]['boxes'].detach().numpy())]  # Bounding boxes
-    pred_score = list(pred[0]['scores'].detach().numpy())
-    if not pred_score:
-        return []
-    # Get list of index with score greater than threshold
-    pred_t = [pred_score.index(x) for x in pred_score if x > threshold][-1]
-    pred_boxes = pred_boxes[:pred_t+1]
-    pred_classes = pred_classes[:pred_t+1]
+                  for i in list(pred[0]['boxes'].cpu().detach().numpy())]  # Bounding boxes
+    pred_scores = list(pred[0]['scores'].cpu().detach().numpy())
 
     person_boxes = []
-    for idx, box in enumerate(pred_boxes):
-        if pred_classes[idx] == 'person':
-            person_boxes.append(box)
+    # Select box has score larger than threshold and is person
+    for pred_class, pred_box, pred_score in zip(pred_classes, pred_boxes, pred_scores):
+        if (pred_score > threshold) and (pred_class == 'person'):
+            person_boxes.append(pred_box)
 
     return person_boxes
 
 
-def get_pose_estimation_prediction(pose_model, image, center, scale):
+def get_pose_estimation_prediction(pose_model, image, centers, scales, transform):
     rotation = 0
 
     # pose estimation transformation
-    trans = get_affine_transform(center, scale, rotation, cfg.MODEL.IMAGE_SIZE)
-    model_input = cv2.warpAffine(
-        image,
-        trans,
-        (int(cfg.MODEL.IMAGE_SIZE[0]), int(cfg.MODEL.IMAGE_SIZE[1])),
-        flags=cv2.INTER_LINEAR)
-    transform = transforms.Compose([
-        transforms.ToTensor(),
-        transforms.Normalize(mean=[0.485, 0.456, 0.406],
-                             std=[0.229, 0.224, 0.225]),
-    ])
-
-    # pose estimation inference
-    model_input = transform(model_input).unsqueeze(0)
-    # switch to evaluate mode
-    pose_model.eval()
-    with torch.no_grad():
-        # compute output heatmap
-        output = pose_model(model_input)
-        preds, _ = get_final_preds(
-            cfg,
-            output.clone().cpu().numpy(),
-            np.asarray([center]),
-            np.asarray([scale]))
-
-        return preds
+    model_inputs = []
+    for center, scale in zip(centers, scales):
+        trans = get_affine_transform(center, scale, rotation, cfg.MODEL.IMAGE_SIZE)
+        # Crop smaller image of people
+        model_input = cv2.warpAffine(
+            image,
+            trans,
+            (int(cfg.MODEL.IMAGE_SIZE[0]), int(cfg.MODEL.IMAGE_SIZE[1])),
+            flags=cv2.INTER_LINEAR)
+
+        # hwc -> 1chw
+        model_input = transform(model_input)#.unsqueeze(0)
+        model_inputs.append(model_input)
+
+    # n * 1chw -> nchw
+    model_inputs = torch.stack(model_inputs)
+
+    # compute output heatmap
+    output = pose_model(model_inputs.to(CTX))
+    coords, _ = get_final_preds(
+        cfg,
+        output.cpu().detach().numpy(),
+        np.asarray(centers),
+        np.asarray(scales))
+
+    return coords
 
 
 def box_to_center_scale(box, model_image_width, model_image_height):
@@ -163,15 +165,11 @@ def box_to_center_scale(box, model_image_width, model_image_height):
 
 
 def prepare_output_dirs(prefix='/output/'):
-    pose_dir = prefix+'poses/'
-    box_dir = prefix+'boxes/'
+    pose_dir = os.path.join(prefix, "pose")
     if os.path.exists(pose_dir) and os.path.isdir(pose_dir):
         shutil.rmtree(pose_dir)
-    if os.path.exists(box_dir) and os.path.isdir(box_dir):
-        shutil.rmtree(box_dir)
     os.makedirs(pose_dir, exist_ok=True)
-    os.makedirs(box_dir, exist_ok=True)
-    return pose_dir, box_dir
+    return pose_dir
 
 
 def parse_args():
@@ -199,20 +197,26 @@ def parse_args():
 
 
 def main():
+    # transformation
+    pose_transform = transforms.Compose([
+        transforms.ToTensor(),
+        transforms.Normalize(mean=[0.485, 0.456, 0.406],
+                             std=[0.229, 0.224, 0.225]),
+    ])
+
     # cudnn related setting
     cudnn.benchmark = cfg.CUDNN.BENCHMARK
     torch.backends.cudnn.deterministic = cfg.CUDNN.DETERMINISTIC
     torch.backends.cudnn.enabled = cfg.CUDNN.ENABLED
 
     args = parse_args()
    update_config(cfg, args)
-    pose_dir, box_dir = prepare_output_dirs(args.outputDir)
-    csv_output_filename = args.outputDir+'pose-data.csv'
+    pose_dir = prepare_output_dirs(args.outputDir)
     csv_output_rows = []
 
     box_model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
+    box_model.to(CTX)
     box_model.eval()
-
     pose_model = eval('models.'+cfg.MODEL.NAME+'.get_pose_net')(
         cfg, is_train=False
     )
@@ -223,76 +227,114 @@ def main():
     else:
         print('expected model defined in config at TEST.MODEL_FILE')
 
-    pose_model = torch.nn.DataParallel(pose_model, device_ids=cfg.GPUS).cuda()
+    pose_model.to(CTX)
+    pose_model.eval()
 
     # Loading an video
     vidcap = cv2.VideoCapture(args.videoFile)
     fps = vidcap.get(cv2.CAP_PROP_FPS)
     if fps < args.inferenceFps:
         print('desired inference fps is '+str(args.inferenceFps)+' but video fps is '+str(fps))
         exit()
-    every_nth_frame = round(fps/args.inferenceFps)
+    skip_frame_cnt = round(fps / args.inferenceFps)
+    frame_width = int(vidcap.get(cv2.CAP_PROP_FRAME_WIDTH))
+    frame_height = int(vidcap.get(cv2.CAP_PROP_FRAME_HEIGHT))
+    outcap = cv2.VideoWriter('{}/{}_pose.avi'.format(args.outputDir, os.path.splitext(os.path.basename(args.videoFile))[0]),
+                             cv2.VideoWriter_fourcc('M', 'J', 'P', 'G'), int(skip_frame_cnt), (frame_width, frame_height))
 
-    success, image_bgr = vidcap.read()
     count = 0
+    while vidcap.isOpened():
+        total_now = time.time()
+        ret, image_bgr = vidcap.read()
+        count += 1
 
-    while success:
-        if count % every_nth_frame != 0:
-            success, image_bgr = vidcap.read()
-            count += 1
+        if not ret:
             continue
 
-        image = image_bgr[:, :, [2, 1, 0]]
-        count_str = str(count).zfill(32)
+        if count % skip_frame_cnt != 0:
+            continue
+
+        image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
+
+        # Clone 2 image for person detection and pose estimation
+        if cfg.DATASET.COLOR_RGB:
+            image_per = image_rgb.copy()
+            image_pose = image_rgb.copy()
+        else:
+            image_per = image_bgr.copy()
+            image_pose = image_bgr.copy()
+
+        # Clone 1 image for debugging purpose
+        image_debug = image_bgr.copy()
 
         # object detection box
-        pred_boxes = get_person_detection_boxes(box_model, image, threshold=0.8)
-        if args.writeBoxFrames:
-            image_bgr_box = image_bgr.copy()
-            for box in pred_boxes:
-                cv2.rectangle(image_bgr_box, box[0], box[1], color=(0, 255, 0),
-                              thickness=3)  # Draw Rectangle with the coordinates
-            cv2.imwrite(box_dir+'box%s.jpg' % count_str, image_bgr_box)
+        now = time.time()
+        pred_boxes = get_person_detection_boxes(box_model, image_per, threshold=0.9)
+        then = time.time()
+        print("Find person bbox in: {} sec".format(then - now))
+
+        # Can not find people. Move to next frame
         if not pred_boxes:
-            success, image_bgr = vidcap.read()
            count += 1
            continue
 
-        # pose estimation
-        box = pred_boxes[0]  # assume there is only 1 person
-        center, scale = box_to_center_scale(box, cfg.MODEL.IMAGE_SIZE[0], cfg.MODEL.IMAGE_SIZE[1])
-        image_pose = image.copy() if cfg.DATASET.COLOR_RGB else image_bgr.copy()
-        pose_preds = get_pose_estimation_prediction(pose_model, image_pose, center, scale)
+        if args.writeBoxFrames:
+            for box in pred_boxes:
+                cv2.rectangle(image_debug, box[0], box[1], color=(0, 255, 0),
+                              thickness=3)  # Draw Rectangle with the coordinates
+
+        # pose estimation : for multiple people
+        centers = []
+        scales = []
+        for box in pred_boxes:
+            center, scale = box_to_center_scale(box, cfg.MODEL.IMAGE_SIZE[0], cfg.MODEL.IMAGE_SIZE[1])
+            centers.append(center)
+            scales.append(scale)
+
+        now = time.time()
+        pose_preds = get_pose_estimation_prediction(pose_model, image_pose, centers, scales, transform=pose_transform)
+        then = time.time()
+        print("Find person pose in: {} sec".format(then - now))
 
         new_csv_row = []
-        for _, mat in enumerate(pose_preds[0]):
-            x_coord, y_coord = int(mat[0]), int(mat[1])
-            cv2.circle(image_bgr, (x_coord, y_coord), 4, (255, 0, 0), 2)
-            new_csv_row.extend([x_coord, y_coord])
+        for coords in pose_preds:
+            # Draw each point on image
+            for coord in coords:
+                x_coord, y_coord = int(coord[0]), int(coord[1])
+                cv2.circle(image_debug, (x_coord, y_coord), 4, (255, 0, 0), 2)
+                new_csv_row.extend([x_coord, y_coord])
+
+        total_then = time.time()
+
+        text = "{:03.2f} sec".format(total_then - total_now)
+        cv2.putText(image_debug, text, (100, 50), cv2.FONT_HERSHEY_SIMPLEX,
+                    1, (0, 0, 255), 2, cv2.LINE_AA)
+
+        cv2.imshow("pos", image_debug)
+        if cv2.waitKey(1) & 0xFF == ord('q'):
+            break
 
         csv_output_rows.append(new_csv_row)
-        cv2.imwrite(pose_dir+'pose%s.jpg' % count_str, image_bgr)
+        img_file = os.path.join(pose_dir, 'pose_{:08d}.jpg'.format(count))
+        cv2.imwrite(img_file, image_debug)
+        outcap.write(image_debug)
 
-        # get next frame
-        success, image_bgr = vidcap.read()
-        count += 1
 
     # write csv
     csv_headers = ['frame']
     for keypoint in COCO_KEYPOINT_INDEXES.values():
         csv_headers.extend([keypoint+'_x', keypoint+'_y'])
 
+    csv_output_filename = os.path.join(args.outputDir, 'pose-data.csv')
     with open(csv_output_filename, 'w', newline='') as csvfile:
         csvwriter = csv.writer(csvfile)
         csvwriter.writerow(csv_headers)
         csvwriter.writerows(csv_output_rows)
 
-    os.system("ffmpeg -y -r "
-              + str(args.inferenceFps)
-              + " -pattern_type glob -i '"
-              + pose_dir
-              + "/*.jpg' -c:v libx264 -vf fps="
-              + str(args.inferenceFps)+" -pix_fmt yuv420p /output/movie.mp4")
+    vidcap.release()
+    outcap.release()
+
+    cv2.destroyAllWindows()
 
 
 if __name__ == '__main__':
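The core of the multi-person change above is that get_pose_estimation_prediction now builds one crop per detected box and stacks all crops into a single NCHW batch, so one forward pass covers everyone in the frame. Below is a simplified, self-contained sketch of that batching idea; batch_person_crops is a hypothetical helper, and a plain crop-and-resize stands in for the repo's get_affine_transform/cv2.warpAffine warp:

```python
import cv2
import torch
from torchvision import transforms

# Same normalization as the demo's pose_transform.
pose_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

def batch_person_crops(image_rgb, person_boxes, input_size=(192, 256)):
    """Turn one crop per person box into a single NCHW batch tensor."""
    crops = []
    for (x1, y1), (x2, y2) in person_boxes:
        crop = image_rgb[int(y1):int(y2), int(x1):int(x2)]  # HWC crop of one person
        crop = cv2.resize(crop, input_size)                  # resize to model input (w, h)
        crops.append(pose_transform(crop))                   # HWC uint8 -> CHW float
    return torch.stack(crops)                                # N x C x H x W

# hypothetical usage with names from the demo script:
# heatmaps = pose_model(batch_person_crops(frame_rgb, person_boxes).to(CTX))
```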

demo/inference_1.jpg (145 KB)

demo/inference_3.jpg (264 KB)

demo/inference_5.jpg (226 KB)
