Skip to content

Commit fca9860

Browse files
committed
allow evaluation on custom images
1 parent d05be9f commit fca9860

File tree

9 files changed

+60
-13
lines changed

9 files changed

+60
-13
lines changed

README.md

+15-4
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,8 @@ Our paper [Unbiased Scene Graph Generation from Biased Training](https://arxiv.o
88

99
## Recent Updates
1010

11-
- [x] 2020.06.23 [No Graph Constraint Mean Recall@K (ng-mR@K) and No Graph Constraint Zero-Shot Recall@K (ng-zR@K)](METRICS.md#explanation-of-our-metrics)
11+
- [x] 2020.06.23 Add No Graph Constraint Mean Recall@K (ng-mR@K) and No Graph Constraint Zero-Shot Recall@K (ng-zR@K)[link](METRICS.md#explanation-of-our-metrics)
12+
- [x] 2020.06.23 Allow Scene Graph Detection (SGDet) on Custom Images[link](#run-SGDet-on-custom-images)
1213

1314
## Contents
1415

@@ -23,9 +24,10 @@ Our paper [Unbiased Scene Graph Generation from Biased Training](https://arxiv.o
2324
6. [Scene Graph Generation as RoI_Head](#scene-graph-generation-as-RoI_Head)
2425
7. [Training on Scene Graph Generation](#perform-training-on-scene-graph-generation)
2526
8. [Evaluation on Scene Graph Generation](#Evaluation)
26-
9. [Other Options that May Improve the SGG](#other-options-that-may-improve-the-SGG)
27-
10. [Tips and Tricks for TDE on any Unbiased Task](#tips-and-Tricks-for-any-unbiased-taskX-from-biased-training)
28-
11. [Citations](#Citations)
27+
9. [SGDet on Custum Images](#run-SGDet-on-custom-images)
28+
10. [Other Options that May Improve the SGG](#other-options-that-may-improve-the-SGG)
29+
11. [Tips and Tricks for TDE on any Unbiased Task](#tips-and-Tricks-for-any-unbiased-taskX-from-biased-training)
30+
12. [Citations](#Citations)
2931

3032
## Overview
3133

@@ -168,6 +170,15 @@ MOTIFS-SGCls-TDE | 20.47 | 26.31 | 28.79 | 9.80 | 13.21 | 15.06 | 1.91 | 2.95
168170
MOTIFS-PredCls-none | 59.64 | 66.11 | 67.96 | 11.46 | 14.60 | 15.84 | 5.79 | 11.02 | 14.74
169171
MOTIFS-PredCls-TDE | 33.38 | 45.88 | 51.25 | 17.85 | 24.75 | 28.70 | 8.28 | 14.31 | 18.04
170172

173+
## Run SGDet on Custom Images
174+
Note that evaluation on custum images is only valid for SGDet model, because PredCls and SGCls model requires additional ground-truth bounding boxes information. You only need to turn on the switch TEST.CUSTUM_EVAL and give a folder path that contains the custom images to TEST.CUSTUM_PATH. Only JPG files are allowed. The output will be custom_prediction.pytorch saved in OUTPUT_DIR, which can be read by torch.load().
175+
176+
Test Example 1 : (SGDet, Motif Model)
177+
```bash
178+
CUDA_VISIBLE_DEVICES=0 python -m torch.distributed.launch --master_port 10027 --nproc_per_node=1 tools/relation_test_net.py --config-file "configs/e2e_relation_X_101_32_8_FPN_1x.yaml" MODEL.ROI_RELATION_HEAD.USE_GT_BOX False MODEL.ROI_RELATION_HEAD.USE_GT_OBJECT_LABEL False MODEL.ROI_RELATION_HEAD.PREDICTOR MotifPredictor TEST.IMS_PER_BATCH 1 DTYPE "float16" GLOVE_DIR /home/kaihua/glove MODEL.PRETRAINED_DETECTOR_CKPT /home/kaihua/checkpoints/motif-precls-exmp OUTPUT_DIR /home/kaihua/checkpoints/motif-precls-exmp TEST.CUSTUM_EVAL True TEST.CUSTUM_PATH /home/kaihua/checkpoints/custom_images
179+
```
180+
181+
171182
## Other Options that May Improve the SGG
172183

173184
- For some models (not all), turning on or turning off ```MODEL.ROI_RELATION_HEAD.POOLING_ALL_LEVELS``` will affect the performance of predicate prediction, e.g., turning it off will improve VCTree PredCls but not the corresponding SGCls and SGGen. For the reported results of VCTree, we simply turn it on for all three protocols like other models.

configs/e2e_relation_X_101_32_8_FPN_1x.yaml

+2
Original file line numberDiff line numberDiff line change
@@ -126,3 +126,5 @@ TEST:
126126
SYNC_GATHER: True # turn on will slow down the evaluation to solve the sgdet test out of memory problem
127127
REQUIRE_OVERLAP: False
128128
LATER_NMS_PREDICTION_THRES: 0.5
129+
CUSTUM_EVAL: False # eval SGDet model on custum images, output a json
130+
CUSTUM_PATH: '.' # the folder that contains the custum images, only jpg files are allowed

maskrcnn_benchmark/config/defaults.py

+5
Original file line numberDiff line numberDiff line change
@@ -574,6 +574,11 @@
574574
_C.TEST.RELATION.SYNC_GATHER = False
575575

576576
_C.TEST.ALLOW_LOAD_FROM_CACHE = True
577+
578+
579+
_C.TEST.CUSTUM_EVAL = False
580+
_C.TEST.CUSTUM_PATH = '.'
581+
577582
# ---------------------------------------------------------------------------- #
578583
# Misc options
579584
# ---------------------------------------------------------------------------- #

maskrcnn_benchmark/config/paths_catalog.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,9 @@ def get(name, cfg):
157157
# else set filter to False, because we need all images for pretraining detector
158158
args['filter_non_overlap'] = (not cfg.MODEL.ROI_RELATION_HEAD.USE_GT_BOX) and cfg.MODEL.RELATION_ON and cfg.MODEL.ROI_RELATION_HEAD.REQUIRE_BOX_OVERLAP
159159
args['filter_empty_rels'] = cfg.MODEL.RELATION_ON
160-
args['flip_aug'] = cfg.MODEL.FLIP_AUG
160+
args['flip_aug'] = cfg.MODEL.FLIP_AUG
161+
args['custom_eval'] = cfg.TEST.CUSTUM_EVAL
162+
args['custom_path'] = cfg.TEST.CUSTUM_PATH
161163
return dict(
162164
factory="VGDataset",
163165
args=args,

maskrcnn_benchmark/data/datasets/evaluation/vg/sgg_eval.py

+4
Original file line numberDiff line numberDiff line change
@@ -314,9 +314,11 @@ def generate_print_string(self, mode):
314314
result_str += ' for mode=%s, type=Mean Recall.' % mode
315315
result_str += '\n'
316316
if self.print_detail:
317+
result_str += '----------------------- Details ------------------------\n'
317318
for n, r in zip(self.rel_name_list, self.result_dict[mode + '_mean_recall_list'][100]):
318319
result_str += '({}:{:.4f}) '.format(str(n), r)
319320
result_str += '\n'
321+
result_str += '--------------------------------------------------------\n'
320322

321323
return result_str
322324

@@ -384,9 +386,11 @@ def generate_print_string(self, mode):
384386
result_str += ' for mode=%s, type=No Graph Constraint Mean Recall.' % mode
385387
result_str += '\n'
386388
if self.print_detail:
389+
result_str += '----------------------- Details ------------------------\n'
387390
for n, r in zip(self.rel_name_list, self.result_dict[mode + '_ng_mean_recall_list'][100]):
388391
result_str += '({}:{:.4f}) '.format(str(n), r)
389392
result_str += '\n'
393+
result_str += '--------------------------------------------------------\n'
390394

391395
return result_str
392396

maskrcnn_benchmark/data/datasets/visual_genome.py

+14-1
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ class VGDataset(torch.utils.data.Dataset):
1818

1919
def __init__(self, split, img_dir, roidb_file, dict_file, image_file, transforms=None,
2020
filter_empty_rels=True, num_im=-1, num_val_im=5000,
21-
filter_duplicate_rels=True, filter_non_overlap=True, flip_aug=False):
21+
filter_duplicate_rels=True, filter_non_overlap=True, flip_aug=False, custom_eval=False, custom_path=''):
2222
"""
2323
Torch dataset for VisualGenome
2424
Parameters:
@@ -63,11 +63,18 @@ def __init__(self, split, img_dir, roidb_file, dict_file, image_file, transforms
6363
self.filenames = [self.filenames[i] for i in np.where(self.split_mask)[0]]
6464
self.img_info = [self.img_info[i] for i in np.where(self.split_mask)[0]]
6565

66+
self.custom_eval = custom_eval
67+
if self.custom_eval:
68+
self.get_custom_imgs(custom_path)
69+
6670

6771
def __getitem__(self, index):
6872
#if self.split == 'train':
6973
# while(random.random() > self.img_info[index]['anti_prop']):
7074
# index = int(random.random() * len(self.filenames))
75+
if self.custom_eval:
76+
img = Image.open(self.custom_files[index]).convert("RGB")
77+
return img, 0, index
7178

7279
img = Image.open(self.filenames[index]).convert("RGB")
7380
if img.size[0] != self.img_info[index]['width'] or img.size[1] != self.img_info[index]['height']:
@@ -103,6 +110,10 @@ def get_statistics(self):
103110
}
104111
return result
105112

113+
def get_custom_imgs(self, path):
114+
self.custom_files = []
115+
for file_name in os.listdir(path):
116+
self.custom_files.append(os.path.join(path, file_name))
106117

107118
def get_img_info(self, index):
108119
# WARNING: original image_file.json has several pictures with false image size
@@ -159,6 +170,8 @@ def get_groundtruth(self, index, evaluation=False, flip_img=False):
159170
return target
160171

161172
def __len__(self):
173+
if self.custom_eval:
174+
return len(self.custom_files)
162175
return len(self.filenames)
163176

164177

maskrcnn_benchmark/engine/inference.py

+4
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,10 @@ def inference(
141141
expected_results_sigma_tol=expected_results_sigma_tol,
142142
)
143143

144+
if cfg.TEST.CUSTUM_EVAL:
145+
torch.save(predictions, os.path.join(cfg.OUTPUT_DIR, 'custom_prediction.pytorch'))
146+
return -1.0
147+
144148
return evaluate(cfg=cfg,
145149
dataset=dataset,
146150
predictions=predictions,

maskrcnn_benchmark/modeling/roi_heads/box_head/box_head.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,8 @@ def forward(self, features, proposals, targets=None):
6464
return x, proposals, {}
6565
else:
6666
# mode==sgdet
67-
proposals = self.samp_processor.assign_label_to_proposals(proposals, targets)
67+
if self.training or not self.cfg.TEST.CUSTUM_EVAL:
68+
proposals = self.samp_processor.assign_label_to_proposals(proposals, targets)
6869
x = self.feature_extractor(features, proposals)
6970
class_logits, box_regression = self.predictor(x)
7071
proposals = add_predict_logits(proposals, class_logits)

maskrcnn_benchmark/modeling/roi_heads/box_head/inference.py

+11-6
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,8 @@ def __init__(
2626
box_coder=None,
2727
cls_agnostic_bbox_reg=False,
2828
bbox_aug_enabled=False,
29-
save_proposals=False
29+
save_proposals=False,
30+
custum_eval=False
3031
):
3132
"""
3233
Arguments:
@@ -47,6 +48,7 @@ def __init__(
4748
self.cls_agnostic_bbox_reg = cls_agnostic_bbox_reg
4849
self.bbox_aug_enabled = bbox_aug_enabled
4950
self.save_proposals = save_proposals
51+
self.custum_eval = custum_eval
5052

5153
def forward(self, x, boxes, relation_mode=False):
5254
"""
@@ -104,11 +106,12 @@ def forward(self, x, boxes, relation_mode=False):
104106

105107
def add_important_fields(self, i, boxes, orig_inds, boxlist, boxes_per_cls, relation_mode=False):
106108
if relation_mode:
107-
gt_labels = boxes[i].get_field('labels')[orig_inds]
108-
gt_attributes = boxes[i].get_field('attributes')[orig_inds]
109+
if not self.custum_eval:
110+
gt_labels = boxes[i].get_field('labels')[orig_inds]
111+
gt_attributes = boxes[i].get_field('attributes')[orig_inds]
109112

110-
boxlist.add_field('labels', gt_labels)
111-
boxlist.add_field('attributes', gt_attributes)
113+
boxlist.add_field('labels', gt_labels)
114+
boxlist.add_field('attributes', gt_attributes)
112115

113116
predict_logits = boxes[i].get_field('predict_logits')[orig_inds]
114117
boxlist.add_field('boxes_per_cls', boxes_per_cls)
@@ -238,6 +241,7 @@ def make_roi_box_post_processor(cfg):
238241
post_nms_per_cls_topn = cfg.MODEL.ROI_HEADS.POST_NMS_PER_CLS_TOPN
239242
nms_filter_duplicates = cfg.MODEL.ROI_HEADS.NMS_FILTER_DUPLICATES
240243
save_proposals = cfg.TEST.SAVE_PROPOSALS
244+
custum_eval = cfg.TEST.CUSTUM_EVAL
241245

242246
postprocessor = PostProcessor(
243247
score_thresh,
@@ -248,6 +252,7 @@ def make_roi_box_post_processor(cfg):
248252
box_coder,
249253
cls_agnostic_bbox_reg,
250254
bbox_aug_enabled,
251-
save_proposals
255+
save_proposals,
256+
custum_eval
252257
)
253258
return postprocessor

0 commit comments

Comments
 (0)