Skip to content

Commit d2d2bf6

Browse files
authored
Implement a coversion pass from pow(E,x) to E-1 mul ops.
Differential Revision: D73473271 Pull Request resolved: #10564
1 parent a2b9952 commit d2d2bf6

File tree

2 files changed

+63
-15
lines changed

2 files changed

+63
-15
lines changed

backends/cadence/aot/replace_ops.py

Lines changed: 23 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2263,9 +2263,9 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
22632263

22642264

22652265
@register_cadence_pass(CadencePassAttribute(opt_level=1))
2266-
class ReplacePowWithMullPass(ExportPass):
2266+
class ReplacePowWithMulPass(ExportPass):
22672267
"""
2268-
Replace the pow op with degree 2 for a mul op.
2268+
Replace the pow op for a mul op.
22692269
"""
22702270

22712271
def call_operator(
@@ -2275,19 +2275,32 @@ def call_operator(
22752275
kwargs: Dict[str, Argument],
22762276
meta: NodeMetadata,
22772277
) -> ProxyValue:
2278-
# TODO(eigen): Add support for other degrees.
2279-
if (
2280-
op
2281-
not in {
2282-
exir_ops.edge.aten.pow.Scalar,
2278+
if not (
2279+
len(args) > 1
2280+
and isinstance(args[1], int)
2281+
and cast(int, args[1]) > 1
2282+
and cast(int, args[1]) < 5
2283+
and op
2284+
in {
2285+
exir_ops.edge.aten.pow.Tensor_Scalar,
22832286
}
2284-
or args[0] != 2
22852287
):
22862288
return super().call_operator(op, args, kwargs, meta)
22872289

2290+
x = args[0]
2291+
exponent = cast(int, args[1])
2292+
2293+
if exponent > 2:
2294+
for _ in range(exponent, 2, -1):
2295+
x = super().call_operator(
2296+
exir_ops.edge.aten.mul.Tensor,
2297+
(x, args[0]),
2298+
{},
2299+
meta,
2300+
)
22882301
return super().call_operator(
22892302
exir_ops.edge.aten.mul.Tensor,
2290-
(args[1], args[1]),
2303+
(x, args[0]),
22912304
{},
22922305
meta,
22932306
)
@@ -2429,5 +2442,5 @@ class CadenceReplaceOpsInGraph:
24292442
ReplaceWhereWithFullArgsWithWhereScalar,
24302443
ReplaceGeluWithApproximateGeluPass,
24312444
ReplaceSplitWithSlicePass,
2432-
ReplacePowWithMullPass,
2445+
ReplacePowWithMulPass,
24332446
]

backends/cadence/aot/tests/test_replace_ops_passes.py

Lines changed: 40 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@
4141
ReplaceNopTransposeOrPermuteWithViewPass,
4242
ReplacePadWithCatPass,
4343
ReplacePermuteWithTransposePass,
44-
ReplacePowWithMullPass,
44+
ReplacePowWithMulPass,
4545
ReplaceRepeatWithCatPass,
4646
ReplaceScalarTensorWithFullPass,
4747
ReplaceScalarWithTensorArgPass,
@@ -1382,22 +1382,23 @@ def test_replace_split_with_sizes_with_slice(self):
13821382
2,
13831383
)
13841384

1385-
def test_replace_pow_with_mul(self):
1385+
@parameterized.expand([[2], [3], [4]])
1386+
def test_replace_pow_with_mul(self, exponent: int):
13861387
class Pow(torch.nn.Module):
13871388
def forward(self, input):
1388-
return torch.ops.aten.pow.Scalar(2, input)
1389+
return torch.ops.aten.pow.Tensor_Scalar(input, exponent)
13891390

13901391
input = torch.randn(2, 1, 64)
13911392

13921393
graph_module = export_to_edge(Pow(), (input,)).exported_program().graph_module
13931394

1394-
p = ReplacePowWithMullPass()
1395+
p = ReplacePowWithMulPass()
13951396
graph_after_passes = cast(PassResult, p(graph_module)).graph_module
13961397

13971398
self.assertEqual(
13981399
count_node(
13991400
graph_after_passes,
1400-
exir_ops.edge.aten.pow.Scalar,
1401+
exir_ops.edge.aten.pow.Tensor_Scalar,
14011402
),
14021403
0,
14031404
)
@@ -1407,9 +1408,43 @@ def forward(self, input):
14071408
graph_after_passes,
14081409
exir_ops.edge.aten.mul.Tensor,
14091410
),
1411+
exponent - 1,
1412+
)
1413+
1414+
@parameterized.expand(
1415+
[
1416+
[1],
1417+
[1.5],
1418+
]
1419+
)
1420+
def test_replace_pow_with_mul_not_applied(self, exponent):
1421+
class Pow(torch.nn.Module):
1422+
def forward(self, input):
1423+
return torch.ops.aten.pow.Tensor_Scalar(input, exponent)
1424+
1425+
input = torch.randn(2, 1, 64)
1426+
1427+
graph_module = export_to_edge(Pow(), (input,)).exported_program().graph_module
1428+
1429+
p = ReplacePowWithMulPass()
1430+
graph_after_passes = cast(PassResult, p(graph_module)).graph_module
1431+
1432+
self.assertEqual(
1433+
count_node(
1434+
graph_after_passes,
1435+
exir_ops.edge.aten.pow.Tensor_Scalar,
1436+
),
14101437
1,
14111438
)
14121439

1440+
self.assertEqual(
1441+
count_node(
1442+
graph_after_passes,
1443+
exir_ops.edge.aten.mul.Tensor,
1444+
),
1445+
0,
1446+
)
1447+
14131448

14141449
class TestReplaceIm2rowWithViewPass(unittest.TestCase):
14151450
def test_no_replacement_for_conv(self):

0 commit comments

Comments
 (0)