Skip to content

Commit 2653477

Browse files
leafs1facebook-github-bot
authored andcommitted
Fix Transpose Optimization Bug With non-4D Tensor Input
Summary: Fixed a bug in the channels last tagged reshape pass, where non-4d inputs were being tagged as contiguous/channels last memory formats, which isn't expected as these formats only apply to 4d tensors. The repro is in N7569847. The fix was completed by checking tensor shape size before tagging input nodes. Reviewed By: mcr229 Differential Revision: D78357428
1 parent 1f885b9 commit 2653477

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

backends/xnnpack/_passes/channels_last_tagged_reshape_pass.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -457,7 +457,7 @@ def call(self, graph_module: torch.fx.GraphModule): # noqa: C901
457457
for node in original_nodes:
458458
if len(node.all_input_nodes) == 0:
459459
# This node has no inputs so we don't need to change anything, but still need to tag input nodes
460-
if "val" in node.meta and isinstance(node.meta["val"], torch.Tensor):
460+
if "val" in node.meta and isinstance(node.meta["val"], torch.Tensor) and len(node.meta["val"].shape) == 4:
461461
if node.meta["val"].is_contiguous():
462462
self.mark_as_nchw_node(node)
463463
else:

0 commit comments

Comments
 (0)