Skip to content

Commit eae6134

Browse files
[ET-VK] Using push constants for unary op. (#12418)
This PR was created by the merge bot to help merge the original PR into the main branch.
ghstack PR number: #12308 by @trivedivivek (please use the original PR as the source of truth for the PR details, comments, and reviews).
ghstack PR base: https://github.com/pytorch/executorch/tree/gh/trivedivivek/114/base
ghstack PR head: https://github.com/pytorch/executorch/tree/gh/trivedivivek/114/head
Merge bot PR base: https://github.com/pytorch/executorch/tree/main
Merge bot PR head: https://github.com/pytorch/executorch/tree/gh/trivedivivek/114/orig
@diff-train-skip-merge
Co-authored-by: Vivek Trivedi <[email protected]>
1 parent 1540659 commit eae6134

File tree

2 files changed

+16
-16
lines changed

2 files changed

+16
-16
lines changed

backends/vulkan/runtime/graph/ops/glsl/unary_op.glsl

Lines changed: 8 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -25,12 +25,15 @@ layout(std430) buffer;
2525

2626
${layout_declare_tensor(0, "w", "t_out", DTYPE, STORAGE)}
2727
${layout_declare_tensor(1, "r", "t_in", DTYPE, STORAGE)}
28+
29+
layout(push_constant) uniform restrict Block {
2830
$if STORAGE == "buffer":
29-
${layout_declare_ubo(2, "int", "numel")}
31+
int numel;
3032
$else:
31-
${layout_declare_ubo(2, "ivec3", "out_limits")}
32-
${layout_declare_ubo(3, "float", "minimum")}
33-
${layout_declare_ubo(4, "float", "maximum")}
33+
ivec4 out_limits;
34+
float minimum;
35+
float maximum;
36+
};
3437

3538
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
3639

@@ -53,7 +56,7 @@ void main() {
5356
void main() {
5457
const ivec3 pos = ivec3(gl_GlobalInvocationID);
5558

56-
if (any(greaterThanEqual(pos, out_limits))) {
59+
if (any(greaterThanEqual(pos, out_limits.xyz))) {
5760
return;
5861
}
5962

backends/vulkan/runtime/graph/ops/impl/UnaryOp.cpp

Lines changed: 8 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -43,15 +43,7 @@ void add_unary_op_node(
4343
add_dtype_suffix(kernel_name, graph.dtype_of(out));
4444
add_storage_type_suffix(kernel_name, graph.storage_type_of(out));
4545

46-
vkapi::ParamsBindList ubos({});
47-
if (graph.is_buffer_storage(out)) {
48-
ubos.append({graph.numel_ubo(out)});
49-
} else {
50-
ubos.append({graph.logical_limits_ubo(out)});
51-
}
52-
ubos.append(
53-
{graph.create_params_buffer(min), graph.create_params_buffer(max)});
54-
46+
const utils::vec2 min_max = {min, max};
5547
graph.execute_nodes().emplace_back(new DynamicDispatchNode(
5648
graph,
5749
VK_KERNEL_FROM_STR(kernel_name),
@@ -60,9 +52,14 @@ void add_unary_op_node(
6052
// Inputs and Outputs
6153
{{out, vkapi::kWrite}, {in, vkapi::kRead}},
6254
// Shader params buffers
63-
ubos,
64-
// Push Constants
6555
{},
56+
// Push Constants
57+
{
58+
graph.is_buffer_storage(out) ? graph.numel_pc_of(out)
59+
: graph.logical_limits_pc_of(out),
60+
PushConstantDataInfo(&min_max, sizeof(min_max)),
61+
},
62+
// pcs,
6663
// Specialization Constants
6764
{},
6865
// Resize Args

0 commit comments

Comments
 (0)