Commit 5b483ab

Remove arg index from the get_arg and set_arg of a torch Node
Differential Revision: D77976838
Pull Request resolved: #12326
1 parent dc4d6ee commit 5b483ab

File tree

backends/cadence/aot/pass_utils.py
backends/cadence/aot/remove_ops.py

2 files changed: +46 -23 lines changed


backends/cadence/aot/pass_utils.py

Lines changed: 37 additions & 14 deletions
@@ -174,30 +174,53 @@ def nodes_not_adjacent_in_gm(
 
 
 def get_arg(
     node: torch.fx.Node,
-    arg_index: int,
     kwarg_name: str,
-    *,
-    default: torch.fx.node.Argument = None,
 ) -> torch.fx.node.Argument:
     """
-    Get the arg at arg_index or kwarg with arg_name of the node. If neither is found
-    return default.
+    Get the arg with arg_name of the node, returns default value if not set.
     """
-    if arg_index < len(node.args):
-        return node.args[arg_index]
-    elif kwarg_name in node.kwargs:
+    # Try to get the arg from kwargs first since this is faster
+    if kwarg_name in node.kwargs:
         return node.kwargs[kwarg_name]
-    else:
-        return default
+
+    # If it's not found in kwargs, try to normalize the args
+    normalized_args = node.normalized_arguments(
+        node.graph.owning_module, normalize_to_only_use_kwargs=True
+    )
+    if not normalized_args:
+        raise RuntimeError(
+            f"get_arg: Node {node} does not support normalization of arguments"
+        )
+
+    return normalized_args.kwargs[kwarg_name]
 
 
 def set_arg(
-    node: torch.fx.Node, arg_index: int, kwarg_name: str, value: torch.fx.node.Argument
+    node: torch.fx.Node, kwarg_name: str, value: torch.fx.node.Argument
 ) -> None:
     """
-    Set the arg at arg_index if it exists, otherwise set the kwarg.
+    Set the node's arg with its name to the given value.
    """
-    if arg_index < len(node.args):
-        node.update_arg(arg_index, value)
+    # Try to set the arg if it is present in kwargs first since this is faster
+    if kwarg_name in node.kwargs:
+        node.update_kwarg(kwarg_name, value)
+        return
+
+    # If it's not found in kwargs, try to normalize the args and set the arg
+    normalized_args = node.normalized_arguments(
+        node.graph.owning_module, normalize_to_only_use_kwargs=True
+    )
+    if not normalized_args:
+        raise RuntimeError(
+            f"set_arg: Node {node} does not support normalization of arguments"
+        )
+
+    kwargs = normalized_args.kwargs
+    if kwarg_name not in kwargs:
+        raise ValueError(f"set_arg: invalid arg name {kwarg_name} for node {node} used")
+
+    idx = list(kwargs.keys()).index(kwarg_name)
+    if idx < len(node.args):
+        node.update_arg(idx, value)
     else:
         node.update_kwarg(kwarg_name, value)
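
For orientation, a minimal sketch of how a call site migrates to the name-only signature. The Add module, the traced graph, and the import path are illustrative assumptions, not part of this commit; the lookup below stays on the kwargs fast path because tracing records alpha as a keyword argument.

import torch
from torch.fx import symbolic_trace

# Assumed import path for the helpers changed in this commit.
from executorch.backends.cadence.aot.pass_utils import get_arg, set_arg

# Hypothetical toy module: traces to one call_function node for torch.add,
# with alpha recorded as a keyword argument.
class Add(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.add(x, x, alpha=2)

gm = symbolic_trace(Add())
add_node = next(n for n in gm.graph.nodes if n.op == "call_function")

# Before this commit: get_arg(add_node, 2, "alpha", default=1)
# After: arguments are addressed by name only.
alpha = get_arg(add_node, "alpha")  # found in node.kwargs, returns 2
set_arg(add_node, "alpha", 3)       # hits node.update_kwarg("alpha", 3)

If the argument were passed positionally instead, get_arg would fall back to node.normalized_arguments, which needs a resolvable schema for the target and raises RuntimeError when normalization is unsupported.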

backends/cadence/aot/remove_ops.py

Lines changed: 9 additions & 9 deletions
@@ -779,17 +779,17 @@ def _remove_unused_cat(self, graph_module: torch.fx.GraphModule) -> None:
         for slice_copy_node in graph_module.graph.find_nodes(
             op="call_function", target=exir_ops.edge.aten.slice_copy.Tensor
         ):
-            cat_node = cast(Node, get_arg(slice_copy_node, 0, "input"))
-            slice_dim = cast(int, get_arg(slice_copy_node, 1, "dim", default=0))
-            start_idx = cast(int, get_arg(slice_copy_node, 2, "start", default=None))
-            end_idx = cast(int, get_arg(slice_copy_node, 3, "end", default=None))
-            step = cast(int, get_arg(slice_copy_node, 4, "step", default=1))
+            cat_node = cast(Node, get_arg(slice_copy_node, "input"))
+            slice_dim = cast(int, get_arg(slice_copy_node, "dim"))
+            start_idx = cast(int, get_arg(slice_copy_node, "start"))
+            end_idx = cast(int, get_arg(slice_copy_node, "end"))
+            step = cast(int, get_arg(slice_copy_node, "step"))
 
             if cat_node.target != exir_ops.edge.aten.cat.default or step != 1:
                 continue
 
             # Make sure cat and slice happens on the same dimension.
-            cat_dim = cast(Node, get_arg(cat_node, 1, "dim", default=0))
+            cat_dim = cast(Node, get_arg(cat_node, "dim"))
             if cat_dim != slice_dim:
                 continue
 
@@ -805,14 +805,14 @@ def _remove_unused_cat(self, graph_module: torch.fx.GraphModule) -> None:
             end_idx += cat_output_shape[cat_dim]
 
             offset = 0
-            for cat_input_node in cast(List[Node], get_arg(cat_node, 0, "tensors")):
+            for cat_input_node in cast(List[Node], get_arg(cat_node, "tensors")):
                 cat_input_shape = cat_input_node.meta["val"].shape
 
                 # Check if the slice range overlaps with the cat input range.
                 if offset <= start_idx and end_idx <= offset + cat_input_shape[cat_dim]:
                     slice_copy_node.replace_input_with(cat_node, cat_input_node)
-                    set_arg(slice_copy_node, 2, "start", start_idx - offset)
-                    set_arg(slice_copy_node, 3, "end", end_idx - offset)
+                    set_arg(slice_copy_node, "start", start_idx - offset)
+                    set_arg(slice_copy_node, "end", end_idx - offset)
                     break
 
                 offset += cat_input_shape[cat_dim]
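
A note on the dropped default= values at these call sites: Node.normalized_arguments delegates to torch.fx.operator_schemas.normalize_function, which binds the target's signature and applies its declared defaults, so an omitted argument such as slice_copy's step comes back with its schema default. A minimal sketch of that behavior, using F.layer_norm as a stand-in because the edge ops above would need a full traced graph:

import torch
import torch.nn.functional as F
from torch.fx.operator_schemas import normalize_function

# Normalize a call that omits eps; normalization applies the signature
# default, so the kwarg is present even though the caller never passed it.
pair = normalize_function(
    F.layer_norm,
    (torch.randn(2, 4), [4]),
    normalize_to_only_use_kwargs=True,
)
assert pair is not None  # None would mean the target could not be normalized
print(pair.kwargs["eps"])  # 1e-05, filled in from the default

This is also why get_arg no longer takes a caller-supplied default: for nodes that normalize, the value returned for an unset argument is the schema default.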
