From ef53422b809f3203f9911e7e4d7405e82f6cfe46 Mon Sep 17 00:00:00 2001 From: Vivek Trivedi Date: Wed, 26 Nov 2025 13:16:34 -0800 Subject: [PATCH] Use 4x3 tiled shader for linear mat mul which performs slightly better. (#15988) Summary: This diff optimizes the performance of the quantized linear matrix multiplication operation by using a 4x3 tiled shader, which performs slightly better than the previous implementation. Reviewed By: yipjustin Differential Revision: D87902847 --- .../graph/ops/glsl/linear_qcsnw_tiled.glsl | 20 ++++++------- .../graph/ops/glsl/linear_qcsnw_tiled.yaml | 2 ++ .../graph/ops/impl/QuantizedLinearQCSNW.cpp | 28 +++++++++---------- 3 files changed, 25 insertions(+), 25 deletions(-) diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_qcsnw_tiled.glsl b/backends/vulkan/runtime/graph/ops/glsl/linear_qcsnw_tiled.glsl index c364e70bc9f..1e5de21cffc 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/linear_qcsnw_tiled.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_qcsnw_tiled.glsl @@ -121,17 +121,15 @@ void main() { packed_weight_tex = texelFetch( t_weight, ivec2(weight_txcol + ${c}, pos + r), 0); - const uvec4 tmp1 = packed_weight_tex >> 4; - qmat2[${c} * 4 * TILE_TXCOLS + 0] = T(tmp1.x); - qmat2[${c} * 4 * TILE_TXCOLS + 1] = T(tmp1.y); - qmat2[${c} * 4 * TILE_TXCOLS + 2] = T(tmp1.z); - qmat2[${c} * 4 * TILE_TXCOLS + 3] = T(tmp1.w); - - const uvec4 tmp2 = packed_weight_tex & 0x0F; - qmat2[${c} * 4 * TILE_TXCOLS + 4] = T(tmp2.x); - qmat2[${c} * 4 * TILE_TXCOLS + 5] = T(tmp2.y); - qmat2[${c} * 4 * TILE_TXCOLS + 6] = T(tmp2.z); - qmat2[${c} * 4 * TILE_TXCOLS + 7] = T(tmp2.w); + qmat2[${c} * 4 * TILE_TXCOLS + 0] = T(packed_weight_tex.x >> 4); + qmat2[${c} * 4 * TILE_TXCOLS + 1] = T(packed_weight_tex.y >> 4); + qmat2[${c} * 4 * TILE_TXCOLS + 2] = T(packed_weight_tex.z >> 4); + qmat2[${c} * 4 * TILE_TXCOLS + 3] = T(packed_weight_tex.w >> 4); + + qmat2[${c} * 4 * TILE_TXCOLS + 4] = T(packed_weight_tex.x & 0xF); + qmat2[${c} * 4 * TILE_TXCOLS + 5] = T(packed_weight_tex.y & 0xF); + qmat2[${c} * 4 * TILE_TXCOLS + 6] = T(packed_weight_tex.z & 0xF); + qmat2[${c} * 4 * TILE_TXCOLS + 7] = T(packed_weight_tex.w & 0xF); $else: $for c in range(TILE_TXCOLS): $if WEIGHT_STORAGE == "buffer": diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_qcsnw_tiled.yaml b/backends/vulkan/runtime/graph/ops/glsl/linear_qcsnw_tiled.yaml index 287b2ee9333..81824a12026 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/linear_qcsnw_tiled.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_qcsnw_tiled.yaml @@ -20,6 +20,8 @@ linear_qcsnw_tiled: SUFFIX: o4x1 - VALUE: 2 SUFFIX: o4x2 + - VALUE: 3 + SUFFIX: o4x3 - VALUE: 4 SUFFIX: o4x4 shader_variants: diff --git a/backends/vulkan/runtime/graph/ops/impl/QuantizedLinearQCSNW.cpp b/backends/vulkan/runtime/graph/ops/impl/QuantizedLinearQCSNW.cpp index e4e08363c6d..18958ccc3ce 100644 --- a/backends/vulkan/runtime/graph/ops/impl/QuantizedLinearQCSNW.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/QuantizedLinearQCSNW.cpp @@ -61,15 +61,15 @@ utils::uvec3 linear_qcsnw_tiled_global_wg_size( std::vector mat1_sizes = graph->sizes_of(mat1); const int64_t M = utils::val_at(-2, mat1_sizes); - uint32_t out_tile_nrows = 4; - if (M % 6 == 0) { - out_tile_nrows = 2; + uint32_t out_tile_nrows = 1; + if (M % 3 == 0) { + out_tile_nrows = 3; } else if (M % 4 == 0) { out_tile_nrows = 4; - } else if (M % 1 == 0) { - out_tile_nrows = 1; + } else if (M % 2 == 0) { + out_tile_nrows = 2; } else { - out_tile_nrows = 4; + out_tile_nrows = 1; } // Number of output texels in the output tile @@ -309,19 +309,19 @@ void add_linear_qcsnw_tiled_node( std::vector mat1_sizes = graph.sizes_of(mat1); const int64_t M = utils::val_at(-2, mat1_sizes); - uint32_t out_tile_nrows = 4; - if (M % 6 == 0) { - kernel_name += "_o4x2"; - out_tile_nrows = 2; + uint32_t out_tile_nrows = 1; + if (M % 3 == 0) { + kernel_name += "_o4x3"; + out_tile_nrows = 3; } else if (M % 4 == 0) { kernel_name += "_o4x4"; out_tile_nrows = 4; - } else if (M % 1 == 0) { + } else if (M % 2 == 0) { + kernel_name += "_o4x2"; + out_tile_nrows = 2; + } else { kernel_name += "_o4x1"; out_tile_nrows = 1; - } else { - kernel_name += "_o4x4"; - out_tile_nrows = 4; } // Number of output texels in the output tile