Commit 8b3b028

Fix memory planning algo for blocked mem IDs.
Differential Revision: D77310021
Pull Request resolved: #11969

1 parent 883084d

4 files changed (+118 −43 lines)
backends/cadence/aot/memory_constraints.py (21 additions, 9 deletions)

@@ -11,7 +11,7 @@
 import typing
 from collections import defaultdict
 from dataclasses import dataclass
-from typing import cast, DefaultDict, Iterable, Optional, Sequence
+from typing import Callable, cast, DefaultDict, Iterable, Optional, Sequence, TypeAlias

 import torch
 import torch.fx
@@ -573,23 +573,34 @@ def compute_slice_and_select_loc_constraints(
     graph_module.recompile()


+ConstraintsGenPass: TypeAlias = Callable[
+    [MemConstraints],
+    Callable[[torch.fx.GraphModule], Optional[PassResult]],
+]
+
+
 # The class to generate all the constraints that will be passed on to the memory
 # planning algorithm.
 class GenerateMemConstraints:
     def __init__(
         self,
         mem_constraints: MemConstraints,
-        additional_constraint_gen_passes: list | None = None,
+        additional_constraint_gen_passes: Sequence[ConstraintsGenPass] | None = None,
     ) -> None:
-        self.mem_constraints = mem_constraints
-        self.additional_constraint_gen_passes = additional_constraint_gen_passes or []
+        self.mem_constraints: MemConstraints = mem_constraints
+        self.additional_constraint_gen_passes: Sequence[ConstraintsGenPass] = (
+            additional_constraint_gen_passes or []
+        )

     def __call__(self, graph_module: torch.fx.GraphModule) -> PassResult:
-        constraint_gen_passes: list = [
-            GenerateMemoryViewConstraints,
-            GenerateSliceAndSelectNopConstraints,
-            GenerateCatNopConstraints,
-        ] + self.additional_constraint_gen_passes
+        constraint_gen_passes: Sequence[ConstraintsGenPass] = cast(
+            list[ConstraintsGenPass],
+            [
+                GenerateMemoryViewConstraints,
+                GenerateSliceAndSelectNopConstraints,
+                GenerateCatNopConstraints,
+            ],
+        ) + list(self.additional_constraint_gen_passes)
         # Create a filter using the opt level in mem_constraints, and filter
         # the relevant passes.
         pass_filter = create_cadence_pass_filter(self.mem_constraints.opt_level)
@@ -602,6 +613,7 @@ def __call__(self, graph_module: torch.fx.GraphModule) -> PassResult:
                         typing.Callable[[torch.fx.GraphModule], Optional[PassResult]],
                     ]
                 ],
+                # pyre-ignore[6]: Incompatible parameter type.
                 list(filter(pass_filter, constraint_gen_passes)),
             )
         ]
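
For reference, a minimal sketch of a callable that satisfies the new ConstraintsGenPass alias, i.e. a factory that takes a MemConstraints instance and returns a GraphModule pass. The function name and its no-op body are hypothetical; only the imported names come from the files touched by this commit.

from typing import Callable, Optional

import torch
import torch.fx
from executorch.backends.cadence.aot.memory_constraints import (
    ConstraintsGenPass,
    MemConstraints,
)
from executorch.exir.pass_base import PassResult


def dummy_constraint_gen(  # hypothetical example, not part of the commit
    constraints: MemConstraints,
) -> Callable[[torch.fx.GraphModule], Optional[PassResult]]:
    def run(graph_module: torch.fx.GraphModule) -> Optional[PassResult]:
        # A real pass would record constraints (e.g. blocklisted mem_ids) on
        # `constraints` while walking the graph; this sketch changes nothing.
        return PassResult(graph_module, False)

    return run


# Type-checks against the alias, so it can be forwarded through the new
# additional_constraint_gen_passes parameter of GenerateMemConstraints.
extra_passes: list[ConstraintsGenPass] = [dummy_constraint_gen]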

backends/cadence/aot/memory_planning.py (7 additions, 13 deletions)

@@ -9,11 +9,12 @@
 import collections
 import itertools
 import logging
-from typing import Callable, Iterable, List, Optional, Set, Tuple, TypeAlias
+from typing import Iterable, List, Optional, Sequence, Set, Tuple

 import torch
 from executorch.backends.cadence.aot.memory_constraints import MemConstraints
 from executorch.backends.cadence.aot.memory_planning_algo import (
+    ConstraintsGenPass,
     get_aligned_offset,
     MemoryPlanningAlgo,
     MemoryPlanningState,
@@ -126,10 +127,9 @@ def plan_spec(self, spec: TensorSpec, state: MemoryPlanningState) -> None:
                     prev_offset,
                 )
             if spec.mem_offset is None:
-                if get_aligned_offset(
-                    prev_offset + spec.allocated_memory,
-                    self.get_alignment(spec.mem_id),
-                ) > self.get_size(spec.mem_id):
+                spec.mem_offset = prev_offset
+                if not self.is_valid_placement(spec):
+                    spec.mem_offset = None
                     continue
                 else:
                     spec.mem_offset = prev_offset
@@ -344,12 +344,6 @@ def print_memory_planning_info(
     )


-ConstraintGenPassType: TypeAlias = Callable[
-    [MemConstraints],
-    Callable[[torch.fx.GraphModule], Optional[PassResult]],
-]
-
-
 class CadenceMemoryPlanning:
     def __init__(
         self,
@@ -358,7 +352,7 @@ def __init__(
         mem_algo: int,
         alloc_graph_input: bool = True,
         alloc_graph_output: bool = True,
-        additional_constraint_gen_passes: Optional[list[ConstraintGenPassType]] = None,
+        additional_constraint_gen_passes: Optional[Sequence[ConstraintsGenPass]] = None,
     ) -> None:
         self.memory_config = memory_config
         self.opt_level = opt_level
@@ -379,7 +373,7 @@ def get_mem_algos(
         opt_level: int,
         alloc_graph_input: bool,
         alloc_graph_output: bool,
-        additional_constraint_gen_passes: Optional[list[ConstraintGenPassType]],
+        additional_constraint_gen_passes: Optional[Sequence[ConstraintsGenPass]],
     ) -> list[MemoryPlanningAlgo]:
         return [
             PositionBasedGreedyWithHierarchy(
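
To make the plan_spec change concrete: the planner now tentatively assigns the candidate offset and lets a single is_valid_placement-style check accept or reject it, rolling the offset back on failure. Below is a self-contained toy sketch of that flow; ToySpec, plan_at, and fits_in_1024 are illustrative stand-ins, not the ExecuTorch API.

from dataclasses import dataclass
from typing import Callable, Optional


@dataclass
class ToySpec:  # stand-in for TensorSpec in this sketch
    allocated_memory: int
    mem_offset: Optional[int] = None


def plan_at(
    spec: ToySpec, prev_offset: int, is_valid: Callable[[ToySpec], bool]
) -> bool:
    # Mirrors the new flow: tentatively place at prev_offset, then let the
    # single validity check decide; roll back (mem_offset = None) on failure.
    spec.mem_offset = prev_offset
    if not is_valid(spec):
        spec.mem_offset = None
        return False
    return True


def fits_in_1024(spec: ToySpec) -> bool:
    # Toy validity check: the buffer must end within a 1024-byte bank.
    return spec.mem_offset + spec.allocated_memory <= 1024


assert plan_at(ToySpec(allocated_memory=64), prev_offset=0, is_valid=fits_in_1024)
assert not plan_at(ToySpec(allocated_memory=64), prev_offset=1000, is_valid=fits_in_1024)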

backends/cadence/aot/memory_planning_algo.py (15 additions, 19 deletions)

@@ -5,16 +5,16 @@
 import logging
 import math
 from abc import ABC, abstractmethod
-from typing import Callable, Optional
+from typing import Optional, Sequence

 import torch
 from executorch.backends.cadence.aot.memory_constraints import (
+    ConstraintsGenPass,
     GenerateMemConstraints,
     MemConstraints,
 )
 from executorch.backends.cadence.aot.utils import MemoryConfig
 from executorch.exir.memory_planning import Verifier
-from executorch.exir.pass_base import PassResult
 from executorch.exir.tensor import TensorSpec
 from torch.export.exported_program import ExportGraphSignature

@@ -68,18 +68,13 @@ def __init__(
         self,
         memory_config: MemoryConfig,
         placement_constraints: MemConstraints,
-        additional_constraint_gen_passes: Optional[
-            list[
-                Callable[
-                    [MemConstraints],
-                    Callable[[torch.fx.GraphModule], Optional[PassResult]],
-                ]
-            ]
-        ] = None,
+        additional_constraint_gen_passes: Optional[Sequence[ConstraintsGenPass]] = None,
     ) -> None:
-        self.memory_config = memory_config
-        self.placement_constraints = placement_constraints
-        self.additional_constraint_gen_passes = additional_constraint_gen_passes
+        self.memory_config: MemoryConfig = memory_config
+        self.placement_constraints: MemConstraints = placement_constraints
+        self.additional_constraint_gen_passes: Optional[
+            Sequence[ConstraintsGenPass]
+        ] = additional_constraint_gen_passes

     def get_num_memories(self) -> int:
         """Get num memories indexed from 1..N, compatible with EXIR's spec.mem_id."""
@@ -102,10 +97,14 @@ def populate_constraints(self, graph_module: torch.fx.GraphModule) -> None:
         )(graph_module)

     def is_valid_placement(self, spec: TensorSpec) -> bool:
-        return get_aligned_offset(
+        """Returns true if the spec can be placed at the given memory id."""
+        end_of_allocation = get_aligned_offset(
             spec.mem_offset + spec.allocated_memory,
             self.get_alignment(spec.mem_id),
-        ) <= self.get_size(spec.mem_id)
+        )
+        return end_of_allocation <= self.get_size(
+            spec.mem_id
+        ) and not self.placement_constraints.is_mem_id_in_blocklist(spec, spec.mem_id)

     @abstractmethod
     def plan(
@@ -133,10 +132,7 @@ def __call__(
         # First plan the memory allocation for specs without relative constraints.
         specs_without_relative_constraints = set(
             filter(
-                lambda spec: not self.placement_constraints.skipped_spec(spec)
-                and not self.placement_constraints.is_mem_id_in_blocklist(
-                    spec, spec.mem_id
-                ),
+                lambda spec: not self.placement_constraints.skipped_spec(spec),
                 specs,
             )
         )
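
The core of the fix is visible in the two hunks above: a blocklisted mem_id no longer removes the spec from planning entirely (the old pre-filter in __call__); instead, is_valid_placement rejects that particular bank and the planner moves on to the next one. A toy sketch of that behavior follows, with illustrative stand-in names (ToySpec, first_valid_bank) rather than the real API.

from dataclasses import dataclass, field
from typing import Optional


@dataclass
class ToySpec:  # stand-in for TensorSpec
    allocated_memory: int
    mem_id: Optional[int] = None
    mem_offset: int = 0
    blocked_mem_ids: set[int] = field(default_factory=set)


def get_aligned_offset(offset: int, alignment: int) -> int:
    # Round `offset` up to the next multiple of `alignment`.
    return ((offset + alignment - 1) // alignment) * alignment


def is_valid_placement(spec: ToySpec, bank_size: int, alignment: int) -> bool:
    # Size/alignment check plus the blocklist check added by this commit.
    end = get_aligned_offset(spec.mem_offset + spec.allocated_memory, alignment)
    return end <= bank_size and spec.mem_id not in spec.blocked_mem_ids


def first_valid_bank(
    spec: ToySpec, num_banks: int, bank_size: int, alignment: int
) -> Optional[int]:
    # Greedy search over banks 1..num_banks; blocked banks are skipped, not fatal.
    for mem_id in range(1, num_banks + 1):
        spec.mem_id = mem_id
        if is_valid_placement(spec, bank_size, alignment):
            return mem_id
    return None


# An "add"-like spec that blocks banks 2 and 3 is still placed, in bank 1.
spec = ToySpec(allocated_memory=64, blocked_mem_ids={2, 3})
assert first_valid_bank(spec, num_banks=4, bank_size=1024, alignment=16) == 1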

backends/cadence/aot/tests/test_memory_passes.py (75 additions, 2 deletions)

@@ -8,26 +8,33 @@

 import math
 import unittest
-from typing import cast, List, Optional
+from typing import cast, List, Optional, Sequence

 import executorch.backends.cadence.aot.ops_registrations  # noqa
 import torch
 from executorch.backends.cadence.aot import compiler
 from executorch.backends.cadence.aot.graph_builder import GraphBuilder
+from executorch.backends.cadence.aot.memory_constraints import ConstraintsGenPass
 from executorch.backends.cadence.aot.memory_planning import (
     CadenceMemoryPlanning,
     find_peak_memory_usage,
 )
-from executorch.backends.cadence.aot.pass_utils import count_node
+from executorch.backends.cadence.aot.pass_utils import (
+    CadencePassAttribute,
+    count_node,
+    register_cadence_pass,
+)
 from executorch.backends.cadence.aot.typing_stubs import expand
 from executorch.backends.cadence.aot.utils import (
     get_default_memory_config,
     MemoryConfig,
 )
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.memory_planning import collect_specs_from_nodes
+from executorch.exir.pass_base import PassBase, PassResult
 from executorch.exir.passes.spec_prop_pass import SpecPropPass
 from executorch.exir.tests.models import MultiLayerPerceptron
+from parameterized import parameterized
 from torch.fx import GraphModule


@@ -230,6 +237,7 @@ def run_memory_planning(
         alloc_graph_input: bool = True,
         alloc_graph_output: bool = True,
         memory_config: Optional[MemoryConfig] = None,
+        additional_constraint_gen_passes: Optional[Sequence[ConstraintsGenPass]] = None,
     ) -> GraphModule:
         if memory_config is None:
             memory_config = get_default_memory_config()
@@ -240,6 +248,7 @@ def run_memory_planning(
             mem_algo=mem_algo,
             alloc_graph_input=alloc_graph_input,
             alloc_graph_output=alloc_graph_output,
+            additional_constraint_gen_passes=additional_constraint_gen_passes,
         )(graph_module).graph_module

     @expand(
@@ -984,3 +993,67 @@ def forward(self, x: torch.Tensor, y: torch.Tensor):
         ):
             if spec and spec.mem_offset:
                 self.assertEqual(spec.mem_offset % 37, 0)
+
+    @parameterized.expand([0, 1])
+    def test_block_mem_id(self, mem_algo: int) -> None:
+        builder = GraphBuilder()
+        x = builder.placeholder("x", torch.randn(16))
+        add = builder.call_operator(
+            op=torch.ops.aten.add.Scalar,
+            args=(x, 2.0),
+        )
+        mul = builder.call_operator(
+            op=torch.ops.aten.mul.Scalar,
+            args=(add, 2.0),
+        )
+        builder.output([mul])
+        original = builder.get_graph_module()
+
+        dummy_memory_config = MemoryConfig([1024, 1024, 1024, 1024])
+
+        add_scalar_block_mem_ids = [2, 3]
+        mul_scalar_block_mem_ids = [1, 3]
+
+        @register_cadence_pass(CadencePassAttribute(opt_level=0))
+        class DummyMemIdBlockConstraintGen(PassBase):
+            """Blocks placement based on op type.
+            add: blocks 2, 3
+            mul: blocks 1, 3
+
+            """
+
+            def __init__(self, memory_constraints: MemoryConfig):
+                self.memory_constraints = memory_constraints
+
+            def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
+                for node in graph_module.graph.find_nodes(
+                    op="call_function", target=torch.ops.aten.add.Scalar
+                ):
+                    spec = node.meta["spec"]
+                    for mem_id in add_scalar_block_mem_ids:
+                        self.memory_constraints.add_mem_id_to_blocklist(spec, mem_id)
+                for node in graph_module.graph.find_nodes(
+                    op="call_function", target=torch.ops.aten.mul.Scalar
+                ):
+                    spec = node.meta["spec"]
+                    for mem_id in mul_scalar_block_mem_ids:
+                        self.memory_constraints.add_mem_id_to_blocklist(spec, mem_id)
+
+        graph_module = self.run_memory_planning(
+            original,
+            mem_algo=mem_algo,
+            memory_config=dummy_memory_config,
+            additional_constraint_gen_passes=[DummyMemIdBlockConstraintGen],
+        )
+        for node in graph_module.graph.find_nodes(
+            op="call_function", target=torch.ops.aten.add.Scalar
+        ):
+            spec = node.meta["spec"]
+            self.assertIsNotNone(spec.mem_id)
+            self.assertNotIn(spec.mem_id, add_scalar_block_mem_ids)
+        for node in graph_module.graph.find_nodes(
+            op="call_function", target=torch.ops.aten.mul.Scalar
+        ):
+            spec = node.meta["spec"]
+            self.assertIsNotNone(spec.mem_id)
+            self.assertNotIn(spec.mem_id, mul_scalar_block_mem_ids)
