Skip to content

Commit 22f4626

Browse files
committed
Minor changes, cleaned the repo, updated README
1 parent b37a874 commit 22f4626

26 files changed

+568
-691
lines changed

.gitignore

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,9 @@
22
.ipynb_checkpoints
33
/AMR2Text/*
44
/amr_suite/vectors
5+
__pycache__
6+
.gitignore
57

68
# Files not to ignore:
7-
!/AMR2Text/processed
9+
#!/AMR2Text/processed
810

AMRAnalysis.py

Lines changed: 365 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,365 @@
1+
import penman
2+
from penman import layout
3+
from penman.graph import Graph
4+
from penman.transform import reify_attributes
5+
6+
import re
7+
import json
8+
import argparse
9+
from pathlib import Path
10+
from collections import defaultdict
11+
12+
13+
def save_corpus(path, amr_analysis, concatenation=False):
    """Write every AMR held by *amr_analysis* to the file at *path*.

    For each AMR the meta block (when one was kept) is written first,
    followed by the graph as printed by ``pprint``.

    Args:
        path: Destination file; missing parent directories are created.
        amr_analysis: An ``AMRAnalysis`` instance.
        concatenation: If true, write the transformed graphs stored in
            ``amr_analysis.graphs_concat_rel`` (computing them on demand);
            otherwise write the ``'amr_string'`` entries of ``info_dict``.
    """
    # Run the transformation *before* opening the file, so that a failure
    # here does not leave a truncated/empty output file behind.
    if concatenation and not amr_analysis.graphs_concat_rel:
        amr_analysis.concat_rel()

    Path(path).parent.mkdir(parents=True, exist_ok=True)
    with open(path, 'w') as f:
        if concatenation:
            graphs = (
                (amr_id, g_concat)
                for amr_id, (_, g_concat) in amr_analysis.graphs_concat_rel.items()
            )
        else:
            graphs = (
                (amr_id, info['amr_string'])
                for amr_id, info in amr_analysis.info_dict.items()
            )

        for amr_id, graph in graphs:
            # 'meta' is only present when the analysis was built with
            # keep_meta=True; skip it instead of raising KeyError.
            meta_block = amr_analysis.info_dict[amr_id].get('meta')
            if meta_block is not None:
                print(meta_block, file=f)
            pprint(graph, file=f)
29+
30+
31+
def pprint(l, reified=False, **args):
    """Print *l* in a readable form, followed by one empty line.

    Supported inputs:
      * ``dict`` -- a ``Key<TAB>Value`` header and one ``key<TAB>value``
        line per item;
      * ``list`` / ``tuple`` / ``set`` -- one element per line;
      * ``penman.Graph``, ``penman.Tree`` or a PENMAN string -- the
        formatted graph.

    Args:
        l: The object to print.
        reified: Only used for graphs, trees and strings. If true, the
            graph is first passed through
            ``AMRAnalysis.reify_rename_graph_from_string``.
        **args: Forwarded unchanged to every ``print`` call
            (e.g. ``file=...``).

    Raises:
        ValueError: If *l* is of none of the supported types.
    """
    if isinstance(l, dict):
        # The header must honour **args too, otherwise it ends up on stdout
        # while the rows go to the requested file.
        print('Key\tValue', **args)
        for k, v in l.items():
            print(f'{k}\t{v}', **args)

    elif isinstance(l, (list, tuple, set)):
        for el in l:
            print(el, **args)

    elif isinstance(l, penman.Graph):
        if reified:
            # The helper is a static method of AMRAnalysis, not a module-level
            # function; the class is looked up at call time.
            l = AMRAnalysis.reify_rename_graph_from_string(penman.encode(l))
        print(penman.encode(l), **args)

    elif isinstance(l, penman.Tree):
        if reified:
            l = AMRAnalysis.reify_rename_graph_from_string(penman.format(l))
            print(penman.encode(l), **args)
        else:
            print(penman.format(l), **args)

    elif isinstance(l, str):
        if reified:
            l = AMRAnalysis.reify_rename_graph_from_string(l)
            print(penman.encode(l), **args)
        else:
            print(penman.format(penman.parse(l)), **args)

    else:
        raise ValueError('Unknown type')

    # Blank separator line after every printed object.
    print(**args)
65+
66+
class AMRAnalysis:
    """Read an AMR-to-text alignment file and derive per-graph information.

    After ``extract_info`` has run, ``info_dict`` maps every AMR id to a
    dict with the keys ``'amr_string'``, ``'toks'``, ``'alignments_dict'``,
    ``'labels_dict'``, ``'amr_graph'`` and, if ``keep_meta`` is set,
    ``'meta'``.  After ``concat_rel`` has run, ``graphs_concat_rel`` maps
    every AMR id to a ``(original_graph, transformed_graph)`` pair.
    """

    def __init__(self, amr2text_alingnment_path, keep_meta=True,
                 extended_meta=False, concat_rel=False):
        """Load the alignment file and optionally run ``concat_rel``.

        Args:
            amr2text_alingnment_path: Path of the alignment file to read.
            keep_meta: Whether to store the header lines of every AMR.
            extended_meta: Whether to append the label and alignment
                dictionaries to the meta block; implies ``keep_meta``.
            concat_rel: If true, ``concat_rel()`` is run right away
                (which also extracts the info); otherwise only
                ``extract_info()`` is run.
        """
        self.amr2text_alingnment_path = amr2text_alingnment_path
        self.extended_meta = bool(extended_meta)
        # The extended meta is appended to the regular meta block, so it
        # forces the meta block to be kept.
        self.keep_meta = True if extended_meta else keep_meta
        self.info_dict = {}
        self.graphs_concat_rel = {}
        if concat_rel:
            self.concat_rel()
        else:
            self.extract_info()

    @staticmethod
    def reify_rename_graph_from_string(amr_string):
        """Decode *amr_string*, reify its attributes and rename its variables.

        Variables are renamed with the format ``MRPNode-{i}``.

        Returns:
            The resulting ``penman.Graph``.
        """
        graph = reify_attributes(penman.decode(amr_string))
        tree = layout.configure(graph)
        tree.reset_variables(fmt='MRPNode-{i}')
        return layout.interpret(tree)

    @staticmethod
    def alignment_labels2mrp_labels(amr_string):
        """Map positional labels such as ``'0.0.1'`` to graph variables.

        Currently works only on reified graphs (the graph is reified here
        via ``reify_rename_graph_from_string``).

        The graph's triples are walked in order and the layout epidata of
        every triple is inspected: a leading ``Push`` marker registers its
        variable under the current label, a leading ``Pop`` marker moves
        the current label up and on to the next sibling index.

        Returns:
            A tuple ``(labels_dict, amr_graph)`` where ``labels_dict`` maps
            label strings to variables and ``amr_graph`` is the reified,
            renamed graph.
        """
        amr_graph = AMRAnalysis.reify_rename_graph_from_string(amr_string)
        epidata = amr_graph.epidata
        cur_label, popped = '0', False
        labels_dict = {cur_label: amr_graph.top}

        for triple in amr_graph.triples:
            epi = epidata[triple]
            if not epi:
                continue
            first = epi[0]

            if isinstance(first, penman.layout.Push):
                # Descend one level, unless a preceding pop already moved
                # the label to the slot this node belongs to.
                if not popped:
                    cur_label += '.0'
                labels_dict[cur_label] = first.variable
                popped = False

            elif isinstance(first, penman.layout.Pop):
                pops_count = epi.count(first)
                parts = cur_label.split('.')
                # Without a preceding pop the current label has not been
                # advanced yet, so one level less is cut off.
                keep = len(parts) - pops_count + (0 if popped else 1)
                parts = parts[:keep]
                parts[-1] = str(int(parts[-1]) + 1)
                cur_label = '.'.join(parts)
                popped = True

        return labels_dict, amr_graph

    @staticmethod
    def get_alignments_dict_from_string(alignments_string, alignment_pattern, toks, labels_dict):
        """Build a node -> token-span dict from the ``# ::alignments`` line.

        Note: the alignments line was found not to contain all aligned
        nodes that are listed in the ``# ::node`` lines, which is why
        ``get_alignments_dict`` exists.

        Args:
            alignments_string: The raw alignments header line.
            alignment_pattern: Regular expression whose first group
                captures the whitespace-separated alignments.
            toks: Unused; kept for interface compatibility.
            labels_dict: Mapping from positional labels to variables.

        Returns:
            A dict mapping variables to token-span strings.

        Raises:
            ValueError: If *alignment_pattern* does not match.
            KeyError: If an aligned label is missing from *labels_dict*.
        """
        matches = re.match(alignment_pattern, alignments_string)
        if not matches:
            raise ValueError(f'Alignments string "{alignments_string}" has wrong format!\nCould not find alignments.')

        alignments_dict = {}
        for alignment in matches.group(1).split():
            # Each alignment looks like '<token span>|<label>+<label>+...'.
            parts = alignment.split('|')
            token_span = parts[0]
            for label in parts[1].split('+'):
                alignments_dict[labels_dict[label]] = token_span
        return alignments_dict

    @staticmethod
    def get_alignments_dict(nodes_block, labels_dict, alignments_with_toks=False, toks=None):
        """Build a node -> token-span dict from the ``# ::node`` lines.

        This works around the incomplete alignments line handled by
        ``get_alignments_dict_from_string``.

        Args:
            nodes_block: The node lines, already split into fields. Only
                entries with exactly three fields are used: the first is
                the positional label, the third the token span.
            labels_dict: Mapping from positional labels to variables.
            alignments_with_toks: If true, the ``start-end`` span is
                replaced by the space-joined tokens ``toks[start:end]``.
            toks: The sentence tokens; required when
                *alignments_with_toks* is true.

        Returns:
            A dict mapping variables to token spans.

        Raises:
            KeyError: If a label is missing from *labels_dict*.
        """
        alignments_dict = {}
        for fields in nodes_block:
            if len(fields) != 3:
                continue
            node = labels_dict[fields[0]]  # e.g. '0.0.0' --> 'MRPNode-2'
            token_span = fields[2]
            if alignments_with_toks:
                start_idx, end_idx = token_span.split('-')
                token_span = ' '.join(toks[int(start_idx):int(end_idx)])
            alignments_dict[node] = token_span
        return alignments_dict

    def extract_info(self, alignments_with_toks=False):
        """Parse the alignment file and fill ``self.info_dict``.

        The file is split into blocks at empty lines.  Within a block the
        AMR id is the last field of the first line, the tokens come from
        the third line (``# ::tok ...``, lower-cased) and the AMR itself
        is the last line.

        Args:
            alignments_with_toks: Forwarded to ``get_alignments_dict``.

        Returns:
            ``self``, to allow chaining.

        Raises:
            KeyError: If a node line refers to a label that is not in the
                computed label dictionary; the offending AMR is printed
                before the error is re-raised.
        """
        with open(self.amr2text_alingnment_path) as f:
            content = f.read().strip()
        # An empty file yields no blocks instead of one empty block.
        blocks = [block.split('\n') for block in content.split('\n\n') if block.strip()]

        for lines in blocks:
            amr_id = lines[0].split()[-1]
            # The first 2 fields of the line are: '# ::tok'
            toks = [tok.lower() for tok in lines[2].split()[2:]]
            amr_string = lines[-1]
            labels_dict, amr_graph = AMRAnalysis.alignment_labels2mrp_labels(amr_string)

            # The first 2 fields of every node line are: '# ::node'
            nodes_block = [line.split()[2:] for line in lines if line.startswith('# ::node')]
            try:
                alignments = AMRAnalysis.get_alignments_dict(
                    nodes_block, labels_dict, alignments_with_toks, toks)
            except KeyError:
                # Dump the offending AMR to ease debugging, then re-raise.
                print(amr_id)
                pprint(amr_string, reified=True)
                pprint(labels_dict)
                raise
            # Nodes without an alignment look up as None.
            alignments_dict = defaultdict(lambda: None, alignments)

            info = {
                'amr_string': penman.encode(amr_graph),
                'toks': toks,
                'alignments_dict': alignments_dict,
                'labels_dict': labels_dict,
                'amr_graph': amr_graph,
            }
            if self.keep_meta:
                # Keep the first three header lines of the block.
                info['meta'] = '\n'.join(lines[:3])
                if self.extended_meta:
                    labels_json = json.dumps(labels_dict)
                    alignments_json = json.dumps(alignments_dict)
                    info['meta'] += f'\n# ::labels_dict {labels_json}\n# ::alignments_dict {alignments_json}'
            self.info_dict[amr_id] = info

        return self

    @staticmethod
    def find_below(labels_dict):
        """Find, for every node, all nodes located below it.

        Uses a dictionary of the following form
        (located in ``info_dict[amr_id]['labels_dict']``):

            Key         Value
            0           MRPNode-0
            0.0         MRPNode-1
            0.0.0       MRPNode-2
            0.0.0.0     MRPNode-3
            0.0.0.0.0   MRPNode-4
            0.0.0.0.1   MRPNode-5
            0.0.1       MRPNode-6
            0.0.1.0     MRPNode-7

        Returns:
            A ``defaultdict(list)`` whose keys are node names
            (e.g. ``'MRPNode-2'``) and whose values are the names of all
            nodes below them.  Nodes without descendants get no entry.
        """
        nodes_below_dict = defaultdict(list)
        for key, value in labels_dict.items():
            # Compare whole path components: '0.10' is a sibling of '0.1',
            # not a descendant, although it shares the string prefix.
            prefix = key + '.'
            below = [v for k, v in labels_dict.items() if k.startswith(prefix)]
            if below:
                nodes_below_dict[value].extend(below)
        return nodes_below_dict

    @staticmethod
    def full_span(subtree_token_spans):
        """Check whether the token spans of a subtree cover a gap-free range.

        Args:
            subtree_token_spans: Iterable of ``'start-end'`` strings; every
                span contributes the indices ``start .. end - 1``.

        Returns:
            The sorted list of covered token indices if they form one
            contiguous range, otherwise ``None`` (also when no index is
            covered at all).
        """
        toks_indices = set()
        for token_span in subtree_token_spans:
            parts = token_span.split('-')
            toks_indices.update(range(int(parts[0]), int(parts[1])))

        if not toks_indices:
            return None
        # A set of integers is contiguous iff its size equals its extent.
        if len(toks_indices) == max(toks_indices) - min(toks_indices) + 1:
            return sorted(toks_indices)
        return None

    def concat_rel(self, rel=':mod'):
        """Collapse subtrees of nodes that have an outgoing *rel* edge.

        For every triple with role *rel*, the source node together with
        all nodes below it is collapsed into a single node if the token
        spans aligned to them form a gap-free range and none of the nodes
        below takes part in a reentrancy.  The collapsed node's concept
        becomes the covered tokens joined by ``'_'`` followed by ``'_'``
        and the number of nodes in the subtree.

        The results are stored in ``self.graphs_concat_rel`` as
        ``amr_id -> (original_graph, transformed_graph)``.  If the meta
        block is kept, the number of removed nodes is appended to it.

        Args:
            rel: The role that triggers collapsing.

        Returns:
            ``self``, to allow chaining.
        """
        if not self.info_dict:
            self.extract_info()
        self.graphs_concat_rel = {}

        for amr_id, info in self.info_dict.items():
            g = info['amr_graph']
            toks = info['toks']
            alignments_dict = info['alignments_dict']
            nodes_below_dict = AMRAnalysis.find_below(info['labels_dict'])
            reentrancies = defaultdict(lambda: None, g.reentrancies())

            triples_filtered = []
            epidata = {}
            changed_instances = {}
            deleted_nodes = set()          # for fast membership tests
            collapsed_instance_nodes = 0   # number of removed nodes

            for triple in g.triples:
                source, role, target = triple
                # Drop everything that touches an already removed node.
                if source in deleted_nodes or target in deleted_nodes:
                    continue

                if role == rel:
                    invoked = source
                    nodes_below_invoked = nodes_below_dict[invoked]
                    subtree_nodes = nodes_below_invoked + [invoked]

                    spans = [alignments_dict[node] for node in subtree_nodes
                             if alignments_dict[node]]
                    subtree_token_span = AMRAnalysis.full_span(spans)
                    reentrancies_below_invoked = any(
                        reentrancies[node] for node in nodes_below_invoked)

                    if subtree_token_span and not reentrancies_below_invoked:
                        merged = [toks[i] for i in subtree_token_span]
                        changed_instances[invoked] = '_'.join(merged) + '_' + str(len(subtree_nodes))
                        deleted_nodes.update(nodes_below_invoked)
                        collapsed_instance_nodes += len(nodes_below_invoked)
                        # The triggering triple itself is dropped as well.
                        continue

                epidata[triple] = g.epidata[triple]
                triples_filtered.append(triple)

            # Give every collapsed node its new concept.
            for i, (n, r, c) in enumerate(triples_filtered):
                if r == ':instance' and n in changed_instances:
                    old_triple = (n, r, c)
                    new_triple = (n, r, changed_instances[n])
                    triples_filtered[i] = new_triple
                    # Re-key the epidata entry in place (keeping the order)
                    # and append a Pop marker -- presumably to close the
                    # node context that the removed subtree used to close.
                    epidata = {
                        (new_triple if k == old_triple else k):
                            (v + [penman.layout.Pop()] if k == old_triple else v)
                        for k, v in epidata.items()
                    }

            new_g = Graph(triples=triples_filtered, epidata=epidata)
            new_t = layout.configure(new_g)
            new_t.reset_variables(fmt='MRPNode-{i}')
            new_g = layout.interpret(new_t)
            self.graphs_concat_rel[amr_id] = (g, new_g)

            if self.keep_meta:
                info['meta'] += f'\n# ::collapsed instance nodes {collapsed_instance_nodes}'

        return self
324+
325+
326+
def do_all_stuff(args):
    """Analyse up to two input files and save one corpus file for each.

    The output path is
    ``{output_prefix}_corpus_{a|b}_{suffix}.amr`` where the suffix is
    derived from the ``concat_rel`` and ``extended_meta`` flags.

    Args:
        args: Parsed command-line arguments providing ``input``,
            ``concat_rel``, ``extended_meta`` and ``output_prefix``.
    """
    if args.concat_rel:
        output_suffix = 'concat_ext' if args.extended_meta else 'concat'
    else:
        output_suffix = 'reif_ext' if args.extended_meta else 'reif'

    print(f'Input parameters: concat_rel={args.concat_rel}, extended_meta={args.extended_meta}')

    # Only the first two inputs are processed (corpus 'a' and corpus 'b');
    # a missing --input is treated like an empty list.
    for i, f in enumerate((args.input or [])[:2]):
        amr_analysis = AMRAnalysis(f, concat_rel=args.concat_rel,
                                   extended_meta=args.extended_meta)

        save = Path(f'{args.output_prefix}_corpus_{chr(97+i)}_{output_suffix}.amr')
        save_corpus(save, amr_analysis, concatenation=args.concat_rel)
        # Report success only after the file has actually been written.
        print(f'File: "{str(save)}" was sucessfully saved!')
350+
351+
352+
if __name__ == '__main__':
    # Command-line entry point: parse the options and run the analysis.
    parser = argparse.ArgumentParser()
    # Without any input there is nothing to do, so let argparse report a
    # proper usage error instead of failing later on a missing value.
    parser.add_argument('-i', '--input', nargs='+', required=True,
                        help='path(s) of the amr2text alignment file')
    parser.add_argument('--extended_meta', action='store_true', default=False,
                        help='defines whether alignment meta has to be added to AMRs')
    parser.add_argument('--concat_rel', action='store_true', default=False,
                        help='defines whether AMR graphs have to be transformed according to their token alignments')
    parser.add_argument('--output_prefix', default='analysis/sts/STS2016',
                        help='defines the prefix of the output file(s), e.g. if == STS2016 -> STS2016_corpus_(a|b)_(reif|ext|concat).amr')
    args = parser.parse_args()

    do_all_stuff(args)
365+

0 commit comments

Comments
 (0)