
Fix replace view ops pass to not be a hardcoded list #12361

Merged: 1 commit, Jul 11, 2025

36 changes: 17 additions & 19 deletions exir/passes/replace_broken_ops_with_function_ops_pass.py
@@ -5,26 +5,10 @@
# LICENSE file in the root directory of this source tree.

# pyre-strict
from typing import Dict

import torch

from executorch.exir.pass_base import ExportPass

from torch._ops import OpOverload


_NON_FUNCTIONAL_OPS_TO_FUNCTIONAL_OPS: Dict[OpOverload, OpOverload] = {
torch.ops.aten._unsafe_view.default: torch.ops.aten.view_copy.default,
torch.ops.aten.t.default: torch.ops.aten.t_copy.default,
torch.ops.aten.view.default: torch.ops.aten.view_copy.default,
torch.ops.aten.expand.default: torch.ops.aten.expand_copy.default,
torch.ops.aten.permute.default: torch.ops.aten.permute_copy.default,
torch.ops.aten.squeeze.default: torch.ops.aten.squeeze_copy.default,
torch.ops.aten.unsqueeze.default: torch.ops.aten.unsqueeze_copy.default,
torch.ops.aten.slice.Tensor: torch.ops.aten.slice_copy.Tensor,
}


class ReplaceBrokenOpsWithFunctionalOpsPass(ExportPass):
"""
@@ -37,8 +21,22 @@ class ReplaceBrokenOpsWithFunctionalOpsPass(ExportPass):

# pyre-ignore
def call_operator(self, op, args, kwargs, meta):
if op in _NON_FUNCTIONAL_OPS_TO_FUNCTIONAL_OPS:
return super().call_operator(
_NON_FUNCTIONAL_OPS_TO_FUNCTIONAL_OPS[op], args, kwargs, meta
if op.is_view:
namespace, op_full_name = op.name().split("::")
split = op_full_name.split(".")
if len(split) == 2:
op_name, overload_name = split[0], split[1]
elif len(split) == 1:
# Add default overload if no overload listed
op_name = op_full_name
overload_name = "default"
else:
raise RuntimeError(
f"Invalid op name expected only one '.' to be present: {op_full_name}"
)

view_copy_op = getattr(
getattr(getattr(torch.ops, namespace), f"{op_name}_copy"), overload_name
)
return super().call_operator(view_copy_op, args, kwargs, meta)
return super().call_operator(op, args, kwargs, meta)
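
Note on the new logic above (an illustration, not part of the diff): instead of consulting a hardcoded table, the pass asks the operator itself whether it is a view (op.is_view) and derives the functional replacement by appending "_copy" to the op name. A minimal standalone sketch of that name resolution, using aten.slice.Tensor purely as an example op:

import torch

op = torch.ops.aten.slice.Tensor  # a non-functional view op
namespace, op_full_name = op.name().split("::")  # "aten", "slice.Tensor"
split = op_full_name.split(".")
if len(split) == 2:
    op_name, overload_name = split
else:
    # ops that only have the default overload report no overload name
    op_name, overload_name = op_full_name, "default"

# resolve torch.ops.aten -> slice_copy -> Tensor
view_copy_op = getattr(
    getattr(getattr(torch.ops, namespace), f"{op_name}_copy"), overload_name
)
print(view_copy_op)  # aten.slice_copy.Tensor
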
79 changes: 62 additions & 17 deletions exir/tests/test_passes.py
@@ -595,33 +595,78 @@ def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor]:
self.assertEqual(counter, 1)

def test_compile_fix_broken_ops(self) -> None:
# When pass an input of more than 4 dimensions to Linear
# aten._unsafe_view is used under the hood
x = torch.randn([2, 3, 4, 5])
model: torch.nn.Linear = torch.nn.Linear(5, 5)

class Foo(torch.nn.Module):
def __init__(self):
class ExportableLoop(nn.Module):
def __init__(self, hidden_size, out_channels):
super().__init__()
self.model = model

def forward(self, inp: torch.Tensor) -> torch.Tensor:
return self.model(inp)

f = Foo()
self.hidden_size = hidden_size
self.B = nn.Parameter(torch.randn(hidden_size, 1)) # (H, in_channels)
self.C = nn.Parameter(
torch.randn(out_channels, hidden_size)
) # (C_out, H)
A = torch.randn(2, hidden_size)
self.A_real = nn.Parameter(A[0].clone())
self.A_imag = nn.Parameter(A[1].clone())

def update_state(self, h, x_t):
# h: [B, 2, H], x_t: [B, H]
hr, hi = h[:, 0, :], h[:, 1, :] # [B, H]
hrn = hr * self.A_real - hi * self.A_imag + x_t # [B, H]
hin = hi * self.A_real + hr * self.A_imag # [B, H]
hn = torch.stack([hrn, hin], dim=1) # [B, 2, H]
return hn, hrn

def forward(self, u):
# u: [B, 1, T]
x = torch.matmul(self.B, u) # (B, H, T)
B, H, T = x.shape

h = torch.zeros(B, 2, H, device=x.device, dtype=x.dtype) # [B, 2, H]
h_accum = torch.zeros(
B, H, T, device=x.device, dtype=x.dtype
) # [B, H, T]
i = torch.tensor(0, device=x.device, dtype=torch.int64)
one = torch.tensor(1, device=x.device, dtype=torch.int64)

def cond(i, h, h_accum):
return i < T

def body(i, h, h_accum):
x_t = x.index_select(-1, i.unsqueeze(0)).squeeze(
-1
) # ✅ safe for export
h, hr = self.update_state(h, x_t) # h: [B, 2, H], hr: [B, H]
h_accum = h_accum.index_copy(
-1, i.unsqueeze(0), hr.unsqueeze(-1)
) # [B, H, T]
i_next = i + one
return i_next, h, h_accum

_, h, h_accum = torch._higher_order_ops.while_loop(
cond, body, (i, h, h_accum)
)
y = torch.matmul(self.C, h_accum).transpose(0, 1) # (B, C_out, T)
return y

# ReplaceBrokenOpsWithFunctionalOpsPass is used in to_edge()
# Instantiate and export
model = ExportableLoop(hidden_size=128, out_channels=10)
inp = torch.randn(1, 1, 32) # (B, in_channels=1, T=32)
ep = export(model, (inp,))
prog = to_edge(
export(f, (x,), strict=True),
ep,
compile_config=exir.EdgeCompileConfig(_check_ir_validity=False),
)
gm = prog.exported_program().graph_module
count_after = 0
for node in gm.graph.nodes:
if node.target == torch.ops.aten._unsafe_view.default:
if (
node.target == torch.ops.aten.squeeze.dims
or node.target == torch.ops.aten.select.int
):
count_after += 1
self.assertEqual(count_after, 0)
self.assertTrue(torch.allclose(prog.exported_program().module()(x), f(x)))
self.assertTrue(
torch.allclose(prog.exported_program().module()(inp), model(inp))
)

def test_convert_symb_ops(self) -> None:
class Foo(torch.nn.Module):
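
The new test drives the pass with ops that the old hardcoded table never covered: the exported while_loop body exercises view ops such as aten.squeeze.dims and aten.select.int (neither appeared in the old mapping), and the assertions check that none survive to_edge(). A rough, more general check along the same lines, shown only as a sketch: it assumes `prog` is the EdgeProgramManager built in the test, that remaining aten targets are plain OpOverload objects (edge-dialect wrappers are skipped by the isinstance filter), and that the *_copy replacements do not themselves report is_view.

from torch._ops import OpOverload

gm = prog.exported_program().graph_module
leftover = [
    node.target
    for node in gm.graph.nodes
    if node.op == "call_function"
    and isinstance(node.target, OpOverload)
    and node.target.is_view
]
assert not leftover, f"non-functional view ops still present: {leftover}"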