Skip to content

Commit 954f2cb

Browse files
authored
Fix linter (#10404)
1 parent dfd3dbe commit 954f2cb

File tree

3 files changed

+12
-8
lines changed

3 files changed

+12
-8
lines changed

backends/cadence/aot/fuse_ops.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -528,7 +528,9 @@ class FuseCascadedViewOps(ExportPass):
528528

529529
def fuse_cascaded_view_ops(self, graph_module: torch.fx.GraphModule):
530530
view_target = exir_ops.edge.aten.view_copy.default
531-
for view_node in graph_module.graph.find_nodes(op="call_function", target=view_target, sort=True):
531+
for view_node in graph_module.graph.find_nodes(
532+
op="call_function", target=view_target, sort=True
533+
):
532534
input_view = view_node.args[0]
533535
if input_view.op != "call_function" or input_view.target != view_target:
534536
continue

backends/cadence/aot/replace_ops.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2259,7 +2259,6 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
22592259
return result
22602260

22612261

2262-
22632262
@register_cadence_pass(CadencePassAttribute(opt_level=1))
22642263
class ReplacePowWithMullPass(ExportPass):
22652264
"""
@@ -2274,9 +2273,13 @@ def call_operator(
22742273
meta: NodeMetadata,
22752274
) -> ProxyValue:
22762275
# TODO(eigen): Add support for other degrees.
2277-
if op not in {
2278-
exir_ops.edge.aten.pow.Scalar,
2279-
} or args[0] != 2:
2276+
if (
2277+
op
2278+
not in {
2279+
exir_ops.edge.aten.pow.Scalar,
2280+
}
2281+
or args[0] != 2
2282+
):
22802283
return super().call_operator(op, args, kwargs, meta)
22812284

22822285
return super().call_operator(

backends/cadence/aot/tests/test_replace_ops_passes.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,13 +30,13 @@
3030
ReplaceEmptyTensorsWithFullPass,
3131
ReplaceFunctionallyEquivalentOpTargets,
3232
ReplaceGeluWithApproximateGeluPass,
33-
ReplacePowWithMullPass,
3433
ReplaceIm2RowWithViewPass,
3534
ReplaceLinearWithFullyConnectedOpPass,
3635
ReplaceMMWithAddMMPass,
3736
ReplaceNopTransposeOrPermuteWithViewPass,
3837
ReplacePadWithCatPass,
3938
ReplacePermuteWithTransposePass,
39+
ReplacePowWithMullPass,
4040
ReplaceRepeatWithCatPass,
4141
ReplaceScalarTensorWithFullPass,
4242
ReplaceScalarWithTensorArgPass,
@@ -1338,7 +1338,7 @@ def test_replace_split_with_sizes_with_slice(self):
13381338
def test_replace_pow_with_mul(self):
13391339
class Pow(torch.nn.Module):
13401340
def forward(self, input):
1341-
return torch.ops.aten.pow.Scalar(2, input)
1341+
return torch.ops.aten.pow.Scalar(2, input)
13421342

13431343
input = torch.randn(2, 1, 64)
13441344

@@ -1347,7 +1347,6 @@ def forward(self, input):
13471347
p = ReplacePowWithMullPass()
13481348
graph_after_passes = cast(PassResult, p(graph_module)).graph_module
13491349

1350-
13511350
self.assertEqual(
13521351
count_node(
13531352
graph_after_passes,

0 commit comments

Comments
 (0)