|
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import math

import torch
import torch.fx as fx
from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor
from executorch.backends.arm.operator_support.tosa_supported_operators import (
    register_tosa_support_check,
    SupportedTOSAOperatorCheck,
)
from executorch.backends.arm.tosa_specification import TosaSpecification
from executorch.exir.dialects._ops import ops as exir_ops


@register_tosa_support_check
class IndexTensorSupported(SupportedTOSAOperatorCheck):
    """Block partitioning of index.Tensor usages we cannot lower to TOSA yet.

    Three usage patterns are rejected:

    1. Any indexing tensor of rank 4 or higher.  The
       AnnotateChannelsLastDimOrder pass would need further changes to
       handle these, and the pattern is rare enough that support can be
       added later if required.

    2. Any slice, None (unsqueeze) or ellipsis appearing BEFORE an indexed
       dimension.  PyTorch decomposes slice/None before aten: slicing and
       unsqueeze insert the corresponding op just before index.Tensor
       (ellipsis adds no extra op), and in all three cases a ``None`` entry
       is placed in the index list — but only when the affected dimension
       precedes one being indexed.  When such ops act on dimensions AFTER
       the indexed ones they change only the shape, not the gathered
       values, so no ``None`` is passed.  ``None`` tells index.Tensor to
       skip (effectively batch over) that dimension; that batching-like
       behavior is not implemented, hence this guard.

       Examples (edge pseudo code), with t = torch.randint(25, size(25, 3, 6)):

           t[1:5, torch.arange(3)] becomes
               slice_res = ...edge__ops_aten_slice_copy_Tensor(t, dim=0, start=1, end=2)
               out = ...edge__ops_aten_index_Tensor(slice_res, [None, torch.arange(3)])

           t[None, torch.arange(3)] becomes
               unsqueeze_res = ...edge__ops_aten_unsqueeze(t, dim=0)
               out = ...edge__ops_aten_index_Tensor(unsqueeze_res, [None, torch.arange(3)])

           t[torch.arange(3), None] (unsqueeze AFTER the index) becomes
               unsqueeze_res = ...edge__ops_aten_unsqueeze(t, dim=1)
               out = ...edge__ops_aten_index_Tensor(unsqueeze_res, [torch.arange(3)])

       NB: with the current flatten-and-accumulate lowering, supporting a
       ``None`` coming from unsqueeze would just mean ignoring that
       dimension; slice and ellipsis are harder because the new dimension
       may have size > 1.  Slices interleaved between index tensors, e.g.
       t[1:3, torch.arange(5), 2:3, torch.arange(3).reshape(3,1)], are
       also possible and mix batching with indexing in unintuitive ways.

    3. Value tensors with more than int32.max elements.  TOSA is limited
       to int32 and the lowering flattens and accumulates all index
       tensors, so lowering is refused whenever the accumulated indices
       could overflow int32.
    """

    targets = [exir_ops.edge.aten.index.Tensor]

    tosa_specs = [
        TosaSpecification.create_from_string("TOSA-0.80+BI"),
        TosaSpecification.create_from_string("TOSA-0.80+MI"),
        TosaSpecification.create_from_string("TOSA-1.0+INT"),
        TosaSpecification.create_from_string("TOSA-1.0+FP"),
    ]

    def is_node_tosa_supported(
        self, node: fx.Node, tosa_spec: TosaSpecification
    ) -> bool:  # type: ignore[override, misc]
        for index in node.args[1]:  # type: ignore[union-attr]
            # Usage 2 guard: a None entry means an un-indexed (batched)
            # dimension, which is not implemented.
            if index is None:
                return False

            # Usage 1 guard: reject rank-4+ index tensors.
            if len(get_first_fake_tensor(index).size()) > 3:  # type: ignore[arg-type]
                return False

        # Usage 3 guard: flattened indices into the value tensor must fit
        # in int32.
        value_elems = math.prod(get_first_fake_tensor(node.args[0]).shape)  # type: ignore[arg-type]
        return value_elems <= torch.iinfo(torch.int32).max
0 commit comments