Skip to content

Commit 7304def

Browse files
trivedivivekfacebook-github-bot
authored andcommitted
Use 4x3 tiled shader for linear mat mul which performs slightly better. (#15988)
Summary: This diff optimizes the performance of the quantized linear matrix multiplication operation by using a 4x3 tiled shader, which performs slightly better than the previous implementation. Reviewed By: yipjustin Differential Revision: D87902847
1 parent 12d17ef commit 7304def

File tree

3 files changed

+25
-25
lines changed

3 files changed

+25
-25
lines changed

backends/vulkan/runtime/graph/ops/glsl/linear_qcsnw_tiled.glsl

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -121,17 +121,15 @@ void main() {
121121
packed_weight_tex = texelFetch(
122122
t_weight, ivec2(weight_txcol + ${c}, pos + r), 0);
123123

124-
const uvec4 tmp1 = packed_weight_tex >> 4;
125-
qmat2[${c} * 4 * TILE_TXCOLS + 0] = T(tmp1.x);
126-
qmat2[${c} * 4 * TILE_TXCOLS + 1] = T(tmp1.y);
127-
qmat2[${c} * 4 * TILE_TXCOLS + 2] = T(tmp1.z);
128-
qmat2[${c} * 4 * TILE_TXCOLS + 3] = T(tmp1.w);
129-
130-
const uvec4 tmp2 = packed_weight_tex & 0x0F;
131-
qmat2[${c} * 4 * TILE_TXCOLS + 4] = T(tmp2.x);
132-
qmat2[${c} * 4 * TILE_TXCOLS + 5] = T(tmp2.y);
133-
qmat2[${c} * 4 * TILE_TXCOLS + 6] = T(tmp2.z);
134-
qmat2[${c} * 4 * TILE_TXCOLS + 7] = T(tmp2.w);
124+
qmat2[${c} * 4 * TILE_TXCOLS + 0] = T(packed_weight_tex.x >> 4);
125+
qmat2[${c} * 4 * TILE_TXCOLS + 1] = T(packed_weight_tex.y >> 4);
126+
qmat2[${c} * 4 * TILE_TXCOLS + 2] = T(packed_weight_tex.z >> 4);
127+
qmat2[${c} * 4 * TILE_TXCOLS + 3] = T(packed_weight_tex.w >> 4);
128+
129+
qmat2[${c} * 4 * TILE_TXCOLS + 4] = T(packed_weight_tex.x & 0xF);
130+
qmat2[${c} * 4 * TILE_TXCOLS + 5] = T(packed_weight_tex.y & 0xF);
131+
qmat2[${c} * 4 * TILE_TXCOLS + 6] = T(packed_weight_tex.z & 0xF);
132+
qmat2[${c} * 4 * TILE_TXCOLS + 7] = T(packed_weight_tex.w & 0xF);
135133
$else:
136134
$for c in range(TILE_TXCOLS):
137135
$if WEIGHT_STORAGE == "buffer":

backends/vulkan/runtime/graph/ops/glsl/linear_qcsnw_tiled.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@ linear_qcsnw_tiled:
2020
SUFFIX: o4x1
2121
- VALUE: 2
2222
SUFFIX: o4x2
23+
- VALUE: 3
24+
SUFFIX: o4x3
2325
- VALUE: 4
2426
SUFFIX: o4x4
2527
shader_variants:

backends/vulkan/runtime/graph/ops/impl/QuantizedLinearQCSNW.cpp

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -61,15 +61,15 @@ utils::uvec3 linear_qcsnw_tiled_global_wg_size(
6161

6262
std::vector<int64_t> mat1_sizes = graph->sizes_of(mat1);
6363
const int64_t M = utils::val_at(-2, mat1_sizes);
64-
uint32_t out_tile_nrows = 4;
65-
if (M % 6 == 0) {
66-
out_tile_nrows = 2;
64+
uint32_t out_tile_nrows = 1;
65+
if (M % 3 == 0) {
66+
out_tile_nrows = 3;
6767
} else if (M % 4 == 0) {
6868
out_tile_nrows = 4;
69-
} else if (M % 1 == 0) {
70-
out_tile_nrows = 1;
69+
} else if (M % 2 == 0) {
70+
out_tile_nrows = 2;
7171
} else {
72-
out_tile_nrows = 4;
72+
out_tile_nrows = 1;
7373
}
7474

7575
// Number of output texels in the output tile
@@ -309,19 +309,19 @@ void add_linear_qcsnw_tiled_node(
309309

310310
std::vector<int64_t> mat1_sizes = graph.sizes_of(mat1);
311311
const int64_t M = utils::val_at(-2, mat1_sizes);
312-
uint32_t out_tile_nrows = 4;
313-
if (M % 6 == 0) {
314-
kernel_name += "_o4x2";
315-
out_tile_nrows = 2;
312+
uint32_t out_tile_nrows = 1;
313+
if (M % 3 == 0) {
314+
kernel_name += "_o4x3";
315+
out_tile_nrows = 3;
316316
} else if (M % 4 == 0) {
317317
kernel_name += "_o4x4";
318318
out_tile_nrows = 4;
319-
} else if (M % 1 == 0) {
319+
} else if (M % 2 == 0) {
320+
kernel_name += "_o4x2";
321+
out_tile_nrows = 2;
322+
} else {
320323
kernel_name += "_o4x1";
321324
out_tile_nrows = 1;
322-
} else {
323-
kernel_name += "_o4x4";
324-
out_tile_nrows = 4;
325325
}
326326

327327
// Number of output texels in the output tile

0 commit comments

Comments
 (0)