diff --git a/backends/arm/tosa/backend.py b/backends/arm/tosa/backend.py index 99fcadac081..38cc6e255de 100644 --- a/backends/arm/tosa/backend.py +++ b/backends/arm/tosa/backend.py @@ -17,9 +17,8 @@ import logging import tempfile -from collections import deque from itertools import count -from typing import cast, Dict, final, List, Set +from typing import cast, Dict, final, List import tosa_serializer as ts from executorch.backends.arm.common.arm_compile_spec import ArmCompileSpec @@ -43,35 +42,36 @@ def _annotate_external_ids(ep_graph: Graph) -> Dict[str, int]: - """Assign deterministic output IDs to nodes reachable from graph outputs. + """Assign deterministic output IDs to leaf outputs. + + Flattens the output structure and assigns the external ID + based on the leaf position in the exported output tuple/list. Args: ep_graph (Graph): FX graph produced by export preprocessing. Returns: - dict[str, int]: Mapping from node name to external output index. - + dict[str, int]: Mapping from *leaf output node name* to external output index. """ node2external_id = {} - def bfs_mark(start_nodes: List[Node], idx: int, seen: Set[Node]): - """Walk producer graph from ``start_nodes`` and record external IDs.""" - q = deque(start_nodes) - while q: - n = q.popleft() - if n in seen: - continue - seen.add(n) - node2external_id[n.name] = idx - # Walk backwards so we touch every producer - q.extend(n.all_input_nodes) + def _collect_leaves(arg, nodes): + # Collect only FX Nodes that are actual outputs + # (ignore ints/None/etc inside structured outputs). + if isinstance(arg, Node): + nodes.append(arg) + elif isinstance(arg, (list, tuple)): + for a in arg: + _collect_leaves(a, nodes) out = ep_graph.output_node() - # First argument of output node is tuple of outputs - output_list = cast(tuple, out.args[0]) - seen: Set[Node] = set() - for idx, val in enumerate(output_list): - bfs_mark([val], idx, seen) + out_leaves: list[Node] = [] + # First argument of output is the structured container (tuple/list) of outputs + _collect_leaves(out.args[0], out_leaves) + + # Map each output leaf's name to its position + node2external_id = {leaf.name: idx for idx, leaf in enumerate(out_leaves)} + return node2external_id