Skip to content

Commit dc72ab8

Browse files
committed
[ET-VK] Using push constants for unary op.
Pull Request resolved: #12308 This diff transitions the unary op to utilize push constants, replacing the previous ubo implementation. ghstack-source-id: 295513701 Differential Revision: [D77706459](https://our.internmc.facebook.com/intern/diff/D77706459/)
1 parent 97a61f4 commit dc72ab8

File tree

2 files changed

+16
-16
lines changed

2 files changed

+16
-16
lines changed

backends/vulkan/runtime/graph/ops/glsl/unary_op.glsl

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,12 +25,15 @@ layout(std430) buffer;
2525

2626
${layout_declare_tensor(0, "w", "t_out", DTYPE, STORAGE)}
2727
${layout_declare_tensor(1, "r", "t_in", DTYPE, STORAGE)}
28+
29+
layout(push_constant) uniform restrict Block {
2830
$if STORAGE == "buffer":
29-
${layout_declare_ubo(2, "int", "numel")}
31+
int numel;
3032
$else:
31-
${layout_declare_ubo(2, "ivec3", "out_limits")}
32-
${layout_declare_ubo(3, "float", "minimum")}
33-
${layout_declare_ubo(4, "float", "maximum")}
33+
ivec4 out_limits;
34+
float minimum;
35+
float maximum;
36+
};
3437

3538
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
3639

@@ -53,7 +56,7 @@ void main() {
5356
void main() {
5457
const ivec3 pos = ivec3(gl_GlobalInvocationID);
5558

56-
if (any(greaterThanEqual(pos, out_limits))) {
59+
if (any(greaterThanEqual(pos, out_limits.xyz))) {
5760
return;
5861
}
5962

backends/vulkan/runtime/graph/ops/impl/UnaryOp.cpp

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -43,15 +43,7 @@ void add_unary_op_node(
4343
add_dtype_suffix(kernel_name, graph.dtype_of(out));
4444
add_storage_type_suffix(kernel_name, graph.storage_type_of(out));
4545

46-
vkapi::ParamsBindList ubos({});
47-
if (graph.is_buffer_storage(out)) {
48-
ubos.append({graph.numel_ubo(out)});
49-
} else {
50-
ubos.append({graph.logical_limits_ubo(out)});
51-
}
52-
ubos.append(
53-
{graph.create_params_buffer(min), graph.create_params_buffer(max)});
54-
46+
const utils::vec2 min_max = {min, max};
5547
graph.execute_nodes().emplace_back(new DynamicDispatchNode(
5648
graph,
5749
VK_KERNEL_FROM_STR(kernel_name),
@@ -60,9 +52,14 @@ void add_unary_op_node(
6052
// Inputs and Outputs
6153
{{out, vkapi::kWrite}, {in, vkapi::kRead}},
6254
// Shader params buffers
63-
ubos,
64-
// Push Constants
6555
{},
56+
// Push Constants
57+
{
58+
graph.is_buffer_storage(out) ? graph.numel_pc_of(out)
59+
: graph.logical_limits_pc_of(out),
60+
PushConstantDataInfo(&min_max, sizeof(min_max)),
61+
},
62+
// pcs,
6663
// Specialization Constants
6764
{},
6865
// Resize Args

0 commit comments

Comments
 (0)