@@ -78,6 +78,12 @@ void main() {
7878
7979 const int in_row_txstride = div4(in_sizes.x);
8080
81+ $if WEIGHT_STORAGE == "buffer ":
82+ $if QUANT_NBITS == 4 :
83+ uint qmat2_bufi = weight_txcol;
84+ $else :
85+ uint qmat2_bufi = out_txcol;
86+
8187 for (int pos = 0 , txpos = 0 ;
8288 txpos < in_row_txstride;
8389 pos += 4 , txpos += 1 ) {
@@ -99,7 +105,6 @@ void main() {
99105 }
100106
101107 $if WEIGHT_STORAGE == "buffer ":
102- uint qmat2_bufi;
103108 uint weight_row_txstride = div4(weight_sizes.x);
104109 uint encoded_weight;
105110
@@ -114,26 +119,31 @@ void main() {
114119 $if QUANT_NBITS == 4 :
115120 $for c in range(0 , TILE_TXCOLS, 2 ):
116121 $if WEIGHT_STORAGE == "buffer ":
117- qmat2_bufi = (pos + r) * weight_row_txstride + weight_txcol;
118122 encoded_weight = t_weight[qmat2_bufi + ${c}];
119- packed_weight_tex = uvec4 (encoded_weight & 0xFF, (encoded_weight >> 8 ) & 0xFF, (encoded_weight >> 16 ) & 0xFF, encoded_weight >> 24 );
123+ qmat2[${c} * 4 * TILE_TXCOLS + 0 ] = T((encoded_weight >> 4 ) & 0xF);
124+ qmat2[${c} * 4 * TILE_TXCOLS + 1 ] = T((encoded_weight >> 12 ) & 0xF);
125+ qmat2[${c} * 4 * TILE_TXCOLS + 2 ] = T((encoded_weight >> 20 ) & 0xF);
126+ qmat2[${c} * 4 * TILE_TXCOLS + 3 ] = T((encoded_weight >> 28 ));
127+
128+ qmat2[${c} * 4 * TILE_TXCOLS + 4 ] = T((encoded_weight) & 0xF);
129+ qmat2[${c} * 4 * TILE_TXCOLS + 5 ] = T((encoded_weight >> 8 ) & 0xF);
130+ qmat2[${c} * 4 * TILE_TXCOLS + 6 ] = T((encoded_weight >> 16 ) & 0xF);
131+ qmat2[${c} * 4 * TILE_TXCOLS + 7 ] = T((encoded_weight >> 24 ) & 0xF);
120132 $else :
121133 packed_weight_tex = texelFetch(
122134 t_weight, ivec2 (weight_txcol + ${c}, pos + r), 0 );
123-
124- qmat2[${c} * 4 * TILE_TXCOLS + 0 ] = T(packed_weight_tex.x >> 4 );
125- qmat2[${c} * 4 * TILE_TXCOLS + 1 ] = T(packed_weight_tex.y >> 4 );
126- qmat2[${c} * 4 * TILE_TXCOLS + 2 ] = T(packed_weight_tex.z >> 4 );
127- qmat2[${c} * 4 * TILE_TXCOLS + 3 ] = T(packed_weight_tex.w >> 4 );
128-
129- qmat2[${c} * 4 * TILE_TXCOLS + 4 ] = T(packed_weight_tex.x & 0xF);
130- qmat2[${c} * 4 * TILE_TXCOLS + 5 ] = T(packed_weight_tex.y & 0xF);
131- qmat2[${c} * 4 * TILE_TXCOLS + 6 ] = T(packed_weight_tex.z & 0xF);
132- qmat2[${c} * 4 * TILE_TXCOLS + 7 ] = T(packed_weight_tex.w & 0xF);
135+ qmat2[${c} * 4 * TILE_TXCOLS + 0 ] = T(packed_weight_tex.x >> 4 );
136+ qmat2[${c} * 4 * TILE_TXCOLS + 1 ] = T(packed_weight_tex.y >> 4 );
137+ qmat2[${c} * 4 * TILE_TXCOLS + 2 ] = T(packed_weight_tex.z >> 4 );
138+ qmat2[${c} * 4 * TILE_TXCOLS + 3 ] = T(packed_weight_tex.w >> 4 );
139+
140+ qmat2[${c} * 4 * TILE_TXCOLS + 4 ] = T(packed_weight_tex.x & 0xF);
141+ qmat2[${c} * 4 * TILE_TXCOLS + 5 ] = T(packed_weight_tex.y & 0xF);
142+ qmat2[${c} * 4 * TILE_TXCOLS + 6 ] = T(packed_weight_tex.z & 0xF);
143+ qmat2[${c} * 4 * TILE_TXCOLS + 7 ] = T(packed_weight_tex.w & 0xF);
133144 $else :
134145 $for c in range(TILE_TXCOLS):
135146 $if WEIGHT_STORAGE == "buffer ":
136- qmat2_bufi = (pos + r) * weight_row_txstride + out_txcol;
137147 encoded_weight = t_weight[qmat2_bufi + ${c}];
138148 packed_weight_tex = ivec4 (encoded_weight & 0xFF, (encoded_weight >> 8 ) & 0xFF, (encoded_weight >> 16 ) & 0xFF, encoded_weight >> 24 );
139149 $else :
@@ -146,6 +156,8 @@ void main() {
146156 $for j in range(4 ):
147157 sums[tr * TILE_TXCOLS * 4 + ${c} * 4 + ${j}] += qmat2[${c} * 4 + ${j}] * mat1[tr * 4 + r];
148158 }
159+ $if WEIGHT_STORAGE == "buffer ":
160+ qmat2_bufi += weight_row_txstride;
149161 }
150162 }
151163
0 commit comments