Skip to content

Commit 19d2cac

Browse files
committed
WIP
1 parent 752353c commit 19d2cac

File tree

3 files changed

+104
-32
lines changed

3 files changed

+104
-32
lines changed

maskrcnn_benchmark/image_retrieval/preprocessing.py

+10-6
Original file line numberDiff line numberDiff line change
@@ -263,8 +263,9 @@ def img_coco_mapping():
263263
valid_ids.append(img_id)
264264

265265
output = generate_detect_sg(detected_result, detected_info, valid_ids, img_coco, obj_thres = 0.1)
266-
267-
txt_img_sg = generate_txt_img_sg(output, cap_graph['vg_coco_id_to_capgraphs'])
266+
#cap_graph['vg_coco_id_to_capgraphs']
267+
cap_graph_new = json.load(open(os.path.join("/home/users/alatif/data/ImageCorpora/vg/checkpoint/causal-motifs-sgdet/inference/new_graphs.json")))
268+
txt_img_sg = generate_txt_img_sg(output, cap_graph_new)
268269

269270
with open(output_path, 'w') as outfile:
270271
json.dump(txt_img_sg, outfile)
@@ -276,21 +277,24 @@ def img_coco_mapping():
276277
# --output-file-name is the name of the output file (will be created under the path given for --test-results-path )
277278
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Preprocessing of Scene Graphs for Image Retrieval")

    # Tag prepended to the output file name so runs against the new caption
    # graphs don't overwrite earlier outputs. (Was named `type`, which
    # shadowed the builtin.)
    graph_type = "new"
    # Split used only for the CLI default below; the loop at the bottom
    # processes every split regardless. The committed f-string referenced an
    # undefined name `split`, raising NameError as soon as the script started.
    split = "test"

    # Root of the sgdet inference outputs; split directories hang off it.
    inference_root = "/home/users/alatif/data/ImageCorpora/vg/checkpoint/causal-motifs-sgdet/inference"

    parser.add_argument(
        "--test-results-path",
        default=f"{inference_root}/VG_stanford_filtered_with_attribute_{split}/",
        help="path to config file",
    )

    parser.add_argument(
        "--output-file-name",
        default=f"{graph_type}_sg_of_causal_sgdet_ctx_only.json",
        help="creates this file under the path specified with --test-results-path",
    )

    args = parser.parse_args()

    def split_dir(s):
        # Directory holding the detection results for split `s`.
        return f"{inference_root}/VG_stanford_filtered_with_attribute_{s}"

    path_to_test_results = args.test_results_path
    outputfile_name = args.output_file_name
    # NOTE(review): --test-results-path is currently ignored — all three
    # splits under the hard-coded inference root are processed instead.
    # Confirm this is the intended behavior of this WIP change.
    for s in ["train", "val", "test"]:
        preprocess_scene_graphs_output(split_dir(s), outputfile_name)

maskrcnn_benchmark/image_retrieval/sentence_to_graph_processing.py

+92-26
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,74 @@
66
import sng_parser
77
import json
88
from tqdm import tqdm
9+
import collections
10+
import torch
11+
import torchtext as tt
12+
import spacy
13+
14+
def make_vocab(all_caps, outpath, file_name_entity, file_name_relation, freq=1):
    """Build (or load cached) entity and relation vocabularies from captions.

    Every caption is parsed with ``sng_parser``; entity head lemmas and
    relation lemmas are counted and wrapped into torchtext ``Vocab`` objects,
    which are cached with ``torch.save`` under *outpath*. If both cache files
    already exist, they are loaded instead of regenerated.

    :param all_caps: dict mapping an image id to a list of caption strings
    :param outpath: directory where the vocabulary cache files live
    :param file_name_entity: cache file name for the entity vocabulary
    :param file_name_relation: cache file name for the relation vocabulary
    :param freq: ``min_freq`` forwarded to torchtext ``Vocab``; rarer tokens
        are excluded from the vocabulary
    :return: tuple ``(vocab_entity, vocab_relation)`` of torchtext ``Vocab``
    """
    counter_entity = collections.Counter()
    counter_relation = collections.Counter()
    # result = tt.vocab.Vocab(counter_obj, min_freq=1)

    ent = os.path.join(outpath, file_name_entity)
    rel = os.path.join(outpath, file_name_relation)
    # Regenerate when either cache file is missing; both are then rewritten.
    if not os.path.exists(ent) or not os.path.exists(rel):
        print("Generating Vocabulary.")
        for k in tqdm(all_caps.keys()):
            caps = all_caps[k]
            # One parsed graph per caption of this image.
            raw_graphs = [sng_parser.parse(cap) for cap in caps]
            for i, g in enumerate(raw_graphs):
                entities = g["entities"]
                relations = g["relations"]
                counter_entity.update([e["lemma_head"] for e in entities])
                counter_relation.update([r["lemma_relation"] for r in relations])
        # TODO find out the logic he used for Stop words or proper name
        vocab_entity = tt.vocab.Vocab(counter_entity, min_freq=freq)
        torch.save(vocab_entity, ent)
        vocab_relation = tt.vocab.Vocab(counter_relation, min_freq=freq)
        torch.save(vocab_relation, rel)
    else:
        print("Loading Vocabulary.")
        # NOTE(review): torch.load unpickles arbitrary objects — only load
        # cache files this pipeline itself produced.
        vocab_entity = torch.load(ent)
        vocab_relation = torch.load(rel)
    return vocab_entity, vocab_relation
