-
Notifications
You must be signed in to change notification settings - Fork 41
/
Copy pathindoor_eval.py
377 lines (326 loc) · 13.4 KB
/
indoor_eval.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
# Copyright (c) OpenRobotLab. All rights reserved.
import numpy as np
import torch
from mmengine.logging import print_log
from terminaltables import AsciiTable
def average_precision(recalls, precisions, mode='area'):
    """Calculate average precision (for single or multiple scales).

    Args:
        recalls (np.ndarray): Recalls with shape of (num_scales, num_dets)
            or (num_dets, ).
        precisions (np.ndarray): Precisions with shape of
            (num_scales, num_dets) or (num_dets, ).
        mode (str): 'area' or '11points', 'area' means calculating the area
            under precision-recall curve, '11points' means calculating
            the average precision of recalls at [0, 0.1, ..., 1].

    Returns:
        np.ndarray: float32 array of shape (num_scales, ) holding one
        average precision per scale. 1-D inputs are treated as a single
        scale, so the result then has shape (1, ).

    Raises:
        ValueError: If ``mode`` is neither 'area' nor '11points'.
    """
    if recalls.ndim == 1:
        recalls = recalls[np.newaxis, :]
        precisions = precisions[np.newaxis, :]

    assert recalls.shape == precisions.shape
    assert recalls.ndim == 2

    num_scales = recalls.shape[0]
    ap = np.zeros(num_scales, dtype=np.float32)
    if mode == 'area':
        zeros = np.zeros((num_scales, 1), dtype=recalls.dtype)
        ones = np.ones((num_scales, 1), dtype=recalls.dtype)
        # Pad the curve so it starts at recall 0 and ends at recall 1.
        mrec = np.hstack((zeros, recalls, ones))
        mpre = np.hstack((zeros, precisions, zeros))
        # Replace each precision by the maximum precision to its right
        # (a right-to-left running maximum along the detection axis).
        mpre = np.maximum.accumulate(mpre[:, ::-1], axis=1)[:, ::-1]
        for i in range(num_scales):
            # Only positions where the recall changes contribute area.
            ind = np.where(mrec[i, 1:] != mrec[i, :-1])[0]
            ap[i] = np.sum(
                (mrec[i, ind + 1] - mrec[i, ind]) * mpre[i, ind + 1])
    elif mode == '11points':
        thresholds = np.arange(0, 1 + 1e-3, 0.1)
        for i in range(num_scales):
            for thr in thresholds:
                precs = precisions[i, recalls[i, :] >= thr]
                ap[i] += precs.max() if precs.size > 0 else 0
        # Normalise exactly once, after every scale has been accumulated,
        # so that no scale is divided by 11 more than one time.
        ap /= 11
    else:
        raise ValueError(
            'Unrecognized mode, only "area" and "11points" are supported')
    return ap
