Skip to content
17 changes: 16 additions & 1 deletion backends/vulkan/runtime/graph/ops/impl/Convolution.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -404,6 +404,21 @@ void add_conv2d_node(
wg_size = {wg_size[0] * wg_size[1], wg_size[2], 1};
}

utils::uvec3 local_wg_size;
if (method == Conv2dMethod::Pointwise) {
uint32_t local_wg_size_y = 1;
if (wg_size[1] % 8 == 0) {
local_wg_size_y = 8;
} else if (wg_size[1] % 4 == 0) {
local_wg_size_y = 4;
} else if (wg_size[1] % 2 == 0) {
local_wg_size_y = 2;
}
local_wg_size = {64 / local_wg_size_y, local_wg_size_y, 1};
} else {
local_wg_size = graph.create_local_wg_size(wg_size);
}

vkapi::ParamsBindList param_buffers;
std::vector<PushConstantDataInfo> push_constants;
if (method == Conv2dMethod::Pointwise) {
Expand Down Expand Up @@ -464,7 +479,7 @@ void add_conv2d_node(
graph,
shader,
wg_size,
graph.create_local_wg_size(wg_size),
local_wg_size,
// Inputs and Outputs
{{out, vkapi::kWrite}, {{in, arg_weight, arg_bias}, vkapi::kRead}},
// Shader params buffers
Expand Down
Loading