Skip to content

Commit c2e51fe

Browse files
authored
Add support for parameters/constants/buffer in program builder.
Differential Revision: D78421771 Pull Request resolved: #12555
1 parent 25637b2 commit c2e51fe

File tree

4 files changed

+148
-5
lines changed

4 files changed

+148
-5
lines changed

backends/cadence/aot/TARGETS

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,19 @@ python_library(
199199
],
200200
)
201201

202+
python_unittest(
203+
name = "test_program_builder",
204+
srcs = [
205+
"tests/test_program_builder.py",
206+
],
207+
typing = True,
208+
deps = [
209+
":program_builder",
210+
"//caffe2:torch",
211+
"//later:lib",
212+
],
213+
)
214+
202215
python_library(
203216
name = "fuse_ops",
204217
srcs = [

backends/cadence/aot/graph_builder.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,13 +66,13 @@ def placeholder(
6666
) -> ProxyValue:
6767
if not isinstance(fake_tensor, FakeTensor):
6868
fake_tensor = self.fake_tensor_mode.from_tensor(fake_tensor)
69-
logging.info(f"Creating placeholder {target} => {fake_tensor.shape}")
69+
logging.debug(f"Creating placeholder {target} => {fake_tensor.shape}")
7070
placeholder = super().placeholder(target, fake_tensor, NodeMetadata({}))
7171
return placeholder
7272

7373
# pyre-ignore[14]: Inconsistent override.
7474
def output(self, results: list[ProxyValue]) -> ProxyValue:
75-
logging.info(f"Creating outputs {results}")
75+
logging.debug(f"Creating outputs {results}")
7676
return super().output(results, NodeMetadata({}))
7777

7878
def get_graph_module(self) -> torch.fx.GraphModule:

backends/cadence/aot/program_builder.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,10 +34,18 @@ def __init__(self) -> None:
3434
def insert_input_spec(
3535
self, target: str, input_kind: InputKind, value: Tensor
3636
) -> None:
37-
if input_kind == InputKind.USER_INPUT:
38-
self.input_specs.append(
39-
InputSpec(input_kind, TensorArgument(target), target=target)
37+
persistent: Optional[bool] = None
38+
if input_kind == InputKind.BUFFER:
39+
persistent = True
40+
self.input_specs.append(
41+
InputSpec(
42+
input_kind, TensorArgument(target), target=target, persistent=persistent
4043
)
44+
)
45+
if input_kind == InputKind.PARAMETER or input_kind == InputKind.BUFFER:
46+
self.state_dict[target] = value
47+
elif input_kind == InputKind.CONSTANT_TENSOR:
48+
self.constants[target] = value
4149

4250
def placeholder(
4351
self,
Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
3+
# pyre-strict
4+
5+
import torch
6+
from executorch.backends.cadence.aot.program_builder import ProgramBuilder
7+
from later.unittest import TestCase
8+
from torch.export.graph_signature import InputKind, OutputKind
9+
10+
11+
class TestProgramBuilder(TestCase):
12+
def test_user_input_with_parameter(self) -> None:
13+
inp = torch.randn([3, 5])
14+
w = torch.nn.Parameter(torch.randn([5]))
15+
# Create a exported program with one user input and one parameter.
16+
# Returns inp + w, w + 2 tuple.
17+
builder = ProgramBuilder()
18+
inp_proxy = builder.placeholder("inp", inp)
19+
w_proxy = builder.placeholder("w", w, input_kind=InputKind.PARAMETER)
20+
add = builder.call_operator(torch.ops.aten.add.Tensor, (inp_proxy, w_proxy))
21+
add_w = builder.call_operator(torch.ops.aten.add.Scalar, (w_proxy, 2))
22+
builder.output([add, add_w])
23+
program = builder.get_program()
24+
25+
self.assertEqual(len(program.graph_signature.input_specs), 2)
26+
self.assertEqual(
27+
program.graph_signature.input_specs[0].kind, InputKind.USER_INPUT
28+
)
29+
self.assertEqual(
30+
program.graph_signature.input_specs[1].kind, InputKind.PARAMETER
31+
)
32+
self.assertEqual(len(program.graph_signature.output_specs), 2)
33+
self.assertEqual(
34+
program.graph_signature.output_specs[0].kind, OutputKind.USER_OUTPUT
35+
)
36+
self.assertEqual(
37+
program.graph_signature.output_specs[1].kind, OutputKind.USER_OUTPUT
38+
)
39+
40+
def test_user_input_with_constant(self) -> None:
41+
inp = torch.randn([3, 5])
42+
const = torch.randn([5])
43+
# Create a exported program with one user input and one constant tensor.
44+
# Returns inp + const
45+
builder = ProgramBuilder()
46+
inp_proxy = builder.placeholder("inp", inp)
47+
const_proxy = builder.placeholder(
48+
"const", const, input_kind=InputKind.CONSTANT_TENSOR
49+
)
50+
add = builder.call_operator(torch.ops.aten.add.Tensor, (inp_proxy, const_proxy))
51+
builder.output([add])
52+
program = builder.get_program()
53+
54+
# Verify the program has the correct inputs and outputs
55+
self.assertEqual(len(program.graph_signature.input_specs), 2)
56+
self.assertEqual(
57+
program.graph_signature.input_specs[0].kind, InputKind.USER_INPUT
58+
)
59+
self.assertEqual(
60+
program.graph_signature.input_specs[1].kind, InputKind.CONSTANT_TENSOR
61+
)
62+
self.assertEqual(len(program.graph_signature.output_specs), 1)
63+
self.assertEqual(
64+
program.graph_signature.output_specs[0].kind, OutputKind.USER_OUTPUT
65+
)
66+
67+
def test_mutable_buffer(self) -> None:
68+
inp = torch.randn([3, 5])
69+
buffer = torch.randn([5])
70+
# Create a exported program with one user input and one buffer that gets mutated.
71+
# Returns inp + buffer, updated_buffer
72+
builder = ProgramBuilder()
73+
inp_proxy = builder.placeholder("inp", inp)
74+
buffer_proxy = builder.placeholder(
75+
"buffer", buffer, input_kind=InputKind.BUFFER
76+
)
77+
add = builder.call_operator(
78+
torch.ops.aten.add.Tensor, (inp_proxy, buffer_proxy)
79+
)
80+
# Mutate the buffer by adding 1
81+
updated_buffer = builder.call_operator(
82+
torch.ops.aten.add.Scalar, (buffer_proxy, 1)
83+
)
84+
builder.output(
85+
[add, updated_buffer], [OutputKind.USER_OUTPUT, OutputKind.BUFFER_MUTATION]
86+
)
87+
program = builder.get_program()
88+
89+
# Verify the program has the correct inputs and outputs
90+
self.assertEqual(len(program.graph_signature.input_specs), 2)
91+
self.assertEqual(
92+
program.graph_signature.input_specs[0].kind, InputKind.USER_INPUT
93+
)
94+
self.assertEqual(program.graph_signature.input_specs[1].kind, InputKind.BUFFER)
95+
self.assertEqual(len(program.graph_signature.output_specs), 2)
96+
self.assertEqual(
97+
program.graph_signature.output_specs[0].kind, OutputKind.USER_OUTPUT
98+
)
99+
self.assertEqual(
100+
program.graph_signature.output_specs[1].kind, OutputKind.BUFFER_MUTATION
101+
)
102+
103+
def test_user_input_mutation(self) -> None:
104+
inp = torch.randn([3, 5])
105+
# Create a exported program with one user input that gets mutated.
106+
# Returns updated_inp
107+
builder = ProgramBuilder()
108+
inp_proxy = builder.placeholder("inp", inp)
109+
# Mutate the input by adding 1
110+
updated_inp = builder.call_operator(torch.ops.aten.add.Scalar, (inp_proxy, 1))
111+
builder.output([updated_inp], [OutputKind.USER_INPUT_MUTATION])
112+
program = builder.get_program()
113+
114+
# Verify the program has the correct inputs and outputs
115+
self.assertEqual(len(program.graph_signature.input_specs), 1)
116+
self.assertEqual(
117+
program.graph_signature.input_specs[0].kind, InputKind.USER_INPUT
118+
)
119+
self.assertEqual(len(program.graph_signature.output_specs), 1)
120+
self.assertEqual(
121+
program.graph_signature.output_specs[0].kind, OutputKind.USER_INPUT_MUTATION
122+
)

0 commit comments

Comments
 (0)