def eval_det_cls(pred, gt, iou_thr=None):
    """Generic functions to compute precision/recall for object detection for a
    single class.

    Box objects are only used through ``.tensor``, ``.new_box(tensor)`` and
    ``.overlaps(boxes_a, boxes_b)``. Every image id present in ``pred`` must
    also be present in ``gt`` (possibly with an empty list).

    Note:
        Predicted boxes that are too thin have their size entries
        (``tensor[:, 3:6]``) clamped in place, i.e. the box objects passed in
        through ``pred`` are modified.

    Args:
        pred (dict): Predictions mapping from image id to a list of
            ``(box, score)`` tuples.
        gt (dict): Ground truths mapping from image id to a list of boxes.
        iou_thr (list[float]): A list of iou thresholds. Must be provided;
            the ``None`` default cannot be iterated.

    Return:
        list[tuple(np.ndarray, np.ndarray, np.ndarray)]: One
        ``(recall, precision, ap)`` tuple per entry of ``iou_thr``.
    """
    # {img_id: {'bbox': box structure, 'det': matched flags per threshold}}
    class_recs = {}
    npos = 0

    # figure out the bbox code size first (9 when no box is available)
    gt_bbox_code_size = 9
    pred_bbox_code_size = 9
    for gt_boxes in gt.values():
        if len(gt_boxes) != 0:
            gt_bbox_code_size = gt_boxes[0].tensor.shape[1]
            break
    for dets in pred.values():
        # Test the detection list itself: its first entry is a
        # (box, score) tuple whose length is always 2, and indexing an
        # empty list would raise IndexError.
        if len(dets) != 0:
            pred_bbox_code_size = dets[0][0].tensor.shape[1]
            break
    assert gt_bbox_code_size == pred_bbox_code_size

    for img_id, gt_boxes in gt.items():
        cur_gt_num = len(gt_boxes)
        if cur_gt_num != 0:
            # Stack the single-box tensors into one box structure.
            gt_cur = torch.zeros([cur_gt_num, gt_bbox_code_size],
                                 dtype=torch.float32)
            for i in range(cur_gt_num):
                gt_cur[i] = gt_boxes[i].tensor
            bbox = gt_boxes[0].new_box(gt_cur)
        else:
            bbox = gt_boxes
        det = [[False] * len(bbox) for _ in iou_thr]
        npos += len(bbox)
        class_recs[img_id] = {'bbox': bbox, 'det': det}

    # construct dets
    image_ids = []
    confidence = []
    ious = []
    for img_id, dets in pred.items():
        cur_num = len(dets)
        if cur_num == 0:
            continue
        pred_cur = torch.zeros((cur_num, pred_bbox_code_size),
                               dtype=torch.float32)
        for box_idx, (box, score) in enumerate(dets):
            image_ids.append(img_id)
            confidence.append(score)
            # handle outlier (too thin) predicted boxes
            width, length, height = box.tensor[0, 3:6]
            faces = [width * length, width * height, height * length]
            if torch.any(box.tensor.new_tensor(faces) < 2e-4):
                # clamp short edges to 2e-2 meters (in place)
                box.tensor[:, 3:6] = torch.clamp(box.tensor[:, 3:6], min=2e-2)
            pred_cur[box_idx] = box.tensor
        # Any box of this image can serve as the factory for the stacked
        # structure; the last one from the loop is used.
        pred_cur = box.new_box(pred_cur)
        gt_cur = class_recs[img_id]['bbox']
        if len(gt_cur) > 0:
            # calculate iou in each image
            iou_cur = pred_cur.overlaps(pred_cur, gt_cur)
            for i in range(cur_num):
                ious.append(iou_cur[i])
        else:
            # No ground truth in this image: placeholder IoU row.
            for i in range(cur_num):
                ious.append(np.zeros(1))

    confidence = np.array(confidence)

    # sort by confidence (descending)
    sorted_ind = np.argsort(-confidence)
    image_ids = [image_ids[x] for x in sorted_ind]
    ious = [ious[x] for x in sorted_ind]

    # go down dets and mark TPs and FPs
    num_dets = len(image_ids)
    tp_thr = [np.zeros(num_dets) for _ in iou_thr]
    fp_thr = [np.zeros(num_dets) for _ in iou_thr]
    for d in range(num_dets):
        record = class_recs[image_ids[d]]
        gt_boxes = record['bbox']
        cur_iou = ious[d]
        iou_max = -np.inf
        jmax = -1
        if len(gt_boxes) > 0:
            # find the ground truth with the largest overlap
            for j in range(len(gt_boxes)):
                iou = cur_iou[j]
                if iou > iou_max:
                    iou_max = iou
                    jmax = j

        for iou_idx, thresh in enumerate(iou_thr):
            if iou_max > thresh:
                if not record['det'][iou_idx][jmax]:
                    # first detection claiming this ground truth
                    tp_thr[iou_idx][d] = 1.
                    record['det'][iou_idx][jmax] = 1
                else:
                    # ground truth already matched by a higher-scored box
                    fp_thr[iou_idx][d] = 1.
            else:
                fp_thr[iou_idx][d] = 1.

    ret = []
    for iou_idx, thresh in enumerate(iou_thr):
        # compute precision recall
        fp = np.cumsum(fp_thr[iou_idx])
        tp = np.cumsum(tp_thr[iou_idx])
        # NOTE: with npos == 0 this division yields non-finite values
        # instead of raising.
        recall = tp / float(npos)
        # avoid divide by zero in case the first detection matches a difficult
        # ground truth
        precision = tp / np.maximum(tp + fp, np.finfo(np.float64).eps)
        ap = average_precision(recall, precision)
        ret.append((recall, precision, ap))

    return ret
