import sng_parser
import json
from tqdm import tqdm
+ import collections
+ import torch
+ import torchtext as tt
+ import spacy
+
+ def make_vocab(all_caps, outpath, file_name_entity, file_name_relation, freq=1):
+     """Build entity/relation vocabularies from the captions, or load them if already cached on disk."""
+     counter_entity = collections.Counter()
+     counter_relation = collections.Counter()
+
+     ent = os.path.join(outpath, file_name_entity)
+     rel = os.path.join(outpath, file_name_relation)
+     if not os.path.exists(ent) or not os.path.exists(rel):
+         print("Generating Vocabulary.")
+         for k in tqdm(all_caps.keys()):
+             caps = all_caps[k]
+             raw_graphs = [sng_parser.parse(cap) for cap in caps]
+             for g in raw_graphs:
+                 entities = g["entities"]
+                 relations = g["relations"]
+                 counter_entity.update([e["lemma_head"] for e in entities])
+                 counter_relation.update([r["lemma_relation"] for r in relations])
+         # TODO find out the logic he used for stop words or proper names
+         vocab_entity = tt.vocab.Vocab(counter_entity, min_freq=freq)
+         torch.save(vocab_entity, ent)
+         vocab_relation = tt.vocab.Vocab(counter_relation, min_freq=freq)
+         torch.save(vocab_relation, rel)
+     else:
+         print("Loading Vocabulary.")
+         vocab_entity = torch.load(ent)
+         vocab_relation = torch.load(rel)
+     return vocab_entity, vocab_relation
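+
+ # Minimal usage sketch (hypothetical paths; assumes the legacy torchtext < 0.10
+ # Vocab API that this script relies on):
+ #   vocab_e, vocab_r = make_vocab(all_caps, "/tmp/vocabs", "ent.pth", "rel.pth")
+ #   idx = vocab_e.stoi["dog"]   # token -> index
+ #   tok = vocab_e.itos[idx]     # index -> token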
+
+
+ def extract_text_graph(all_caps, entity_vocabulary, relation_vocabulary):
+     """
+     Parse every caption into a scene graph and keep only in-vocabulary items.
+
+     :param all_caps: dict mapping an image id to its list of captions
+     :param entity_vocabulary: entity lemmas to keep; all others become 'none'
+     :param relation_vocabulary: relation lemmas to keep; all others are dropped
+     :return: dict mapping each image id to a list of cleaned graphs
+     """
+     new_graphs = {}
+     for k in tqdm(all_caps.keys()):
+         caps = all_caps[k]
+         raw_graphs = [sng_parser.parse(cap) for cap in caps]
+         cleaned_graphs = []
+         for g in raw_graphs:
+             entities = g["entities"]
+             relations = g["relations"]
+             # TODO find out the logic he used for stop words or proper names
+             filtered_entities = [e["lemma_head"] if e["lemma_head"] in entity_vocabulary else 'none'
+                                  for e in entities]
+             filtered_relations = [[r["subject"], r["object"], r["lemma_relation"]]
+                                   for r in relations if r["lemma_relation"] in relation_vocabulary]
+             extracted_graph = {'entities': filtered_entities, 'relations': filtered_relations}
+             cleaned_graphs.append(extracted_graph)
+         new_graphs[k] = cleaned_graphs
+     return new_graphs
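+
+ # Output shape for one caption, sketched with illustrative values:
+ #   {"<image_id>": [{'entities': ['man', 'horse'],
+ #                    'relations': [[0, 1, 'ride']]}]}
+ # Each relation is [subject_index, object_index, lemma]; the indices point into
+ # the 'entities' list of the same graph, following sng_parser's convention.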


if __name__ == "__main__":
    ...
        type=str,
    )

+     parser.add_argument(
+         "--outpath",
+         default="/media/rafi/Samsung_T5/_DATASETS/",
+         metavar="DIR",
+         help="directory where the vocabulary files and the graph file are written",
+         type=str,
+     )
+
+     parser.add_argument(
+         "--graph_file_name",
+         default="new_graphs.json",
+         metavar="FILE",
+         help="name of the output JSON graph file",
+         type=str,
+     )
+
    args = parser.parse_args()
+     # Make sure the spaCy English model required by sng_parser is installed.
+     spacy.load('en_core_web_sm')
+
+     ent_file = "vocab_entity.pth"
+     rel_file = "vocab_relation.pth"

    data_dir = DatasetCatalog.DATA_DIR
    attrs = DatasetCatalog.DATASETS["VG_stanford_filtered_with_attribute"]
    cap_graph_file = os.path.join(data_dir, attrs["capgraphs_file"])
    ...
    entity_vocabulary = cap_graph["cap_category"].keys()
    relation_vocabulary = cap_graph["cap_predicate"].keys()

+     entity_vocab, relation_vocab = make_vocab(all_caps, args.outpath, ent_file, rel_file)
+     # NOTE: the filtering below uses the capgraph vocabularies; the torchtext
+     # Vocab objects built above are only persisted to disk at this point.
+     new_graphs = extract_text_graph(all_caps, entity_vocabulary, relation_vocabulary)
+
+     with open(os.path.join(args.outpath, args.graph_file_name), 'w', encoding='utf-8') as f:
+         json.dump(new_graphs, f, ensure_ascii=False, indent=4)
+     print("Saved graph")

    # Looks like "this" is a stop word => completely removed.
    # Proper names are replaced by 'none'.
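+     # Illustrative check (hypothetical caption; exact parser output may vary):
+     #   extract_text_graph({"0": ["John pets this dog"]}, entity_vocabulary, relation_vocabulary)
+     # would keep 'dog', map the out-of-vocabulary proper name 'John' to 'none',
+     # and drop 'this' entirely, since the parser treats it as a stop word.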
-     new_graphs = {}
-     for k in tqdm(all_caps.keys()):
-         caps = all_caps[k]
-         raw_graphs = [sng_parser.parse(cap) for cap in caps]
-         cleaned_graphs = []
-         for i, g in enumerate(raw_graphs):
-             entities = g["entities"]
-             relations = g["relations"]
-             # print(str(i),"\n")
-             # print(caps[i])
-             # print("\n")
-             # print(graphs[i])
-             # if len(entities) == 0 or len(relations) == 0:
-             #     continue
-             # else:
-             # TODO find out the logic he used for Stop words or proper name
-             filtered_entities = [e["lemma_head"] if e["lemma_head"] in entity_vocabulary else 'none' for e in entities]
-             filtered_relations = [[r["subject"], r["object"], r["lemma_relation"]] for r in relations if r["lemma_relation"] in relation_vocabulary]
-             extracted_graph = {'entities': filtered_entities, 'relations': filtered_relations}
-             cleaned_graphs.append(extracted_graph)
-
-         new_graphs[k] = cleaned_graphs
-
-
-     pass