Commit 9625f16

chunnienc authored and copybara-github committed
Use explicit fake impl for torch.ops.tfl.reshape
PiperOrigin-RevId: 746338648
1 parent 2d87897 commit 9625f16

File tree

1 file changed: +8 −1 lines changed
  • ai_edge_torch/odml_torch/experimental/torch_tfl

ai_edge_torch/odml_torch/experimental/torch_tfl/_ops.py

Lines changed: 8 additions & 1 deletion

@@ -102,13 +102,20 @@ def tfl_transpose(input: torch.Tensor, perm: Sequence[int]) -> torch.Tensor:
   return torch.permute(input, perm).clone()
 
 
-@custom_op_with_fake("tfl::reshape")
+@torch.library.custom_op("tfl::reshape", mutates_args=())
 def tfl_reshape(input: torch.Tensor, shape: Sequence[int]) -> torch.Tensor:
   assert torch.Size(shape).numel() == input.numel()
 
   return input.view(shape).clone()
 
 
+# Use explicit fake implementation for tfl.reshape because dynamo cannot
+# derive the output's symbolic shape from the impl above.
+@torch.library.register_fake("tfl::reshape")
+def tfl_reshape_fake(input: torch.Tensor, shape: Sequence[int]) -> torch.Tensor:
+  return torch.empty(shape, dtype=input.dtype)
+
+
 @custom_op_with_fake("tfl::softmax")
 def tfl_softmax(x: torch.Tensor) -> torch.Tensor:
   return torch.nn.functional.softmax(x)