def eval_map_recall(pred, gt, ovthresh=None):
    """Evaluate mAP and recall over all classes.

    Runs the single-class evaluation for every ground-truth class that also
    has predictions and regroups the results per IoU threshold. Classes
    that appear in ``gt`` but not in ``pred`` get ``np.zeros(1)`` for
    recall, precision and AP.

    Args:
        pred (dict): Information of detection results,
            which maps class_id and predictions.
        gt (dict): Information of ground truths, which maps class_id and
            ground truths.
        ovthresh (list[float], optional): iou threshold. Default: None.

    Return:
        tuple[list[dict]]: ``(recall, precision, ap)``; each is a list with
        one dict per threshold, keyed by class id.
    """
    # Per-class results, only for classes that were actually predicted.
    per_class = {
        cls_id: eval_det_cls(pred[cls_id], cls_gt, ovthresh)
        for cls_id, cls_gt in gt.items()
        if cls_id in pred
    }

    recall = [{} for _ in ovthresh]
    precision = [{} for _ in ovthresh]
    ap = [{} for _ in ovthresh]

    for cls_id in gt:
        has_preds = cls_id in pred
        for thr_idx, _ in enumerate(ovthresh):
            if has_preds:
                cls_rec, cls_prec, cls_ap = per_class[cls_id][thr_idx]
            else:
                cls_rec = np.zeros(1)
                cls_prec = np.zeros(1)
                cls_ap = np.zeros(1)
            recall[thr_idx][cls_id] = cls_rec
            precision[thr_idx][cls_id] = cls_prec
            ap[thr_idx][cls_id] = cls_ap

    return recall, precision, ap
def _format_column(values, overall):
    """Render per-class values followed by their overall value as strings."""
    return [f'{x:.4f}' for x in list(values) + [overall]]


def _log_table(header, table_columns, logger=None):
    """Transpose the columns into rows and log them as an ASCII table."""
    table_data = [header]
    table_data += list(zip(*table_columns))
    table = AsciiTable(table_data)
    table.inner_footing_row_border = True
    print_log('\n' + table.table, logger=logger)


