Remove arg index from the get_arg and set_arg of a torch Node #12326

Merged: 1 commit, Jul 11, 2025

51 changes: 37 additions & 14 deletions backends/cadence/aot/pass_utils.py
@@ -174,30 +174,53 @@ def nodes_not_adjacent_in_gm(

 def get_arg(
     node: torch.fx.Node,
-    arg_index: int,
     kwarg_name: str,
-    *,
-    default: torch.fx.node.Argument = None,
 ) -> torch.fx.node.Argument:
     """
-    Get the arg at arg_index or kwarg with arg_name of the node. If neither is found
-    return default.
+    Get the arg with arg_name of the node, returns default value if not set.
     """
-    if arg_index < len(node.args):
-        return node.args[arg_index]
-    elif kwarg_name in node.kwargs:
+    # Try to get the arg from kwargs first since this is faster
+    if kwarg_name in node.kwargs:
         return node.kwargs[kwarg_name]
-    else:
-        return default
+
+    # If it's not found in kwargs, try to normalize the args
+    normalized_args = node.normalized_arguments(
+        node.graph.owning_module, normalize_to_only_use_kwargs=True
+    )
+    if not normalized_args:
+        raise RuntimeError(
+            f"get_arg: Node {node} does not support normalization of arguments"
+        )
+
+    return normalized_args.kwargs[kwarg_name]


 def set_arg(
-    node: torch.fx.Node, arg_index: int, kwarg_name: str, value: torch.fx.node.Argument
+    node: torch.fx.Node, kwarg_name: str, value: torch.fx.node.Argument
 ) -> None:
     """
-    Set the arg at arg_index if it exists, otherwise set the kwarg.
+    Set the node's arg with its name to the given value.
     """
-    if arg_index < len(node.args):
-        node.update_arg(arg_index, value)
+    # Try to set the arg if it is present in kwargs first since this is faster
+    if kwarg_name in node.kwargs:
+        node.update_kwarg(kwarg_name, value)
+        return
+
+    # If it's not found in kwargs, try to normalize the args and set the arg
+    normalized_args = node.normalized_arguments(
+        node.graph.owning_module, normalize_to_only_use_kwargs=True
+    )
+    if not normalized_args:
+        raise RuntimeError(
+            f"set_arg: Node {node} does not support normalization of arguments"
+        )
+
+    kwargs = normalized_args.kwargs
+    if kwarg_name not in kwargs:
+        raise ValueError(f"set_arg: invalid arg name {kwarg_name} for node {node} used")
+
+    idx = list(kwargs.keys()).index(kwarg_name)
+    if idx < len(node.args):
+        node.update_arg(idx, value)
     else:
         node.update_kwarg(kwarg_name, value)
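
For reference, a minimal sketch of how a pass would use the name-based helpers after this change. The helper function, the import path, and the node it operates on are illustrative assumptions mirroring the remove_ops.py call sites below; they are not part of this PR.

```python
from typing import cast

import torch

# Import path assumed from the file location touched in this PR.
from executorch.backends.cadence.aot.pass_utils import get_arg, set_arg


def shift_slice_window(slice_node: torch.fx.Node, offset: int) -> None:
    """Hypothetical helper: rebase a slice_copy node's start/end by name.

    `slice_node` is assumed to be a call_function node targeting
    aten.slice_copy, as in the Cadence pass below.
    """
    start = cast(int, get_arg(slice_node, "start"))
    end = cast(int, get_arg(slice_node, "end"))
    # No positional indices: set_arg checks node.kwargs first and otherwise
    # normalizes the node's schema to decide between update_arg and update_kwarg.
    set_arg(slice_node, "start", start - offset)
    set_arg(slice_node, "end", end - offset)
```

Note that with the explicit `default` parameter gone, a name not present in kwargs is resolved through `normalized_arguments`, which fills in schema defaults or raises if the node cannot be normalized.
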
18 changes: 9 additions & 9 deletions backends/cadence/aot/remove_ops.py
@@ -779,17 +779,17 @@ def _remove_unused_cat(self, graph_module: torch.fx.GraphModule) -> None:
         for slice_copy_node in graph_module.graph.find_nodes(
             op="call_function", target=exir_ops.edge.aten.slice_copy.Tensor
         ):
-            cat_node = cast(Node, get_arg(slice_copy_node, 0, "input"))
-            slice_dim = cast(int, get_arg(slice_copy_node, 1, "dim", default=0))
-            start_idx = cast(int, get_arg(slice_copy_node, 2, "start", default=None))
-            end_idx = cast(int, get_arg(slice_copy_node, 3, "end", default=None))
-            step = cast(int, get_arg(slice_copy_node, 4, "step", default=1))
+            cat_node = cast(Node, get_arg(slice_copy_node, "input"))
+            slice_dim = cast(int, get_arg(slice_copy_node, "dim"))
+            start_idx = cast(int, get_arg(slice_copy_node, "start"))
+            end_idx = cast(int, get_arg(slice_copy_node, "end"))
+            step = cast(int, get_arg(slice_copy_node, "step"))

             if cat_node.target != exir_ops.edge.aten.cat.default or step != 1:
                 continue

             # Make sure cat and slice happens on the same dimension.
-            cat_dim = cast(Node, get_arg(cat_node, 1, "dim", default=0))
+            cat_dim = cast(Node, get_arg(cat_node, "dim"))
             if cat_dim != slice_dim:
                 continue

@@ -805,14 +805,14 @@ def _remove_unused_cat(self, graph_module: torch.fx.GraphModule) -> None:
                 end_idx += cat_output_shape[cat_dim]

             offset = 0
-            for cat_input_node in cast(List[Node], get_arg(cat_node, 0, "tensors")):
+            for cat_input_node in cast(List[Node], get_arg(cat_node, "tensors")):
                 cat_input_shape = cat_input_node.meta["val"].shape

                 # Check if the slice range overlaps with the cat input range.
                 if offset <= start_idx and end_idx <= offset + cat_input_shape[cat_dim]:
                     slice_copy_node.replace_input_with(cat_node, cat_input_node)
-                    set_arg(slice_copy_node, 2, "start", start_idx - offset)
-                    set_arg(slice_copy_node, 3, "end", end_idx - offset)
+                    set_arg(slice_copy_node, "start", start_idx - offset)
+                    set_arg(slice_copy_node, "end", end_idx - offset)
                     break

                 offset += cat_input_shape[cat_dim]
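
As a concrete check of the rebasing arithmetic in this hunk, with hypothetical shapes not taken from the PR: two cat inputs of sizes 4 and 6 along dim 0, sliced at [5, 8), reduce to slicing the second input at [1, 4).

```python
import torch

# Hypothetical shapes: cat inputs of sizes 4 and 6 along dim 0, slice [5, 8).
a, b = torch.arange(4), torch.arange(6)
sliced_cat = torch.cat([a, b], dim=0)[5:8]

# The pass keeps a running `offset`; here the slice lies entirely inside `b`
# (offset = 4, since 4 <= 5 and 8 <= 4 + 6), so start/end are rebased by 4.
offset = a.shape[0]
rebased = b[5 - offset : 8 - offset]  # b[1:4]

assert torch.equal(sliced_cat, rebased)
```
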