Commit 2f8ecf3

JacobSzwejbka authored and facebook-github-bot committed
Program.fbs change to support serialized mutable state (#4216)
Summary:
Pull Request resolved: #4216

Need a way to indicate values that have a meaningful initial state serialized in the program and that can also be mutated on device.

https://docs.google.com/document/d/1D8WpMmIiQxU_n5OYWXl3mrpBYUewz79izyAO2UknSsM/edit?usp=sharing

Reviewed By: dbort

Differential Revision: D58747605

fbshipit-source-id: 096b40443ba4ecc8044a4d397838309e8c97c8fa
1 parent a567abf commit 2f8ecf3

File tree

10 files changed (+87, -34 lines)


exir/emit/_emitter.py

Lines changed: 1 addition & 1 deletion

@@ -1090,7 +1090,7 @@ def _get_empty_tensor_evalue() -> EValue:
 dim_order=[],
 requires_grad=False,
 layout=0,
-constant_buffer_idx=0,
+data_buffer_idx=0,
 allocation_info=None,
 shape_dynamism=TensorShapeDynamism.STATIC,
 )

exir/emit/test/test_emit.py

Lines changed: 2 additions & 2 deletions

@@ -109,7 +109,7 @@ def check_tensor_buffer_loc(
 value = typing.cast(schema.Tensor, values[value_index].val)
 self.assertIsInstance(value, schema.Tensor)

-self.assertEqual(value.constant_buffer_idx, exp_buffer_idx)
+self.assertEqual(value.data_buffer_idx, exp_buffer_idx)

 if not value.allocation_info:
 self.assertIsNone(exp_mem_id)

@@ -810,7 +810,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 < non_const_buffer_size_without_const_prop_pass[1]
 )

-# cant compare plans directly with __eq__ because of the plan names, and constant_buffer_idx in tensor values
+# cant compare plans directly with __eq__ because of the plan names, and data_buffer_idx in tensor values
 def _compare_execution_plans(
 self, plan_single: ExecutionPlan, plan_merged: ExecutionPlan
 ) -> None:

exir/print_program.py

Lines changed: 1 addition & 1 deletion

@@ -79,7 +79,7 @@ def _format_evalue( # noqa: C901
 evstr = "\033[34m"
 if isinstance(evalue.val, Tensor):
 tensor = evalue.val
-if tensor.constant_buffer_idx > 0:
+if tensor.data_buffer_idx > 0:
 assert not _is_dynamic_shape_tensor(
 tensor
 ), "A constant tensor can not be dynamic shape"

exir/schema.py

Lines changed: 1 addition & 1 deletion

@@ -51,7 +51,7 @@ class Tensor:
 dim_order: List[bytes]
 requires_grad: bool
 layout: int
-constant_buffer_idx: int
+data_buffer_idx: int
 allocation_info: Optional[AllocationDetails]

 # check schema.fbs for explanations

exir/tensor.py

Lines changed: 2 additions & 2 deletions

@@ -308,7 +308,7 @@ def make_allocation_info(mem_id: int, mem_offset: int) -> schema.AllocationDetai


 def make_tensor_value(
-constant_buffer_idx: int,
+data_buffer_idx: int,
 allocation_info: Optional[schema.AllocationDetails],
 spec: TensorSpec,
 ) -> schema.Tensor:

@@ -341,7 +341,7 @@ def to_list(
 sizes=tensor_size,
 dim_order=tensor_dim_order,
 requires_grad=spec.requires_grad,
-constant_buffer_idx=constant_buffer_idx,
+data_buffer_idx=data_buffer_idx,
 allocation_info=allocation_info,
 layout=layout_enum(spec.layout),
 shape_dynamism=spec.shape_dynamism,

exir/tests/common.py

Lines changed: 1 addition & 1 deletion

@@ -49,7 +49,7 @@ def get_test_program() -> Program:
 dim_order=typing.cast(List[bytes], [0, 1]),
 requires_grad=False,
 layout=0,
-constant_buffer_idx=0,
+data_buffer_idx=0,
 allocation_info=AllocationDetails(
 memory_id=1,
 memory_offset_high=0,

exir/tests/test_verification.py

Lines changed: 3 additions & 1 deletion

@@ -4,6 +4,8 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

+# pyre-unsafe
+
 import unittest

 import torch

@@ -47,7 +49,7 @@ def f(x: torch.Tensor) -> torch.Tensor:
 for val_idx in range(len(test.execution_plan.values)):
 val = test.execution_plan.values[val_idx].val
 if not (
-isinstance(val, Tensor) and val.constant_buffer_idx == 0
+isinstance(val, Tensor) and val.data_buffer_idx == 0
 ) and not isinstance(val, TensorList):
 test.load_value(val_idx)
 vlist = test.get_value_list()

exir/verification/interpreter.py

Lines changed: 4 additions & 4 deletions

@@ -166,11 +166,11 @@ def get_constant_tensors(self) -> List[Tensor]:
 tensors = []
 for elem in self.execution_plan.values:
 val = elem.val
-if isinstance(val, Tensor) and val.constant_buffer_idx != 0:
+if isinstance(val, Tensor) and val.data_buffer_idx != 0:
 # load val into res
 # pyre-fixme[16]
 tensor = bindings.convert_to_tensor(
-self.data_buffers[val.constant_buffer_idx],
+self.data_buffers[val.data_buffer_idx],
 val.scalar_type,
 val.sizes,
 stride_from_dim_order(val.sizes, val.dim_order),

@@ -239,7 +239,7 @@ def load_from_value_list(self, idx: int) -> None: # noqa
 tensor_list.append(self._value_list[i])
 self._value_list[idx] = tensor_list
 elif isinstance(val, Tensor):
-if val.constant_buffer_idx == 0:
+if val.data_buffer_idx == 0:
 # TODO(zhengxu) Verify that argument is actually an out variant
 self._value_list[idx] = torch.empty(
 val.sizes, dtype=get_scalar_type(val.scalar_type)

@@ -248,7 +248,7 @@ def load_from_value_list(self, idx: int) -> None: # noqa
 # Constant Tensor conversion
 # pyre-fixme [16]
 tensor = bindings.convert_to_tensor(
-self.data_buffers[val.constant_buffer_idx],
+self.data_buffers[val.data_buffer_idx],
 val.scalar_type,
 val.sizes,
 stride_from_dim_order(val.sizes, val.dim_order),
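
The interpreter hunks above all branch on the same test: a tensor value with data_buffer_idx == 0 has no serialized bytes and is given an empty placeholder, while a non-zero index is materialized from the program's data buffers. The following is a minimal standalone sketch of that dispatch, using only the fields shown in the diff; SimpleTensor and materialize are hypothetical stand-ins for schema.Tensor and bindings.convert_to_tensor, not code from this commit.

from dataclasses import dataclass
from typing import List

import torch


@dataclass
class SimpleTensor:
    # Hypothetical stand-in for schema.Tensor with just the fields used here.
    sizes: List[int]
    dtype: torch.dtype
    data_buffer_idx: int


def materialize(val: SimpleTensor, data_buffers: List[bytes]) -> torch.Tensor:
    if val.data_buffer_idx == 0:
        # No serialized data: allocate an empty tensor to be filled later
        # (for example by an out-variant op or a user-provided input).
        return torch.empty(val.sizes, dtype=val.dtype)
    # Serialized data: reinterpret the raw bytes at the given buffer index.
    raw = bytearray(data_buffers[val.data_buffer_idx])
    return torch.frombuffer(raw, dtype=val.dtype).reshape(val.sizes)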

runtime/executor/tensor_parser_exec_aten.cpp

Lines changed: 3 additions & 3 deletions

@@ -53,9 +53,9 @@ __ET_NODISCARD Result<void*> getTensorDataPtr(
 const Program* program,
 size_t nbytes,
 HierarchicalAllocator* allocator) {
-if (s_tensor->constant_buffer_idx() > 0) {
-auto data = program->get_constant_buffer_data(
-s_tensor->constant_buffer_idx(), nbytes);
+if (s_tensor->data_buffer_idx() > 0) {
+auto data =
+program->get_constant_buffer_data(s_tensor->data_buffer_idx(), nbytes);
 if (!data.ok()) {
 return data.error();
 }

schema/program.fbs

Lines changed: 69 additions & 18 deletions

@@ -53,6 +53,20 @@ enum TensorShapeDynamism : byte {
 DYNAMIC_UNBOUND = 2,
 }

+
+// Table to put additional information about tensors in that is not applicable
+// to the vast majority of tensors in the vast majority of programs.
+table ExtraTensorInfo {
+// [Optional] Specifies the SubsegmentOffsets in
+// program.mutable_data_segments that specifies where the data is located in.
+// If not present and the data is located in a segment, then the data is in
+// the first index.
+mutable_data_segments_idx:uint64;
+
+// [Optional] The unique name of the tensor. e.g. 'mod.linear.weight'
+fully_qualified_name:string;
+}
+
 table Tensor {
 scalar_type:ScalarType;

@@ -63,26 +77,47 @@ table Tensor {

 sizes:[int];

-// Specifies in what order the dimensions are laid out in memory (from outer to inner).
-// For example, given a rank 3 Tensor of size (3, 5, 2). If we name dimensions: [row, column, batch], then a dim_order of:
-// (2, 0, 1) represents a [batch, row, column] ordering where "column" is the innermost dimension, then comes "row", and the outermost dimension is "batch".
-// (0, 2, 1) represents a [row, batch, column] ordering where "column" is the innermost dimension, then comes "batch", and the outermost dimension is "row".
+// Specifies in what order the dimensions are laid out in memory (from outer
+// to inner).
+//
+// For example, given a rank 3 Tensor of size (3, 5, 2). If we name
+// dimensions: [row, column, batch], then a dim_order of:
+// - (2, 0, 1) represents a [batch, row, column] ordering where "column" is
+// the innermost dimension, then comes "row", and the outermost dimension is
+// "batch".
+// - (0, 2, 1) represents a [row, batch, column] ordering where "column" is
+// the innermost dimension, then comes "batch", and the outermost dimension
+// is "row".
 dim_order:[ubyte];

 // out of scope M1
 requires_grad:bool;

-// Overall, a Tensor is either constant or non-constant, except we differentiate 2 special
-// variants of non-constant Tensor ("input" and control-flow "placeholder") as a special
-// optimization to avoid holding unnecessary AllocationDetails.
+// Overall, a Tensor is either constant or mutable. At method load time
+// constant tensors receive a dataptr into the serialized program. Mutable
+// tensors can either receive a pointer from the heirarchical allocator or a
+// nullptr if they will receive a data pointer at execution time (inputs
+// and control flow placeholders can be like this). Mutable tensors may or
+// may not also have an initial value in the serialized program.
+//
 // In summary:
-// constant_buffer_idx > 0, allocation_info = Null: Tensor is a constant
-// constant_buffer_idx = 0, allocation_info = Non Null: Tensor is a non-constant.
-// constant_buffer_idx = 0, allocation_info = Null: Tensor is a non-constant
-// that will receive a dataptr at input time or during execution.
+// data_buffer_idx > 0, allocation_info = Null: Tensor is a constant.
+// data_buffer_idx = 0, allocation_info = Non Null: Tensor is mutable and
+// will receive a dataptr at method load time.
+// data_buffer_idx = 0, allocation_info = Null: Tensor is mutable and
+// will receive a dataptr at input time or during execution.
+// data_buffer_idx > 0, allocation_info = Non Null: Tensor is mutable and
+// will receive a dataptr at method load time, and has an initial state.
 //
-// Index to the program's constant buffer table, value 0 is reserved to indicate non constant
-constant_buffer_idx:uint;
+// Tensor data is stored inline if program.constant_buffer is null. Otherwise
+// it is in a segment. If this tensor's allocation_info is null then the
+// tensor data location is specified by program.constant_segment. If the
+// allocation_info is non_null then the data is somewhere in
+// program.mutable_data_segments. If tensor_info is Null, then the data is
+// in program.mutable_data_segments[0] otherwise if tensor_info is non-null
+// then the mutable_data_segment index is specified by
+// tensor_info.mutable_data_segments_index.
+data_buffer_idx:uint;

 // [Optional] preallocation details for non-constants (null otherwise).
 allocation_info:AllocationDetails;

@@ -102,7 +137,11 @@ table Tensor {
 //
 // 3. dynamism == DYNAMIC_UNBOUND: the stored sizes field can be ignored since
 // shape is fully dynamic.
-shape_dynamism: TensorShapeDynamism;
+shape_dynamism:TensorShapeDynamism;
+
+// [Optional] Additional information about the Tensor that is not applicable
+// to most tensors.
+extra_tensor_info:ExtraTensorInfo;
 }

 table Int {

@@ -276,9 +315,11 @@ table BackendDelegate {
 compile_specs: [CompileSpec];
 }

-// A sequence of blocking instructions to be executed in order. The abstraction is not currently leveraged,
-// all current programs are 1 chain. We are leaving chains as part of the program definition for future
-// use cases around graph level async where different threads will be represented as seperate chains.
+// A sequence of blocking instructions to be executed in order. The
+// abstraction is not currently leveraged, all current programs are 1 chain.
+// We are leaving chains as part of the program definition for future use cases
+// around graph level async where different threads will be represented as
+// seperate chains.
 table Chain {
 // Indices of the values that are (non-static) inputs into this Chain.
 inputs:[int];

@@ -401,7 +442,17 @@ table Program {
 // offset. If constant_segment.offsets field is non-empty, constant_buffer
 // must be empty. constant_segment.offsets[0] is reserved to be pointed to by
 // non-constant Tensors.
-constant_segment: SubsegmentOffsets;
+constant_segment:SubsegmentOffsets;
+
+// [Optional] Describes the offsets into various segments for each mutable
+// tensor. Only mutable tensors with a meaningful initial state are
+// serialized here (for example weights that will be trained on-device as
+// opposed to just layer activations). Seperate from the constant_segment to
+// reduce peak memory usage by letting us read directly from the PTE file
+// into the mutable tensor, as opposed to loading the .pte data into
+// constant memory, copying it over, and then being unable to release the
+// constant segment. No two elements should point to the same segment.
+mutable_data_segments:[SubsegmentOffsets];
 }

 root_type Program;
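
Taken together, the Tensor and Program comments in this schema change define four legal (data_buffer_idx, allocation_info) combinations plus a rule for locating a mutable tensor's serialized state. Below is a minimal sketch of how a consumer of the schema might interpret them; the dataclass stubs and helper names are hypothetical illustrations, not part of the generated flatbuffer API.

from dataclasses import dataclass
from typing import Optional


@dataclass
class ExtraTensorInfoStub:
    # Hypothetical stand-in for the ExtraTensorInfo table above.
    mutable_data_segments_idx: int = 0
    fully_qualified_name: Optional[str] = None


@dataclass
class TensorStub:
    # Hypothetical stand-in for the Tensor table; has_allocation_info models
    # whether allocation_info is Null or Non Null.
    data_buffer_idx: int
    has_allocation_info: bool
    extra_tensor_info: Optional[ExtraTensorInfoStub] = None


def describe(t: TensorStub) -> str:
    if t.data_buffer_idx > 0 and not t.has_allocation_info:
        return "constant: data served from the constant buffer/segment"
    if t.data_buffer_idx == 0 and t.has_allocation_info:
        return "mutable: planned memory, no serialized initial state"
    if t.data_buffer_idx > 0 and t.has_allocation_info:
        # The case this commit enables: a serialized initial state that is
        # loaded into planned mutable memory at method load time.
        return "mutable: planned memory with a serialized initial state"
    return "mutable: data pointer arrives at input time or during execution"


def mutable_segment_index(t: TensorStub) -> int:
    # Per ExtraTensorInfo: if no extra info is present, segment-based mutable
    # data lives in program.mutable_data_segments[0].
    if t.extra_tensor_info is None:
        return 0
    return t.extra_tensor_info.mutable_data_segments_idx

For example, a weight that will be trained on-device would be serialized with data_buffer_idx > 0 and a non-null allocation_info, and its initial bytes would be read from the mutable_data_segments entry selected by mutable_segment_index.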