def indoor_eval(gt_annos,
                dt_annos,
                metric,
                label2cat,
                logger=None,
                box_mode_3d=None,
                classes_split=None):
    """Indoor Evaluation.

    Evaluate the result of the detection.

    Args:
        gt_annos (list[dict]): Ground truth annotations. Each dict is read
            through the keys ``gt_bboxes_3d`` and ``gt_labels_3d``.
        dt_annos (list[dict]): Detection annotations. the dict
            includes the following keys

            - labels_3d (torch.Tensor): Labels of boxes.
            - bboxes_3d (:obj:`BaseInstance3DBoxes`):
              3D bounding boxes in Depth coordinate.
            - scores_3d (torch.Tensor): Scores of boxes.
        metric (list[float]): IoU thresholds for computing average precisions.
        label2cat (tuple): Map from label to category.
        logger (logging.Logger | str, optional): Passed to ``print_log``
            (imported from ``mmengine.logging``) when printing the mAP
            summary. Default: None.
        box_mode_3d (optional): Target passed to ``convert_to`` of the
            detected boxes. Default: None.
        classes_split (sequence, optional): If given, ``classes_split[0]``,
            ``classes_split[1]`` and ``classes_split[2]`` are the labels of
            the 'head', 'common' and 'tail' groups; one extra table is
            logged per group. These tables are not added to the returned
            dict. Default: None.

    Return:
        dict[str, float]: Dict of results with the keys
        ``{category}_AP_{thr:.2f}``, ``{category}_rec_{thr:.2f}``,
        ``mAP_{thr:.2f}`` and ``mAR_{thr:.2f}`` for every threshold in
        ``metric``. Classes whose AP at the first threshold is NaN are
        dropped from all results.
    """
    assert len(dt_annos) == len(gt_annos)
    pred = {}  # map {class_id: {img_id: [(box, score), ...]}}
    gt = {}  # map {class_id: {img_id: [box, ...]}}
    for img_id, det_anno in enumerate(dt_annos):
        # parse detected annotations
        num_dets = len(det_anno['labels_3d'])
        if num_dets > 0:
            # Convert once per image instead of once per box.
            det_labels = det_anno['labels_3d'].numpy()
            det_scores = det_anno['scores_3d'].numpy()
            det_boxes = det_anno['bboxes_3d'].convert_to(box_mode_3d)
            for i in range(num_dets):
                label = int(det_labels[i])
                pred.setdefault(label, {}).setdefault(img_id, []).append(
                    (det_boxes[i], det_scores[i]))
                # Make sure the class/image also exists on the gt side so
                # the per-class evaluation can look it up.
                gt.setdefault(label, {}).setdefault(img_id, [])

        # parse gt annotations
        gt_anno = gt_annos[img_id]
        gt_boxes = gt_anno['gt_bboxes_3d']
        gt_labels = gt_anno['gt_labels_3d']
        for i in range(len(gt_labels)):
            # Normalise to a plain int so gt and detection labels always
            # land on the same dict key.
            label = int(gt_labels[i])
            gt.setdefault(label, {}).setdefault(img_id, []).append(
                gt_boxes[i])

    rec, prec, ap = eval_map_recall(pred, gt, metric)

    # filter nan results
    for key in list(ap[0].keys()):
        if np.isnan(ap[0][key][0]):
            for per_thr in rec + prec + ap:
                del per_thr[key]

    ret_dict = dict()
    header = ['classes']
    table_columns = [[label2cat[label] for label in ap[0]] + ['Overall']]

    for i, iou_thresh in enumerate(metric):
        header.append(f'AP_{iou_thresh:.2f}')
        header.append(f'AR_{iou_thresh:.2f}')

        for label, cls_ap in ap[i].items():
            ret_dict[f'{label2cat[label]}_AP_{iou_thresh:.2f}'] = float(
                cls_ap[0])
        mean_ap = float(np.mean(list(ap[i].values())))
        ret_dict[f'mAP_{iou_thresh:.2f}'] = mean_ap
        # Index explicitly: each AP is a one-element array.
        table_columns.append(
            _format_column([float(v[0]) for v in ap[i].values()], mean_ap))

        rec_list = []
        for label, cls_rec in rec[i].items():
            # The last entry of the recall curve is the final recall.
            ret_dict[f'{label2cat[label]}_rec_{iou_thresh:.2f}'] = float(
                cls_rec[-1])
            rec_list.append(cls_rec[-1])
        mean_rec = float(np.mean(rec_list))
        ret_dict[f'mAR_{iou_thresh:.2f}'] = mean_rec
        table_columns.append(
            _format_column([float(x) for x in rec_list], mean_rec))

    _log_table(header, table_columns, logger)

    if classes_split is not None:
        for idx, split_name in enumerate(['head', 'common', 'tail']):
            split_labels = classes_split[idx]
            header = [f'{split_name}_classes']
            # init the category list/column
            cat_list = [
                label2cat[label] for label in split_labels if label in ap[0]
            ]
            table_columns = [cat_list + ['Overall']]

            for i, iou_thresh in enumerate(metric):
                header.append(f'AP_{iou_thresh:.2f}')
                header.append(f'AR_{iou_thresh:.2f}')

                ap_list = [
                    float(ap[i][label][0]) for label in split_labels
                    if label in ap[i]
                ]
                mean_ap = float(np.mean(ap_list))
                table_columns.append(_format_column(ap_list, mean_ap))

                rec_list = [
                    rec[i][label][-1] for label in split_labels
                    if label in rec[i]
                ]
                mean_rec = float(np.mean(rec_list))
                table_columns.append(
                    _format_column([float(x) for x in rec_list], mean_rec))

            _log_table(header, table_columns, logger)

    return ret_dict