
Commit 1793bae

Arm backend: Add partial support for index.Tensor (#11851)
- Add Node visitor to facilitate lowering
- Add dedicated support check to prevent unsupported usage
- Add tests
- Supports 4D tensors being indexed
- Does not support Slice, Ellipsis or None before an indexing tensor
- Does not support indexing tensors of rank 4 or higher

Signed-off-by: Iliyan Georgiev <[email protected]>
1 parent d59fddc commit 1793bae
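
As a quick illustration of the supported subset described in the bullets above (a hedged sketch; the tensor names and shapes here are invented for the example, not taken from the commit):

    import torch

    t = torch.randn(2, 3, 4, 5)    # 4D value tensor: indexing it is supported
    idx = torch.tensor([0, 1])     # rank-1 indexing tensor: supported

    y = t[idx]                     # plain index.Tensor usage, eligible for lowering

    # Patterns the new support check rejects:
    # t[1:2, idx]    - slice before an indexing tensor
    # t[None, idx]   - None before an indexing tensor
    # t[..., idx]    - ellipsis before an indexing tensor
    # t[torch.zeros(1, 1, 1, 1, dtype=torch.long)]  - rank-4 indexing tensor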

File tree: 8 files changed, +1100 -8 lines changed


backends/arm/_passes/annotate_channels_last_dim_order_pass.py

Lines changed: 10 additions & 2 deletions
@@ -203,10 +203,18 @@ def insert_tosa_transposes(self, graph_module: torch.fx.GraphModule):
             - 1D/2D tensors
         """
         for node in graph_module.graph.nodes:
-            if node.op != "call_function":
+            # call_function and placeholder are both allowed because
+            # index.Tensor inputs can come in as either
+            if node.op not in ["call_function", "placeholder"]:
                 continue

-            elif node.target == exir_ops.edge.aten.view_copy.default:
+            elif node.target in (
+                exir_ops.edge.aten.view_copy.default,
+                exir_ops.edge.aten.index.Tensor,
+            ):
+                # For index.Tensor:
+                # if we want to support 4D indexing tensors this logic
+                # should be updated.
                 input_node = node.args[0]
                 input_shape = input_node.meta["val"].shape
                 output_shape = node.meta["val"].shape
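
The new comment notes that index.Tensor inputs can come in as both call_function and placeholder nodes. One plausible reading, sketched below with a hypothetical module (not part of this commit): when the indexing tensor is a model input it enters the exported FX graph as a placeholder, whereas an index computed inside the graph arrives as the output of a call_function node.

    import torch

    class IndexAsInput(torch.nn.Module):
        # idx is a forward() argument, so it appears in the exported
        # graph as a placeholder node.
        def forward(self, x: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
            return x[idx]  # lowers to aten.index.Tensor

    class IndexComputedInGraph(torch.nn.Module):
        # The indexing tensor is produced inside the graph, so it reaches
        # aten.index.Tensor as the output of a call_function node.
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return x[torch.arange(2)]

    x = torch.randn(2, 3, 4, 5)
    print(IndexAsInput()(x, torch.tensor([0, 1])).shape)  # torch.Size([2, 3, 4, 5])
    print(IndexComputedInGraph()(x).shape)                # torch.Size([2, 3, 4, 5])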

backends/arm/operator_support/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -10,6 +10,7 @@
     embedding_support,
     ethos_u55_support,
     index_select_support,
+    index_tensor_support,
     minmax_support,
     pool_2d_support,
     reduce_sum_support,
backends/arm/operator_support/index_tensor_support.py

Lines changed: 128 additions & 0 deletions

@@ -0,0 +1,128 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import math
+
+import torch
+import torch.fx as fx
+from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor
+from executorch.backends.arm.operator_support.tosa_supported_operators import (
+    register_tosa_support_check,
+    SupportedTOSAOperatorCheck,
+)
+from executorch.backends.arm.tosa_specification import TosaSpecification
+from executorch.exir.dialects._ops import ops as exir_ops
+
+
+@register_tosa_support_check
+class IndexTensorSupported(SupportedTOSAOperatorCheck):
+    """
+    This support check is intended to prevent the partitioning of
+    currently unsupported usages of the index.Tensor operator:
+
+    1. Usages where indexing tensors are of rank 4 or higher.
+       This is due to the AnnotateChannelsLastDimOrder pass and the
+       rarity of such an operation. Support is possible but would
+       require further changes to that pass, which can be added when
+       the need arises.
+
+    2. Usages where a slice, ellipsis or None appears before an
+       indexing tensor:
+           t[{start}:{end}, indexTensor] - slicing
+           t[None, indexTensor]          - unsqueeze
+           t[..., indexTensor]           - ellipsis
+
+    3. Usages where the value tensor contains more than int32.max
+       elements. This is due to the int32 TOSA limitation and the fact
+       that we flatten out and accumulate all index tensors. To avoid
+       overflow, we reject lowering of this operator if it is possible
+       for indices to go over the int32 limit.
+
+    Extra information regarding #2:
+    PyTorch decomposes slice and None usages before they reach aten.
+    In the case of slicing and unsqueeze, PyTorch adds the relevant
+    operation just before the index.Tensor op; in the case of ellipsis,
+    no extra operation is added.
+
+    In all three cases PyTorch inserts None(s) in the index list only
+    if the above operations act on a dimension BEFORE one being
+    indexed.
+
+    When slicing, unsqueeze and ellipsis act on dimensions after the
+    ones being indexed, they do not affect the final output values,
+    only the shape, so no None is passed to the index.Tensor op.
+
+    The purpose of None is to signal to index.Tensor that a dimension
+    should not be indexed; the logic then behaves similarly to batching
+    along that dimension. For the sake of simplicity we have not
+    implemented this behavior yet, and this support check is in place
+    to prevent the partitioning of index.Tensor ops which include None.
+
+    Examples:
+        #1 - Slice -----------------------------------------------------
+        t = torch.randint(25, size=(25, 3, 6))
+        t[1:5, torch.arange(3)]
+
+        Turns into (edge pseudo code):
+        slice_res = ...edge__ops_aten_slice_copy_Tensor(t, dim=0, start=1, end=5)
+        out = ...edge__ops_aten_index_Tensor(slice_res, [None, torch.arange(3)])
+
+        #2 - None (Unsqueeze) ------------------------------------------
+        t = torch.randint(25, size=(25, 3, 6))
+        t[None, torch.arange(3)]
+
+        Turns into (edge pseudo code):
+        unsqueeze_res = ...edge__ops_aten_unsqueeze(t, dim=0)
+        out = ...edge__ops_aten_index_Tensor(unsqueeze_res, [None, torch.arange(3)])
+
+        #3 - None (Unsqueeze) after index ------------------------------
+        t = torch.randint(25, size=(25, 3, 6))
+        t[torch.arange(3), None]
+
+        Turns into (edge pseudo code):
+        unsqueeze_res = ...edge__ops_aten_unsqueeze(t, dim=1)
+        out = ...edge__ops_aten_index_Tensor(unsqueeze_res, [torch.arange(3)])
+
+    NB: with the current implementation of flattening tensors and
+    indices out, supporting None (unsqueeze) is simply a matter of
+    ignoring the None dimension. This is not the case for the slice
+    and ellipsis operators, where the size of the new dimension can
+    be > 1.
+
+    Note that slice ops interleaved between indices, such as
+    t[1:3, torch.arange(5), 2:3, torch.arange(3).reshape(3, 1)],
+    are also possible and can result in some unintuitive behaviors
+    where batching and indexing are mixed together.
+    """
+
+    targets = [exir_ops.edge.aten.index.Tensor]
+
+    tosa_specs = [
+        TosaSpecification.create_from_string("TOSA-0.80+BI"),
+        TosaSpecification.create_from_string("TOSA-0.80+MI"),
+        TosaSpecification.create_from_string("TOSA-1.0+INT"),
+        TosaSpecification.create_from_string("TOSA-1.0+FP"),
+    ]
+
+    def is_node_tosa_supported(
+        self, node: fx.Node, tosa_spec: TosaSpecification
+    ) -> bool:  # type: ignore[override, misc]
+        indices = node.args[1]
+        for index in indices:  # type: ignore[union-attr]
+            # Usage 2 guard: reject None entries in the index list.
+            if index is None:
+                return False

+            # Usage 1 guard: reject indexing tensors of rank 4 or higher.
+            fake_tensor = get_first_fake_tensor(index)  # type: ignore[arg-type]
+            if len(fake_tensor.size()) > 3:
+                return False
+
+        # Usage 3 guard: reject value tensors with more than int32.max elements.
+        total_vals = math.prod(get_first_fake_tensor(node.args[0]).shape)  # type: ignore[arg-type]
+        if total_vals > torch.iinfo(torch.int32).max:
+            return False
+
+        return True
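
The decomposition behavior described in the docstring's examples can be checked in eager PyTorch. A small sanity script (not part of the commit; the shapes mirror the docstring examples) exercising #1-#3, plus the arithmetic behind the Usage 3 guard:

    import math

    import torch

    t = torch.randint(25, size=(25, 3, 6))

    # Example #1: a slice before the indexing tensor is equivalent to
    # slicing first, then indexing the following dimension.
    assert torch.equal(t[1:5, torch.arange(3)], t[1:5][:, torch.arange(3)])

    # Example #2: None before the indexing tensor acts as unsqueeze(dim=0).
    assert torch.equal(t[None, torch.arange(3)], t.unsqueeze(0)[:, torch.arange(3)])

    # Example #3: None after the indexing tensor acts as unsqueeze(dim=1)
    # and does not change which dimension is indexed.
    assert torch.equal(t[torch.arange(3), None], t.unsqueeze(1)[torch.arange(3)])

    # Usage 3 guard: a value tensor with 2**31 elements exceeds what int32
    # indices can address (int32.max is 2**31 - 1), so lowering is rejected.
    big_shape = (2048, 2048, 512)  # 2**31 elements
    assert math.prod(big_shape) > torch.iinfo(torch.int32).max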

backends/arm/operators/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -25,6 +25,7 @@
     op_ge,
     op_gt,
     op_index_select,
+    op_index_tensor,
     op_le,
     op_log,
     op_lt,
