Skip to content

Commit 953fa0e

Browse files
authored
Do not get numel for unbacked symint (#12490)
Summary: The current code in _check_inputs_are_valid_dtypes assumes that the tensor does not have any dim that is unbacked symint. Factorized joiner contains a data dependent intermediate tensor, where the batch dim is unbacked: N7582435. Currently the code fails there. This diff only checks for the numel() if the tensor dimensions are not unbacked. Differential Revision: D78343875
1 parent 1f885b9 commit 953fa0e

File tree

1 file changed

+5
-2
lines changed

1 file changed

+5
-2
lines changed

backends/xnnpack/partition/config/xnnpack_config.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
)
1717
from executorch.exir.backend.utils import WhyNoPartition
1818
from torch.export import ExportedProgram
19+
from torch.fx.experimental.symbolic_shapes import has_free_unbacked_symbols
1920

2021
logger = logging.getLogger(__name__)
2122
why = WhyNoPartition(logger=logger)
@@ -168,8 +169,10 @@ def _check_inputs_are_valid_dtypes(self, node, valid_dtypes):
168169
if not isinstance(arg_val, torch.Tensor):
169170
return False
170171

171-
# XNNPACK does not support empty tensors
172-
if arg_val.numel() == 0:
172+
# XNNPACK does not support empty tensors. But we can't get numel()
173+
# for unbacked symints, so we conservatively bail out here if any
174+
# dimension of the tensor is unbacked symint.
175+
if has_free_unbacked_symbols(arg_val) or arg_val.numel() == 0:
173176
return False
174177

175178
if arg_val.dtype not in valid_dtypes:

0 commit comments

Comments
 (0)