Commit 80da097

Fix Transpose Optimization Bug With non-4D Tensor Input (#12520)
Summary: Fixes a bug in the channels-last tagged reshape pass where non-4D inputs were being tagged with contiguous/channels-last memory formats. This is unexpected, since those formats only apply to 4D tensors. The repro is in N7569847. The fix checks the tensor's number of dimensions before tagging input nodes.

Reviewed By: mcr229

Differential Revision: D78357428
1 parent 9e05d89 commit 80da097

File tree

1 file changed: +5 −1 lines changed


backends/xnnpack/_passes/channels_last_tagged_reshape_pass.py

Lines changed: 5 additions & 1 deletion
@@ -457,7 +457,11 @@ def call(self, graph_module: torch.fx.GraphModule):  # noqa: C901
         for node in original_nodes:
             if len(node.all_input_nodes) == 0:
                 # This node has no inputs so we don't need to change anything, but still need to tag input nodes
-                if "val" in node.meta and isinstance(node.meta["val"], torch.Tensor):
+                if (
+                    "val" in node.meta
+                    and isinstance(node.meta["val"], torch.Tensor)
+                    and len(node.meta["val"].shape) == 4
+                ):
                     if node.meta["val"].is_contiguous():
                         self.mark_as_nchw_node(node)
                     else:
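The guard added by this commit can be illustrated in isolation. The sketch below uses a hypothetical helper, `should_tag_memory_format`, that is not part of the ExecuTorch codebase; it only mirrors the condition in the diff: a node is eligible for NCHW/NHWC tagging only when its `meta["val"]` is a 4D tensor.

```python
import torch


def should_tag_memory_format(node_meta: dict) -> bool:
    # Hypothetical helper mirroring the commit's guard: contiguous /
    # channels-last memory formats only make sense for 4D (NCHW) tensors,
    # so non-4D inputs must not be tagged.
    return (
        "val" in node_meta
        and isinstance(node_meta["val"], torch.Tensor)
        and len(node_meta["val"].shape) == 4
    )


# A 4D (NCHW-shaped) tensor qualifies; a 2D tensor does not.
print(should_tag_memory_format({"val": torch.zeros(1, 3, 8, 8)}))  # True
print(should_tag_memory_format({"val": torch.zeros(8, 8)}))        # False
```

Before this fix, the pass only checked that `meta["val"]` was a tensor, so a 2D input like the one above would have been tagged and could trigger the bug described in the summary.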
