diff --git a/exir/passes/replace_broken_ops_with_function_ops_pass.py b/exir/passes/replace_broken_ops_with_function_ops_pass.py
index 22619e28bac..4fbaa539132 100644
--- a/exir/passes/replace_broken_ops_with_function_ops_pass.py
+++ b/exir/passes/replace_broken_ops_with_function_ops_pass.py
@@ -5,26 +5,10 @@
 # LICENSE file in the root directory of this source tree.
 
 # pyre-strict
 
-from typing import Dict
-
 import torch
 
 from executorch.exir.pass_base import ExportPass
-from torch._ops import OpOverload
-
-
-_NON_FUNCTIONAL_OPS_TO_FUNCTIONAL_OPS: Dict[OpOverload, OpOverload] = {
-    torch.ops.aten._unsafe_view.default: torch.ops.aten.view_copy.default,
-    torch.ops.aten.t.default: torch.ops.aten.t_copy.default,
-    torch.ops.aten.view.default: torch.ops.aten.view_copy.default,
-    torch.ops.aten.expand.default: torch.ops.aten.expand_copy.default,
-    torch.ops.aten.permute.default: torch.ops.aten.permute_copy.default,
-    torch.ops.aten.squeeze.default: torch.ops.aten.squeeze_copy.default,
-    torch.ops.aten.unsqueeze.default: torch.ops.aten.unsqueeze_copy.default,
-    torch.ops.aten.slice.Tensor: torch.ops.aten.slice_copy.Tensor,
-}
-
 
 class ReplaceBrokenOpsWithFunctionalOpsPass(ExportPass):
     """
@@ -37,8 +21,22 @@ class ReplaceBrokenOpsWithFunctionalOpsPass(ExportPass):
 
     # pyre-ignore
     def call_operator(self, op, args, kwargs, meta):
-        if op in _NON_FUNCTIONAL_OPS_TO_FUNCTIONAL_OPS:
-            return super().call_operator(
-                _NON_FUNCTIONAL_OPS_TO_FUNCTIONAL_OPS[op], args, kwargs, meta
+        if op.is_view:
+            namespace, op_full_name = op.name().split("::")
+            split = op_full_name.split(".")
+            if len(split) == 2:
+                op_name, overload_name = split[0], split[1]
+            elif len(split) == 1:
+                # Add default overload if no overload listed
+                op_name = op_full_name
+                overload_name = "default"
+            else:
+                raise RuntimeError(
+                    f"Invalid op name expected only one '.' to be present: {op_full_name}"
+                )
+
+            view_copy_op = getattr(
+                getattr(getattr(torch.ops, namespace), f"{op_name}_copy"), overload_name
             )
+            return super().call_operator(view_copy_op, args, kwargs, meta)
         return super().call_operator(op, args, kwargs, meta)
diff --git a/exir/tests/test_passes.py b/exir/tests/test_passes.py
index 656e20e2fb7..422b133f0e0 100644
--- a/exir/tests/test_passes.py
+++ b/exir/tests/test_passes.py
@@ -595,33 +595,78 @@ def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor]:
         self.assertEqual(counter, 1)
 
     def test_compile_fix_broken_ops(self) -> None:
-        # When pass an input of more than 4 dimensions to Linear
-        # aten._unsafe_view is used under the hood
-        x = torch.randn([2, 3, 4, 5])
-        model: torch.nn.Linear = torch.nn.Linear(5, 5)
-
-        class Foo(torch.nn.Module):
-            def __init__(self):
+        class ExportableLoop(nn.Module):
+            def __init__(self, hidden_size, out_channels):
                 super().__init__()
-                self.model = model
-
-            def forward(self, inp: torch.Tensor) -> torch.Tensor:
-                return self.model(inp)
-
-        f = Foo()
+                self.hidden_size = hidden_size
+                self.B = nn.Parameter(torch.randn(hidden_size, 1))  # (H, in_channels)
+                self.C = nn.Parameter(
+                    torch.randn(out_channels, hidden_size)
+                )  # (C_out, H)
+                A = torch.randn(2, hidden_size)
+                self.A_real = nn.Parameter(A[0].clone())
+                self.A_imag = nn.Parameter(A[1].clone())
+
+            def update_state(self, h, x_t):
+                # h: [B, 2, H], x_t: [B, H]
+                hr, hi = h[:, 0, :], h[:, 1, :]  # [B, H]
+                hrn = hr * self.A_real - hi * self.A_imag + x_t  # [B, H]
+                hin = hi * self.A_real + hr * self.A_imag  # [B, H]
+                hn = torch.stack([hrn, hin], dim=1)  # [B, 2, H]
+                return hn, hrn
+
+            def forward(self, u):
+                # u: [B, 1, T]
+                x = torch.matmul(self.B, u)  # (B, H, T)
+                B, H, T = x.shape
+
+                h = torch.zeros(B, 2, H, device=x.device, dtype=x.dtype)  # [B, 2, H]
+                h_accum = torch.zeros(
+                    B, H, T, device=x.device, dtype=x.dtype
+                )  # [B, H, T]
+                i = torch.tensor(0, device=x.device, dtype=torch.int64)
+                one = torch.tensor(1, device=x.device, dtype=torch.int64)
+
+                def cond(i, h, h_accum):
+                    return i < T
+
+                def body(i, h, h_accum):
+                    x_t = x.index_select(-1, i.unsqueeze(0)).squeeze(
+                        -1
+                    )  # ✅ safe for export
+                    h, hr = self.update_state(h, x_t)  # h: [B, 2, H], hr: [B, H]
+                    h_accum = h_accum.index_copy(
+                        -1, i.unsqueeze(0), hr.unsqueeze(-1)
+                    )  # [B, H, T]
+                    i_next = i + one
+                    return i_next, h, h_accum
+
+                _, h, h_accum = torch._higher_order_ops.while_loop(
+                    cond, body, (i, h, h_accum)
+                )
+                y = torch.matmul(self.C, h_accum).transpose(0, 1)  # (B, C_out, T)
+                return y
 
-        # ReplaceBrokenOpsWithFunctionalOpsPass is used in to_edge()
+        # Instantiate and export
+        model = ExportableLoop(hidden_size=128, out_channels=10)
+        inp = torch.randn(1, 1, 32)  # (B, in_channels=1, T=32)
+        ep = export(model, (inp,))
         prog = to_edge(
-            export(f, (x,), strict=True),
+            ep,
             compile_config=exir.EdgeCompileConfig(_check_ir_validity=False),
         )
         gm = prog.exported_program().graph_module
         count_after = 0
         for node in gm.graph.nodes:
-            if node.target == torch.ops.aten._unsafe_view.default:
+            if (
+                node.target == torch.ops.aten.squeeze.dims
+                or node.target == torch.ops.aten.select.int
+            ):
                 count_after += 1
         self.assertEqual(count_after, 0)
-        self.assertTrue(torch.allclose(prog.exported_program().module()(x), f(x)))
+        self.assertTrue(
+            torch.allclose(prog.exported_program().module()(inp), model(inp))
+        )
 
     def test_convert_symb_ops(self) -> None:
         class Foo(torch.nn.Module):
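
With this change the pass no longer consults a hand-maintained table: any operator whose overload reports `op.is_view` is redirected to its functional `*_copy` counterpart, resolved by name through `torch.ops`. Below is a minimal standalone sketch of that lookup, using `torch.ops.aten.squeeze.dims` as a representative view overload chosen for illustration (the printed form is indicative):

    import torch

    op = torch.ops.aten.squeeze.dims  # a view overload; op.is_view is True
    namespace, op_full_name = op.name().split("::")  # "aten", "squeeze.dims"
    split = op_full_name.split(".")
    if len(split) == 2:
        op_name, overload_name = split
    else:
        op_name, overload_name = op_full_name, "default"  # no overload listed

    # Resolve e.g. aten::squeeze.dims -> torch.ops.aten.squeeze_copy.dims
    view_copy_op = getattr(
        getattr(getattr(torch.ops, namespace), f"{op_name}_copy"), overload_name
    )
    print(view_copy_op)  # e.g. aten.squeeze_copy.dims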
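The updated test swaps the Linear-based repro for a scan-style recurrence because its body exports to view overloads such as `aten.squeeze.dims` and `aten.select.int`, which the removed hard-coded table never covered; asserting they are gone after `to_edge()` exercises the name-based rewrite. For context, a minimal sketch of the `while_loop` calling convention the test relies on, assuming a PyTorch build that ships this higher-order op (carried values are tensors; `body` must return a tuple matching their shapes and dtypes):

    import torch

    # Sum the integers 0..9: the carry is (i, acc), two int64 scalar tensors.
    def cond(i, acc):
        return i < 10  # must evaluate to a scalar bool tensor

    def body(i, acc):
        return i + 1, acc + i  # next carry, same metadata as the inputs

    i_final, acc_final = torch._higher_order_ops.while_loop(
        cond, body, (torch.tensor(0), torch.tensor(0))
    )
    print(i_final, acc_final)  # tensor(10) tensor(45)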