Skip to content

Commit c0db862

Browse files
JacobSzwejbkafacebook-github-bot
authored andcommitted
Fix replace view ops pass to not be a hardcoded list (#12361)
Summary: Instead of maintaining a list we can check if its a view op from its property and then surgery out the name and inject _copy. Issue: #12103 Differential Revision: D78116054
1 parent 378f062 commit c0db862

File tree

2 files changed

+79
-36
lines changed

2 files changed

+79
-36
lines changed

exir/passes/replace_broken_ops_with_function_ops_pass.py

Lines changed: 17 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -5,26 +5,10 @@
55
# LICENSE file in the root directory of this source tree.
66

77
# pyre-strict
8-
from typing import Dict
9-
108
import torch
119

1210
from executorch.exir.pass_base import ExportPass
1311

14-
from torch._ops import OpOverload
15-
16-
17-
_NON_FUNCTIONAL_OPS_TO_FUNCTIONAL_OPS: Dict[OpOverload, OpOverload] = {
18-
torch.ops.aten._unsafe_view.default: torch.ops.aten.view_copy.default,
19-
torch.ops.aten.t.default: torch.ops.aten.t_copy.default,
20-
torch.ops.aten.view.default: torch.ops.aten.view_copy.default,
21-
torch.ops.aten.expand.default: torch.ops.aten.expand_copy.default,
22-
torch.ops.aten.permute.default: torch.ops.aten.permute_copy.default,
23-
torch.ops.aten.squeeze.default: torch.ops.aten.squeeze_copy.default,
24-
torch.ops.aten.unsqueeze.default: torch.ops.aten.unsqueeze_copy.default,
25-
torch.ops.aten.slice.Tensor: torch.ops.aten.slice_copy.Tensor,
26-
}
27-
2812

2913
class ReplaceBrokenOpsWithFunctionalOpsPass(ExportPass):
3014
"""
@@ -37,8 +21,22 @@ class ReplaceBrokenOpsWithFunctionalOpsPass(ExportPass):
3721

3822
# pyre-ignore
3923
def call_operator(self, op, args, kwargs, meta):
40-
if op in _NON_FUNCTIONAL_OPS_TO_FUNCTIONAL_OPS:
41-
return super().call_operator(
42-
_NON_FUNCTIONAL_OPS_TO_FUNCTIONAL_OPS[op], args, kwargs, meta
24+
if op.is_view:
25+
namespace, op_full_name = op.name().split("::")
26+
split = op_full_name.split(".")
27+
if len(split) == 2:
28+
op_name, overload_name = split[0], split[1]
29+
elif len(split) == 1:
30+
# Add default overload if no overload listed
31+
op_name = op_full_name
32+
overload_name = "default"
33+
else:
34+
raise RuntimeError(
35+
f"Invalid op name expected only one '.' to be present: {op_full_name}"
36+
)
37+
38+
view_copy_op = getattr(
39+
getattr(getattr(torch.ops, namespace), f"{op_name}_copy"), overload_name
4340
)
41+
return super().call_operator(view_copy_op, args, kwargs, meta)
4442
return super().call_operator(op, args, kwargs, meta)

exir/tests/test_passes.py

Lines changed: 62 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -595,33 +595,78 @@ def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor]:
595595
self.assertEqual(counter, 1)
596596

597597
def test_compile_fix_broken_ops(self) -> None:
598-
# When pass an input of more than 4 dimensions to Linear
599-
# aten._unsafe_view is used under the hood
600-
x = torch.randn([2, 3, 4, 5])
601-
model: torch.nn.Linear = torch.nn.Linear(5, 5)
602-
603-
class Foo(torch.nn.Module):
604-
def __init__(self):
598+
class ExportableLoop(nn.Module):
599+
def __init__(self, hidden_size, out_channels):
605600
super().__init__()
606-
self.model = model
607-
608-
def forward(self, inp: torch.Tensor) -> torch.Tensor:
609-
return self.model(inp)
610-
611-
f = Foo()
601+
self.hidden_size = hidden_size
602+
self.B = nn.Parameter(torch.randn(hidden_size, 1)) # (H, in_channels)
603+
self.C = nn.Parameter(
604+
torch.randn(out_channels, hidden_size)
605+
) # (C_out, H)
606+
A = torch.randn(2, hidden_size)
607+
self.A_real = nn.Parameter(A[0].clone())
608+
self.A_imag = nn.Parameter(A[1].clone())
609+
610+
def update_state(self, h, x_t):
611+
# h: [B, 2, H], x_t: [B, H]
612+
hr, hi = h[:, 0, :], h[:, 1, :] # [B, H]
613+
hrn = hr * self.A_real - hi * self.A_imag + x_t # [B, H]
614+
hin = hi * self.A_real + hr * self.A_imag # [B, H]
615+
hn = torch.stack([hrn, hin], dim=1) # [B, 2, H]
616+
return hn, hrn
617+
618+
def forward(self, u):
619+
# u: [B, 1, T]
620+
x = torch.matmul(self.B, u) # (B, H, T)
621+
B, H, T = x.shape
622+
623+
h = torch.zeros(B, 2, H, device=x.device, dtype=x.dtype) # [B, 2, H]
624+
h_accum = torch.zeros(
625+
B, H, T, device=x.device, dtype=x.dtype
626+
) # [B, H, T]
627+
i = torch.tensor(0, device=x.device, dtype=torch.int64)
628+
one = torch.tensor(1, device=x.device, dtype=torch.int64)
629+
630+
def cond(i, h, h_accum):
631+
return i < T
632+
633+
def body(i, h, h_accum):
634+
x_t = x.index_select(-1, i.unsqueeze(0)).squeeze(
635+
-1
636+
) # ✅ safe for export
637+
h, hr = self.update_state(h, x_t) # h: [B, 2, H], hr: [B, H]
638+
h_accum = h_accum.index_copy(
639+
-1, i.unsqueeze(0), hr.unsqueeze(-1)
640+
) # [B, H, T]
641+
i_next = i + one
642+
return i_next, h, h_accum
643+
644+
_, h, h_accum = torch._higher_order_ops.while_loop(
645+
cond, body, (i, h, h_accum)
646+
)
647+
y = torch.matmul(self.C, h_accum).transpose(0, 1) # (B, C_out, T)
648+
return y
612649

613-
# ReplaceBrokenOpsWithFunctionalOpsPass is used in to_edge()
650+
# Instantiate and export
651+
model = ExportableLoop(hidden_size=128, out_channels=10)
652+
inp = torch.randn(1, 1, 32) # (B, in_channels=1, T=32)
653+
ep = export(model, (inp,))
614654
prog = to_edge(
615-
export(f, (x,), strict=True),
655+
ep,
616656
compile_config=exir.EdgeCompileConfig(_check_ir_validity=False),
617657
)
618658
gm = prog.exported_program().graph_module
619659
count_after = 0
620660
for node in gm.graph.nodes:
621-
if node.target == torch.ops.aten._unsafe_view.default:
661+
if (
662+
node.target == torch.ops.aten.squeeze.dims
663+
or node.target == torch.ops.aten.select.int
664+
):
622665
count_after += 1
623666
self.assertEqual(count_after, 0)
624-
self.assertTrue(torch.allclose(prog.exported_program().module()(x), f(x)))
667+
self.assertTrue(
668+
torch.allclose(prog.exported_program().module()(inp), model(inp))
669+
)
625670

626671
def test_convert_symb_ops(self) -> None:
627672
class Foo(torch.nn.Module):

0 commit comments

Comments
 (0)