Skip to content

Commit 0db3444

Browse files
chunnienccopybara-github
authored andcommitted
Reduce number of retracing (Canonicalization) in conversion fx pass pipeline.
PiperOrigin-RevId: 749104077
1 parent 2619605 commit 0db3444

File tree

3 files changed

+10
-10
lines changed

3 files changed

+10
-10
lines changed

ai_edge_torch/_convert/conversion.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,14 +35,12 @@ def _run_convert_passes(
3535
)
3636

3737
passes = [
38+
fx_passes.CastInputsBf16ToF32Pass(),
3839
fx_passes.BuildInterpolateCompositePass(),
39-
fx_passes.CanonicalizePass(),
4040
fx_passes.OptimizeLayoutTransposesPass(),
4141
fx_passes.CanonicalizePass(),
4242
fx_passes.BuildAtenCompositePass(),
4343
fx_passes.RemoveNonUserOutputsPass(),
44-
fx_passes.CastInputsBf16ToF32Pass(),
45-
fx_passes.CanonicalizePass(),
4644
]
4745

4846
# Debuginfo is not injected automatically by odml_torch. Only inject

ai_edge_torch/generative/fx_passes/__init__.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,5 @@ def run_generative_passes(
2323
) -> torch.export.ExportedProgram:
2424
return fx_infra.run_passes(
2525
exported_program,
26-
[
27-
RemoveSDPACompositeZeroMaskPass(),
28-
CanonicalizePass(),
29-
],
26+
[RemoveSDPACompositeZeroMaskPass()],
3027
)

ai_edge_torch/odml_torch/export.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -264,13 +264,16 @@ def _convert_i64_to_i32(exported_program: torch.export.ExportedProgram):
264264
exported_program: The exported program to apply the pass.
265265
"""
266266

267+
is_modified = False
268+
267269
def in_i32(x: int):
268270
return -2147483648 <= x <= 2147483647
269271

270272
def to_int32(x: torch.Tensor):
271273
return torch.ops.aten._to_copy.default(x, dtype=torch.int32)
272274

273275
def rewrite_arange(node: torch.fx.Node):
276+
nonlocal is_modified
274277
tensor_meta = node.meta.get("tensor_meta", None)
275278
if not tensor_meta:
276279
return
@@ -282,12 +285,14 @@ def rewrite_arange(node: torch.fx.Node):
282285
return
283286
op = node.target
284287
node.target = lambda *args, **kwargs: to_int32(op(*args, **kwargs))
288+
is_modified = True
285289

286290
graph_module = exported_program.graph_module
287291
for node in graph_module.graph.nodes:
288292

289293
if node.target == torch.ops.aten.arange.start_step:
290294
rewrite_arange(node)
295+
return is_modified
291296

292297

293298
# TODO(b/331481564) Make this a ai_edge_torch FX pass.
@@ -351,9 +356,9 @@ def exported_program_to_mlir(
351356
exported_program,
352357
fx_infra.decomp.pre_lower_decomp(),
353358
)
354-
_convert_i64_to_i32(exported_program)
355-
# Run decompositions for retracing and cananicalization.
356-
exported_program = fx_infra.safe_run_decompositions(exported_program, {})
359+
if _convert_i64_to_i32(exported_program):
360+
# Run decompositions for retracing and cananicalization, if modified.
361+
exported_program = fx_infra.safe_run_decompositions(exported_program, {})
357362

358363
# Passes below mutate the exported program to a state not executable by torch.
359364
# Do not call run_decompositions after applying the passes.

0 commit comments

Comments
 (0)