
Commit b542210 (committed Aug 26, 2019)

Added support for raw dataset inference

1 parent: 00d7bf7

8 files changed (+96 -40 lines)

experiments/mpii/hrnet/w32_256x256_adam_lr1e-3.yaml (+3 -3)

@@ -4,10 +4,10 @@ CUDNN:
   DETERMINISTIC: false
   ENABLED: true
 DATA_DIR: ''
-GPUS: (0,1,2,3)
+GPUS: (0,)
 OUTPUT_DIR: 'output'
 LOG_DIR: 'log'
-WORKERS: 24
+WORKERS: 6
 PRINT_FREQ: 100
 
 DATASET:
@@ -107,7 +107,7 @@ TRAIN:
   MOMENTUM: 0.9
   NESTEROV: false
 TEST:
-  BATCH_SIZE_PER_GPU: 32
+  BATCH_SIZE_PER_GPU: 128
   MODEL_FILE: ''
   FLIP_TEST: true
   POST_PROCESS: true

lib/config/default.py (+3 -3)

@@ -20,7 +20,7 @@
 _C.LOG_DIR = ''
 _C.DATA_DIR = ''
 _C.GPUS = (0,)
-_C.WORKERS = 4
+_C.WORKERS = 1
 _C.PRINT_FREQ = 20
 _C.AUTO_RESUME = False
 _C.PIN_MEMORY = True
@@ -89,14 +89,14 @@
 _C.TRAIN.RESUME = False
 _C.TRAIN.CHECKPOINT = ''
 
-_C.TRAIN.BATCH_SIZE_PER_GPU = 32
+_C.TRAIN.BATCH_SIZE_PER_GPU = 4
 _C.TRAIN.SHUFFLE = True
 
 # testing
 _C.TEST = CN()
 
 # size of images for each device
-_C.TEST.BATCH_SIZE_PER_GPU = 32
+_C.TEST.BATCH_SIZE_PER_GPU = 8
 # Test Model Epoch
 _C.TEST.FLIP_TEST = False
 _C.TEST.POST_PROCESS = False
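
These defaults are overridden by the experiment yaml when the config is merged, but they apply to anything the yaml leaves unset. They feed directly into loader construction: the effective batch size is BATCH_SIZE_PER_GPU scaled by the number of configured GPUs, and WORKERS sets the number of DataLoader processes. A minimal sketch of that wiring, following the pattern used in tools/test.py (cfg and valid_dataset are assumed to come from the surrounding script):

    import torch

    # Sketch: how the config values above drive the validation DataLoader
    # (mirrors the construction in tools/test.py; cfg/valid_dataset assumed).
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=cfg.TEST.BATCH_SIZE_PER_GPU * len(cfg.GPUS),  # per-GPU size x GPU count
        shuffle=False,
        num_workers=cfg.WORKERS,
        pin_memory=True
    )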

lib/core/function.py (+47 -23)

@@ -14,6 +14,7 @@
 
 import numpy as np
 import torch
+import cv2
 
 from core.evaluate import accuracy
 from core.inference import get_final_preds
@@ -94,8 +95,23 @@ def train(config, train_loader, model, criterion, optimizer, epoch,
                               prefix)
 
 
+def compute_joints(batch_image, batch_joints, batch_joints_vis):
+    for k in range(batch_image.size(0)):
+        image_tensor = batch_image[k]
+        image = image_tensor.mul(255).clamp(0, 255).byte().permute(1, 2, 0).cpu().numpy()
+        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+
+        joints = batch_joints[k]
+        joints_vis = batch_joints_vis[k]
+        for joint in joints:
+            cv2.circle(image, (int(joint[0]), int(joint[1])), 2, [0, 0, 255], 2)
+
+        cv2.imshow("im", image)
+        cv2.waitKey()
+
+
 def validate(config, val_loader, val_dataset, model, criterion, output_dir,
-             tb_log_dir, writer_dict=None):
+             tb_log_dir, writer_dict=None, predict_only=False):
     batch_time = AverageMeter()
     losses = AverageMeter()
     acc = AverageMeter()
@@ -116,6 +132,8 @@ def validate(config, val_loader, val_dataset, model, criterion, output_dir,
     with torch.no_grad():
         end = time.time()
         for i, (input, target, target_weight, meta) in enumerate(val_loader):
+            img = input.data[0].mul(255).clamp(0, 255).byte().permute(1, 2, 0).cpu().numpy()
+
             # compute output
             outputs = model(input)
             if isinstance(outputs, list):
@@ -147,12 +165,12 @@ def validate(config, val_loader, val_dataset, model, criterion, output_dir,
 
                 output = (output + output_flipped) * 0.5
 
+            num_images = input.size(0)
+
             target = target.cuda(non_blocking=True)
             target_weight = target_weight.cuda(non_blocking=True)
 
             loss = criterion(output, target, target_weight)
-
-            num_images = input.size(0)
             # measure accuracy and record loss
             losses.update(loss.item(), num_images)
             _, avg_acc, cnt, pred = accuracy(output.cpu().numpy(),
@@ -181,6 +199,7 @@ def validate(config, val_loader, val_dataset, model, criterion, output_dir,
             image_path.extend(meta['image'])
 
             idx += num_images
+            compute_joints(input, pred*4, meta['joints_vis'])
 
             if i % config.PRINT_FREQ == 0:
                 msg = 'Test: [{0}/{1}]\t' \
@@ -196,18 +215,20 @@ def validate(config, val_loader, val_dataset, model, criterion, output_dir,
                 )
                 save_debug_images(config, input, meta, target, pred*4, output,
                                   prefix)
+        name_values = None
+        perf_indicator = None
+        if not predict_only:
+            name_values, perf_indicator = val_dataset.evaluate(
+                config, all_preds, output_dir, all_boxes, image_path,
+                filenames, imgnums
+            )
 
-        name_values, perf_indicator = val_dataset.evaluate(
-            config, all_preds, output_dir, all_boxes, image_path,
-            filenames, imgnums
-        )
-
-        model_name = config.MODEL.NAME
-        if isinstance(name_values, list):
-            for name_value in name_values:
-                _print_name_value(name_value, model_name)
-        else:
-            _print_name_value(name_values, model_name)
+            model_name = config.MODEL.NAME
+            if isinstance(name_values, list):
+                for name_value in name_values:
+                    _print_name_value(name_value, model_name)
+            else:
+                _print_name_value(name_values, model_name)
 
         if writer_dict:
             writer = writer_dict['writer']
@@ -222,19 +243,22 @@ def validate(config, val_loader, val_dataset, model, criterion, output_dir,
                 acc.avg,
                 global_steps
             )
-            if isinstance(name_values, list):
-                for name_value in name_values:
+
+            if not predict_only:
+                if isinstance(name_values, list):
+                    for name_value in name_values:
+                        writer.add_scalars(
+                            'valid',
+                            dict(name_value),
+                            global_steps
+                        )
+                else:
                     writer.add_scalars(
                         'valid',
-                        dict(name_value),
+                        dict(name_values),
                         global_steps
                     )
-            else:
-                writer.add_scalars(
-                    'valid',
-                    dict(name_values),
-                    global_steps
-                )
+
             writer_dict['valid_global_steps'] = global_steps + 1
 
     return perf_indicator
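
Note that compute_joints opens an OpenCV window and blocks on cv2.waitKey() for every image in the batch, so validation stalls until a key is pressed and a display must be available. For headless runs, a sketch of a variant that writes the overlays to disk instead (the out_dir parameter is hypothetical, not part of this commit):

    import os
    import cv2

    def save_joints(batch_image, batch_joints, out_dir):
        # Headless variant of compute_joints: draw predicted joints and save
        # to disk instead of calling cv2.imshow/cv2.waitKey. out_dir is a
        # hypothetical parameter.
        os.makedirs(out_dir, exist_ok=True)
        for k in range(batch_image.size(0)):
            image = batch_image[k].mul(255).clamp(0, 255).byte().permute(1, 2, 0).cpu().numpy()
            image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)  # tensors are RGB; imwrite expects BGR
            for joint in batch_joints[k]:
                cv2.circle(image, (int(joint[0]), int(joint[1])), 2, (0, 0, 255), 2)
            cv2.imwrite(os.path.join(out_dir, 'joints_{:04d}.jpg'.format(k)), image)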

lib/dataset/JointsDataset.py (+4)

@@ -174,6 +174,10 @@ def __getitem__(self, idx):
         if self.transform:
             input = self.transform(input)
 
+        input = cv2.resize(data_numpy, (256, 256), interpolation=cv2.INTER_LINEAR)
+        input = cv2.normalize(input, None, alpha=0, beta=1, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_32F)
+        input = torch.from_numpy(input).permute(2, 0, 1).float()
+
         for i in range(self.num_joints):
             if joints_vis[i, 0] > 0.0:
                 joints[i, 0:2] = affine_transform(joints[i, 0:2], trans)
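
As written, these four lines rebuild input from the raw data_numpy for every sample, replacing the affine-cropped, transform-normalized tensor produced just above, on the annotated path as well as the raw one. If the 256x256 min-max preprocessing is meant only for unannotated images, a guard along these lines would keep the annotated path intact (the raw_mode flag is hypothetical, not part of this commit):

    # Hypothetical guard: apply the raw 256x256 preprocessing only when the
    # dataset was built without annotations; otherwise keep the affine path.
    if getattr(self, 'raw_mode', False):
        input = cv2.resize(data_numpy, (256, 256), interpolation=cv2.INTER_LINEAR)
        input = cv2.normalize(input, None, alpha=0, beta=1,
                              norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_32F)
        input = torch.from_numpy(input).permute(2, 0, 1).float()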

lib/dataset/mpii.py (+25 -4)

@@ -23,7 +23,7 @@
 
 
 class MPIIDataset(JointsDataset):
-    def __init__(self, cfg, root, image_set, is_train, transform=None):
+    def __init__(self, cfg, root, image_set, is_train, transform=None, unannot_imgs_path=None):
         super().__init__(cfg, root, image_set, is_train, transform)
 
         self.num_joints = 16
@@ -33,10 +33,13 @@ def __init__(self, cfg, root, image_set, is_train, transform=None):
         self.upper_body_ids = (7, 8, 9, 10, 11, 12, 13, 14, 15)
         self.lower_body_ids = (0, 1, 2, 3, 4, 5, 6)
 
-        self.db = self._get_db()
+        if unannot_imgs_path is None:
+            self.db = self._get_db()
 
-        if is_train and cfg.DATASET.SELECT_DATA:
-            self.db = self.select_data(self.db)
+            if is_train and cfg.DATASET.SELECT_DATA:
+                self.db = self.select_data(self.db)
+        else:
+            self.db = self._get_db_raw(unannot_imgs_path)
 
         logger.info('=> load {} samples'.format(len(self.db)))
 
@@ -93,6 +96,24 @@ def _get_db(self):
 
         return gt_db
 
+    def _get_db_raw(self, unannot_imgs_path):
+        gt_db = []
+
+        for image_name in os.listdir(unannot_imgs_path):
+            gt_db.append(
+                {
+                    'image': os.path.join(unannot_imgs_path, image_name),
+                    'center': np.array([0, 0]),
+                    'scale': np.array([1., 1.]),
+                    'joints_3d': np.ones((16, 3)),
+                    'joints_3d_vis': np.ones((16, 3)),
+                    'filename': '',
+                    'imgnum': 0,
+                }
+            )
+
+        return gt_db
+
     def evaluate(self, cfg, preds, output_dir, *args, **kwargs):
         # convert 0-based index to 1-based index
         preds = preds[:, :, 0:2] + 1.0
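
With the new keyword argument, the dataset can point at a plain directory of images; _get_db_raw fills each record with placeholder center/scale/joint values so the rest of the loading pipeline runs unchanged. A usage sketch, assuming cfg is the merged yacs config (the image directory is a placeholder):

    import torchvision.transforms as transforms
    from dataset.mpii import MPIIDataset

    # Build an MPII-style dataset over unlabeled images. 'data/raw_images'
    # is a placeholder; cfg is assumed to be the loaded experiment config.
    raw_dataset = MPIIDataset(
        cfg, cfg.DATASET.ROOT, cfg.DATASET.TEST_SET, False,
        transform=transforms.Compose([transforms.ToTensor()]),
        unannot_imgs_path='data/raw_images'
    )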

lib/utils/vis.py (+2 -2)

@@ -43,8 +43,8 @@ def save_batch_image_with_joints(batch_image, batch_joints, batch_joints_vis,
             joints_vis = batch_joints_vis[k]
 
             for joint, joint_vis in zip(joints, joints_vis):
-                joint[0] = x * width + padding + joint[0]
-                joint[1] = y * height + padding + joint[1]
+                # joint[0] = x * width + padding + joint[0]
+                # joint[1] = y * height + padding + joint[1]
                 if joint_vis[0]:
                     cv2.circle(ndarr, (int(joint[0]), int(joint[1])), 2, [255, 0, 0], 2)
             k = k + 1
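
Note: with the offsets commented out, joints are drawn in single-image coordinates rather than at each image's cell in the make_grid mosaic, so the saved debug overlay is only correct when the grid holds one image; for larger batches, every image's joints land in the first cell's frame.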

requirements.txt (-1)

@@ -1,5 +1,4 @@
 EasyDict==1.7
-opencv-python==3.4.1.15
 shapely==1.6.4
 Cython
 scipy

tools/test.py (+12 -4)

@@ -21,7 +21,7 @@
 import torch.utils.data.distributed
 import torchvision.transforms as transforms
 
-import _init_paths
+#import _init_paths
 from config import cfg
 from config import update_config
 from core.loss import JointsMSELoss
@@ -62,6 +62,12 @@ def parse_args():
                         type=str,
                         default='')
 
+    parser.add_argument('--unnanotImgsPath',
+                        help='annotations availability',
+                        type=str,
+                        default='')
+    # default='/home/andrettin/Repo/POSE_ESTIMATION/deep-high-resolution-net.pytorch/data/mpii/images')
+
     args = parser.parse_args()
     return args
 
@@ -95,7 +101,7 @@ def main():
     logger.info('=> loading model from {}'.format(model_state_file))
     model.load_state_dict(torch.load(model_state_file))
 
-    model = torch.nn.DataParallel(model, device_ids=cfg.GPUS).cuda()
+    model = torch.nn.DataParallel(model, device_ids=[0]).cuda()
 
     # define loss function (criterion) and optimizer
     criterion = JointsMSELoss(
@@ -111,7 +117,8 @@ def main():
         transforms.Compose([
             transforms.ToTensor(),
             normalize,
-        ])
+        ]),
+        args.unnanotImgsPath if args.unnanotImgsPath != '' else None
     )
     valid_loader = torch.utils.data.DataLoader(
         valid_dataset,
@@ -123,7 +130,8 @@ def main():
 
     # evaluate on validation set
     validate(cfg, valid_loader, valid_dataset, model, criterion,
-             final_output_dir, tb_log_dir)
+             final_output_dir, tb_log_dir,
+             predict_only=True if args.unnanotImgsPath != '' else False)
 
 
 if __name__ == '__main__':
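
With all the pieces in place, passing the new flag switches the pipeline into raw-inference mode: MPIIDataset builds its db from the image directory, and validate() runs with predict_only=True, skipping ground-truth evaluation. A plausible invocation, assuming the model weights are configured (e.g. via TEST.MODEL_FILE in the yaml) and with a placeholder image directory:

    python tools/test.py \
        --cfg experiments/mpii/hrnet/w32_256x256_adam_lr1e-3.yaml \
        --unnanotImgsPath data/raw_images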
