diff --git a/backends/cadence/aot/pass_utils.py b/backends/cadence/aot/pass_utils.py index 170c81f571e..01045590f1e 100644 --- a/backends/cadence/aot/pass_utils.py +++ b/backends/cadence/aot/pass_utils.py @@ -174,30 +174,53 @@ def nodes_not_adjacent_in_gm( def get_arg( node: torch.fx.Node, - arg_index: int, kwarg_name: str, - *, - default: torch.fx.node.Argument = None, ) -> torch.fx.node.Argument: """ - Get the arg at arg_index or kwarg with arg_name of the node. If neither is found - return default. + Get the arg with arg_name of the node, returns default value if not set. """ - if arg_index < len(node.args): - return node.args[arg_index] - elif kwarg_name in node.kwargs: + # Try to get the arg from kwargs first since this is faster + if kwarg_name in node.kwargs: return node.kwargs[kwarg_name] - else: - return default + + # If it's not found in kwargs, try to normalize the args + normalized_args = node.normalized_arguments( + node.graph.owning_module, normalize_to_only_use_kwargs=True + ) + if not normalized_args: + raise RuntimeError( + f"get_arg: Node {node} does not support normalization of arguments" + ) + + return normalized_args.kwargs[kwarg_name] def set_arg( - node: torch.fx.Node, arg_index: int, kwarg_name: str, value: torch.fx.node.Argument + node: torch.fx.Node, kwarg_name: str, value: torch.fx.node.Argument ) -> None: """ - Set the arg at arg_index if it exists, otherwise set the kwarg. + Set the node's arg with its name to the given value. """ - if arg_index < len(node.args): - node.update_arg(arg_index, value) + # Try to set the arg if it is present in kwargs first since this is faster + if kwarg_name in node.kwargs: + node.update_kwarg(kwarg_name, value) + return + + # If it's not found in kwargs, try to normalize the args and set the arg + normalized_args = node.normalized_arguments( + node.graph.owning_module, normalize_to_only_use_kwargs=True + ) + if not normalized_args: + raise RuntimeError( + f"set_arg: Node {node} does not support normalization of arguments" + ) + + kwargs = normalized_args.kwargs + if kwarg_name not in kwargs: + raise ValueError(f"set_arg: invalid arg name {kwarg_name} for node {node} used") + + idx = list(kwargs.keys()).index(kwarg_name) + if idx < len(node.args): + node.update_arg(idx, value) else: node.update_kwarg(kwarg_name, value) diff --git a/backends/cadence/aot/remove_ops.py b/backends/cadence/aot/remove_ops.py index fe23ea73754..faee453346c 100644 --- a/backends/cadence/aot/remove_ops.py +++ b/backends/cadence/aot/remove_ops.py @@ -779,17 +779,17 @@ def _remove_unused_cat(self, graph_module: torch.fx.GraphModule) -> None: for slice_copy_node in graph_module.graph.find_nodes( op="call_function", target=exir_ops.edge.aten.slice_copy.Tensor ): - cat_node = cast(Node, get_arg(slice_copy_node, 0, "input")) - slice_dim = cast(int, get_arg(slice_copy_node, 1, "dim", default=0)) - start_idx = cast(int, get_arg(slice_copy_node, 2, "start", default=None)) - end_idx = cast(int, get_arg(slice_copy_node, 3, "end", default=None)) - step = cast(int, get_arg(slice_copy_node, 4, "step", default=1)) + cat_node = cast(Node, get_arg(slice_copy_node, "input")) + slice_dim = cast(int, get_arg(slice_copy_node, "dim")) + start_idx = cast(int, get_arg(slice_copy_node, "start")) + end_idx = cast(int, get_arg(slice_copy_node, "end")) + step = cast(int, get_arg(slice_copy_node, "step")) if cat_node.target != exir_ops.edge.aten.cat.default or step != 1: continue # Make sure cat and slice happens on the same dimension. - cat_dim = cast(Node, get_arg(cat_node, 1, "dim", default=0)) + cat_dim = cast(Node, get_arg(cat_node, "dim")) if cat_dim != slice_dim: continue @@ -805,14 +805,14 @@ def _remove_unused_cat(self, graph_module: torch.fx.GraphModule) -> None: end_idx += cat_output_shape[cat_dim] offset = 0 - for cat_input_node in cast(List[Node], get_arg(cat_node, 0, "tensors")): + for cat_input_node in cast(List[Node], get_arg(cat_node, "tensors")): cat_input_shape = cat_input_node.meta["val"].shape # Check if the slice range overlaps with the cat input range. if offset <= start_idx and end_idx <= offset + cat_input_shape[cat_dim]: slice_copy_node.replace_input_with(cat_node, cat_input_node) - set_arg(slice_copy_node, 2, "start", start_idx - offset) - set_arg(slice_copy_node, 3, "end", end_idx - offset) + set_arg(slice_copy_node, "start", start_idx - offset) + set_arg(slice_copy_node, "end", end_idx - offset) break offset += cat_input_shape[cat_dim]