File tree Expand file tree Collapse file tree 1 file changed +8
-1
lines changed
ai_edge_torch/odml_torch/experimental/torch_tfl Expand file tree Collapse file tree 1 file changed +8
-1
lines changed Original file line number Diff line number Diff line change @@ -102,13 +102,20 @@ def tfl_transpose(input: torch.Tensor, perm: Sequence[int]) -> torch.Tensor:
102
102
return torch .permute (input , perm ).clone ()
103
103
104
104
105
- @custom_op_with_fake ("tfl::reshape" )
105
+ @torch . library . custom_op ("tfl::reshape" , mutates_args = () )
106
106
def tfl_reshape (input : torch .Tensor , shape : Sequence [int ]) -> torch .Tensor :
107
107
assert torch .Size (shape ).numel () == input .numel ()
108
108
109
109
return input .view (shape ).clone ()
110
110
111
111
112
+ # Use explicit fake implementation for tfl.reshape because dynamo cannot
113
+ # derive the output's symbolic shape from the impl above.
114
+ @torch .library .register_fake ("tfl::reshape" )
115
+ def tfl_reshape_fake (input : torch .Tensor , shape : Sequence [int ]) -> torch .Tensor :
116
+ return torch .empty (shape , dtype = input .dtype )
117
+
118
+
112
119
@custom_op_with_fake ("tfl::softmax" )
113
120
def tfl_softmax (x : torch .Tensor ) -> torch .Tensor :
114
121
return torch .nn .functional .softmax (x )
You can’t perform that action at this time.
0 commit comments