
Commit 634a6e2

Merge pull request #156 from rafiberlin/fix_image_retrieval_pipeline
Makes the image retrieval part work without too much effort. (should address issues #64 and #109)
2 parents 18c16b6 + 6c4d1bb commit 634a6e2

19 files changed: +1053 −289 lines

INSTALL.md

+6
@@ -46,8 +46,14 @@ python setup.py build_ext install
 cd $INSTALL_DIR
 git clone https://github.com/NVIDIA/apex.git
 cd apex
+
+# WARNING: if you use an older version of PyTorch (anything below 1.7), you will need a hard reset,
+# as newer versions of apex require newer PyTorch versions. Ignore the hard reset otherwise.
+git reset --hard 3fe10b5597ba14a748ebb271a6ab97c09c5701ac
+
 python setup.py install --cuda_ext --cpp_ext
 
+
 # install PyTorch Detection
 cd $INSTALL_DIR
 git clone https://github.com/KaihuaTang/Scene-Graph-Benchmark.pytorch.git
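
The conditional reset can also be scripted. A minimal sketch, assuming `python` and `torch` are importable in the active environment (the version patterns are illustrative, not part of the commit):

```bash
# Pin apex only for PyTorch < 1.7, as the warning above advises.
TORCH_VERSION=$(python -c "import torch; print(torch.__version__)")
case "$TORCH_VERSION" in
  0.*|1.[0-6].*) git reset --hard 3fe10b5597ba14a748ebb271a6ab97c09c5701ac ;;
  *) echo "PyTorch $TORCH_VERSION is recent enough; skipping the apex hard reset." ;;
esac
```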

README.md

+1 −1
@@ -176,7 +176,7 @@ MOTIFS-PredCls-none | 59.64 | 66.11 | 67.96 | 11.46 | 14.60 | 15.84 | 5.79 | 11.
 MOTIFS-PredCls-TDE | 33.38 | 45.88 | 51.25 | 17.85 | 24.75 | 28.70 | 8.28 | 14.31 | 18.04
 
 ## SGDet on Custom Images
-Note that evaluation on custom images is only applicable to the SGDet model, because the PredCls and SGCls models require additional ground-truth bounding box information. To detect scene graphs into a json file on your own images, you need to turn on the switch TEST.CUSTUM_EVAL and give a folder path that contains the custom images to TEST.CUSTUM_PATH. Only JPG files are allowed. The output will be saved as custom_prediction.json in the given DETECTED_SGG_DIR.
+Note that evaluation on custom images is only applicable to the SGDet model, because the PredCls and SGCls models require additional ground-truth bounding box information. To detect scene graphs into a json file on your own images, you need to turn on the switch TEST.CUSTUM_EVAL and give a folder path (or a json file containing a list of image paths) that contains the custom images to TEST.CUSTUM_PATH. Only JPG files are allowed. The output will be saved as custom_prediction.json in the given DETECTED_SGG_DIR.
 
 Test Example 1 : (SGDet, **Causal TDE**, MOTIFS Model, SUM Fusion) [(checkpoint)](https://onedrive.live.com/embed?cid=22376FFAD72C4B64&resid=22376FFAD72C4B64%21781947&authkey=AF_EM-rkbMyT3gs)

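The parenthetical added above means `TEST.CUSTUM_PATH` now accepts either a folder or a JSON file listing image paths. A hedged sketch of the JSON route; the file paths are placeholders, and the abridged command assumes the detector and checkpoint flags from the README's custom-eval example:

```bash
# Hypothetical list of images; entries must be JPG paths.
echo '["/data/imgs/dog.jpg", "/data/imgs/street.jpg"]' > /data/my_images.json

# Abridged; add the model flags from Test Example 1 in the README.
CUDA_VISIBLE_DEVICES=0 python -m torch.distributed.launch --master_port 10027 --nproc_per_node=1 \
  tools/relation_test_net.py --config-file "configs/e2e_relation_X_101_32_8_FPN_1x.yaml" \
  TEST.CUSTUM_EVAL True TEST.CUSTUM_PATH /data/my_images.json \
  DETECTED_SGG_DIR /data/sgg_output
```
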
maskrcnn_benchmark/config/defaults.py

+1
@@ -78,6 +78,7 @@
 _C.DATASETS.VAL = ()
 # List of the dataset names for testing, as present in paths_catalog.py
 _C.DATASETS.TEST = ()
+_C.DATASETS.TO_TEST = None
 
 # -----------------------------------------------------------------------------
 # DataLoader
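
The new key defaults to `None`, so existing configs are unaffected; it is meant to be overridden from the command line, as the S2G instructions later in this commit do. A minimal sketch of how such an override reaches the config (standard yacs behaviour, assuming the repo's cfg):

```python
# Command-line opts like "DATASETS.TO_TEST train" arrive via merge_from_list.
from maskrcnn_benchmark.config import cfg

cfg.merge_from_list(["DATASETS.TO_TEST", "train"])  # 'train', 'val' or 'test'
assert cfg.DATASETS.TO_TEST == "train"
```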

maskrcnn_benchmark/config/paths_catalog.py

+3 −1
@@ -6,7 +6,8 @@
 
 
 class DatasetCatalog(object):
-    DATA_DIR = "datasets"
+    #DATA_DIR = "/home/users/alatif/data/ImageCorpora/"
+    DATA_DIR = "/media/rafi/Samsung_T5/_DATASETS/"
     DATASETS = {
         "coco_2017_train": {
             "img_dir": "coco/train2017",
@@ -116,6 +117,7 @@ class DatasetCatalog(object):
             "roidb_file": "vg/VG-SGG-with-attri.h5",
             "dict_file": "vg/VG-SGG-dicts-with-attri.json",
             "image_file": "vg/image_data.json",
+            "capgraphs_file": "vg/vg_capgraphs_anno.json",
         },
     }

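Note that this hunk replaces the repo default `DATA_DIR = "datasets"` with a contributor-specific path, so you will have to edit the line for your own setup. A hedged alternative sketch that keeps the old default; `SGG_DATA_DIR` is a hypothetical variable name, not part of the commit:

```python
import os

class DatasetCatalog(object):
    # Keep the repo default but allow an environment override.
    DATA_DIR = os.environ.get("SGG_DATA_DIR", "datasets")
```
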
maskrcnn_benchmark/data/build.py

+19 −3
@@ -39,12 +39,17 @@ def get_dataset_statistics(cfg):
         logger.info('Loading data statistics from: ' + str(save_file))
         logger.info('-'*100)
         return torch.load(save_file, map_location=torch.device("cpu"))
+    else:
+        logger.info('Unable to load data statistics from: ' + str(save_file))
 
     statistics = []
     for dataset_name in dataset_names:
         data = DatasetCatalog.get(dataset_name, cfg)
         factory = getattr(D, data["factory"])
         args = data["args"]
+        # Removed because it is not part of the original repo (the factory can't deal with additional parameters).
+        if "capgraphs_file" in args.keys():
+            del args["capgraphs_file"]
         dataset = factory(**args)
         statistics.append(dataset.get_statistics())
     logger.info('finish')
@@ -89,6 +94,11 @@ def build_dataset(cfg, dataset_list, transforms, dataset_catalog, is_train=True)
         if data["factory"] == "PascalVOCDataset":
             args["use_difficult"] = not is_train
         args["transforms"] = transforms
+
+        # Removed because it is not part of the original repo (the factory can't deal with additional parameters).
+        if "capgraphs_file" in args.keys():
+            del args["capgraphs_file"]
+
         # make dataset from factory
         dataset = factory(**args)
         datasets.append(dataset)
@@ -153,8 +163,14 @@
     return batch_sampler
 
 
-def make_data_loader(cfg, mode='train', is_distributed=False, start_iter=0):
+def make_data_loader(cfg, mode='train', is_distributed=False, start_iter=0, dataset_to_test=None):
     assert mode in {'train', 'val', 'test'}
+    assert dataset_to_test in {'train', 'val', 'test', None}
+    # this variable enables running a test on any data split, even on the training dataset,
+    # without actually flagging it for training
+    if dataset_to_test is None:
+        dataset_to_test = mode
+
     num_gpus = get_world_size()
     is_train = mode == 'train'
     if is_train:
@@ -199,9 +215,9 @@ def make_data_loader(cfg, mode='train', is_distributed=False, start_iter=0):
         "maskrcnn_benchmark.config.paths_catalog", cfg.PATHS_CATALOG, True
     )
     DatasetCatalog = paths_catalog.DatasetCatalog
-    if mode == 'train':
+    if dataset_to_test == 'train':
         dataset_list = cfg.DATASETS.TRAIN
-    elif mode == 'val':
+    elif dataset_to_test == 'val':
         dataset_list = cfg.DATASETS.VAL
     else:
         dataset_list = cfg.DATASETS.TEST

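A short usage sketch of the new `dataset_to_test` parameter, assuming a fully merged cfg as in `tools/relation_test_net.py`; in test mode this repo's `make_data_loader` returns one loader per configured test dataset:

```python
from maskrcnn_benchmark.config import cfg
from maskrcnn_benchmark.data import make_data_loader

cfg.merge_from_file("configs/e2e_relation_X_101_32_8_FPN_1x.yaml")

# Evaluate on the training split without flagging it for training:
# mode='test' keeps test-time behaviour, dataset_to_test picks the split.
data_loaders = make_data_loader(cfg, mode='test', is_distributed=False,
                                dataset_to_test='train')
```
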
maskrcnn_benchmark/data/datasets/visual_genome.py

+11 −4
@@ -115,10 +115,17 @@ def get_statistics(self):
     def get_custom_imgs(self, path):
         self.custom_files = []
         self.img_info = []
-        for file_name in os.listdir(path):
-            self.custom_files.append(os.path.join(path, file_name))
-            img = Image.open(os.path.join(path, file_name)).convert("RGB")
-            self.img_info.append({'width':int(img.width), 'height':int(img.height)})
+        if os.path.isdir(path):
+            for file_name in tqdm(os.listdir(path)):
+                self.custom_files.append(os.path.join(path, file_name))
+                img = Image.open(os.path.join(path, file_name)).convert("RGB")
+                self.img_info.append({'width':int(img.width), 'height':int(img.height)})
+        # Expecting a list of paths in a json file
+        if os.path.isfile(path):
+            file_list = json.load(open(path))
+            for file in tqdm(file_list):
+                img = Image.open(file).convert("RGB")
+                self.img_info.append({'width': int(img.width), 'height': int(img.height)})
 
     def get_img_info(self, index):
         # WARNING: original image_file.json has several pictures with false image size

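For reference, the resulting method as a standalone sketch. Note the committed JSON branch fills `img_info` but never appends to `custom_files`; the extra append below is an assumed completion, not part of the commit:

```python
import json
import os

from PIL import Image
from tqdm import tqdm

def get_custom_imgs(self, path):
    """Collect image paths and sizes from a folder, or from a JSON list of paths."""
    self.custom_files = []
    self.img_info = []
    if os.path.isdir(path):
        for file_name in tqdm(os.listdir(path)):
            full_path = os.path.join(path, file_name)
            self.custom_files.append(full_path)
            img = Image.open(full_path).convert("RGB")
            self.img_info.append({'width': int(img.width), 'height': int(img.height)})
    # Expecting a list of paths in a JSON file
    elif os.path.isfile(path):
        file_list = json.load(open(path))
        for file_path in tqdm(file_list):
            self.custom_files.append(file_path)  # assumed completion; the diff records only img_info here
            img = Image.open(file_path).convert("RGB")
            self.img_info.append({'width': int(img.width), 'height': int(img.height)})
```
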
maskrcnn_benchmark/image_retrieval/S2G-RETRIEVAL.md

+56 −4
@@ -1,18 +1,70 @@
 # Sentence-to-Graph Retrieval (S2G)
 
-Forgive me, this part of code is ugly and less organized.
+Warning: this part of the code is less organized.
 
 ## Preprocessing
 
-Run the ```maskrcnn_benchmark/image_retrieval/preprocessing.py``` to process the annotations and checkpoints, where ```detected_path``` should be set to the corresponding checkpoints you want to use, ```vg_data, vg_dict, vg_info``` should have already downloaded if you followed DATASET.md, ```cap_graph``` is the ground-truth captions and generated sentence graphs (you can download it from [here](https://onedrive.live.com/embed?cid=22376FFAD72C4B64&resid=22376FFAD72C4B64%21779999&authkey=AGW0Wxjb1JSDFnc)). We use [SceneGraphParser](https://github.com/vacancy/SceneGraphParser) to generate these sentence graphs.
+Prerequisite: ```vg_data, vg_dict, vg_info``` should already have been downloaded if you followed DATASET.md.
 
-You also need to set the ```cap_graph``` PATH and ```vg_dict``` PATH in ```maskrcnn_benchmark/image_retrieval/dataloader.py``` manually.
+You will also need a pre-trained SGDet model, for example from [here](https://onedrive.live.com/embed?cid=22376FFAD72C4B64&resid=22376FFAD72C4B64%21781947&authkey=AF_EM-rkbMyT3gs). This is the SGDet model described in the main `README.md`.
+
+Download the ground-truth captions and generated sentence graphs from [here](https://onedrive.live.com/embed?cid=22376FFAD72C4B64&resid=22376FFAD72C4B64%21779999&authkey=AGW0Wxjb1JSDFnc).
+
+Please note that this file needs to be configured in maskrcnn_benchmark/config/paths_catalog.py: see `DATASETS`, `VG_stanford_filtered_with_attribute`, under the key `capgraphs_file`.
+
+We used [SceneGraphParser](https://github.com/vacancy/SceneGraphParser) to generate these sentence graphs.
+The script ```maskrcnn_benchmark/image_retrieval/sentence_to_graph_processing.py``` partially shows how the text scene graphs were generated (under the key `vg_coco_id_to_capgraphs` in the downloaded sentence graphs file).
+
+Create the test results of the SGDet model for the training and test datasets with:
+
+```bash
+CUDA_VISIBLE_DEVICES=0 python -m torch.distributed.launch --master_port 10027 --nproc_per_node=1 tools/relation_test_net.py --config-file "configs/e2e_relation_X_101_32_8_FPN_1x.yaml" MODEL.ROI_RELATION_HEAD.USE_GT_BOX False MODEL.ROI_RELATION_HEAD.USE_GT_OBJECT_LABEL False MODEL.ROI_RELATION_HEAD.PREDICTOR CausalAnalysisPredictor MODEL.ROI_RELATION_HEAD.CAUSAL.EFFECT_TYPE TDE MODEL.ROI_RELATION_HEAD.CAUSAL.FUSION_TYPE sum MODEL.ROI_RELATION_HEAD.CAUSAL.CONTEXT_LAYER motifs TEST.IMS_PER_BATCH 1 DTYPE "float16" GLOVE_DIR /home/kaihua/glove MODEL.PRETRAINED_DETECTOR_CKPT /home/kaihua/checkpoints/causal-motifs-sgdet OUTPUT_DIR /home/kaihua/checkpoints/causal-motifs-sgdet DATASETS.TO_TEST train
+```
+
+```bash
+CUDA_VISIBLE_DEVICES=0 python -m torch.distributed.launch --master_port 10027 --nproc_per_node=1 tools/relation_test_net.py --config-file "configs/e2e_relation_X_101_32_8_FPN_1x.yaml" MODEL.ROI_RELATION_HEAD.USE_GT_BOX False MODEL.ROI_RELATION_HEAD.USE_GT_OBJECT_LABEL False MODEL.ROI_RELATION_HEAD.PREDICTOR CausalAnalysisPredictor MODEL.ROI_RELATION_HEAD.CAUSAL.EFFECT_TYPE TDE MODEL.ROI_RELATION_HEAD.CAUSAL.FUSION_TYPE sum MODEL.ROI_RELATION_HEAD.CAUSAL.CONTEXT_LAYER motifs TEST.IMS_PER_BATCH 1 DTYPE "float16" GLOVE_DIR /home/kaihua/glove MODEL.PRETRAINED_DETECTOR_CKPT /home/kaihua/checkpoints/causal-motifs-sgdet OUTPUT_DIR /home/kaihua/checkpoints/causal-motifs-sgdet DATASETS.TO_TEST test
+```
+
+This will create the directories `VG_stanford_filtered_with_attribute_train` and `VG_stanford_filtered_with_attribute_test` with saved results under `/home/kaihua/checkpoints/causal-motifs-sgdet/inference/`.
+
+Now run ```maskrcnn_benchmark/image_retrieval/preprocessing.py --test-results-path your-result-path --output-file-name outfile.json``` on both the training and testing results produced above (see the sketch after this diff).
+
+You should obtain two files:
+
+`/home/kaihua/checkpoints/causal-motifs-sgdet/inference/VG_stanford_filtered_with_attribute_train/sg_of_causal_sgdet_ctx_only.json`
+
+and
+
+`/home/kaihua/checkpoints/causal-motifs-sgdet/inference/VG_stanford_filtered_with_attribute_test/sg_of_causal_sgdet_ctx_only.json`
 
 ## Training and Evaluation
 
+You need to manually set ```sg_train_path```, ```sg_val_path``` and ```sg_test_path``` in ```tools/image_retrieval_main.py``` to `/home/kaihua/checkpoints/causal-motifs-sgdet/inference/VG_stanford_filtered_with_attribute_train/sg_of_causal_sgdet_ctx_only.json`, `/home/kaihua/checkpoints/causal-motifs-sgdet/inference/VG_stanford_filtered_with_attribute_val/sg_of_causal_sgdet_ctx_only.json` and `/home/kaihua/checkpoints/causal-motifs-sgdet/inference/VG_stanford_filtered_with_attribute_test/sg_of_causal_sgdet_ctx_only.json` respectively.
+
+If you use your own pretrained model, keep in mind that you need to evaluate it on the **training, validation and testing sets** to get the generated crude scene graphs. Our evaluation code will automatically save the crude SGGs into ```checkpoints/MODEL_NAME/inference/VG_stanford_filtered_with_attribute_train/```, ```checkpoints/MODEL_NAME/inference/VG_stanford_filtered_with_attribute_val/``` or ```checkpoints/MODEL_NAME/inference/VG_stanford_filtered_with_attribute_test/```.
+
 Run the ```tools/image_retrieval_main.py``` for both training and evaluation.
 
-To load the generated scene graphs of the given SGG checkpoints, you need to manually set ```sg_train_path``` and ```sg_test_path``` in ```tools/image_retrieval_main.py```, which means you need to evaluate your model on **both training and testing set** to get the generated crude scene graphs. Our evaluation code will automatically saves the crude SGGs into ```checkpoints/MODEL_NAME/inference/VG_stanford_filtered_wth_attribute_test/``` or ```checkpoints/MODEL_NAME/inference/VG_stanford_filtered_wth_attribute_train/```, which will be further processed to generate the input of ```sg_train_path``` and ```sg_test_path``` by our preprocessing code ```maskrcnn_benchmark/image_retrieval/preprocessing.py```.
+For example, you can train it with:
+
+```tools/image_retrieval_main.py --config-file "configs/e2e_relation_X_101_32_8_FPN_1x.yaml" SOLVER.IMS_PER_BATCH 32 SOLVER.PRE_VAL True SOLVER.SCHEDULE.TYPE WarmupMultiStepLR SOLVER.MAX_ITER 18 SOLVER.CHECKPOINT_PERIOD 3 OUTPUT_DIR /media/rafi/Samsung_T5/_DATASETS/vg/model/ SOLVER.VAL_PERIOD 3```
+
+You can also run an evaluation on any split (parameter `DATASETS.TO_TEST`) with:
+
+```tools/image_retrieval_test.py --config-file "configs/e2e_relation_X_101_32_8_FPN_1x.yaml" SOLVER.IMS_PER_BATCH 32 MODEL.PRETRAINED_DETECTOR_CKPT /media/rafi/Samsung_T5/_DATASETS/vg/model/[your_model_name].pytorch OUTPUT_DIR /media/rafi/Samsung_T5/_DATASETS/vg/model/results DATASETS.TO_TEST test```
+
+Please note that the calculation logic differs from the one used in ```tools/image_retrieval_main.py```.
+Details of the calculation can be found in ```Test Cases Metrics.pdf```, under the type "Fei Fei".
 
 ## Results
 
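The preprocessing step referenced above, written out as two concrete invocations. A hedged sketch: the paths follow the checkpoint layout from the examples, and the output name matches the files the training step expects:

```bash
python maskrcnn_benchmark/image_retrieval/preprocessing.py \
  --test-results-path /home/kaihua/checkpoints/causal-motifs-sgdet/inference/VG_stanford_filtered_with_attribute_train \
  --output-file-name sg_of_causal_sgdet_ctx_only.json

python maskrcnn_benchmark/image_retrieval/preprocessing.py \
  --test-results-path /home/kaihua/checkpoints/causal-motifs-sgdet/inference/VG_stanford_filtered_with_attribute_test \
  --output-file-name sg_of_causal_sgdet_ctx_only.json
```
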
Binary file not shown.

maskrcnn_benchmark/image_retrieval/dataloader.py

+10 −2
@@ -16,6 +16,7 @@
 from tqdm import tqdm
 
 from maskrcnn_benchmark.config import cfg
+from maskrcnn_benchmark.config.paths_catalog import DatasetCatalog
 from maskrcnn_benchmark.data import make_data_loader
 from maskrcnn_benchmark.solver import make_lr_scheduler
 from maskrcnn_benchmark.solver import make_optimizer
@@ -30,13 +31,20 @@
 from maskrcnn_benchmark.utils.logger import setup_logger, debug_print
 from maskrcnn_benchmark.utils.miscellaneous import mkdir, save_config
 from maskrcnn_benchmark.utils.metric_logger import MetricLogger
+import os
 
 class SGEncoding(data.Dataset):
     """ SGEncoding dataset """
     def __init__(self, train_ids, test_ids, sg_data, test_on=False, val_on=False, num_test=5000, num_val=5000):
         super(SGEncoding, self).__init__()
-        cap_graph = json.load(open('/data1/vg_capgraphs_anno.json'))
-        vg_dict = json.load(open('/home/kaihua/projects/maskrcnn-benchmark/datasets/vg/VG-SGG-dicts-with-attri.json'))
+
+        data_dir = DatasetCatalog.DATA_DIR
+        attrs = DatasetCatalog.DATASETS["VG_stanford_filtered_with_attribute"]
+        cap_graph_file = os.path.join(data_dir, attrs["capgraphs_file"])
+        vg_dict_file = os.path.join(data_dir, attrs["dict_file"])
+
+        cap_graph = json.load(open(cap_graph_file))
+        vg_dict = json.load(open(vg_dict_file))
         self.img_txt_sg = sg_data
         self.key_list = list(self.img_txt_sg.keys())
         self.key_list.sort()
