Skip to content

Commit f6ae97d

Browse files
Ritwik Dasfacebook-github-bot
authored andcommitted
Remove arg index from the get_arg and set_arg of a torch Node (#12326)
Summary: Pull Request resolved: #12326 Instead of using the arg index, rely on the normalized kwargs dictionary for a uniform view of a node's arguments. This diff cleans up the get_arg and set_arg api so that callers do not need to additionally keep track of arg indices. This simplifies the calling code and also prevents unnecessary code maintenance in case the op signature changes thereby changing the relative positioning of the args within a node. Reviewed By: abeakkas Differential Revision: D77976838
1 parent bdbad3f commit f6ae97d

File tree

2 files changed

+46
-23
lines changed

2 files changed

+46
-23
lines changed

backends/cadence/aot/pass_utils.py

Lines changed: 37 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -174,30 +174,53 @@ def nodes_not_adjacent_in_gm(
174174

175175
def get_arg(
176176
node: torch.fx.Node,
177-
arg_index: int,
178177
kwarg_name: str,
179-
*,
180-
default: torch.fx.node.Argument = None,
181178
) -> torch.fx.node.Argument:
182179
"""
183-
Get the arg at arg_index or kwarg with arg_name of the node. If neither is found
184-
return default.
180+
Get the arg with arg_name of the node, returns default value if not set.
185181
"""
186-
if arg_index < len(node.args):
187-
return node.args[arg_index]
188-
elif kwarg_name in node.kwargs:
182+
# Try to get the arg from kwargs first since this is faster
183+
if kwarg_name in node.kwargs:
189184
return node.kwargs[kwarg_name]
190-
else:
191-
return default
185+
186+
# If it's not found in kwargs, try to normalize the args
187+
normalized_args = node.normalized_arguments(
188+
node.graph.owning_module, normalize_to_only_use_kwargs=True
189+
)
190+
if not normalized_args:
191+
raise RuntimeError(
192+
f"get_arg: Node {node} does not support normalization of arguments"
193+
)
194+
195+
return normalized_args.kwargs[kwarg_name]
192196

193197

194198
def set_arg(
195-
node: torch.fx.Node, arg_index: int, kwarg_name: str, value: torch.fx.node.Argument
199+
node: torch.fx.Node, kwarg_name: str, value: torch.fx.node.Argument
196200
) -> None:
197201
"""
198-
Set the arg at arg_index if it exists, otherwise set the kwarg.
202+
Set the node's arg with its name to the given value.
199203
"""
200-
if arg_index < len(node.args):
201-
node.update_arg(arg_index, value)
204+
# Try to set the arg if it is present in kwargs first since this is faster
205+
if kwarg_name in node.kwargs:
206+
node.update_kwarg(kwarg_name, value)
207+
return
208+
209+
# If it's not found in kwargs, try to normalize the args and set the arg
210+
normalized_args = node.normalized_arguments(
211+
node.graph.owning_module, normalize_to_only_use_kwargs=True
212+
)
213+
if not normalized_args:
214+
raise RuntimeError(
215+
f"set_arg: Node {node} does not support normalization of arguments"
216+
)
217+
218+
kwargs = normalized_args.kwargs
219+
if kwarg_name not in kwargs:
220+
raise ValueError(f"set_arg: invalid arg name {kwarg_name} for node {node} used")
221+
222+
idx = list(kwargs.keys()).index(kwarg_name)
223+
if idx < len(node.args):
224+
node.update_arg(idx, value)
202225
else:
203226
node.update_kwarg(kwarg_name, value)

backends/cadence/aot/remove_ops.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -779,17 +779,17 @@ def _remove_unused_cat(self, graph_module: torch.fx.GraphModule) -> None:
779779
for slice_copy_node in graph_module.graph.find_nodes(
780780
op="call_function", target=exir_ops.edge.aten.slice_copy.Tensor
781781
):
782-
cat_node = cast(Node, get_arg(slice_copy_node, 0, "input"))
783-
slice_dim = cast(int, get_arg(slice_copy_node, 1, "dim", default=0))
784-
start_idx = cast(int, get_arg(slice_copy_node, 2, "start", default=None))
785-
end_idx = cast(int, get_arg(slice_copy_node, 3, "end", default=None))
786-
step = cast(int, get_arg(slice_copy_node, 4, "step", default=1))
782+
cat_node = cast(Node, get_arg(slice_copy_node, "input"))
783+
slice_dim = cast(int, get_arg(slice_copy_node, "dim"))
784+
start_idx = cast(int, get_arg(slice_copy_node, "start"))
785+
end_idx = cast(int, get_arg(slice_copy_node, "end"))
786+
step = cast(int, get_arg(slice_copy_node, "step"))
787787

788788
if cat_node.target != exir_ops.edge.aten.cat.default or step != 1:
789789
continue
790790

791791
# Make sure cat and slice happens on the same dimension.
792-
cat_dim = cast(Node, get_arg(cat_node, 1, "dim", default=0))
792+
cat_dim = cast(Node, get_arg(cat_node, "dim"))
793793
if cat_dim != slice_dim:
794794
continue
795795

@@ -805,14 +805,14 @@ def _remove_unused_cat(self, graph_module: torch.fx.GraphModule) -> None:
805805
end_idx += cat_output_shape[cat_dim]
806806

807807
offset = 0
808-
for cat_input_node in cast(List[Node], get_arg(cat_node, 0, "tensors")):
808+
for cat_input_node in cast(List[Node], get_arg(cat_node, "tensors")):
809809
cat_input_shape = cat_input_node.meta["val"].shape
810810

811811
# Check if the slice range overlaps with the cat input range.
812812
if offset <= start_idx and end_idx <= offset + cat_input_shape[cat_dim]:
813813
slice_copy_node.replace_input_with(cat_node, cat_input_node)
814-
set_arg(slice_copy_node, 2, "start", start_idx - offset)
815-
set_arg(slice_copy_node, 3, "end", end_idx - offset)
814+
set_arg(slice_copy_node, "start", start_idx - offset)
815+
set_arg(slice_copy_node, "end", end_idx - offset)
816816
break
817817

818818
offset += cat_input_shape[cat_dim]

0 commit comments

Comments
 (0)