Skip to content

Commit f8572ef

Browse files
authored
[ET-VK] Tuning local workgroup size calculation for conv2d pw to improve performance.
Differential Revision: D75420517 Pull Request resolved: #11135
1 parent a0cfa86 commit f8572ef

File tree

1 file changed

+16
-1
lines changed

1 file changed

+16
-1
lines changed

backends/vulkan/runtime/graph/ops/impl/Convolution.cpp

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -404,6 +404,21 @@ void add_conv2d_node(
404404
wg_size = {wg_size[0] * wg_size[1], wg_size[2], 1};
405405
}
406406

407+
utils::uvec3 local_wg_size;
408+
if (method == Conv2dMethod::Pointwise) {
409+
uint32_t local_wg_size_y = 1;
410+
if (wg_size[1] % 8 == 0) {
411+
local_wg_size_y = 8;
412+
} else if (wg_size[1] % 4 == 0) {
413+
local_wg_size_y = 4;
414+
} else if (wg_size[1] % 2 == 0) {
415+
local_wg_size_y = 2;
416+
}
417+
local_wg_size = {64 / local_wg_size_y, local_wg_size_y, 1};
418+
} else {
419+
local_wg_size = graph.create_local_wg_size(wg_size);
420+
}
421+
407422
vkapi::ParamsBindList param_buffers;
408423
std::vector<PushConstantDataInfo> push_constants;
409424
if (method == Conv2dMethod::Pointwise) {
@@ -464,7 +479,7 @@ void add_conv2d_node(
464479
graph,
465480
shader,
466481
wg_size,
467-
graph.create_local_wg_size(wg_size),
482+
local_wg_size,
468483
// Inputs and Outputs
469484
{{out, vkapi::kWrite}, {{in, arg_weight, arg_bias}, vkapi::kRead}},
470485
// Shader params buffers

0 commit comments

Comments
 (0)