import penman
from penman import layout
from penman.graph import Graph
from penman.transform import reify_attributes

import re
import json
import argparse
from pathlib import Path
from collections import defaultdict


def save_corpus(path, amr_analysis, concatenation=False):
    """Write each AMR (the collapsed graph if `concatenation` is True) to `path`,
    preceded by its metadata block.
    """
    Path(path).parent.mkdir(parents=True, exist_ok=True)
    with open(path, 'w') as f:
        if concatenation:
            if not amr_analysis.graphs_concat_rel:
                amr_analysis.concat_rel()
            for amr_id, (g, g_concat) in amr_analysis.graphs_concat_rel.items():
                meta_block = amr_analysis.info_dict[amr_id]['meta']
                print(meta_block, file=f)
                pprint(g_concat, file=f)

        else:
            for amr_id in amr_analysis.info_dict:
                meta_block = amr_analysis.info_dict[amr_id]['meta']
                print(meta_block, file=f)
                pprint(amr_analysis.info_dict[amr_id]['amr_string'], file=f)


def pprint(l, reified=False, **args):
    """Pretty-print dicts, iterables, penman Graphs/Trees or raw PENMAN strings.

    Extra keyword arguments (e.g. file=...) are forwarded to print().
    """
    if isinstance(l, dict):
        print('Key\tValue', **args)
        for k, v in l.items():
            print(f'{k}\t{v}', **args)

    elif isinstance(l, (list, tuple, set)):
        for el in l:
            print(el, **args)

    elif isinstance(l, penman.Graph):
        if reified:
            l = penman.encode(l)
            l = AMRAnalysis.reify_rename_graph_from_string(l)
        print(penman.encode(l), **args)

    elif isinstance(l, penman.Tree):
        if reified:
            l = penman.format(l)
            l = AMRAnalysis.reify_rename_graph_from_string(l)
            print(penman.encode(l), **args)
        else:
            print(penman.format(l), **args)

    elif isinstance(l, str):
        if reified:
            l = AMRAnalysis.reify_rename_graph_from_string(l)
            print(penman.encode(l), **args)
        else:
            print(penman.format(penman.parse(l)), **args)

    else:
        raise ValueError('Unknown type')
    print(**args)
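
# Illustrative use of pprint (the AMR below is made up, not from the corpus); it
# dispatches on the argument type and forwards keyword arguments such as file=... to print():
#
#   pprint({'0': 'MRPNode-0', '0.0': 'MRPNode-1'})       # tab-separated key/value table
#   pprint('(w / want-01 :ARG0 (b / boy))')              # re-formatted PENMAN string
#   pprint('(w / want-01 :polarity -)', reified=True)    # attributes reified, variables renamed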


class AMRAnalysis:
    """
    Parses an AMR-to-text alignment file, reifies attributes, renames variables to
    MRPNode-*, records the token alignment of every node and can optionally collapse
    aligned relation subtrees via concat_rel().
    """
    def __init__(self, amr2text_alignment_path, keep_meta=True,
                 extended_meta=False, concat_rel=False):
        self.amr2text_alignment_path = amr2text_alignment_path
        self.keep_meta = keep_meta
        if extended_meta:
            self.keep_meta = self.extended_meta = True
        else:
            self.extended_meta = False
        self.info_dict = {}
        self.graphs_concat_rel = {}
        if concat_rel:
            self.concat_rel()
        else:
            self.extract_info()

    @staticmethod
    def reify_rename_graph_from_string(amr_string):
        """Reify attribute triples and rename all variables to 'MRPNode-{i}'."""
        g1 = reify_attributes(penman.decode(amr_string))
        t1 = layout.configure(g1)
        t1.reset_variables(fmt='MRPNode-{i}')
        g1 = layout.interpret(t1)

        return g1
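
    # Rough illustration (hypothetical input; exact variable numbering is up to penman):
    # '(c / chapter :mod 7)' is reified so that the attribute value 7 becomes a node of
    # its own, and all variables are then renamed to MRPNode-0, MRPNode-1, ... so they
    # can be matched against the dotted alignment labels ('0', '0.0', ...).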

    @staticmethod
    def alignment_labels2mrp_labels(amr_string):
        """Map dotted tree positions ('0', '0.0', ...) to MRPNode-* variables.

        The label tracking below assumes the reified, renamed graph produced by
        reify_rename_graph_from_string.
        """
        amr_graph = AMRAnalysis.reify_rename_graph_from_string(amr_string)
        epidata, triples = amr_graph.epidata, amr_graph.triples
        cur_label, popped = '0', False
        labels_dict = {cur_label: amr_graph.top}
        for triple in triples:
            cur_node = triple[0]
            epi = epidata[triple]
            if epi and isinstance(epi[0], penman.layout.Push):
                cur_node = epi[0].variable
                if not popped:
                    cur_label += '.0'
                labels_dict[cur_label] = cur_node
                popped = False
            elif epi and isinstance(epi[0], penman.layout.Pop):
                pops_count = epi.count(epi[0])
                split = cur_label.split('.')
                if popped:
                    split = split[:len(split)-pops_count]
                else:
                    split = split[:len(split)-pops_count+1]
                split[-1] = str(int(split[-1])+1)
                cur_label = '.'.join(split)
                popped = True

        return labels_dict, amr_graph
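
    # Sketch of the returned labels_dict (shape only; see find_below's docstring for a
    # fuller example): dotted positions map to the renamed variables in preorder, e.g.
    #   {'0': 'MRPNode-0', '0.0': 'MRPNode-1', ...}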

    @staticmethod
    def get_alignments_dict_from_string(alignments_string, alignment_pattern, toks, labels_dict):
        r"""
        Somehow the alignments string in 'new_aligned' does not contain all
        aligned nodes that are specified in the '# ::node' lines below ¯\_(ツ)_/¯
        (see get_alignments_dict for the replacement).
        """
        matches = re.match(alignment_pattern, alignments_string)
        if not matches:
            raise ValueError(f'Alignments string "{alignments_string}" has wrong format!\nCould not find alignments.')
        alignments = matches.group(1).split()
        alignments_dict = {}

        for alignment in alignments:
            parts = alignment.split('|')
            token_span = parts[0]
            #indices = span.split('-')
            #token_span = ' '.join(toks[int(indices[0]):int(indices[1])])
            nodes = parts[1].split('+')
            nodes = [labels_dict[node] for node in nodes]
            for node in nodes:
                alignments_dict[node] = token_span
        return alignments_dict
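
    # The '# ::alignments' line parsed here is expected to look roughly like this
    # (hypothetical values and trailing fields):
    #   # ::alignments 0-1|0 1-2|0.0+0.0.0 ::annotator ...
    # i.e. whitespace-separated 'span|dottedLabel(+dottedLabel)*' pairs, where each
    # dotted label is mapped to its MRPNode-* variable via labels_dict.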

    @staticmethod
    def get_alignments_dict(nodes_block, labels_dict, alignments_with_toks=False, toks=None):
        """
        Replacement for get_alignments_dict_from_string: reads the alignments from the
        '# ::node' lines instead, which do contain all aligned nodes.
        """
        nodes_block = [spl_line for spl_line in nodes_block if len(spl_line) == 3]
        alignments_dict = {}
        for spl_line in nodes_block:
            node = spl_line[0]
            node = labels_dict[node]  # '0.0.0' --> 'MRPNode-2'
            token_span = spl_line[2]
            if alignments_with_toks:
                start_idx, end_idx = token_span.split('-')
                token_span = ' '.join(toks[int(start_idx):int(end_idx)])
            alignments_dict[node] = token_span

        return alignments_dict
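
    # Example of the '# ::node' entries this consumes (hypothetical values): the line
    #   # ::node  0.0  boy  2-3
    # arrives here (already split, minus the '# ::node' prefix) as ['0.0', 'boy', '2-3'];
    # the dotted label is translated to its MRPNode-* variable and mapped to the span
    # '2-3' (or to the tokens themselves when alignments_with_toks=True).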

    def extract_info(self, alignments_with_toks=False):
        """Parse the aligned AMR file into self.info_dict, keyed by AMR id."""
        with open(self.amr2text_alignment_path) as f:
            amrs = f.read().strip().split('\n\n')
            amrs = [amr.split('\n') for amr in amrs]

        alignment_pattern = re.compile(r'# ::alignments\s(.+?)\s::')

        for amr_analysis in amrs:
            amr_id = amr_analysis[0].split()[-1]

            toks = amr_analysis[2].split()[2:]  # first 2 tokens are: '# ::tok'
            toks = [tok.lower() for tok in toks]

            amr_string = amr_analysis[-1]
            labels_dict, amr_graph = AMRAnalysis.alignment_labels2mrp_labels(amr_string)

            alignments_string = amr_analysis[3]
            nodes_block = [line.split()[2:] for line in amr_analysis if line.startswith('# ::node')]  # first 2 tokens are: '# ::node'
            try:
                # get_alignments_dict_from_string works well, but the alignments string
                # doesn't contain all alignments, so get_alignments_dict is used instead
                #alignments_dict = AMRAnalysis.get_alignments_dict_from_string(alignments_string, alignment_pattern, toks, labels_dict)
                alignments_dict = AMRAnalysis.get_alignments_dict(nodes_block, labels_dict, alignments_with_toks, toks)
                alignments_dict = defaultdict(lambda: None, alignments_dict)
            except KeyError as e:
                print(amr_id)
                pprint(amr_string, reified=True)
                pprint(labels_dict)
                raise e

            self.info_dict[amr_id] = {'amr_string': penman.encode(amr_graph),
                                      'toks': toks,
                                      'alignments_dict': alignments_dict,
                                      'labels_dict': labels_dict,
                                      'amr_graph': amr_graph}
            if self.keep_meta:
                meta = amr_analysis[:3]  # save the '# ::id', '# ::snt' and '# ::tok' lines
                meta = '\n'.join(meta)
                self.info_dict[amr_id]['meta'] = meta
                if self.extended_meta:
                    labels_dict_string = json.dumps(labels_dict)
                    alignments_dict_string = json.dumps(alignments_dict)
                    self.info_dict[amr_id]['meta'] += f'\n# ::labels_dict {labels_dict_string}\n# ::alignments_dict {alignments_dict_string}'

        return self
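
    # Sketch of one input block extract_info expects (field layout inferred from the
    # parsing above; the values are made up). Blocks are separated by blank lines and
    # the AMR sits on a single final line:
    #
    #   # ::id 1
    #   # ::snt The boy wants to go .
    #   # ::tok The boy wants to go .
    #   # ::alignments 1-2|0.0 ...
    #   # ::node  0.0  boy  1-2
    #   (w / want-01 :ARG0 (b / boy) ...)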

    @staticmethod
    def find_below(labels_dict):
        """
        Finds the nodes below each node using a dictionary of the following form
        (located in 'info_dict[amr_id]['labels_dict']'):

        Key        Value
        0          MRPNode-0
        0.0        MRPNode-1
        0.0.0      MRPNode-2
        0.0.0.0    MRPNode-3
        0.0.0.0.0  MRPNode-4
        0.0.0.0.1  MRPNode-5
        0.0.1      MRPNode-6
        0.0.1.0    MRPNode-7

        Returns a dict where the key is a node variable (e.g. 'MRPNode-2') and
        the value is a list of the variables of all nodes below it.
        """
        nodes_below_dict = defaultdict(list)
        for key, value in labels_dict.items():
            for k, v in labels_dict.items():
                # matching on 'key + .' avoids false prefix matches such as '0.1' vs. '0.10'
                if k.startswith(key + '.'):
                    nodes_below_dict[value].append(v)
        return nodes_below_dict
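
    # With the labels_dict from the docstring above, a minimal sketch of the result:
    #   find_below(labels_dict)['MRPNode-2'] == ['MRPNode-3', 'MRPNode-4', 'MRPNode-5']
    # i.e. everything dominated by position '0.0.0'.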

    @staticmethod
    def full_span(subtree_token_spans):
        """
        Takes the list of token spans of a whole subtree and checks whether
        there are gaps.

        Returns the sorted list of token indices if the span is contiguous,
        otherwise None.
        """
        toks_indices = set()
        for token_span in subtree_token_spans:
            spl = token_span.split('-')
            i1, i2 = int(spl[0]), int(spl[1])
            indices = set(range(i1, i2))
            toks_indices.update(indices)
        if not toks_indices:
            return None
        minimum, maximum = min(toks_indices), max(toks_indices)
        toks_indices = sorted(toks_indices)
        if toks_indices == list(range(minimum, maximum+1)):
            return toks_indices
        return None
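
    # Minimal sketch: spans use end-exclusive token indices, so
    #   full_span(['2-3', '3-5']) -> [2, 3, 4]   (contiguous)
    #   full_span(['2-3', '4-5']) -> None        (gap at token 3)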

    def concat_rel(self, rel=':mod'):
        """
        For each AMR, collapse every `rel` (default ':mod') subtree whose aligned
        tokens form a contiguous span and which contains no reentrancies, replacing
        it with a single instance node whose concept is the concatenated token span
        plus the number of collapsed nodes.
        """
        if not self.info_dict:
            self.extract_info()
        self.graphs_concat_rel = {}

        # ONLY FOR DEBUGGING CERTAIN IDS!!!
        # DELETE FOR NORMAL USE!!!
        #self.info_dict = {k: v for k, v in self.info_dict.items() if k == '3'}

        for amr_id in self.info_dict:
            triples_filtered = []
            g = self.info_dict[amr_id]['amr_graph']
            toks = self.info_dict[amr_id]['toks']
            alignments_dict = self.info_dict[amr_id]['alignments_dict']
            nodes_below_dict = AMRAnalysis.find_below(self.info_dict[amr_id]['labels_dict'])
            instances_dict = defaultdict(lambda: None, {node: concept for node, _, concept in g.instances()})
            reentrancies = defaultdict(lambda: None, g.reentrancies())

            changed_instances = {}
            nodes_to_delete = []
            epidata = {}

            for triple in g.triples:
                if triple[0] not in nodes_to_delete and triple[2] not in nodes_to_delete:
                    if triple[1] == rel:
                        invoked = triple[0]
                        nodes_below_invoked = nodes_below_dict[invoked]
                        nodes_below_invoked_with_invoked = nodes_below_invoked + [invoked]
                        instances_below_invoked = [instances_dict[node] for node in nodes_below_invoked]

                        span = [alignments_dict[node] for node in nodes_below_invoked_with_invoked if alignments_dict[node]]
                        subtree_token_span = AMRAnalysis.full_span(span)
                        reentrancies_below_invoked = any(reentrancies[node] for node in nodes_below_invoked)

                        if subtree_token_span and not reentrancies_below_invoked:
                            merged = [toks[i] for i in subtree_token_span]
                            num_nodes_in_subtree = len(nodes_below_invoked_with_invoked)
                            changed_instances[invoked] = '_'.join(merged) + '_' + str(num_nodes_in_subtree)
                            nodes_to_delete += nodes_below_invoked
                            continue

                    epidata[triple] = g.epidata[triple]
                    triples_filtered.append(triple)

            for i in range(len(triples_filtered)):
                n, r, c = triples_filtered[i]
                old_tuple = (n, r, c)
                if n in changed_instances and r == ':instance':
                    new_tuple = (n, r, changed_instances[n])
                    triples_filtered[i] = new_tuple
                    epidata = {(k if k != old_tuple else new_tuple): (v if k != old_tuple else v + [penman.layout.Pop()])
                               for k, v in epidata.items()}

            new_g = Graph(triples=triples_filtered, epidata=epidata)
            new_t = layout.configure(new_g)
            new_t.reset_variables(fmt='MRPNode-{i}')
            new_g = layout.interpret(new_t)
            self.graphs_concat_rel[amr_id] = (g, new_g)

            collapsed_instance_nodes = len(nodes_to_delete)
            if self.keep_meta:
                self.info_dict[amr_id]['meta'] += f'\n# ::collapsed instance nodes {collapsed_instance_nodes}'
            #else:
            #    self.info_dict[amr_id]['meta'] = f'# ::collapsed instance nodes {collapsed_instance_nodes}'

        return self
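
# Illustrative programmatic use of AMRAnalysis (the path and AMR id are hypothetical):
#
#   analysis = AMRAnalysis('alignments/corpus_a.aligned', extended_meta=True)
#   pprint(analysis.info_dict['1']['labels_dict'])
#   analysis.concat_rel(rel=':mod')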


def do_all_stuff(args):

    if (not args.concat_rel) and (not args.extended_meta):
        output_suffix = 'reif'

    elif args.concat_rel and args.extended_meta:
        output_suffix = 'concat_ext'

    elif args.concat_rel:
        output_suffix = 'concat'

    else:
        output_suffix = 'reif_ext'

    print(f'Input parameters: concat_rel={args.concat_rel}, extended_meta={args.extended_meta}')

    for i, f in enumerate(args.input[:2]):  # only the first two inputs are used: corpus a and corpus b
        amr_analysis = AMRAnalysis(f, concat_rel=args.concat_rel,
                                   extended_meta=args.extended_meta)

        save = Path(f'{args.output_prefix}_corpus_{chr(97+i)}_{output_suffix}.amr')

        save_corpus(save, amr_analysis, concatenation=args.concat_rel)
        print(f'File: "{str(save)}" was successfully saved!')


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('-i', '--input', nargs='+',
                        help='path(s) of the amr2text alignment file(s)')
    parser.add_argument('--extended_meta', action='store_true', default=False,
                        help='defines whether the alignment meta data has to be added to the AMRs')
    parser.add_argument('--concat_rel', action='store_true', default=False,
                        help='defines whether the AMR graphs have to be transformed according to their token alignments')
    parser.add_argument('--output_prefix', default='analysis/sts/STS2016',
                        help='defines the prefix of the output file(s), e.g. STS2016 -> STS2016_corpus_(a|b)_(reif|reif_ext|concat|concat_ext).amr')
    args = parser.parse_args()

    do_all_stuff(args)
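
# Example invocation (the script name and input paths are hypothetical):
#   python amr_analysis.py -i corpus_a.aligned corpus_b.aligned --concat_rel \
#       --output_prefix analysis/sts/STS2016
# which would write analysis/sts/STS2016_corpus_a_concat.amr and the corresponding
# _corpus_b_concat.amr.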