@@ -779,17 +779,17 @@ def _remove_unused_cat(self, graph_module: torch.fx.GraphModule) -> None:
779
779
for slice_copy_node in graph_module .graph .find_nodes (
780
780
op = "call_function" , target = exir_ops .edge .aten .slice_copy .Tensor
781
781
):
782
- cat_node = cast (Node , get_arg (slice_copy_node , 0 , "input" ))
783
- slice_dim = cast (int , get_arg (slice_copy_node , 1 , "dim" , default = 0 ))
784
- start_idx = cast (int , get_arg (slice_copy_node , 2 , "start" , default = None ))
785
- end_idx = cast (int , get_arg (slice_copy_node , 3 , "end" , default = None ))
786
- step = cast (int , get_arg (slice_copy_node , 4 , "step" , default = 1 ))
782
+ cat_node = cast (Node , get_arg (slice_copy_node , "input" ))
783
+ slice_dim = cast (int , get_arg (slice_copy_node , "dim" ))
784
+ start_idx = cast (int , get_arg (slice_copy_node , "start" ))
785
+ end_idx = cast (int , get_arg (slice_copy_node , "end" ))
786
+ step = cast (int , get_arg (slice_copy_node , "step" ))
787
787
788
788
if cat_node .target != exir_ops .edge .aten .cat .default or step != 1 :
789
789
continue
790
790
791
791
# Make sure cat and slice happens on the same dimension.
792
- cat_dim = cast (Node , get_arg (cat_node , 1 , "dim" , default = 0 ))
792
+ cat_dim = cast (Node , get_arg (cat_node , "dim" ))
793
793
if cat_dim != slice_dim :
794
794
continue
795
795
@@ -805,14 +805,14 @@ def _remove_unused_cat(self, graph_module: torch.fx.GraphModule) -> None:
805
805
end_idx += cat_output_shape [cat_dim ]
806
806
807
807
offset = 0
808
- for cat_input_node in cast (List [Node ], get_arg (cat_node , 0 , "tensors" )):
808
+ for cat_input_node in cast (List [Node ], get_arg (cat_node , "tensors" )):
809
809
cat_input_shape = cat_input_node .meta ["val" ].shape
810
810
811
811
# Check if the slice range overlaps with the cat input range.
812
812
if offset <= start_idx and end_idx <= offset + cat_input_shape [cat_dim ]:
813
813
slice_copy_node .replace_input_with (cat_node , cat_input_node )
814
- set_arg (slice_copy_node , 2 , "start" , start_idx - offset )
815
- set_arg (slice_copy_node , 3 , "end" , end_idx - offset )
814
+ set_arg (slice_copy_node , "start" , start_idx - offset )
815
+ set_arg (slice_copy_node , "end" , end_idx - offset )
816
816
break
817
817
818
818
offset += cat_input_shape [cat_dim ]
0 commit comments