Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 9 additions & 11 deletions backends/vulkan/runtime/graph/ops/glsl/linear_qcsnw_tiled.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -121,17 +121,15 @@ void main() {
packed_weight_tex = texelFetch(
t_weight, ivec2(weight_txcol + ${c}, pos + r), 0);

const uvec4 tmp1 = packed_weight_tex >> 4;
qmat2[${c} * 4 * TILE_TXCOLS + 0] = T(tmp1.x);
qmat2[${c} * 4 * TILE_TXCOLS + 1] = T(tmp1.y);
qmat2[${c} * 4 * TILE_TXCOLS + 2] = T(tmp1.z);
qmat2[${c} * 4 * TILE_TXCOLS + 3] = T(tmp1.w);

const uvec4 tmp2 = packed_weight_tex & 0x0F;
qmat2[${c} * 4 * TILE_TXCOLS + 4] = T(tmp2.x);
qmat2[${c} * 4 * TILE_TXCOLS + 5] = T(tmp2.y);
qmat2[${c} * 4 * TILE_TXCOLS + 6] = T(tmp2.z);
qmat2[${c} * 4 * TILE_TXCOLS + 7] = T(tmp2.w);
qmat2[${c} * 4 * TILE_TXCOLS + 0] = T(packed_weight_tex.x >> 4);
qmat2[${c} * 4 * TILE_TXCOLS + 1] = T(packed_weight_tex.y >> 4);
qmat2[${c} * 4 * TILE_TXCOLS + 2] = T(packed_weight_tex.z >> 4);
qmat2[${c} * 4 * TILE_TXCOLS + 3] = T(packed_weight_tex.w >> 4);

qmat2[${c} * 4 * TILE_TXCOLS + 4] = T(packed_weight_tex.x & 0xF);
qmat2[${c} * 4 * TILE_TXCOLS + 5] = T(packed_weight_tex.y & 0xF);
qmat2[${c} * 4 * TILE_TXCOLS + 6] = T(packed_weight_tex.z & 0xF);
qmat2[${c} * 4 * TILE_TXCOLS + 7] = T(packed_weight_tex.w & 0xF);
$else:
$for c in range(TILE_TXCOLS):
$if WEIGHT_STORAGE == "buffer":
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ linear_qcsnw_tiled:
SUFFIX: o4x1
- VALUE: 2
SUFFIX: o4x2
- VALUE: 3
SUFFIX: o4x3
- VALUE: 4
SUFFIX: o4x4
shader_variants:
Expand Down
28 changes: 14 additions & 14 deletions backends/vulkan/runtime/graph/ops/impl/QuantizedLinearQCSNW.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,15 +61,15 @@ utils::uvec3 linear_qcsnw_tiled_global_wg_size(

std::vector<int64_t> mat1_sizes = graph->sizes_of(mat1);
const int64_t M = utils::val_at(-2, mat1_sizes);
uint32_t out_tile_nrows = 4;
if (M % 6 == 0) {
out_tile_nrows = 2;
uint32_t out_tile_nrows = 1;
if (M % 3 == 0) {
out_tile_nrows = 3;
} else if (M % 4 == 0) {
out_tile_nrows = 4;
} else if (M % 1 == 0) {
out_tile_nrows = 1;
} else if (M % 2 == 0) {
out_tile_nrows = 2;
} else {
out_tile_nrows = 4;
out_tile_nrows = 1;
}

// Number of output texels in the output tile
Expand Down Expand Up @@ -309,19 +309,19 @@ void add_linear_qcsnw_tiled_node(

std::vector<int64_t> mat1_sizes = graph.sizes_of(mat1);
const int64_t M = utils::val_at(-2, mat1_sizes);
uint32_t out_tile_nrows = 4;
if (M % 6 == 0) {
kernel_name += "_o4x2";
out_tile_nrows = 2;
uint32_t out_tile_nrows = 1;
if (M % 3 == 0) {
kernel_name += "_o4x3";
out_tile_nrows = 3;
} else if (M % 4 == 0) {
kernel_name += "_o4x4";
out_tile_nrows = 4;
} else if (M % 1 == 0) {
} else if (M % 2 == 0) {
kernel_name += "_o4x2";
out_tile_nrows = 2;
} else {
kernel_name += "_o4x1";
out_tile_nrows = 1;
} else {
kernel_name += "_o4x4";
out_tile_nrows = 4;
}

// Number of output texels in the output tile
Expand Down