Skip to content

Commit 1356e32

Browse files
trivedivivek authored and facebook-github-bot committed
Some minor performance improvements to buffer 4b mat mul. (pytorch#15989)
Summary: The code change in this diff aims to improve the performance of buffered 4-bit matrix multiplication by reducing unnecessary computations and by spreading operations to allow better latency hiding. Reviewed By: yipjustin. Differential Revision: D87910988
1 parent d92c91a commit 1356e32

File tree

1 file changed

+26
-14
lines changed

1 file changed

+26
-14
lines changed

backends/vulkan/runtime/graph/ops/glsl/linear_qcsnw_tiled.glsl

Lines changed: 26 additions & 14 deletions
Original file line number · Diff line number · Diff line change
@@ -78,6 +78,12 @@ void main() {
7878

7979
const int in_row_txstride = div4(in_sizes.x);
8080

81+
$if WEIGHT_STORAGE == "buffer":
82+
$if QUANT_NBITS == 4:
83+
uint qmat2_bufi = weight_txcol;
84+
$else:
85+
uint qmat2_bufi = out_txcol;
86+
8187
for (int pos = 0, txpos = 0;
8288
txpos < in_row_txstride;
8389
pos += 4, txpos += 1) {
@@ -99,7 +105,6 @@ void main() {
99105
}
100106

101107
$if WEIGHT_STORAGE == "buffer":
102-
uint qmat2_bufi;
103108
uint weight_row_txstride = div4(weight_sizes.x);
104109
uint encoded_weight;
105110

@@ -114,26 +119,31 @@ void main() {
114119
$if QUANT_NBITS == 4:
115120
$for c in range(0, TILE_TXCOLS, 2):
116121
$if WEIGHT_STORAGE == "buffer":
117-
qmat2_bufi = (pos + r) * weight_row_txstride + weight_txcol;
118122
encoded_weight = t_weight[qmat2_bufi + ${c}];
119-
packed_weight_tex = uvec4(encoded_weight & 0xFF, (encoded_weight >> 8) & 0xFF, (encoded_weight >> 16) & 0xFF, encoded_weight >> 24);
123+
qmat2[${c} * 4 * TILE_TXCOLS + 0] = T((encoded_weight >> 4) & 0xF);
124+
qmat2[${c} * 4 * TILE_TXCOLS + 1] = T((encoded_weight >> 12) & 0xF);
125+
qmat2[${c} * 4 * TILE_TXCOLS + 2] = T((encoded_weight >> 20) & 0xF);
126+
qmat2[${c} * 4 * TILE_TXCOLS + 3] = T((encoded_weight >> 28));
127+
128+
qmat2[${c} * 4 * TILE_TXCOLS + 4] = T((encoded_weight) & 0xF);
129+
qmat2[${c} * 4 * TILE_TXCOLS + 5] = T((encoded_weight >> 8) & 0xF);
130+
qmat2[${c} * 4 * TILE_TXCOLS + 6] = T((encoded_weight >> 16) & 0xF);
131+
qmat2[${c} * 4 * TILE_TXCOLS + 7] = T((encoded_weight >> 24) & 0xF);
120132
$else:
121133
packed_weight_tex = texelFetch(
122134
t_weight, ivec2(weight_txcol + ${c}, pos + r), 0);
123-
124-
qmat2[${c} * 4 * TILE_TXCOLS + 0] = T(packed_weight_tex.x >> 4);
125-
qmat2[${c} * 4 * TILE_TXCOLS + 1] = T(packed_weight_tex.y >> 4);
126-
qmat2[${c} * 4 * TILE_TXCOLS + 2] = T(packed_weight_tex.z >> 4);
127-
qmat2[${c} * 4 * TILE_TXCOLS + 3] = T(packed_weight_tex.w >> 4);
128-
129-
qmat2[${c} * 4 * TILE_TXCOLS + 4] = T(packed_weight_tex.x & 0xF);
130-
qmat2[${c} * 4 * TILE_TXCOLS + 5] = T(packed_weight_tex.y & 0xF);
131-
qmat2[${c} * 4 * TILE_TXCOLS + 6] = T(packed_weight_tex.z & 0xF);
132-
qmat2[${c} * 4 * TILE_TXCOLS + 7] = T(packed_weight_tex.w & 0xF);
135+
qmat2[${c} * 4 * TILE_TXCOLS + 0] = T(packed_weight_tex.x >> 4);
136+
qmat2[${c} * 4 * TILE_TXCOLS + 1] = T(packed_weight_tex.y >> 4);
137+
qmat2[${c} * 4 * TILE_TXCOLS + 2] = T(packed_weight_tex.z >> 4);
138+
qmat2[${c} * 4 * TILE_TXCOLS + 3] = T(packed_weight_tex.w >> 4);
139+
140+
qmat2[${c} * 4 * TILE_TXCOLS + 4] = T(packed_weight_tex.x & 0xF);
141+
qmat2[${c} * 4 * TILE_TXCOLS + 5] = T(packed_weight_tex.y & 0xF);
142+
qmat2[${c} * 4 * TILE_TXCOLS + 6] = T(packed_weight_tex.z & 0xF);
143+
qmat2[${c} * 4 * TILE_TXCOLS + 7] = T(packed_weight_tex.w & 0xF);
133144
$else:
134145
$for c in range(TILE_TXCOLS):
135146
$if WEIGHT_STORAGE == "buffer":
136-
qmat2_bufi = (pos + r) * weight_row_txstride + out_txcol;
137147
encoded_weight = t_weight[qmat2_bufi + ${c}];
138148
packed_weight_tex = ivec4(encoded_weight & 0xFF, (encoded_weight >> 8) & 0xFF, (encoded_weight >> 16) & 0xFF, encoded_weight >> 24);
139149
$else:
@@ -146,6 +156,8 @@ void main() {
146156
$for j in range(4):
147157
sums[tr * TILE_TXCOLS * 4 + ${c} * 4 + ${j}] += qmat2[${c} * 4 + ${j}] * mat1[tr * 4 + r];
148158
}
159+
$if WEIGHT_STORAGE == "buffer":
160+
qmat2_bufi += weight_row_txstride;
149161
}
150162
}
151163

0 commit comments

Comments (0)