diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_qta8a_qga4w_coop.glsl b/backends/vulkan/runtime/graph/ops/glsl/linear_qta8a_qga4w_coop.glsl
new file mode 100644
index 00000000000..174ea1cc9bb
--- /dev/null
+++ b/backends/vulkan/runtime/graph/ops/glsl/linear_qta8a_qga4w_coop.glsl
@@ -0,0 +1,232 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#version 450 core
+
+#define PRECISION ${PRECISION}
+
+#define T ${buffer_scalar_type(DTYPE)}
+#define VEC4_T ${buffer_gvec_type(DTYPE, 4)}
+
+#define TILE_ROWS ${TILE_ROWS}
+
+#define NGROUPS 8
+#define NWORKERS 8
+
+${define_required_extensions(DTYPE)}
+$if IN_STORAGE == "buffer":
+  ${define_required_extensions("int8")}
+$if WEIGHT_STORAGE == "buffer":
+  ${define_required_extensions("uint8")}
+
+#extension GL_EXT_control_flow_attributes : require
+
+layout(std430) buffer;
+
+${layout_declare_tensor(B, "w", "t_out", DTYPE, OUT_STORAGE, is_scalar_array=False)}
+${layout_declare_tensor(B, "r", "t_mat1", "int8", IN_STORAGE, is_scalar_array=False)}
+${layout_declare_tensor(B, "r", "t_qmat2", "uint8", WEIGHT_STORAGE, is_scalar_array=False)}
+${layout_declare_tensor(B, "r", "t_weight_scales", "float", PARAMS_STORAGE, is_scalar_array=False)}
+${layout_declare_tensor(B, "r", "t_weight_zeros", "int", PARAMS_STORAGE, is_scalar_array=False)}
+${layout_declare_tensor(B, "r", "t_input_scale", "float", PARAMS_STORAGE, is_scalar_array=True)}
+${layout_declare_tensor(B, "r", "t_input_zero_point", "int", PARAMS_STORAGE, is_scalar_array=True)}
+
+layout(push_constant) uniform restrict Block {
+  ivec4 out_sizes;
+  ivec4 mat1_sizes;
+  ivec4 qmat2_sizes;
+};
+
+layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
+
+layout(constant_id = 3) const int group_size = 64;
+
+shared vec4 partial_results[NGROUPS][NWORKERS][TILE_ROWS][2];
+
+/*
+ * This shader computes a linear operator between a quantized int8 input matrix
+ * x and a weights matrix that is quantized to 4 bits, producing a float output.
+ *
+ * This shader implements a co-operative algorithm to compute the output. The
+ * work group size is {NGROUPS, 1, NWORKERS}, and each group of NWORKERS threads
+ * cooperates to compute TILE_ROWS * 2 output texels. Therefore,
+ * NGROUPS * TILE_ROWS * 2 output texels are computed across one work group.
+ *
+ * The threads co-operate by each thread computing a partial reduction along the
+ * K dimension. To illustrate the computation, consider a scalar variant of the
+ * algorithm that computes the dot product of 2 vectors. Also assume that
+ * NWORKERS is 8.
+ *
+ * Thread 1 in each group will compute:
+ * (mat1[0] * mat2[0]) + (mat1[8] * mat2[8]) + (mat1[16] * mat2[16]) + ...
+ *
+ * Thread 2 in each group will compute:
+ * (mat1[1] * mat2[1]) + (mat1[9] * mat2[9]) + (mat1[17] * mat2[17]) + ...
+ *
+ * Thread 3 in each group will compute:
+ * (mat1[2] * mat2[2]) + (mat1[10] * mat2[10]) + (mat1[18] * mat2[18]) + ...
+ *
+ * The partial accumulations are structured such that memory accesses in each
+ * loop iteration can be coalesced.
+ *
+ * Then, at the end, the first thread in each group accumulates the partial
+ * results computed by each worker to obtain the final result.
+ *
+ * Note that this shader assumes that all tensors are width packed.
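+ *
+ * As a guide to the integer accumulation below (a symbolic restatement of what
+ * the code computes; q_in, q_w, in_zp, w_zp, in_scale and w_scale are
+ * illustrative names, not variables in this shader), each group-sized block
+ * contributes:
+ *
+ *   out += in_scale * w_scale * sum_k (q_in[k] - in_zp) * ((q_w[k] - 8) - w_zp)
+ *        = in_scale * w_scale *
+ *            (sum_k (q_in[k] - in_zp) * (q_w[k] - 8) - w_zp * sum_k (q_in[k] - in_zp))
+ *
+ * which is why the inner loop tracks both the int32 dot products and the plain
+ * sum of zero-point-centered inputs, and applies the weight zero point once per
+ * block as `zeros * input_sum`.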
+ */ + +void main() { + const uint out_row = gl_GlobalInvocationID.y * TILE_ROWS; + const uint out_col = gl_GlobalInvocationID.x << 3; + const int out_col_texel_idx = int(gl_GlobalInvocationID.x) << 1; + + const uint gid = gl_LocalInvocationID.x; // group id + const uint wid = gl_LocalInvocationID.z; // worker id + + if (out_col >= out_sizes.x || out_row >= out_sizes.y) { + return; + } + + const int num_blocks = mat1_sizes.x / group_size; + + ivec4 mat1_quantized[TILE_ROWS]; + ivec4 qmat2_quantized[4][2]; + vec4 final_result[TILE_ROWS][2]; + + // Initialize accumulators + [[unroll]] for (int r = 0; r < TILE_ROWS; ++r) { + final_result[r][0] = vec4(0.0); + final_result[r][1] = vec4(0.0); + } + + vec4 scales[2]; + vec4 zeros[2]; + + $if WEIGHT_STORAGE == "buffer": + const int qmat2_stride = qmat2_sizes.x >> 2; + $if PARAMS_STORAGE == "buffer": + const int qparams_stride = out_sizes.x >> 2; + + for (int block_idx = 0; block_idx < num_blocks; ++block_idx) { + $if PARAMS_STORAGE == "buffer": + scales[0] = t_weight_scales[block_idx * qparams_stride + out_col_texel_idx]; + scales[1] = t_weight_scales[block_idx * qparams_stride + out_col_texel_idx + 1]; + + zeros[0] = vec4(t_weight_zeros[block_idx * qparams_stride + out_col_texel_idx]); + zeros[1] = vec4(t_weight_zeros[block_idx * qparams_stride + out_col_texel_idx + 1]); + $else: + scales[0] = texelFetch(t_weight_scales, ivec3(out_col_texel_idx, block_idx, 0), 0); + scales[1] = texelFetch(t_weight_scales, ivec3(out_col_texel_idx + 1, block_idx, 0), 0); + + zeros[0] = vec4(texelFetch(t_weight_zeros, ivec3(out_col_texel_idx, block_idx, 0), 0)); + zeros[1] = vec4(texelFetch(t_weight_zeros, ivec3(out_col_texel_idx + 1, block_idx, 0), 0)); + + ivec4 int32_sums[TILE_ROWS][2]; + int input_sums[TILE_ROWS]; + + // Initialize accumulators for this block + [[unroll]] for (int r = 0; r < TILE_ROWS; ++r) { + int32_sums[r][0] = ivec4(0); + int32_sums[r][1] = ivec4(0); + input_sums[r] = 0; + } + + for (int g_idx = 4 * int(wid); g_idx < group_size; g_idx += (4 * NWORKERS)) { + const int k = block_idx * group_size + g_idx; + + // Preload B (weights) - keep as quantized integers + [[unroll]] for (int r = 0; r < 4; ++r) { + $if WEIGHT_STORAGE == "buffer": + const u8vec4 packed_weight_tex = t_qmat2[(k + r) * qmat2_stride + gl_GlobalInvocationID.x]; + $else: + const uvec4 packed_weight_tex = texelFetch( + t_qmat2, + ivec2(gl_GlobalInvocationID.x, k + r), + 0); + + // Unpack 4-bit weights to integers and subtract zero point (8 for 4-bit) + qmat2_quantized[r][0] = ivec4((packed_weight_tex & 0xF0) >> 4) - 8; + qmat2_quantized[r][1] = ivec4(packed_weight_tex & 0x0F) - 8; + } + + // Preload A (quantized input) - keep as quantized integers + [[unroll]] for (int r = 0; r < TILE_ROWS; ++r) { + $if IN_STORAGE == "buffer": + mat1_quantized[r] = t_mat1[((out_row + r) * mat1_sizes.x + k) >> 2] - t_input_zero_point[int(out_row) + r]; + $else: + mat1_quantized[r] = texelFetch(t_mat1, ivec3(k >> 2, out_row + r, 0), 0) - t_input_zero_point[int(out_row) + r]; + } + + // Accumulate in integer arithmetic: (input_quantized - input_zero_point) * (weight_quantized - weight_zero_point) + [[unroll]] for (int r = 0; r < TILE_ROWS; ++r) { + input_sums[r] += mat1_quantized[r].x + mat1_quantized[r].y + mat1_quantized[r].z + mat1_quantized[r].w; + + int32_sums[r][0] += mat1_quantized[r].x * qmat2_quantized[0][0] + + mat1_quantized[r].y * qmat2_quantized[1][0] + + mat1_quantized[r].z * qmat2_quantized[2][0] + + mat1_quantized[r].w * qmat2_quantized[3][0]; + + int32_sums[r][1] += 
mat1_quantized[r].x * qmat2_quantized[0][1] + + mat1_quantized[r].y * qmat2_quantized[1][1] + + mat1_quantized[r].z * qmat2_quantized[2][1] + + mat1_quantized[r].w * qmat2_quantized[3][1]; + } + } + + // Incorporates this block's results into the final accumulation + // Following proper quantization paradigm: result = input_scale * weight_scale * + // Sum((input_quantized - input_zero) * (weight_quantized - weight_zero)) + [[unroll]] for (int r = 0; r < TILE_ROWS; ++r) { + if (out_row + r >= out_sizes.y) { + continue; + } + + float input_scale = t_input_scale[int(out_row) + r]; + float input_sum_scalar = float(input_sums[r]); + + // Apply proper quantization paradigm: input_scale * weight_scale * (accumulator - weight_zero * input_sum) + final_result[r][0] += input_scale * scales[0] * (vec4(int32_sums[r][0]) - zeros[0] * input_sum_scalar); + final_result[r][1] += input_scale * scales[1] * (vec4(int32_sums[r][1]) - zeros[1] * input_sum_scalar); + } + } + + // Store worker results in shared memory + [[unroll]] for (int r = 0; r < TILE_ROWS; ++r) { + partial_results[gid][wid][r][0] = final_result[r][0]; + partial_results[gid][wid][r][1] = final_result[r][1]; + } + + memoryBarrierShared(); + barrier(); + + // Only the first worker in each group accumulates and writes output + if (wid != 0) { + return; + } + + vec4 cooperative_result[TILE_ROWS][2]; + + for (int r = 0; r < TILE_ROWS; ++r) { + cooperative_result[r][0] = vec4(0.0); + cooperative_result[r][1] = vec4(0.0); + [[unroll]] for (int worker = 0; worker < NWORKERS; ++worker) { + cooperative_result[r][0] += partial_results[gid][worker][r][0]; + cooperative_result[r][1] += partial_results[gid][worker][r][1]; + } + } + + // Apply final output quantization + [[unroll]] for (int r = 0; r < TILE_ROWS; ++r) { + $if OUT_STORAGE == "buffer": + t_out[((out_row + r) * out_sizes.x + out_col) >> 2] = cooperative_result[r][0]; + t_out[((out_row + r) * out_sizes.x + out_col + 4) >> 2] = cooperative_result[r][1]; + $else: + imageStore(t_out, ivec3(out_col_texel_idx, out_row + r, 0), cooperative_result[r][0]); + imageStore(t_out, ivec3(out_col_texel_idx + 1, out_row + r, 0), cooperative_result[r][1]); + } +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_qta8a_qga4w_coop.yaml b/backends/vulkan/runtime/graph/ops/glsl/linear_qta8a_qga4w_coop.yaml new file mode 100644 index 00000000000..9f6db77094a --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_qta8a_qga4w_coop.yaml @@ -0,0 +1,26 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
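+
+# The variant names below are what the runtime resolves at dispatch time:
+# QuantizedLinear_QTA8A_QGA4W.cpp starts from "linear_qta8a_qga4w_coop" and
+# appends storage-type suffixes for the output, input and weight tensors plus a
+# dtype suffix, so e.g. buffer output/input with texture2d weights and float
+# dtype selects linear_qta8a_qga4w_coop_buffer_buffer_texture2d_float.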
+ +linear_qta8a_qga4w_coop: + parameter_names_with_default_values: + DTYPE: float + OUT_STORAGE: texture3d + IN_STORAGE: texture3d + WEIGHT_STORAGE: texture2d + PARAMS_STORAGE: buffer + TILE_ROWS: 1 + shader_variants: + - NAME: linear_qta8a_qga4w_coop_texture3d_texture3d_texture2d_float + - NAME: linear_qta8a_qga4w_coop_buffer_buffer_texture2d_float + OUT_STORAGE: buffer + IN_STORAGE: buffer + - NAME: linear_qta8a_qga4w_coop_buffer_buffer_buffer_float + OUT_STORAGE: buffer + IN_STORAGE: buffer + WEIGHT_STORAGE: buffer + - NAME: linear_qta8a_qga4w_coop_buffer_texture2d_buffer_float + OUT_STORAGE: buffer + WEIGHT_STORAGE: buffer diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_qta8a_qga4w_tiled.glsl b/backends/vulkan/runtime/graph/ops/glsl/linear_qta8a_qga4w_tiled.glsl new file mode 100644 index 00000000000..dbb7da998f4 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_qta8a_qga4w_tiled.glsl @@ -0,0 +1,196 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core + +#define PRECISION ${PRECISION} + +#define T ${buffer_scalar_type(DTYPE)} +#define VEC4_T ${buffer_gvec_type(DTYPE, 4)} + +#define TILE_ROWS ${TILE_ROWS} + +${define_required_extensions(DTYPE)} +$if IN_STORAGE == "buffer": + ${define_required_extensions("int8")} +$if WEIGHT_STORAGE == "buffer": + ${define_required_extensions("uint8")} + +#extension GL_EXT_control_flow_attributes : require + +layout(std430) buffer; + +${layout_declare_tensor(B, "w", "t_out", DTYPE, OUT_STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_mat1", "int8", IN_STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_qmat2", "uint8", WEIGHT_STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_weight_scales", "float", PARAMS_STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_weight_zeros", "int", PARAMS_STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_input_scale", "float", "buffer", is_scalar_array=True)} +${layout_declare_tensor(B, "r", "t_input_zero_point", "int", "buffer", is_scalar_array=True)} + +layout(push_constant) uniform restrict Block { + ivec4 out_sizes; + ivec4 mat1_sizes; + ivec4 qmat2_sizes; +}; + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +layout(constant_id = 3) const int group_size = 64; + +/* + * This shader computes a linear operator between a quantized int8 input matrix + * x and a weights matrix that is quantized to 4 bits, producing a float output. + * + * The (W, H, C) shape of each tensor is: + * - x: (K, M) - quantized int8 input with per-token quantization + * - weights: (N / 2, K) + * - The weights tensor has a data type of `uint8`. Each element in the tensor + * contains 2 4-bit values packed into a uint8. + * - See the pack_int4_linear_weight_transposed_interleave shader to see more + * details on how the weight tensor is stored. + * - qparams: (2, N, number_of_groups) + * - This tensor contains the scales and zeros quantization parameters for the + * weights tensor. The weight tensor is quantized group-wise, which means + * that every `group_size` elements along the K dimension of the weights + * tensor has independent quantization parameters. Along the width dim, the + * first value contains the scale for the group and the second value + * contains the zero point for the group. 
+ * - input_scale: (num_tokens,) - per-token scale values for input quantization + * - input_zero_point: (num_tokens,) - per-token zero points for input quantization + * - output: (N, M) - float output + * + * Each thread computes a tile of TILE_ROWS * 2 texels of the output tensor. + * + * Note that this shader assumes that all tensors are width packed. + */ + +void main() { + const uint out_row = gl_GlobalInvocationID.y * TILE_ROWS; + const uint out_col = gl_GlobalInvocationID.x << 3; + const int out_col_texel_idx = int(gl_GlobalInvocationID.x) << 1; + + if (out_col >= out_sizes.x || out_row >= out_sizes.y) { + return; + } + + const int num_blocks = mat1_sizes.x / group_size; + + ivec4 mat1_quantized[TILE_ROWS]; + ivec4 qmat2_quantized[4][2]; + vec4 final_result[TILE_ROWS][2]; + + // Initialize accumulatoxrs + [[unroll]] for (int r = 0; r < TILE_ROWS; ++r) { + final_result[r][0] = vec4(0.0); + final_result[r][1] = vec4(0.0); + } + + vec4 scales[2]; + vec4 zeros[2]; + + $if WEIGHT_STORAGE == "buffer": + const int qmat2_stride = qmat2_sizes.x >> 2; + $if PARAMS_STORAGE == "buffer": + const int qparams_stride = out_sizes.x >> 2; + + for (int block_idx = 0; block_idx < num_blocks; ++block_idx) { + $if PARAMS_STORAGE == "buffer": + scales[0] = t_weight_scales[block_idx * qparams_stride + out_col_texel_idx]; + scales[1] = t_weight_scales[block_idx * qparams_stride + out_col_texel_idx + 1]; + + zeros[0] = vec4(t_weight_zeros[block_idx * qparams_stride + out_col_texel_idx]); + zeros[1] = vec4(t_weight_zeros[block_idx * qparams_stride + out_col_texel_idx + 1]); + $else: + scales[0] = texelFetch(t_weight_scales, ivec3(out_col_texel_idx, block_idx, 0), 0); + scales[1] = texelFetch(t_weight_scales, ivec3(out_col_texel_idx + 1, block_idx, 0), 0); + + zeros[0] = vec4(texelFetch(t_weight_zeros, ivec3(out_col_texel_idx, block_idx, 0), 0)); + zeros[1] = vec4(texelFetch(t_weight_zeros, ivec3(out_col_texel_idx + 1, block_idx, 0), 0)); + + ivec4 int32_sums[TILE_ROWS][2]; + int input_sums[TILE_ROWS]; + + // Initialize accumulators + [[unroll]] for (int r = 0; r < TILE_ROWS; ++r) { + int32_sums[r][0] = ivec4(0); + int32_sums[r][1] = ivec4(0); + input_sums[r] = 0; + } + + for (int g_idx = 0; g_idx < group_size; g_idx += 4) { + const int k = block_idx * group_size + g_idx; + + // Preload B (weights) - keep as quantized integers + [[unroll]] for (int r = 0; r < 4; ++r) { + $if WEIGHT_STORAGE == "buffer": + const u8vec4 packed_weight_tex = t_qmat2[(k + r) * qmat2_stride + gl_GlobalInvocationID.x]; + $else: + const uvec4 packed_weight_tex = texelFetch( + t_qmat2, + ivec2(gl_GlobalInvocationID.x, k + r), + 0); + + // Unpack 4-bit weights to integers (subtract 8 as the 4-bit zero point) + qmat2_quantized[r][0] = ivec4((packed_weight_tex & 0xF0) >> 4) - 8; + qmat2_quantized[r][1] = ivec4(packed_weight_tex & 0x0F) - 8; + } + + // Preload A (quantized input) - keep as quantized integers + [[unroll]] for (int r = 0; r < TILE_ROWS; ++r) { + $if IN_STORAGE == "buffer": + mat1_quantized[r] = t_mat1[((out_row + r) * mat1_sizes.x + k) >> 2] - t_input_zero_point[int(out_row) + r]; + $else: + mat1_quantized[r] = texelFetch(t_mat1, ivec3(k >> 2, out_row + r, 0), 0) - t_input_zero_point[int(out_row) + r]; + } + + // Accumulate in integer arithmetic: (input_quantized - input_zero_point) * (weight_quantized - weight_zero_point) + [[unroll]] for (int r = 0; r < TILE_ROWS; ++r) { + input_sums[r] += mat1_quantized[r].x + mat1_quantized[r].y + mat1_quantized[r].z + mat1_quantized[r].w; + + int32_sums[r][0] += mat1_quantized[r].x * 
qmat2_quantized[0][0] + + mat1_quantized[r].y * qmat2_quantized[1][0] + + mat1_quantized[r].z * qmat2_quantized[2][0] + + mat1_quantized[r].w * qmat2_quantized[3][0]; + + int32_sums[r][1] += mat1_quantized[r].x * qmat2_quantized[0][1] + + mat1_quantized[r].y * qmat2_quantized[1][1] + + mat1_quantized[r].z * qmat2_quantized[2][1] + + mat1_quantized[r].w * qmat2_quantized[3][1]; + } + } + + // Incorporates this block's results into the final accumulation + // Following proper quantization paradigm: result = input_scale * weight_scale * + // Sum((input_quantized - input_zero) * (weight_quantized - weight_zero)) + [[unroll]] for (int r = 0; r < TILE_ROWS; ++r) { + if (out_row + r >= out_sizes.y) { + continue; + } + + float input_scale = t_input_scale[int(out_row) + r]; + float input_sum_scalar = float(input_sums[r]); + + // Apply proper quantization paradigm: input_scale * weight_scale * (accumulator - weight_zero * input_sum) + final_result[r][0] += input_scale * scales[0] * (vec4(int32_sums[r][0]) - zeros[0] * input_sum_scalar); + final_result[r][1] += input_scale * scales[1] * (vec4(int32_sums[r][1]) - zeros[1] * input_sum_scalar); + } + } + + // Apply ALL scaling at the very end + [[unroll]] for (int r = 0; r < TILE_ROWS; ++r) { + $if OUT_STORAGE == "buffer": + if (out_row + r < out_sizes.y) { + t_out[((out_row + r) * out_sizes.x + out_col) >> 2] = final_result[r][0]; + t_out[((out_row + r) * out_sizes.x + out_col + 4) >> 2] = final_result[r][1]; + } + $else: + imageStore(t_out, ivec3(out_col_texel_idx, out_row + r, 0), final_result[r][0]); + imageStore(t_out, ivec3(out_col_texel_idx + 1, out_row + r, 0), final_result[r][1]); + } +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_qta8a_qga4w_tiled.yaml b/backends/vulkan/runtime/graph/ops/glsl/linear_qta8a_qga4w_tiled.yaml new file mode 100644 index 00000000000..c96d693834b --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_qta8a_qga4w_tiled.yaml @@ -0,0 +1,26 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +linear_qta8a_qga4w_tiled: + parameter_names_with_default_values: + DTYPE: float + OUT_STORAGE: texture3d + IN_STORAGE: texture3d + WEIGHT_STORAGE: texture2d + PARAMS_STORAGE: buffer + TILE_ROWS: 3 + shader_variants: + - NAME: linear_qta8a_qga4w_tiled_texture3d_texture3d_texture2d_float + - NAME: linear_qta8a_qga4w_tiled_buffer_buffer_texture2d_float + OUT_STORAGE: buffer + IN_STORAGE: buffer + - NAME: linear_qta8a_qga4w_tiled_buffer_buffer_buffer_float + OUT_STORAGE: buffer + IN_STORAGE: buffer + WEIGHT_STORAGE: buffer + - NAME: linear_qta8a_qga4w_tiled_buffer_texture2d_buffer_float + OUT_STORAGE: buffer + WEIGHT_STORAGE: buffer diff --git a/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear_QTA8A_QGA4W.cpp b/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear_QTA8A_QGA4W.cpp new file mode 100644 index 00000000000..a47c58b7ef6 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear_QTA8A_QGA4W.cpp @@ -0,0 +1,268 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include + +#include +#include + +#include +#include + +namespace vkcompute { + +void check_linear_qta8a_qga4w_args( + ComputeGraph& graph, + const ValueRef mat1, + const ValueRef mat1_scale, + const ValueRef mat1_zero_point, + const ValueRef mat2_data, + const ValueRef group_size, + const ValueRef weight_scales, + const ValueRef weight_zeros, + const ValueRef out) { + VK_CHECK_COND(graph.val_is_tensor(mat1)); + VK_CHECK_COND(graph.val_is_tensor(mat1_scale)); + VK_CHECK_COND(graph.val_is_tensor(mat1_zero_point)); + VK_CHECK_COND(graph.val_is_tref(mat2_data)); + VK_CHECK_COND(graph.val_is_tref(weight_scales)); + VK_CHECK_COND(graph.val_is_tref(weight_zeros)); + + VK_CHECK_COND(graph.dim_of(mat1) <= 3); + VK_CHECK_COND(graph.dim_of(mat2_data) == 2); + VK_CHECK_COND(graph.dim_of(weight_scales) == 2); + VK_CHECK_COND(graph.dim_of(weight_zeros) == 2); + + VK_CHECK_COND(graph.size_at(-3, mat1) == 1); + const int K = graph.size_at(-1, mat1); + VK_CHECK_COND(graph.size_at(-1, mat2_data) * 2 == K); + + const int group_size_val = graph.extract_scalar(group_size); + VK_CHECK_COND(K % group_size_val == 0); + // Due to the way weight packing works, group size needs to be a multiple of 8 + VK_CHECK_COND(group_size_val % 8 == 0); + + VK_CHECK_COND(graph.has_standard_axis_map(mat1)); + VK_CHECK_COND(graph.has_standard_axis_map(out)); + + // Check that scale and zero_point tensors are buffer storage with width + // packing + VK_CHECK_COND(graph.is_buffer_storage(mat1_scale)); + VK_CHECK_COND(graph.packed_dim_of(mat1_scale) == WHCN::kWidthDim); + VK_CHECK_COND(graph.is_buffer_storage(mat1_zero_point)); + VK_CHECK_COND(graph.packed_dim_of(mat1_zero_point) == WHCN::kWidthDim); + + // Calculate number of tokens for input + int64_t input_num_tokens = 1; + const auto mat1_sizes = graph.sizes_of(mat1); + for (size_t i = 0; i < mat1_sizes.size() - 1; i++) { + input_num_tokens *= mat1_sizes[i]; + } + + // Verify scale and zero_point tensor sizes match number of tokens + const auto mat1_scale_sizes = graph.sizes_of(mat1_scale); + const auto mat1_zero_point_sizes = graph.sizes_of(mat1_zero_point); + + VK_CHECK_COND(mat1_scale_sizes.size() == 1); + VK_CHECK_COND(mat1_zero_point_sizes.size() == 1); + + VK_CHECK_COND(mat1_scale_sizes[0] == input_num_tokens); + VK_CHECK_COND(mat1_zero_point_sizes[0] == input_num_tokens); + + // Verify weight scales and zeros have the same shape + const auto weight_scales_sizes = graph.sizes_of(weight_scales); + const auto weight_zeros_sizes = graph.sizes_of(weight_zeros); + VK_CHECK_COND(weight_scales_sizes == weight_zeros_sizes); +} + +void resize_linear_qta8a_qga4w_node( + ComputeGraph* graph, + const std::vector& args, + const std::vector& extra_args) { + (void)extra_args; + + vTensorPtr out = graph->get_tensor(args[0].refs[0]); + vTensorPtr mat1 = graph->get_tensor(args[1].refs[0]); + vTensorPtr mat2 = graph->get_tensor(args[1].refs[1]); + + const int64_t out_cols = utils::val_at(-2, mat1->sizes()); + const int64_t out_rows = utils::val_at(-1, mat2->sizes()) * 2; + + std::vector new_out_sizes(3); + if (mat1->sizes().size() == 2) { + new_out_sizes.resize(2); + new_out_sizes.at(0) = out_cols; + new_out_sizes.at(1) = out_rows; + } else { + new_out_sizes.at(0) = mat1->sizes().at(0); + new_out_sizes.at(1) = out_cols; + new_out_sizes.at(2) = out_rows; + } + + out->virtual_resize(new_out_sizes); +} + +/** + * Determines if the cooperative algorithm should be used based on input tensor + * dimensions. 
Apply the coop algorithm for vectors (GEMV cases), tiled for + * matrices (GEMM cases). + */ +bool should_use_coop_algorithm_qta8a_qga4w( + ComputeGraph* graph, + const ValueRef& mat1) { + const uint32_t M = graph->size_at(-2, mat1); + // Use coop algorithm for vectors (GEMV), tiled for larger matrices (GEMM) + return M == 1; +} + +vkapi::ShaderInfo pick_linear_qta8a_qga4w_shader( + ComputeGraph* graph, + const std::vector& args, + const std::vector& resize_args) { + (void)resize_args; + + const ValueRef out = args.at(0).refs.at(0); + const ValueRef mat1 = args.at(1).refs.at(0); + const ValueRef mat2 = args.at(1).refs.at(1); + + const bool use_coop_algorithm = + should_use_coop_algorithm_qta8a_qga4w(graph, mat1); + + std::string kernel_name = "linear_qta8a_qga4w"; + if (use_coop_algorithm) { + kernel_name += "_coop"; + } else { + kernel_name += "_tiled"; + } + add_storage_type_suffix(kernel_name, graph->storage_type_of(out)); + add_storage_type_suffix(kernel_name, graph->storage_type_of(mat1)); + add_storage_type_suffix(kernel_name, graph->storage_type_of(mat2)); + add_dtype_suffix(kernel_name, graph->dtype_of(out)); + + return VK_KERNEL_FROM_STR(kernel_name); +} + +utils::uvec3 linear_qta8a_qga4w_global_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const std::vector& args, + const std::vector& resize_args) { + (void)resize_args; + const ValueRef out = args.at(0).refs.at(0); + + const bool use_coop_algorithm = + shader.kernel_name.find("_coop") != std::string::npos; + + // C = 1, H = 2, W = 3 + // global_wg_size = {round_up(C / 2f), round_up(H / 3f), W} --> (2W, 1H, 0C) + // --> {1, 1, 3} global + + utils::uvec3 global_wg_size = graph->logical_limits_of(out); + global_wg_size[0] = utils::div_up(global_wg_size[0], uint32_t(2)); + if (!use_coop_algorithm) { // GEMM - TILED + global_wg_size[1] = utils::div_up(global_wg_size[1], uint32_t(3)); + } + + return global_wg_size; +} + +utils::uvec3 linear_qta8a_qga4w_local_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const utils::uvec3& global_workgroup_size, + const std::vector& args, + const std::vector& resize_args) { + (void)args; + (void)resize_args; + + const bool use_coop_algorithm = + shader.kernel_name.find("_coop") != std::string::npos; + + utils::uvec3 local_wg_size; + if (use_coop_algorithm) { // GEMV - COOP + local_wg_size = {8, 1, 8}; + } else { // GEMM - TILED + local_wg_size = graph->create_local_wg_size(global_workgroup_size); + } + + return local_wg_size; +} + +void add_linear_qta8a_qga4w_node( + ComputeGraph& graph, + const ValueRef mat1, + const ValueRef mat1_scale, + const ValueRef mat1_zero_point, + const ValueRef mat2_data, + const ValueRef group_size, + const ValueRef weight_scales_data, + const ValueRef weight_zeros_data, + const ValueRef out) { + check_linear_qta8a_qga4w_args( + graph, + mat1, + mat1_scale, + mat1_zero_point, + mat2_data, + group_size, + weight_scales_data, + weight_zeros_data, + out); + const uint32_t group_size_val = graph.extract_scalar(group_size); + + ValueRef mat2 = + prepack_int4_linear_weight_transposed_interleaved(graph, mat2_data); + ValueRef weight_scales = prepack_standard( + graph, weight_scales_data, utils::kBuffer, utils::kWidthPacked); + ValueRef weight_zeros = prepack_standard( + graph, weight_zeros_data, utils::kBuffer, utils::kWidthPacked); + + graph.execute_nodes().emplace_back(new DynamicDispatchNode( + graph, + pick_linear_qta8a_qga4w_shader, + linear_qta8a_qga4w_global_wg_size, + linear_qta8a_qga4w_local_wg_size, + // Inputs and Outputs + 
{{out, vkapi::kWrite}, + {{mat1, mat2, weight_scales, weight_zeros, mat1_scale, mat1_zero_point}, + vkapi::kRead}}, + // Shader params buffers + {}, + // Push Constants + {graph.sizes_pc_of(out), + graph.sizes_pc_of(mat1), + graph.sizes_pc_of(mat2)}, + // Specialization Constants + {SV(group_size_val)}, + // Resize Args + {}, + // Resizing Logic + resize_linear_qta8a_qga4w_node)); +} + +void linear_qta8a_qga4w( + ComputeGraph& graph, + const std::vector& args) { + return add_linear_qta8a_qga4w_node( + graph, + args[0], // quantized input (char tensor) + args[1], // input_scale (float buffer tensor) + args[2], // input_zero_point (int buffer tensor) + args[3], // quantized weights (4-bit packed, byte) + args[4], // group_size (int) + args[5], // weight_scales (float tensor) + args[6], // weight_zeros (int tensor) + args[7] // float output tensor + ); +} + +REGISTER_OPERATORS { + VK_REGISTER_OP(et_vk.linear_qta8a_qga4w.default, linear_qta8a_qga4w); +} + +} // namespace vkcompute diff --git a/backends/vulkan/test/op_tests/quantized_linear_test.cpp b/backends/vulkan/test/op_tests/quantized_linear_test.cpp index 108770bb02e..0cd27ea07f3 100644 --- a/backends/vulkan/test/op_tests/quantized_linear_test.cpp +++ b/backends/vulkan/test/op_tests/quantized_linear_test.cpp @@ -581,6 +581,181 @@ void test_vulkan_linear_qcs4w( B, M, K, N, vkcompute::utils::kTexture3D, vkcompute::utils::kTexture3D); } +void test_vulkan_linear_qta8a_qga4w_impl( + const int B, + const int M, + const int K, + const int N, + const int group_size = 8, + const vkcompute::utils::StorageType in_storage = + vkcompute::utils::kTexture3D, + const vkcompute::utils::StorageType out_storage = + vkcompute::utils::kTexture3D) { + assert(K % group_size == 0); + + const int64_t input_num_tokens = B * M; + const int k_groups = K / group_size; + + at::Tensor input_scale = + at::rand({input_num_tokens}, at::device(at::kCPU).dtype(at::kFloat)); + at::Tensor input_zero_point = at::randint( + -10, 10, {input_num_tokens}, at::device(at::kCPU).dtype(at::kInt)); + + at::Tensor float_x = + at::rand({B, M, K}, at::device(at::kCPU).dtype(at::kFloat)); + + // Create a reference quantized tensor using per-token quantization + // Mimic per-token quantization using at::quantize_per_channel by reshaping + // [num_tokens, features] + at::Tensor float_x_reshaped = float_x.view({input_num_tokens, K}); + at::Tensor qx_ref_reshaped = at::quantize_per_channel( + float_x_reshaped, + input_scale.to(at::kDouble), + input_zero_point.to(at::kLong), + 0, // axis 0 for per-token (first dimension after reshape) + c10::ScalarType::QInt8); + + at::Tensor x = + at::int_repr(qx_ref_reshaped).view(float_x.sizes()).to(at::kChar); + + at::Tensor weights_4x2 = + at::randint(0, 256, {N, K / 2}, at::device(at::kCPU).dtype(at::kByte)); + at::Tensor weight_scales = + at::rand({k_groups, N}, at::device(at::kCPU).dtype(at::kFloat)); + at::Tensor weight_zeros = at::randint( + -128, 128, {k_groups, N}, at::device(at::kCPU).dtype(at::kInt)); + + at::Tensor out_ref = linear_qta8a_qga4w_4bit_dequant_impl( + x, + input_scale, + input_zero_point, + weights_4x2, + group_size, + weight_scales, + weight_zeros); + + // Build Vulkan graph + using namespace vkcompute; + + GraphConfig config; + config.set_storage_type_override(utils::kTexture3D); + ComputeGraph graph(config); + +#define MAKE_TENSORREF_FOR(x) \ + ValueRef r_##x = graph.add_tensorref( \ + x.sizes().vec(), \ + from_at_scalartype(x.scalar_type()), \ + x.const_data_ptr()); + + MAKE_TENSORREF_FOR(weights_4x2); + 
MAKE_TENSORREF_FOR(weight_scales); + MAKE_TENSORREF_FOR(weight_zeros); + + IOValueRef r_x = graph.add_input_tensor( + x.sizes().vec(), from_at_scalartype(x.scalar_type()), in_storage); + + IOValueRef r_input_scale = graph.add_input_tensor( + input_scale.sizes().vec(), + from_at_scalartype(input_scale.scalar_type()), + utils::kBuffer); + + IOValueRef r_input_zero_point = graph.add_input_tensor( + input_zero_point.sizes().vec(), + from_at_scalartype(input_zero_point.scalar_type()), + utils::kBuffer); + + const ValueRef r_out = graph.add_tensor( + out_ref.sizes().vec(), + from_at_scalartype(out_ref.scalar_type()), + out_storage); + + VK_GET_OP_FN("et_vk.linear_qta8a_qga4w.default") + (graph, + {r_x.value, + r_input_scale.value, + r_input_zero_point.value, + r_weights_4x2, + graph.add_scalar(group_size), + r_weight_scales, + r_weight_zeros, + r_out}); + + ValueRef staging_out = graph.set_output_tensor(r_out); + + graph.prepare(); + graph.encode_prepack(); + graph.prepack(); + graph.encode_execute(); + + // + // Run model + // + + graph.propagate_resize(); + graph.copy_into_staging(r_x.staging, x.const_data_ptr(), x.numel()); + graph.copy_into_staging( + r_input_scale.staging, input_scale.const_data_ptr(), input_scale.numel()); + graph.copy_into_staging( + r_input_zero_point.staging, + input_zero_point.const_data_ptr(), + input_zero_point.numel()); + + graph.execute(); + + at::Tensor vk_out = at::empty_like(out_ref); + graph.copy_from_staging( + staging_out, vk_out.mutable_data_ptr(), vk_out.numel()); + + // This is a reference implementation that uses the quantized + // matmul paradigm. It should follow closely with how the vulkan + // implementation works, and demonstrates reasonably close results. + at::Tensor qmm_ref = linear_qta8a_qga4w_quantized_matmul( + x, + input_scale, + input_zero_point, + weights_4x2, + group_size, + weight_scales, + weight_zeros); + + // For quantized int8 operations, allow for 1-unit differences due to rounding + bool is_close = at::allclose(vk_out, out_ref, 5e-3, 5e-3); + if (!is_close) { + std::cout << "qmm_ref: \n" << qmm_ref << std::endl; + std::cout << "out_ref: \n" << out_ref << std::endl; + std::cout << "vk_out: \n" << vk_out << std::endl; + } + + ASSERT_TRUE(is_close); +} + +void test_vulkan_linear_qta8a_qga4w( + const int B, + const int M, + const int K, + const int N, + const int group_size = 32) { + test_vulkan_linear_qta8a_qga4w_impl( + B, + M, + K, + N, + group_size, + vkcompute::utils::kBuffer, + vkcompute::utils::kBuffer); + + test_vulkan_linear_qta8a_qga4w_impl( + B, + M, + K, + N, + group_size, + vkcompute::utils::kTexture3D, + vkcompute::utils::kTexture3D); +} + +// Test linear_qga4w operator + TEST(VulkanLinearQGA4WTest, test_reference_impl) { test_reference_linear_qga4w( /*B = */ 1, @@ -611,6 +786,8 @@ TEST(VulkanLinearQGA4WTest, test_vulkan_impl_gemm) { /*N = */ 256); } +// Test linear_qcs4w operator + TEST_F(VulkanLinearQCS4WTest, test_reference_impl) { test_reference_linear_qcs4w( /*B = */ 1, @@ -640,3 +817,87 @@ TEST_F(VulkanLinearQCS4WTest, test_vulkan_impl_gemm) { /*K = */ 32, /*N = */ 32); } + +// Test linear_qta8a_qga4w operator + +TEST_F( + VulkanLinearQTA8AQGA4WTest, + test_vulkan_linear_quant_gemm_custom_groupsize) { + test_vulkan_linear_qta8a_qga4w( + /*B = */ 1, + /*M = */ 2, + /*K = */ 8, + /*N = */ 8, + /*group_size = */ 8); + + test_vulkan_linear_qta8a_qga4w( + /*B = */ 1, + /*M = */ 2, + /*K = */ 16, + /*N = */ 8, + /*group_size = */ 8); +} + +TEST_F(VulkanLinearQTA8AQGA4WTest, test_vulkan_linear_quant_gemm) { + 
test_vulkan_linear_qta8a_qga4w( + /*B = */ 1, + /*M = */ 4, + /*K = */ 64, + /*N = */ 32); + + test_vulkan_linear_qta8a_qga4w( + /*B = */ 1, + /*M = */ 4, + /*K = */ 128, + /*N = */ 32); + + test_vulkan_linear_qta8a_qga4w( + /*B = */ 1, + /*M = */ 8, + /*K = */ 64, + /*N = */ 16); + + test_vulkan_linear_qta8a_qga4w( + /*B = */ 1, + /*M = */ 256, + /*K = */ 256, + /*N = */ 256); +} + +TEST_F( + VulkanLinearQTA8AQGA4WTest, + test_vulkan_linear_quant_gemv_custom_groupsize) { + test_vulkan_linear_qta8a_qga4w( + /*B = */ 1, + /*M = */ 1, + /*K = */ 8, + /*N = */ 8, + /*group_size = */ 8); + + test_vulkan_linear_qta8a_qga4w( + /*B = */ 1, + /*M = */ 1, + /*K = */ 16, + /*N = */ 8, + /*group_size = */ 8); +} + +TEST_F(VulkanLinearQTA8AQGA4WTest, test_vulkan_linear_quant_gemv) { + test_vulkan_linear_qta8a_qga4w( + /*B = */ 1, + /*M = */ 1, + /*K = */ 32, + /*N = */ 32); + + test_vulkan_linear_qta8a_qga4w( + /*B = */ 1, + /*M = */ 1, + /*K = */ 64, + /*N = */ 16); + + test_vulkan_linear_qta8a_qga4w( + /*B = */ 1, + /*M = */ 1, + /*K = */ 256, + /*N = */ 256); +} \ No newline at end of file
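
For readers of the patch, the arithmetic that both shader variants (and the dequantize-based reference used by the tests) are expected to agree on, up to float rounding, can be sketched on the CPU as below. This is a minimal illustration under stated assumptions, not code from this change: the function name, argument layout, and flat row-major indexing are hypothetical; the interleaved 4-bit packing handled by pack_int4_linear_weight_transposed_interleave is ignored (weights are taken as already-unpacked values in [0, 16)); and the minus-8 centering plus group zero point mirror the shaders' unpack-and-scale steps.

// Reference sketch (not part of the patch): CPU version of the
// linear_qta8a_qga4w math for an [M, K] int8 input and [N, K] 4-bit weights.
#include <cstdint>
#include <vector>

std::vector<float> linear_qta8a_qga4w_reference(
    const std::vector<int8_t>& q_in,      // [M][K] quantized input
    const std::vector<float>& in_scale,   // [M] per-token scales
    const std::vector<int32_t>& in_zero,  // [M] per-token zero points
    const std::vector<uint8_t>& w_q,      // [N][K] unpacked 4-bit values, 0..15
    const std::vector<float>& w_scale,    // [K/group_size][N] group scales
    const std::vector<int32_t>& w_zero,   // [K/group_size][N] group zero points
    int M, int K, int N, int group_size) {
  std::vector<float> out(static_cast<size_t>(M) * N, 0.0f);
  for (int m = 0; m < M; ++m) {
    for (int n = 0; n < N; ++n) {
      float acc = 0.0f;
      for (int g = 0; g < K / group_size; ++g) {
        // Integer accumulation within one group, as in the shaders.
        int32_t int_sum = 0;   // sum of centered_input * centered_weight
        int32_t input_sum = 0; // sum of centered_input
        for (int kk = 0; kk < group_size; ++kk) {
          const int k = g * group_size + kk;
          const int32_t a = static_cast<int32_t>(q_in[m * K + k]) - in_zero[m];
          const int32_t b = static_cast<int32_t>(w_q[n * K + k]) - 8;
          int_sum += a * b;
          input_sum += a;
        }
        // Dequantize once per group:
        //   in_scale * w_scale * (int_sum - w_zero * input_sum)
        acc += in_scale[m] * w_scale[g * N + n] *
            (static_cast<float>(int_sum) -
             static_cast<float>(w_zero[g * N + n]) *
                 static_cast<float>(input_sum));
      }
      out[m * N + n] = acc;
    }
  }
  return out;
}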