Skip to content

Commit c32d209

Browse files
pytorchbot and SS-JIA authored
[ET-VK][ez] Explicitly skip marking output nodes that are mutable buffers (#12020)
This PR was created by the merge bot to help merge the original PR into the main branch.
ghstack PR number: #11983 by @SS-JIA
^ Please use this as the source of truth for the PR details, comments, and reviews
ghstack PR base: https://github.com/pytorch/executorch/tree/gh/SS-JIA/249/base
ghstack PR head: https://github.com/pytorch/executorch/tree/gh/SS-JIA/249/head
Merge bot PR base: https://github.com/pytorch/executorch/tree/main
Merge bot PR head: https://github.com/pytorch/executorch/tree/gh/SS-JIA/249/orig
@diff-train-skip-merge
Co-authored-by: Stephen Jia <[email protected]>
1 parent c5ecea6 commit c32d209

File tree

5 files changed

+34
-7
lines changed

5 files changed

+34
-7
lines changed

backends/vulkan/runtime/VulkanBackend.cpp

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -359,15 +359,11 @@ class GraphBuilder {
359359
vkFn(*compute_graph_, args);
360360
}
361361

362-
// Parse the outputs, which will be mostly tensors. For some reason,
363-
// mutable buffers are shown to be returned in the fx.Graph but do not get
364-
// returned by the delegate; this may be an implementation detail of how the
365-
// executorch emitter handles mutable buffers.
362+
// Parse the outputs, which will be mostly tensors but may contain tensorref
363+
// values as well if the source graph returns parameter nodes.
366364
for (const uint32_t fb_id : *flatbuffer_->output_ids()) {
367365
const ValueRef ref = get_fb_id_valueref(fb_id);
368-
if (compute_graph_->val_is_tensor(ref)) {
369-
compute_graph_->set_output_tensor(ref);
370-
}
366+
compute_graph_->set_output_value(ref);
371367
}
372368

373369
if (compute_graph_->graphconfig().enable_querypool) {
@@ -609,6 +605,12 @@ class VulkanBackend final : public ::executorch::runtime::BackendInterface {
609605
compute_graph->outputs()[i].staging,
610606
args[o]->toTensor().mutable_data_ptr(),
611607
args[o]->toTensor().numel());
608+
}
609+
// TensorRef values represent constant tensors which will not have been
610+
// modified by the graph execution. Therefore, if a constant tensor is
611+
// returned as an output, no action is required.
612+
else if (compute_graph->val_is_tref(oref)) {
613+
continue;
612614
} else {
613615
VK_THROW(
614616
"Could not handle output with type ",

backends/vulkan/runtime/graph/ComputeGraph.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -519,6 +519,14 @@ ValueRef ComputeGraph::set_output_tensor(
519519
return idx;
520520
}
521521

522+
// Registers the value at `idx` as an output of the graph and returns the
// resulting ValueRef. Tensor values are forwarded to set_output_tensor(). Any
// other value type is appended to outputs_ directly, paired with
// kDummyValueRef as its second field (presumably the staging reference, since
// non-tensor outputs have nothing to copy back - confirm against the IOValueRef
// definition).
ValueRef ComputeGraph::set_output_value(const ValueRef idx) {
  const bool is_tensor_value = values_.at(idx).isTensor();
  if (!is_tensor_value) {
    outputs_.push_back({idx, kDummyValueRef});
    return idx;
  }
  return set_output_tensor(idx);
}
529+
522530
vkapi::BufferBindInfo ComputeGraph::get_or_create_int_param_buffer(
523531
const ValueRef idx) {
524532
if (values_.at(idx).isInt()) {

backends/vulkan/runtime/graph/ComputeGraph.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -658,6 +658,8 @@ class ComputeGraph final {
658658
ValueRef set_input_tensor(const ValueRef idx, const bool use_staging = true);
659659
ValueRef set_output_tensor(const ValueRef idx, const bool use_staging = true);
660660

661+
ValueRef set_output_value(const ValueRef idx);
662+
661663
template <typename Block>
662664
vkapi::BufferBindInfo create_params_buffer(const Block& data) {
663665
param_ubos_.emplace_back(api::ParamsBuffer(context_.get(), data));

backends/vulkan/serialization/vulkan_graph_builder.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from executorch.backends.vulkan.utils import (
2121
is_constant,
2222
is_get_attr_node,
23+
is_mutable_buffer_node,
2324
is_param_node,
2425
is_symint_node,
2526
)
@@ -382,6 +383,11 @@ def process_output_node(self, node: Node) -> None:
382383
"the output node is being serialized before its corresponding "
383384
"internal node which is not allowed."
384385
)
386+
# Mutable buffer outputs are not included as an output to the
387+
# delegate call. Skip marking them as an output.
388+
if is_mutable_buffer_node(out_node, self.program):
389+
continue
390+
385391
self.output_ids.append(self.node_to_value_ids[out_node])
386392

387393
def process_node(self, node: Node, call_node_debug_hdl: int) -> None:

backends/vulkan/utils.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,15 @@ def is_param_node(program: ExportedProgram, node: torch.fx.Node) -> bool:
8484
)
8585

8686

87+
def is_mutable_buffer_node(
88+
node: torch.fx.Node, exported_program: ExportedProgram
89+
) -> bool:
90+
if node.target not in exported_program.graph_signature.inputs_to_buffers:
91+
return False
92+
buf = exported_program.graph_signature.inputs_to_buffers[node.target]
93+
return buf in exported_program.graph_signature.buffers_to_mutate.values()
94+
95+
8796
def is_symint_node(node: torch.fx.Node) -> bool:
8897
"""
8998
Returns true if the given node produces a SymInt value

0 commit comments

Comments
 (0)