41+
42+
43+
def extract_text_graph(all_caps, entity_vocabulary, relation_vocabulary):
    """Parse captions into scene graphs filtered against the vocabularies.

    Each caption is run through ``sng_parser``. Entity head lemmas missing
    from *entity_vocabulary* (e.g. stop words, proper names) are replaced by
    the placeholder ``'none'``; relations whose lemma is missing from
    *relation_vocabulary* are dropped entirely.

    :param all_caps: dict mapping an image id to a list of caption strings
    :param entity_vocabulary: container supporting ``in`` for entity lemmas
    :param relation_vocabulary: container supporting ``in`` for relation lemmas
    :return: dict mapping each image id to a list of
        ``{'entities': [...], 'relations': [[subj, obj, lemma], ...]}``
    """

    def _clean(graph):
        # Known entity lemmas are kept; unknown ones become 'none'.
        # TODO find out the logic he used for Stop words or proper name
        kept_entities = []
        for entity in graph["entities"]:
            lemma = entity["lemma_head"]
            kept_entities.append(lemma if lemma in entity_vocabulary else 'none')
        # Relations with an out-of-vocabulary lemma are discarded outright.
        kept_relations = []
        for relation in graph["relations"]:
            if relation["lemma_relation"] in relation_vocabulary:
                kept_relations.append(
                    [relation["subject"], relation["object"], relation["lemma_relation"]]
                )
        return {'entities': kept_entities, 'relations': kept_relations}

    new_graphs = {}
    for image_id in tqdm(all_caps.keys()):
        captions = all_caps[image_id]
        new_graphs[image_id] = [_clean(sng_parser.parse(cap)) for cap in captions]
    return new_graphs
977

1078

1179
if __name__ == "__main__":
@@ -18,8 +86,27 @@
1886
type=str,
1987
)
2088

89+
parser.add_argument(
90+
"--outpath",
91+
default="/media/rafi/Samsung_T5/_DATASETS/",
92+
metavar="FILE",
93+
help="path to config file",
94+
type=str,
95+
)
96+
97+
parser.add_argument(
98+
"--graph_file_name",
99+
default="new_graphs.json",
100+
metavar="FILE",
101+
help="file_name.json",
102+
type=str,
103+
)
104+
21105
args = parser.parse_args()
106+
spacy.load('en_core_web_sm')
22107

108+
ent_file = "vocab_entity.pth"
109+
rel_file = "vocab_relation.pth"
23110
data_dir = DatasetCatalog.DATA_DIR
24111
attrs = DatasetCatalog.DATASETS["VG_stanford_filtered_with_attribute"]
25112
cap_graph_file = os.path.join(data_dir, attrs["capgraphs_file"])
@@ -36,32 +123,11 @@
36123
entity_vocabulary = cap_graph["cap_category"].keys()
37124
relation_vocabulary = cap_graph["cap_predicate"].keys()
38125

126+
entity_vocab, relation_vocab = make_vocab(all_caps, args.outpath, ent_file, rel_file)
127+
news_graphs = extract_text_graph(all_caps, entity_vocabulary, relation_vocabulary)
39128

129+
with open(os.path.join(args.outpath, args.graph_file_name), 'w', encoding='utf-8') as f:
130+
json.dump(news_graphs, f, ensure_ascii=False, indent=4)
131+
print("Saved graph")
40132
# Looks like "this" is a stop word => completely removed.
41133
# Proper Name are replaced by 'none'
42-
new_graphs = {}
43-
for k in tqdm(all_caps.keys()):
44-
caps = all_caps[k]
45-
raw_graphs = [sng_parser.parse(cap) for cap in caps]
46-
cleaned_graphs = []
47-
for i, g in enumerate(raw_graphs):
48-
entities = g["entities"]
49-
relations = g["relations"]
50-
# print(str(i),"\n")
51-
# print(caps[i])
52-
# print("\n")
53-
# print(graphs[i])
54-
# if len (entities) == 0 or len (relations) == 0:
55-
# continue
56-
# else:
57-
#TODO find out the logic he used for Stop words or proper name
58-
filtered_entities = [ e["lemma_head"] if e["lemma_head"] in entity_vocabulary else 'none' for e in entities ]
59-
filtered_relations = [ [ r["subject"], r["object"], r["lemma_relation"] ] for r in relations if r["lemma_relation"] in relation_vocabulary]
60-
extracted_graph = {'entities': filtered_entities, 'relations': filtered_relations}
61-
cleaned_graphs.append(extracted_graph)
62-
63-
new_graphs[k] = cleaned_graphs
64-
65-
66-
67-
pass

requirements.txt

+2
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,5 @@ yacs
33
cython
44
matplotlib
55
tqdm
6+
torchtext==0.4
7+
pandas

0 commit comments

Comments
 (0)