
Commit 251fa36

Committed Jan 19, 2022
Update.
1 parent 36fd8b6 commit 251fa36

File tree

10 files changed, +1049 −4 lines changed

10 files changed

+1049
-4
lines changed
 

dataset_tools/keypoints/__init__.py

Whitespace-only changes.
+274
@@ -0,0 +1,274 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os.path as osp
import argparse
import copy
import glob
import math
import os
import pprint
import cv2
import numpy as np
import onnxruntime as ort
import torch
import torch.backends.cudnn as cudnn
import torch.nn.parallel
import torch.optim
import torch.utils.data
import torch.utils.data.distributed
import torch.multiprocessing
import torchvision.transforms as transforms
import object_detection2.bboxes as odb
from .yolox import *
from .image_encode import ImageEncoder

curdir_path = osp.dirname(__file__)
class PersonDetection:
    def __init__(self):
        self.model = YOLOXDetection()

    def __call__(self, img):
        '''
        img: BGR order
        '''
        assert len(img.shape) == 3, "Error img size"
        output = self.model(img)
        # keep only class 0 (person) detections
        mask = output[..., -1] == 0
        output = output[mask]
        bboxes = output[..., :4]
        # detection confidence = objectness score * class score
        probs = output[..., 4] * output[..., 5]

        # expand each box by 1.2x (clipped to the image), then drop boxes
        # whose original width or height was no more than one pixel
        wh = bboxes[..., 2:] - bboxes[..., :2]
        bboxes = odb.npscale_bboxes(bboxes, 1.2, max_size=[img.shape[1], img.shape[0]])
        wh_mask = wh > 1
        size_mask = np.logical_and(wh_mask[..., 0], wh_mask[..., 1])
        bboxes = bboxes[size_mask]
        probs = probs[size_mask]

        return bboxes, probs
class KPDetection:
    def __init__(self) -> None:
        onnx_path = osp.join(curdir_path, "keypoints.onnx")
        self.model = ort.InferenceSession(onnx_path)
        self.input_name = self.model.get_inputs()[0].name
        self.person_det = PersonDetection()
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])
        self.image_encoder = ImageEncoder()
    @staticmethod
    def cut_and_resizev0(img, bboxes, size=(288, 384)):
        res = []
        for bbox in bboxes:
            cur_img = img[bbox[1]:bbox[3], bbox[0]:bbox[2], :]
            if cur_img.shape[0] > 1 and cur_img.shape[1] > 1:
                cur_img = cv2.resize(cur_img, size, interpolation=cv2.INTER_LINEAR)
            else:
                cur_img = np.zeros([size[1], size[0], 3], dtype=np.float32)
            res.append(cur_img)
        return res
    @staticmethod
    def cut_and_resize(img, bboxes, size=(288, 384)):
        res = []
        res_bboxes = []
        for i, bbox in enumerate(bboxes):
            cur_img = img[bbox[1]:bbox[3], bbox[0]:bbox[2], :]
            if cur_img.shape[0] > 1 and cur_img.shape[1] > 1:
                cur_img, bbox = KPDetection.resize_img(cur_img, bbox, size)
            else:
                cur_img = np.zeros([size[1], size[0], 3], dtype=np.float32)
            res.append(cur_img)
            res_bboxes.append(bbox)
        return res, np.array(res_bboxes, dtype=np.float32)
    @staticmethod
    def resize_img(img, bbox, target_size, pad_color=(127, 127, 127)):
        # letterbox the crop into target_size while keeping its aspect
        # ratio, and grow the bbox to the same aspect ratio so heatmap
        # coordinates can be mapped back to the original image later
        res = np.ndarray([target_size[1], target_size[0], 3], dtype=np.uint8)
        res[:, :] = np.array(pad_color, dtype=np.uint8)
        ratio = target_size[0] / target_size[1]
        bbox_cx = (bbox[2] + bbox[0]) / 2
        bbox_cy = (bbox[3] + bbox[1]) / 2
        bbox_w = bbox[2] - bbox[0]
        bbox_h = bbox[3] - bbox[1]
        if img.shape[1] > ratio * img.shape[0]:
            # wider than the target ratio: fit width, pad height
            nw = target_size[0]
            nh = int(target_size[0] * img.shape[0] / img.shape[1])
            bbox_h = bbox_w / ratio
        else:
            # taller than the target ratio: fit height, pad width
            nh = target_size[1]
            nw = int(target_size[1] * img.shape[1] / img.shape[0])
            bbox_w = bbox_h * ratio

        img = cv2.resize(img, (nw, nh), interpolation=cv2.INTER_LINEAR)
        xoffset = (target_size[0] - nw) // 2
        yoffset = (target_size[1] - nh) // 2
        res[yoffset:yoffset + nh, xoffset:xoffset + nw] = img
        bbox = np.array([bbox_cx - bbox_w / 2, bbox_cy - bbox_h / 2,
                         bbox_cx + bbox_w / 2, bbox_cy + bbox_h / 2], dtype=np.float32)
        return res, bbox
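    # A worked example of the letterbox above (illustrative numbers): for a
    # 100x300 (w x h) crop and target_size=(288, 384), ratio = 0.75 and
    # 100 <= 0.75*300, so the crop fits by height: nh = 384,
    # nw = int(384*100/300) = 128, leaving (288-128)//2 = 80 pixels of
    # padding on each side, and bbox_w is widened to bbox_h*0.75 to match.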
    @staticmethod
    def get_offset_and_scalar(bboxes, size=(288, 384)):
        offset = bboxes[..., :2]
        offset = np.expand_dims(offset, axis=1)
        bboxes_size = bboxes[..., 2:] - bboxes[..., :2]
        cur_size = np.array(size, np.float32)
        cur_size = np.resize(cur_size, [1, 2])
        scalar = bboxes_size / cur_size
        # the *4 matches the heatmap stride: the model appears to output
        # heatmaps at 1/4 of the 288x384 input resolution
        scalar = np.expand_dims(scalar, axis=1) * 4
        return offset, scalar
    @staticmethod
    def get_max_preds(batch_heatmaps):
        '''
        get predictions from score maps
        heatmaps: numpy.ndarray([batch_size, num_joints, height, width])
        '''
        assert isinstance(batch_heatmaps, np.ndarray), \
            'batch_heatmaps should be numpy.ndarray'
        assert batch_heatmaps.ndim == 4, 'batch_heatmaps should be 4-ndim'
        batch_size = batch_heatmaps.shape[0]
        num_joints = batch_heatmaps.shape[1]
        width = batch_heatmaps.shape[3]
        heatmaps_reshaped = batch_heatmaps.reshape((batch_size, num_joints, -1))
        idx = np.argmax(heatmaps_reshaped, 2)
        maxvals = np.amax(heatmaps_reshaped, 2)

        maxvals = maxvals.reshape((batch_size, num_joints, 1))
        idx = idx.reshape((batch_size, num_joints, 1))

        preds = np.tile(idx, (1, 1, 2)).astype(np.float32)

        preds[:, :, 0] = preds[:, :, 0] % width            # x
        preds[:, :, 1] = np.floor(preds[:, :, 1] / width)  # y

        # zero out joints whose peak score is at most 0.05
        pred_mask = np.tile(np.greater(maxvals, 0.05), (1, 1, 2))
        pred_mask = pred_mask.astype(np.float32)

        preds *= pred_mask
        return preds, maxvals
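    # A worked example of the flat-index decode above, assuming the usual
    # 72x96 (w x h) heatmaps for a 288x384 input, i.e. width = 72:
    #   idx = 3461  ->  x = 3461 % 72 = 5,  y = 3461 // 72 = 48
    # so that joint's heatmap peak sits at cell (5, 48).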
    @staticmethod
    def get_final_preds(batch_heatmaps):
        coords, maxvals = KPDetection.get_max_preds(batch_heatmaps)

        heatmap_height = batch_heatmaps.shape[2]
        heatmap_width = batch_heatmaps.shape[3]

        # post-processing: nudge each peak a quarter pixel towards the
        # higher of its two neighbouring responses on each axis
        for n in range(coords.shape[0]):
            for p in range(coords.shape[1]):
                hm = batch_heatmaps[n][p]
                px = int(math.floor(coords[n][p][0] + 0.5))
                py = int(math.floor(coords[n][p][1] + 0.5))
                if 1 < px < heatmap_width - 1 and 1 < py < heatmap_height - 1:
                    diff = np.array(
                        [
                            hm[py][px + 1] - hm[py][px - 1],
                            hm[py + 1][px] - hm[py - 1][px]
                        ]
                    )
                    coords[n][p] += np.sign(diff) * .25

        preds = coords.copy()

        return np.concatenate([preds, maxvals], axis=-1)
    def get_person_bboxes(self, img, return_ext_info=False):
        '''
        Args:
            img: RGB order
        Returns:
            bboxes: [N,4] person boxes; with return_ext_info also
            embds: [N,128] appearance embeddings
        '''
        img = img[..., ::-1]  # the person detector expects BGR
        bboxes, probs = self.person_det(img)
        if len(probs) == 0:
            if return_ext_info:
                return np.zeros([0, 4], dtype=np.float32), np.zeros([0, 128], dtype=np.float32)
            return np.zeros([0, 4], dtype=np.float32)
        if return_ext_info:
            img_patchs = self.cut_and_resizev0(img, bboxes.astype(np.int32), size=(64, 128))
            img_patchs = np.array(img_patchs)
            img_patchs = img_patchs[..., ::-1]  # flip channel order back for the encoder
            embds = self.image_encoder(img_patchs)
            return bboxes, embds
        return bboxes
    def get_kps_by_bboxes(self, img, bboxes, return_fea=True):
        '''
        Args:
            img: RGB order
            bboxes: [N,4] person boxes
        Returns:
            ans: [N,17,3] (x,y,score); with return_fea also the raw
            heatmaps and the aspect-adjusted bboxes
        '''
        imgs, bboxes = self.cut_and_resize(img, bboxes.astype(np.int32))
        imgs = [self.transform(x) for x in imgs]
        imgs = [x.cpu().numpy() for x in imgs]
        imgs = np.ascontiguousarray(np.array(imgs))
        try:
            output = self.model.run(None, {self.input_name: imgs})[0]
        except Exception as e:
            print(f"ERROR: {e}")
            if return_fea:
                return np.zeros([imgs.shape[0], 17, 3], dtype=np.float32), None, bboxes
            else:
                return np.zeros([imgs.shape[0], 17, 3], dtype=np.float32)
        output_fea = output
        output = self.get_final_preds(output)
        # map heatmap coordinates back into the original image
        offset, scalar = self.get_offset_and_scalar(bboxes)
        output[..., :2] = output[..., :2] * scalar + offset
        if return_fea:
            return output, output_fea, bboxes
        return output
    @staticmethod
    def trans_person_bboxes(bboxes, img_width=None, img_height=None):
        # force every box to the 288/384 input aspect ratio, expand it by
        # 1.25x, then clip to the image bounds
        ratio = 288 / 384
        bboxes = odb.npto_cyxhw(bboxes)
        cx, cy, w, h = bboxes[..., 0], bboxes[..., 1], bboxes[..., 2], bboxes[..., 3]
        mask0 = w > (h * ratio)
        h0 = w / ratio
        mask1 = h > (w / ratio)
        w1 = h * ratio
        h = np.where(mask0, h0, h)
        w = np.where(mask1, w1, w)
        w = w * 1.25
        h = h * 1.25
        bboxes = np.stack([cx, cy, w, h], axis=-1)
        bboxes = odb.npto_yminxminymaxxmax(bboxes)
        bboxes = np.maximum(bboxes, 0)
        if img_width is not None:
            bboxes[..., 2] = np.minimum(bboxes[..., 2], img_width - 1)
        if img_height is not None:
            bboxes[..., 3] = np.minimum(bboxes[..., 3], img_height - 1)
        return bboxes
    def __call__(self, img):
        '''
        Args:
            img: RGB order
        Returns:
            ans: [N,17,3] (x,y,score)
        '''
        bboxes = self.get_person_bboxes(img, return_ext_info=False)
        # return_fea=False so the result matches the documented [N,17,3]
        return self.get_kps_by_bboxes(img, bboxes, return_fea=False)
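
A minimal usage sketch for KPDetection, assuming the class is importable as dataset_tools.keypoints (an assumption, since this commit's __init__.py shows whitespace-only changes) and that keypoints.onnx and the YOLOX weights are in place; the image path is a placeholder:

import cv2
from dataset_tools.keypoints import KPDetection  # assumed import path

det = KPDetection()
img = cv2.imread("person.jpg")  # placeholder path; cv2 loads BGR
img = img[..., ::-1]            # KPDetection expects RGB
kps = det(img)                  # [N,17,3]: (x, y, score) per joint
print(kps.shape)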
+29
@@ -0,0 +1,29 @@
import tensorflow as tf
import numpy as np
import os.path as osp

class ImageEncoder(object):

    def __init__(self, input_name="images",
                 output_name="features"):
        # TF1-style graph loading; under TensorFlow 2.x this requires the
        # tf.compat.v1 API (Session, GraphDef, gfile, get_default_graph)
        self.session = tf.Session()
        cur_dir = osp.dirname(__file__)
        checkpoint_filename = osp.join(cur_dir, "networks", "mars-small128.pb")
        with tf.gfile.GFile(checkpoint_filename, "rb") as file_handle:
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(file_handle.read())
        tf.import_graph_def(graph_def, name="net")
        self.input_var = tf.get_default_graph().get_tensor_by_name(
            "net/%s:0" % input_name)
        self.output_var = tf.get_default_graph().get_tensor_by_name(
            "net/%s:0" % output_name)

        assert len(self.output_var.get_shape()) == 2
        assert len(self.input_var.get_shape()) == 4
        self.feature_dim = self.output_var.get_shape().as_list()[-1]
        self.image_shape = self.input_var.get_shape().as_list()[1:]

    def __call__(self, data_x):
        feed_dict = {self.input_var: data_x}
        out = self.session.run(self.output_var, feed_dict=feed_dict)
        return out
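
A minimal usage sketch for ImageEncoder, assuming mars-small128.pb follows the usual deep SORT layout (128x64x3 uint8 person patches in, 128-d appearance features out); the module path and patch data are placeholders:

import numpy as np
from dataset_tools.keypoints.image_encode import ImageEncoder  # assumed path

encoder = ImageEncoder()
patches = np.random.randint(0, 255, size=(4, 128, 64, 3), dtype=np.uint8)  # dummy crops
features = encoder(patches)
print(features.shape)  # expected: (4, 128)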
Binary file not shown.
Binary file not shown.
Binary file not shown.
