From 9aba8d6e3f1ca872e35fbd56f5aa1ea08b3b4b5e Mon Sep 17 00:00:00 2001 From: morelos Date: Thu, 10 Jul 2025 11:26:22 -0700 Subject: [PATCH 1/2] [ET-VK][Ops] linear_qta8a_qga4w test framework MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Pull Request resolved: https://github.com/pytorch/executorch/pull/12005 # Context This test framework establishes the foundation for validating the `linear_qta8a_qga4w` operator implementation as part of enabling dynamic quantization. The motivation stems from advancing beyond weight-only quantization to full activation and weight quantized linear operations, enabling true integer arithmetic throughout the matrix multiplication process for improved performance on GPU hardware. The current weight-only quantized linear implementations in ET-VK dequantize weights to floating point before computation, missing the performance benefits of integer arithmetic. This operator nomenclature breakdown: - **qta8a**: Quantized per-token affine 8-bit activation inputs - **qga4w**: Quantized per-group affine 4-bit weights # Changes The reference implementation (`linear_qta8a_qga4w_4bit_dequant_impl`) provides a baseline for validating the GPU shader implementation through a deliberately simplified computation path. The quantized int8 input tensor is dequantized using the standard affine transformation `(quantized_input.to(at::kFloat) - input_zero_point) * input_scale`. After dequantization, the implementation performs standard floating point linear operation `at::linear(x_float, weights_dequantized)`. This two-stage approach of dequantize → compute provides a clear reference against which the GPU's integer arithmetic implementation can be validated. ghstack-source-id: 295393632 @exported-using-ghexport Differential Revision: [D77173442](https://our.internmc.facebook.com/intern/diff/D77173442/) --- ...nt4_test.cpp => quantized_linear_test.cpp} | 192 +++++++++++++++++- backends/vulkan/test/op_tests/targets.bzl | 2 +- 2 files changed, 190 insertions(+), 4 deletions(-) rename backends/vulkan/test/op_tests/{linear_weight_int4_test.cpp => quantized_linear_test.cpp} (64%) diff --git a/backends/vulkan/test/op_tests/linear_weight_int4_test.cpp b/backends/vulkan/test/op_tests/quantized_linear_test.cpp similarity index 64% rename from backends/vulkan/test/op_tests/linear_weight_int4_test.cpp rename to backends/vulkan/test/op_tests/quantized_linear_test.cpp index e48042c4620..108770bb02e 100644 --- a/backends/vulkan/test/op_tests/linear_weight_int4_test.cpp +++ b/backends/vulkan/test/op_tests/quantized_linear_test.cpp @@ -18,6 +18,36 @@ #include +class VulkanLinearQCS4WTest : public ::testing::Test { + public: + void SetUp() override { + if (!vkcompute::api::context() + ->adapter_ptr() + ->supports_int16_shader_types()) { + GTEST_SKIP(); + } + } + + void TearDown() override { + // Clean up any resources if needed + } +}; + +class VulkanLinearQTA8AQGA4WTest : public ::testing::Test { + public: + void SetUp() override { + if (!vkcompute::api::context() + ->adapter_ptr() + ->has_full_int8_buffers_support()) { + GTEST_SKIP(); + } + } + + void TearDown() override { + // Clean up any resources if needed + } +}; + // // Reference Implementations // @@ -149,6 +179,162 @@ at::Tensor linear_qcs4w_reference_impl( return out.reshape(out_shape); } +at::Tensor linear_qta8a_qga4w_quantized_matmul( + const at::Tensor& quantized_input, // [B, M, K] int8 quantized input + const at::Tensor& input_scale, // [B*M] per-token input scales + const at::Tensor& 
input_zero_point, // [B*M] per-token input zero points + const at::Tensor& weights_4x2, // [N, K/2] 4-bit packed weights + const int64_t group_size, // Group size for weight quantization + const at::Tensor& weight_scales, // [K/group_size, N] weight scales + const at::Tensor& weight_zeros) { // [K/group_size, N] weight zeros + + const int64_t B = quantized_input.size(0); + const int64_t M = quantized_input.size(1); + const int64_t K = quantized_input.size(2); + const int64_t N = weights_4x2.size(0); + + // Create output tensor for floating point results + at::Tensor float_output = + at::zeros({B, M, N}, at::device(at::kCPU).dtype(at::kFloat)); + + // Accessors for efficient access + auto input_accessor = quantized_input.accessor(); + auto output_accessor = float_output.accessor(); + auto weights_accessor = weights_4x2.accessor(); + auto weight_scales_accessor = weight_scales.accessor(); + auto weight_zeros_accessor = weight_zeros.accessor(); + auto input_scale_accessor = input_scale.accessor(); + auto input_zero_accessor = input_zero_point.accessor(); + + // Perform quantized matrix multiplication following quantization.md equation + // (5): result_real_value = lhs_scale * rhs_scale * Sum_over_k( + // (lhs_quantized_value[k] - lhs_zero_point) * + // (rhs_quantized_value[k] - rhs_zero_point) + // ) + for (int64_t b = 0; b < B; b++) { + for (int64_t m = 0; m < M; m++) { + const int64_t token_idx = b * M + m; + const float lhs_scale = + input_scale_accessor[token_idx]; // Per-token input scale + const int32_t lhs_zero_point = + input_zero_accessor[token_idx]; // Per-token input zero point + + for (int64_t n = 0; n < N; n++) { + float result_real_value = 0.0f; + + for (int64_t k = 0; k < K; k++) { + // Get per-group weight quantization parameters + const int64_t group_idx = k / group_size; + const float rhs_scale = + weight_scales_accessor[group_idx][n]; // Per-group weight scale + const int32_t rhs_zero_point = + weight_zeros_accessor[group_idx] + [n]; // Per-group weight zero point + + // Unpack the 4-bit weight for this position + const uint8_t packed_val = weights_accessor[n][k / 2]; + uint8_t weight_4bit; + if (k % 2 == 0) { + weight_4bit = (packed_val & 0xF0) >> 4; // First weight in pair + } else { + weight_4bit = packed_val & 0x0F; // Second weight in pair + } + + // Get quantized values + const int32_t lhs_quantized_value = + static_cast(input_accessor[b][m][k]); + // Convert 4-bit weight to signed: subtract 8 to get range [-8, 7] + const int32_t rhs_quantized_value = + static_cast(weight_4bit) - 8; + + // Apply proper quantization paradigm from quantization.md equation + // (3): real_value = scale * (quantized_value - zero_point) Following + // equation (5): result = lhs_scale * rhs_scale * + // (lhs_quantized - lhs_zero) * (rhs_quantized - rhs_zero) + const float lhs_diff = + static_cast(lhs_quantized_value - lhs_zero_point); + const float rhs_diff = + static_cast(rhs_quantized_value - rhs_zero_point); + + result_real_value += lhs_scale * rhs_scale * lhs_diff * rhs_diff; + } + + output_accessor[b][m][n] = result_real_value; + } + } + } + + return float_output; +} + +at::Tensor linear_qta8a_qga4w_4bit_dequant_impl( + const at::Tensor& quantized_input, + const at::Tensor& input_scale, + const at::Tensor& input_zero_point, + const at::Tensor& weights_4x2, + const int64_t group_size, + const at::Tensor& weight_scales, + const at::Tensor& weight_zeros) { + // Calculate number of input tokens + int64_t input_num_tokens = 1; + for (size_t i = 0; i < quantized_input.sizes().size() - 1; i++) 
{ + input_num_tokens *= quantized_input.size(i); + } + + // Manually dequantize the char tensor using per-token quantization + at::Tensor x_float = at::zeros_like(quantized_input, at::kFloat); + + // Apply per-token dequantization + auto input_accessor = quantized_input.accessor(); + auto output_accessor = x_float.accessor(); + + for (int64_t token_idx = 0; token_idx < input_num_tokens; token_idx++) { + float scale_val = input_scale[token_idx].item(); + int zero_point_val = input_zero_point[token_idx].item(); + + // Calculate batch and sequence indices for this token + int64_t b = token_idx / quantized_input.size(1); + int64_t m = token_idx % quantized_input.size(1); + + // Apply dequantization for all features in this token + for (int64_t k = 0; k < quantized_input.size(-1); k++) { + float dequant_val = + (input_accessor[b][m][k] - zero_point_val) * scale_val; + output_accessor[b][m][k] = dequant_val; + } + } + + std::vector weights_shape(weights_4x2.sizes().vec()); + weights_shape[1] *= 2; + + at::Tensor weights_dequantized = + at::empty(weights_shape, at::device(at::kCPU).dtype(at::kFloat)); + + const int64_t N = weights_dequantized.size(0); + const int64_t K = weights_dequantized.size(1); + + for (int n = 0; n < N; n++) { + for (int k = 0; k < K; k += 2) { + const int group_idx = k / group_size; + const uint8_t packed_val = weights_4x2[n][k / 2].item().to(); + const uint8_t second_val = packed_val & 0x0F; + const uint8_t first_val = (packed_val & 0xF0) >> 4; + + const float scale = weight_scales[group_idx][n].item().to(); + const int zero = weight_zeros[group_idx][n].item().to(); + + weights_dequantized[n][k] = + ((float(first_val) - 8.0) - float(zero)) * scale; + weights_dequantized[n][k + 1] = + ((float(second_val) - 8.0) - float(zero)) * scale; + } + } + + at::Tensor linear_result = at::linear(x_float, weights_dequantized); + + return linear_result; +} + // // Test functions // @@ -425,7 +611,7 @@ TEST(VulkanLinearQGA4WTest, test_vulkan_impl_gemm) { /*N = */ 256); } -TEST(VulkanLinearQCS4WTest, test_reference_impl) { +TEST_F(VulkanLinearQCS4WTest, test_reference_impl) { test_reference_linear_qcs4w( /*B = */ 1, /*M = */ 4, @@ -433,7 +619,7 @@ TEST(VulkanLinearQCS4WTest, test_reference_impl) { /*N = */ 32); } -TEST(VulkanLinearQCS4WTest, test_vulkan_impl_small_m) { +TEST_F(VulkanLinearQCS4WTest, test_vulkan_impl_small_m) { test_vulkan_linear_qcs4w( /*B = */ 1, /*M = */ 4, @@ -447,7 +633,7 @@ TEST(VulkanLinearQCS4WTest, test_vulkan_impl_small_m) { /*N = */ 256); } -TEST(VulkanLinearQCS4WTest, test_vulkan_impl_gemm) { +TEST_F(VulkanLinearQCS4WTest, test_vulkan_impl_gemm) { test_vulkan_linear_qcs4w( /*B = */ 1, /*M = */ 32, diff --git a/backends/vulkan/test/op_tests/targets.bzl b/backends/vulkan/test/op_tests/targets.bzl index 0d014c7ef29..9eac90ac33d 100644 --- a/backends/vulkan/test/op_tests/targets.bzl +++ b/backends/vulkan/test/op_tests/targets.bzl @@ -205,7 +205,7 @@ def define_common_targets(is_fbcode = False): ] ) define_test_targets( - "linear_weight_int4_test", + "quantized_linear_test", extra_deps = [ ":test_utils", ] From 89548ff769b56e8caa307805857186d969b09014 Mon Sep 17 00:00:00 2001 From: morelos Date: Thu, 10 Jul 2025 12:02:01 -0700 Subject: [PATCH 2/2] [ET-VK][Ops] linear_qta8a_qga4w impl and shaders Pull Request resolved: https://github.com/pytorch/executorch/pull/12006 # Operator Description The linear_qta8a_qga4w operator implements a quantized linear transformation that enables efficient neural network inference through dynamic quantization. 
This operator performs matrix multiplication between quantized 8-bit activations and 4-bit grouped quantized weights, producing quantized 8-bit outputs. The quantization scheme follows the standard affine mapping where `real_value = scale * (quantized_value - zero_point)`. Input activations use 8-bit signed integers with per-token scale and zero-point parameters, while weights employ 4-bit quantization with group-wise parameters. # Implementation Architecture The operator provides two distinct computational approaches optimized for different matrix multiplication scenarios: the TILED algorithm for general matrix-matrix multiplication (GEMM) and the COOPERATIVE algorithm for matrix-vector multiplication (GEMV). ## TILED Algorithm (GEMM Cases) The tiled implementation processes the output matrix in rectangular blocks. Each thread is responsible for calculating a tile of output values, typically processing 3 rows and 2 columns worth of results in each iteration. The algorithm operates by having each thread load blocks of quantized weights and activations, perform integer arithmetic accumulation, and then apply the necessary scaling operations. Weight data is pre-packed in a specialized format where two 4-bit values are stored in each byte. Each thread loads multiple weight elements simultaneously and unpacks them during computation. The quantization parameters for weights are organized by groups, where each group of consecutive weight elements shares the same scale and zero-point values. ## COOPERATIVE Algorithm (GEMV Cases) The cooperative implementation uses shared memory and thread cooperation where this approach uses workgroups of 64 threads arranged as 8 groups of 8 workers each. The key insight is that GEMV operations have limited parallelism in the output dimension but substantial parallelism in the reduction dimension, making cooperative reduction strategies more effective than independent thread computation. Each group of 8 worker threads collaboratively computes a portion of the output vector. The workers divide the reduction work along the input feature dimension, with each worker processing every 8th element in a strided pattern. 
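To make the arithmetic concrete, here is a scalar C++ sketch of the per-group integer accumulation that both shader variants perform (the function and variable names are illustrative only and are not part of this change). The inner loop accumulates the integer dot product together with a running sum of the zero-point-adjusted activations; because the weight zero point is constant within a group, the scales and the weight zero-point correction can be applied once per group rather than per element.

```cpp
// Scalar sketch of the per-group accumulation, assuming the activations are
// int8 with a per-token zero point and the 4-bit weights have already been
// unpacked and re-centered by subtracting 8:
//
//   result = input_scale * weight_scale *
//            Sum_k (a_q[k] - a_zp) * (w_q[k] - w_zp)
//
#include <cstdint>
#include <vector>

float dot_one_group(
    const std::vector<int8_t>& a_q,  // quantized activations for one group
    const std::vector<int8_t>& w_q,  // unpacked 4-bit weights, already minus 8
    int32_t a_zp,                    // per-token activation zero point
    int32_t w_zp,                    // per-group weight zero point
    float a_scale,                   // per-token activation scale
    float w_scale) {                 // per-group weight scale
  int32_t int_acc = 0;  // Sum_k (a_q - a_zp) * w_q   (integer dot product)
  int32_t a_sum = 0;    // Sum_k (a_q - a_zp)          (running activation sum)
  for (size_t k = 0; k < a_q.size(); ++k) {
    const int32_t a = static_cast<int32_t>(a_q[k]) - a_zp;
    int_acc += a * static_cast<int32_t>(w_q[k]);
    a_sum += a;
  }
  // Scales and the weight zero-point correction are applied once per group,
  // outside the inner loop, keeping the hot loop in pure integer arithmetic.
  return a_scale * w_scale * static_cast<float>(int_acc - w_zp * a_sum);
}
```

The full kernels repeat this per group along K and sum the per-group results, which is why the inner reduction stays in integer arithmetic and only one fused multiply by the scales happens per group and output element.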
# Future Performance Improvements - Making use of dotPacked4x8EXT (this requires upgrading glslc and vulkan) - Fixed point math for pure integer operations - Might be more performant to avoid preloading tensors - Might also be more performant to avoid excessive register overhead by defining the ivec4 within each block operation (allowing more threads to be more register intensive) ghstack-source-id: 295447206 Differential Revision: [D77173441](https://our.internmc.facebook.com/intern/diff/D77173441/) --- .../ops/glsl/linear_qta8a_qga4w_coop.glsl | 232 +++++++++++++++ .../ops/glsl/linear_qta8a_qga4w_coop.yaml | 26 ++ .../ops/glsl/linear_qta8a_qga4w_tiled.glsl | 196 +++++++++++++ .../ops/glsl/linear_qta8a_qga4w_tiled.yaml | 26 ++ .../ops/impl/QuantizedLinear_QTA8A_QGA4W.cpp | 268 ++++++++++++++++++ .../test/op_tests/quantized_linear_test.cpp | 261 +++++++++++++++++ 6 files changed, 1009 insertions(+) create mode 100644 backends/vulkan/runtime/graph/ops/glsl/linear_qta8a_qga4w_coop.glsl create mode 100644 backends/vulkan/runtime/graph/ops/glsl/linear_qta8a_qga4w_coop.yaml create mode 100644 backends/vulkan/runtime/graph/ops/glsl/linear_qta8a_qga4w_tiled.glsl create mode 100644 backends/vulkan/runtime/graph/ops/glsl/linear_qta8a_qga4w_tiled.yaml create mode 100644 backends/vulkan/runtime/graph/ops/impl/QuantizedLinear_QTA8A_QGA4W.cpp diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_qta8a_qga4w_coop.glsl b/backends/vulkan/runtime/graph/ops/glsl/linear_qta8a_qga4w_coop.glsl new file mode 100644 index 00000000000..174ea1cc9bb --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_qta8a_qga4w_coop.glsl @@ -0,0 +1,232 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core + +#define PRECISION ${PRECISION} + +#define T ${buffer_scalar_type(DTYPE)} +#define VEC4_T ${buffer_gvec_type(DTYPE, 4)} + +#define TILE_ROWS ${TILE_ROWS} + +#define NGROUPS 8 +#define NWORKERS 8 + +${define_required_extensions(DTYPE)} +$if IN_STORAGE == "buffer": + ${define_required_extensions("int8")} +$if WEIGHT_STORAGE == "buffer": + ${define_required_extensions("uint8")} + +#extension GL_EXT_control_flow_attributes : require + +layout(std430) buffer; + +${layout_declare_tensor(B, "w", "t_out", DTYPE, OUT_STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_mat1", "int8", IN_STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_qmat2", "uint8", WEIGHT_STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_weight_scales", "float", PARAMS_STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_weight_zeros", "int", PARAMS_STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_input_scale", "float", PARAMS_STORAGE, is_scalar_array=True)} +${layout_declare_tensor(B, "r", "t_input_zero_point", "int", PARAMS_STORAGE, is_scalar_array=True)} + +layout(push_constant) uniform restrict Block { + ivec4 out_sizes; + ivec4 mat1_sizes; + ivec4 qmat2_sizes; +}; + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +layout(constant_id = 3) const int group_size = 64; + +shared vec4 partial_results[NGROUPS][NWORKERS][TILE_ROWS][2]; + +/* + * This shader computes a linear operator between a quantized int8 input matrix + * x and a weights matrix that is quantized to 4 bits, producing a float output. 
+ * + * This shader implements a co-operative algorithm to compute the output. The + * work group size is {NGROUP, 1, NWORKERS}, and each group of NWORKERS threads + * cooperative to compute TILE_ROWS * 2 output texels. Therefore, + * NGROUP * TILE_ROWS * 2 output texels are computed across one work group. + * + * The threads co-operate by each thread computing a partial reduction along the + * K dimension. To illustrate the computation, consider a scalar variant of the + * algorithm that computes the dot product of 2 vectors. Also assume that + * NWORKERS is 8. + * + * Thread 1 in each group will compute: + * (mat1[0] * mat2[0]) + (mat1[8] * mat2[8]) + (mat1[16] * mat2[16]) + ... + * + * Thread 2 in each group will compute: + * (mat1[1] * mat2[1]) + (mat2[9] * mat2[9]) + (mat1[17] * mat2[17]) + ... + * + * Thread 3 in each group will compute: + * (mat1[2] * mat2[2]) + (mat2[10] * mat2[10]) + (mat1[18] * mat2[18]) + ... + * + * The partial accumulations is structured such that memory accesses in each + * loop iteration can be coalesced. + * + * Then, at the end first thread in each group will accumulate the partial + * accumulations computed by each thread to obtain the final result. + * + * Note that this shader assumes that all tensors are width packed. + */ + +void main() { + const uint out_row = gl_GlobalInvocationID.y * TILE_ROWS; + const uint out_col = gl_GlobalInvocationID.x << 3; + const int out_col_texel_idx = int(gl_GlobalInvocationID.x) << 1; + + const uint gid = gl_LocalInvocationID.x; // group id + const uint wid = gl_LocalInvocationID.z; // worker id + + if (out_col >= out_sizes.x || out_row >= out_sizes.y) { + return; + } + + const int num_blocks = mat1_sizes.x / group_size; + + ivec4 mat1_quantized[TILE_ROWS]; + ivec4 qmat2_quantized[4][2]; + vec4 final_result[TILE_ROWS][2]; + + // Initialize accumulators + [[unroll]] for (int r = 0; r < TILE_ROWS; ++r) { + final_result[r][0] = vec4(0.0); + final_result[r][1] = vec4(0.0); + } + + vec4 scales[2]; + vec4 zeros[2]; + + $if WEIGHT_STORAGE == "buffer": + const int qmat2_stride = qmat2_sizes.x >> 2; + $if PARAMS_STORAGE == "buffer": + const int qparams_stride = out_sizes.x >> 2; + + for (int block_idx = 0; block_idx < num_blocks; ++block_idx) { + $if PARAMS_STORAGE == "buffer": + scales[0] = t_weight_scales[block_idx * qparams_stride + out_col_texel_idx]; + scales[1] = t_weight_scales[block_idx * qparams_stride + out_col_texel_idx + 1]; + + zeros[0] = vec4(t_weight_zeros[block_idx * qparams_stride + out_col_texel_idx]); + zeros[1] = vec4(t_weight_zeros[block_idx * qparams_stride + out_col_texel_idx + 1]); + $else: + scales[0] = texelFetch(t_weight_scales, ivec3(out_col_texel_idx, block_idx, 0), 0); + scales[1] = texelFetch(t_weight_scales, ivec3(out_col_texel_idx + 1, block_idx, 0), 0); + + zeros[0] = vec4(texelFetch(t_weight_zeros, ivec3(out_col_texel_idx, block_idx, 0), 0)); + zeros[1] = vec4(texelFetch(t_weight_zeros, ivec3(out_col_texel_idx + 1, block_idx, 0), 0)); + + ivec4 int32_sums[TILE_ROWS][2]; + int input_sums[TILE_ROWS]; + + // Initialize accumulators for this block + [[unroll]] for (int r = 0; r < TILE_ROWS; ++r) { + int32_sums[r][0] = ivec4(0); + int32_sums[r][1] = ivec4(0); + input_sums[r] = 0; + } + + for (int g_idx = 4 * int(wid); g_idx < group_size; g_idx += (4 * NWORKERS)) { + const int k = block_idx * group_size + g_idx; + + // Preload B (weights) - keep as quantized integers + [[unroll]] for (int r = 0; r < 4; ++r) { + $if WEIGHT_STORAGE == "buffer": + const u8vec4 packed_weight_tex = t_qmat2[(k + r) * 
qmat2_stride + gl_GlobalInvocationID.x]; + $else: + const uvec4 packed_weight_tex = texelFetch( + t_qmat2, + ivec2(gl_GlobalInvocationID.x, k + r), + 0); + + // Unpack 4-bit weights to integers and subtract zero point (8 for 4-bit) + qmat2_quantized[r][0] = ivec4((packed_weight_tex & 0xF0) >> 4) - 8; + qmat2_quantized[r][1] = ivec4(packed_weight_tex & 0x0F) - 8; + } + + // Preload A (quantized input) - keep as quantized integers + [[unroll]] for (int r = 0; r < TILE_ROWS; ++r) { + $if IN_STORAGE == "buffer": + mat1_quantized[r] = t_mat1[((out_row + r) * mat1_sizes.x + k) >> 2] - t_input_zero_point[int(out_row) + r]; + $else: + mat1_quantized[r] = texelFetch(t_mat1, ivec3(k >> 2, out_row + r, 0), 0) - t_input_zero_point[int(out_row) + r]; + } + + // Accumulate in integer arithmetic: (input_quantized - input_zero_point) * (weight_quantized - weight_zero_point) + [[unroll]] for (int r = 0; r < TILE_ROWS; ++r) { + input_sums[r] += mat1_quantized[r].x + mat1_quantized[r].y + mat1_quantized[r].z + mat1_quantized[r].w; + + int32_sums[r][0] += mat1_quantized[r].x * qmat2_quantized[0][0] + + mat1_quantized[r].y * qmat2_quantized[1][0] + + mat1_quantized[r].z * qmat2_quantized[2][0] + + mat1_quantized[r].w * qmat2_quantized[3][0]; + + int32_sums[r][1] += mat1_quantized[r].x * qmat2_quantized[0][1] + + mat1_quantized[r].y * qmat2_quantized[1][1] + + mat1_quantized[r].z * qmat2_quantized[2][1] + + mat1_quantized[r].w * qmat2_quantized[3][1]; + } + } + + // Incorporates this block's results into the final accumulation + // Following proper quantization paradigm: result = input_scale * weight_scale * + // Sum((input_quantized - input_zero) * (weight_quantized - weight_zero)) + [[unroll]] for (int r = 0; r < TILE_ROWS; ++r) { + if (out_row + r >= out_sizes.y) { + continue; + } + + float input_scale = t_input_scale[int(out_row) + r]; + float input_sum_scalar = float(input_sums[r]); + + // Apply proper quantization paradigm: input_scale * weight_scale * (accumulator - weight_zero * input_sum) + final_result[r][0] += input_scale * scales[0] * (vec4(int32_sums[r][0]) - zeros[0] * input_sum_scalar); + final_result[r][1] += input_scale * scales[1] * (vec4(int32_sums[r][1]) - zeros[1] * input_sum_scalar); + } + } + + // Store worker results in shared memory + [[unroll]] for (int r = 0; r < TILE_ROWS; ++r) { + partial_results[gid][wid][r][0] = final_result[r][0]; + partial_results[gid][wid][r][1] = final_result[r][1]; + } + + memoryBarrierShared(); + barrier(); + + // Only the first worker in each group accumulates and writes output + if (wid != 0) { + return; + } + + vec4 cooperative_result[TILE_ROWS][2]; + + for (int r = 0; r < TILE_ROWS; ++r) { + cooperative_result[r][0] = vec4(0.0); + cooperative_result[r][1] = vec4(0.0); + [[unroll]] for (int worker = 0; worker < NWORKERS; ++worker) { + cooperative_result[r][0] += partial_results[gid][worker][r][0]; + cooperative_result[r][1] += partial_results[gid][worker][r][1]; + } + } + + // Apply final output quantization + [[unroll]] for (int r = 0; r < TILE_ROWS; ++r) { + $if OUT_STORAGE == "buffer": + t_out[((out_row + r) * out_sizes.x + out_col) >> 2] = cooperative_result[r][0]; + t_out[((out_row + r) * out_sizes.x + out_col + 4) >> 2] = cooperative_result[r][1]; + $else: + imageStore(t_out, ivec3(out_col_texel_idx, out_row + r, 0), cooperative_result[r][0]); + imageStore(t_out, ivec3(out_col_texel_idx + 1, out_row + r, 0), cooperative_result[r][1]); + } +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_qta8a_qga4w_coop.yaml 
b/backends/vulkan/runtime/graph/ops/glsl/linear_qta8a_qga4w_coop.yaml new file mode 100644 index 00000000000..9f6db77094a --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_qta8a_qga4w_coop.yaml @@ -0,0 +1,26 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +linear_qta8a_qga4w_coop: + parameter_names_with_default_values: + DTYPE: float + OUT_STORAGE: texture3d + IN_STORAGE: texture3d + WEIGHT_STORAGE: texture2d + PARAMS_STORAGE: buffer + TILE_ROWS: 1 + shader_variants: + - NAME: linear_qta8a_qga4w_coop_texture3d_texture3d_texture2d_float + - NAME: linear_qta8a_qga4w_coop_buffer_buffer_texture2d_float + OUT_STORAGE: buffer + IN_STORAGE: buffer + - NAME: linear_qta8a_qga4w_coop_buffer_buffer_buffer_float + OUT_STORAGE: buffer + IN_STORAGE: buffer + WEIGHT_STORAGE: buffer + - NAME: linear_qta8a_qga4w_coop_buffer_texture2d_buffer_float + OUT_STORAGE: buffer + WEIGHT_STORAGE: buffer diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_qta8a_qga4w_tiled.glsl b/backends/vulkan/runtime/graph/ops/glsl/linear_qta8a_qga4w_tiled.glsl new file mode 100644 index 00000000000..dbb7da998f4 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_qta8a_qga4w_tiled.glsl @@ -0,0 +1,196 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core + +#define PRECISION ${PRECISION} + +#define T ${buffer_scalar_type(DTYPE)} +#define VEC4_T ${buffer_gvec_type(DTYPE, 4)} + +#define TILE_ROWS ${TILE_ROWS} + +${define_required_extensions(DTYPE)} +$if IN_STORAGE == "buffer": + ${define_required_extensions("int8")} +$if WEIGHT_STORAGE == "buffer": + ${define_required_extensions("uint8")} + +#extension GL_EXT_control_flow_attributes : require + +layout(std430) buffer; + +${layout_declare_tensor(B, "w", "t_out", DTYPE, OUT_STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_mat1", "int8", IN_STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_qmat2", "uint8", WEIGHT_STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_weight_scales", "float", PARAMS_STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_weight_zeros", "int", PARAMS_STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_input_scale", "float", "buffer", is_scalar_array=True)} +${layout_declare_tensor(B, "r", "t_input_zero_point", "int", "buffer", is_scalar_array=True)} + +layout(push_constant) uniform restrict Block { + ivec4 out_sizes; + ivec4 mat1_sizes; + ivec4 qmat2_sizes; +}; + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +layout(constant_id = 3) const int group_size = 64; + +/* + * This shader computes a linear operator between a quantized int8 input matrix + * x and a weights matrix that is quantized to 4 bits, producing a float output. + * + * The (W, H, C) shape of each tensor is: + * - x: (K, M) - quantized int8 input with per-token quantization + * - weights: (N / 2, K) + * - The weights tensor has a data type of `uint8`. Each element in the tensor + * contains 2 4-bit values packed into a uint8. + * - See the pack_int4_linear_weight_transposed_interleave shader to see more + * details on how the weight tensor is stored. 
+ * - qparams: (2, N, number_of_groups) + * - This tensor contains the scales and zeros quantization parameters for the + * weights tensor. The weight tensor is quantized group-wise, which means + * that every `group_size` elements along the K dimension of the weights + * tensor has independent quantization parameters. Along the width dim, the + * first value contains the scale for the group and the second value + * contains the zero point for the group. + * - input_scale: (num_tokens,) - per-token scale values for input quantization + * - input_zero_point: (num_tokens,) - per-token zero points for input quantization + * - output: (N, M) - float output + * + * Each thread computes a tile of TILE_ROWS * 2 texels of the output tensor. + * + * Note that this shader assumes that all tensors are width packed. + */ + +void main() { + const uint out_row = gl_GlobalInvocationID.y * TILE_ROWS; + const uint out_col = gl_GlobalInvocationID.x << 3; + const int out_col_texel_idx = int(gl_GlobalInvocationID.x) << 1; + + if (out_col >= out_sizes.x || out_row >= out_sizes.y) { + return; + } + + const int num_blocks = mat1_sizes.x / group_size; + + ivec4 mat1_quantized[TILE_ROWS]; + ivec4 qmat2_quantized[4][2]; + vec4 final_result[TILE_ROWS][2]; + + // Initialize accumulatoxrs + [[unroll]] for (int r = 0; r < TILE_ROWS; ++r) { + final_result[r][0] = vec4(0.0); + final_result[r][1] = vec4(0.0); + } + + vec4 scales[2]; + vec4 zeros[2]; + + $if WEIGHT_STORAGE == "buffer": + const int qmat2_stride = qmat2_sizes.x >> 2; + $if PARAMS_STORAGE == "buffer": + const int qparams_stride = out_sizes.x >> 2; + + for (int block_idx = 0; block_idx < num_blocks; ++block_idx) { + $if PARAMS_STORAGE == "buffer": + scales[0] = t_weight_scales[block_idx * qparams_stride + out_col_texel_idx]; + scales[1] = t_weight_scales[block_idx * qparams_stride + out_col_texel_idx + 1]; + + zeros[0] = vec4(t_weight_zeros[block_idx * qparams_stride + out_col_texel_idx]); + zeros[1] = vec4(t_weight_zeros[block_idx * qparams_stride + out_col_texel_idx + 1]); + $else: + scales[0] = texelFetch(t_weight_scales, ivec3(out_col_texel_idx, block_idx, 0), 0); + scales[1] = texelFetch(t_weight_scales, ivec3(out_col_texel_idx + 1, block_idx, 0), 0); + + zeros[0] = vec4(texelFetch(t_weight_zeros, ivec3(out_col_texel_idx, block_idx, 0), 0)); + zeros[1] = vec4(texelFetch(t_weight_zeros, ivec3(out_col_texel_idx + 1, block_idx, 0), 0)); + + ivec4 int32_sums[TILE_ROWS][2]; + int input_sums[TILE_ROWS]; + + // Initialize accumulators + [[unroll]] for (int r = 0; r < TILE_ROWS; ++r) { + int32_sums[r][0] = ivec4(0); + int32_sums[r][1] = ivec4(0); + input_sums[r] = 0; + } + + for (int g_idx = 0; g_idx < group_size; g_idx += 4) { + const int k = block_idx * group_size + g_idx; + + // Preload B (weights) - keep as quantized integers + [[unroll]] for (int r = 0; r < 4; ++r) { + $if WEIGHT_STORAGE == "buffer": + const u8vec4 packed_weight_tex = t_qmat2[(k + r) * qmat2_stride + gl_GlobalInvocationID.x]; + $else: + const uvec4 packed_weight_tex = texelFetch( + t_qmat2, + ivec2(gl_GlobalInvocationID.x, k + r), + 0); + + // Unpack 4-bit weights to integers (subtract 8 as the 4-bit zero point) + qmat2_quantized[r][0] = ivec4((packed_weight_tex & 0xF0) >> 4) - 8; + qmat2_quantized[r][1] = ivec4(packed_weight_tex & 0x0F) - 8; + } + + // Preload A (quantized input) - keep as quantized integers + [[unroll]] for (int r = 0; r < TILE_ROWS; ++r) { + $if IN_STORAGE == "buffer": + mat1_quantized[r] = t_mat1[((out_row + r) * mat1_sizes.x + k) >> 2] - t_input_zero_point[int(out_row) 
+ r]; + $else: + mat1_quantized[r] = texelFetch(t_mat1, ivec3(k >> 2, out_row + r, 0), 0) - t_input_zero_point[int(out_row) + r]; + } + + // Accumulate in integer arithmetic: (input_quantized - input_zero_point) * (weight_quantized - weight_zero_point) + [[unroll]] for (int r = 0; r < TILE_ROWS; ++r) { + input_sums[r] += mat1_quantized[r].x + mat1_quantized[r].y + mat1_quantized[r].z + mat1_quantized[r].w; + + int32_sums[r][0] += mat1_quantized[r].x * qmat2_quantized[0][0] + + mat1_quantized[r].y * qmat2_quantized[1][0] + + mat1_quantized[r].z * qmat2_quantized[2][0] + + mat1_quantized[r].w * qmat2_quantized[3][0]; + + int32_sums[r][1] += mat1_quantized[r].x * qmat2_quantized[0][1] + + mat1_quantized[r].y * qmat2_quantized[1][1] + + mat1_quantized[r].z * qmat2_quantized[2][1] + + mat1_quantized[r].w * qmat2_quantized[3][1]; + } + } + + // Incorporates this block's results into the final accumulation + // Following proper quantization paradigm: result = input_scale * weight_scale * + // Sum((input_quantized - input_zero) * (weight_quantized - weight_zero)) + [[unroll]] for (int r = 0; r < TILE_ROWS; ++r) { + if (out_row + r >= out_sizes.y) { + continue; + } + + float input_scale = t_input_scale[int(out_row) + r]; + float input_sum_scalar = float(input_sums[r]); + + // Apply proper quantization paradigm: input_scale * weight_scale * (accumulator - weight_zero * input_sum) + final_result[r][0] += input_scale * scales[0] * (vec4(int32_sums[r][0]) - zeros[0] * input_sum_scalar); + final_result[r][1] += input_scale * scales[1] * (vec4(int32_sums[r][1]) - zeros[1] * input_sum_scalar); + } + } + + // Apply ALL scaling at the very end + [[unroll]] for (int r = 0; r < TILE_ROWS; ++r) { + $if OUT_STORAGE == "buffer": + if (out_row + r < out_sizes.y) { + t_out[((out_row + r) * out_sizes.x + out_col) >> 2] = final_result[r][0]; + t_out[((out_row + r) * out_sizes.x + out_col + 4) >> 2] = final_result[r][1]; + } + $else: + imageStore(t_out, ivec3(out_col_texel_idx, out_row + r, 0), final_result[r][0]); + imageStore(t_out, ivec3(out_col_texel_idx + 1, out_row + r, 0), final_result[r][1]); + } +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_qta8a_qga4w_tiled.yaml b/backends/vulkan/runtime/graph/ops/glsl/linear_qta8a_qga4w_tiled.yaml new file mode 100644 index 00000000000..c96d693834b --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_qta8a_qga4w_tiled.yaml @@ -0,0 +1,26 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +linear_qta8a_qga4w_tiled: + parameter_names_with_default_values: + DTYPE: float + OUT_STORAGE: texture3d + IN_STORAGE: texture3d + WEIGHT_STORAGE: texture2d + PARAMS_STORAGE: buffer + TILE_ROWS: 3 + shader_variants: + - NAME: linear_qta8a_qga4w_tiled_texture3d_texture3d_texture2d_float + - NAME: linear_qta8a_qga4w_tiled_buffer_buffer_texture2d_float + OUT_STORAGE: buffer + IN_STORAGE: buffer + - NAME: linear_qta8a_qga4w_tiled_buffer_buffer_buffer_float + OUT_STORAGE: buffer + IN_STORAGE: buffer + WEIGHT_STORAGE: buffer + - NAME: linear_qta8a_qga4w_tiled_buffer_texture2d_buffer_float + OUT_STORAGE: buffer + WEIGHT_STORAGE: buffer diff --git a/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear_QTA8A_QGA4W.cpp b/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear_QTA8A_QGA4W.cpp new file mode 100644 index 00000000000..a47c58b7ef6 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear_QTA8A_QGA4W.cpp @@ -0,0 +1,268 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include +#include + +#include +#include + +namespace vkcompute { + +void check_linear_qta8a_qga4w_args( + ComputeGraph& graph, + const ValueRef mat1, + const ValueRef mat1_scale, + const ValueRef mat1_zero_point, + const ValueRef mat2_data, + const ValueRef group_size, + const ValueRef weight_scales, + const ValueRef weight_zeros, + const ValueRef out) { + VK_CHECK_COND(graph.val_is_tensor(mat1)); + VK_CHECK_COND(graph.val_is_tensor(mat1_scale)); + VK_CHECK_COND(graph.val_is_tensor(mat1_zero_point)); + VK_CHECK_COND(graph.val_is_tref(mat2_data)); + VK_CHECK_COND(graph.val_is_tref(weight_scales)); + VK_CHECK_COND(graph.val_is_tref(weight_zeros)); + + VK_CHECK_COND(graph.dim_of(mat1) <= 3); + VK_CHECK_COND(graph.dim_of(mat2_data) == 2); + VK_CHECK_COND(graph.dim_of(weight_scales) == 2); + VK_CHECK_COND(graph.dim_of(weight_zeros) == 2); + + VK_CHECK_COND(graph.size_at(-3, mat1) == 1); + const int K = graph.size_at(-1, mat1); + VK_CHECK_COND(graph.size_at(-1, mat2_data) * 2 == K); + + const int group_size_val = graph.extract_scalar(group_size); + VK_CHECK_COND(K % group_size_val == 0); + // Due to the way weight packing works, group size needs to be a multiple of 8 + VK_CHECK_COND(group_size_val % 8 == 0); + + VK_CHECK_COND(graph.has_standard_axis_map(mat1)); + VK_CHECK_COND(graph.has_standard_axis_map(out)); + + // Check that scale and zero_point tensors are buffer storage with width + // packing + VK_CHECK_COND(graph.is_buffer_storage(mat1_scale)); + VK_CHECK_COND(graph.packed_dim_of(mat1_scale) == WHCN::kWidthDim); + VK_CHECK_COND(graph.is_buffer_storage(mat1_zero_point)); + VK_CHECK_COND(graph.packed_dim_of(mat1_zero_point) == WHCN::kWidthDim); + + // Calculate number of tokens for input + int64_t input_num_tokens = 1; + const auto mat1_sizes = graph.sizes_of(mat1); + for (size_t i = 0; i < mat1_sizes.size() - 1; i++) { + input_num_tokens *= mat1_sizes[i]; + } + + // Verify scale and zero_point tensor sizes match number of tokens + const auto mat1_scale_sizes = graph.sizes_of(mat1_scale); + const auto mat1_zero_point_sizes = graph.sizes_of(mat1_zero_point); + + VK_CHECK_COND(mat1_scale_sizes.size() == 1); + VK_CHECK_COND(mat1_zero_point_sizes.size() == 1); + + VK_CHECK_COND(mat1_scale_sizes[0] == input_num_tokens); + VK_CHECK_COND(mat1_zero_point_sizes[0] == input_num_tokens); + + // Verify weight scales and 
zeros have the same shape + const auto weight_scales_sizes = graph.sizes_of(weight_scales); + const auto weight_zeros_sizes = graph.sizes_of(weight_zeros); + VK_CHECK_COND(weight_scales_sizes == weight_zeros_sizes); +} + +void resize_linear_qta8a_qga4w_node( + ComputeGraph* graph, + const std::vector& args, + const std::vector& extra_args) { + (void)extra_args; + + vTensorPtr out = graph->get_tensor(args[0].refs[0]); + vTensorPtr mat1 = graph->get_tensor(args[1].refs[0]); + vTensorPtr mat2 = graph->get_tensor(args[1].refs[1]); + + const int64_t out_cols = utils::val_at(-2, mat1->sizes()); + const int64_t out_rows = utils::val_at(-1, mat2->sizes()) * 2; + + std::vector new_out_sizes(3); + if (mat1->sizes().size() == 2) { + new_out_sizes.resize(2); + new_out_sizes.at(0) = out_cols; + new_out_sizes.at(1) = out_rows; + } else { + new_out_sizes.at(0) = mat1->sizes().at(0); + new_out_sizes.at(1) = out_cols; + new_out_sizes.at(2) = out_rows; + } + + out->virtual_resize(new_out_sizes); +} + +/** + * Determines if the cooperative algorithm should be used based on input tensor + * dimensions. Apply the coop algorithm for vectors (GEMV cases), tiled for + * matrices (GEMM cases). + */ +bool should_use_coop_algorithm_qta8a_qga4w( + ComputeGraph* graph, + const ValueRef& mat1) { + const uint32_t M = graph->size_at(-2, mat1); + // Use coop algorithm for vectors (GEMV), tiled for larger matrices (GEMM) + return M == 1; +} + +vkapi::ShaderInfo pick_linear_qta8a_qga4w_shader( + ComputeGraph* graph, + const std::vector& args, + const std::vector& resize_args) { + (void)resize_args; + + const ValueRef out = args.at(0).refs.at(0); + const ValueRef mat1 = args.at(1).refs.at(0); + const ValueRef mat2 = args.at(1).refs.at(1); + + const bool use_coop_algorithm = + should_use_coop_algorithm_qta8a_qga4w(graph, mat1); + + std::string kernel_name = "linear_qta8a_qga4w"; + if (use_coop_algorithm) { + kernel_name += "_coop"; + } else { + kernel_name += "_tiled"; + } + add_storage_type_suffix(kernel_name, graph->storage_type_of(out)); + add_storage_type_suffix(kernel_name, graph->storage_type_of(mat1)); + add_storage_type_suffix(kernel_name, graph->storage_type_of(mat2)); + add_dtype_suffix(kernel_name, graph->dtype_of(out)); + + return VK_KERNEL_FROM_STR(kernel_name); +} + +utils::uvec3 linear_qta8a_qga4w_global_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const std::vector& args, + const std::vector& resize_args) { + (void)resize_args; + const ValueRef out = args.at(0).refs.at(0); + + const bool use_coop_algorithm = + shader.kernel_name.find("_coop") != std::string::npos; + + // C = 1, H = 2, W = 3 + // global_wg_size = {round_up(C / 2f), round_up(H / 3f), W} --> (2W, 1H, 0C) + // --> {1, 1, 3} global + + utils::uvec3 global_wg_size = graph->logical_limits_of(out); + global_wg_size[0] = utils::div_up(global_wg_size[0], uint32_t(2)); + if (!use_coop_algorithm) { // GEMM - TILED + global_wg_size[1] = utils::div_up(global_wg_size[1], uint32_t(3)); + } + + return global_wg_size; +} + +utils::uvec3 linear_qta8a_qga4w_local_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const utils::uvec3& global_workgroup_size, + const std::vector& args, + const std::vector& resize_args) { + (void)args; + (void)resize_args; + + const bool use_coop_algorithm = + shader.kernel_name.find("_coop") != std::string::npos; + + utils::uvec3 local_wg_size; + if (use_coop_algorithm) { // GEMV - COOP + local_wg_size = {8, 1, 8}; + } else { // GEMM - TILED + local_wg_size = 
graph->create_local_wg_size(global_workgroup_size); + } + + return local_wg_size; +} + +void add_linear_qta8a_qga4w_node( + ComputeGraph& graph, + const ValueRef mat1, + const ValueRef mat1_scale, + const ValueRef mat1_zero_point, + const ValueRef mat2_data, + const ValueRef group_size, + const ValueRef weight_scales_data, + const ValueRef weight_zeros_data, + const ValueRef out) { + check_linear_qta8a_qga4w_args( + graph, + mat1, + mat1_scale, + mat1_zero_point, + mat2_data, + group_size, + weight_scales_data, + weight_zeros_data, + out); + const uint32_t group_size_val = graph.extract_scalar(group_size); + + ValueRef mat2 = + prepack_int4_linear_weight_transposed_interleaved(graph, mat2_data); + ValueRef weight_scales = prepack_standard( + graph, weight_scales_data, utils::kBuffer, utils::kWidthPacked); + ValueRef weight_zeros = prepack_standard( + graph, weight_zeros_data, utils::kBuffer, utils::kWidthPacked); + + graph.execute_nodes().emplace_back(new DynamicDispatchNode( + graph, + pick_linear_qta8a_qga4w_shader, + linear_qta8a_qga4w_global_wg_size, + linear_qta8a_qga4w_local_wg_size, + // Inputs and Outputs + {{out, vkapi::kWrite}, + {{mat1, mat2, weight_scales, weight_zeros, mat1_scale, mat1_zero_point}, + vkapi::kRead}}, + // Shader params buffers + {}, + // Push Constants + {graph.sizes_pc_of(out), + graph.sizes_pc_of(mat1), + graph.sizes_pc_of(mat2)}, + // Specialization Constants + {SV(group_size_val)}, + // Resize Args + {}, + // Resizing Logic + resize_linear_qta8a_qga4w_node)); +} + +void linear_qta8a_qga4w( + ComputeGraph& graph, + const std::vector& args) { + return add_linear_qta8a_qga4w_node( + graph, + args[0], // quantized input (char tensor) + args[1], // input_scale (float buffer tensor) + args[2], // input_zero_point (int buffer tensor) + args[3], // quantized weights (4-bit packed, byte) + args[4], // group_size (int) + args[5], // weight_scales (float tensor) + args[6], // weight_zeros (int tensor) + args[7] // float output tensor + ); +} + +REGISTER_OPERATORS { + VK_REGISTER_OP(et_vk.linear_qta8a_qga4w.default, linear_qta8a_qga4w); +} + +} // namespace vkcompute diff --git a/backends/vulkan/test/op_tests/quantized_linear_test.cpp b/backends/vulkan/test/op_tests/quantized_linear_test.cpp index 108770bb02e..4c21face8f7 100644 --- a/backends/vulkan/test/op_tests/quantized_linear_test.cpp +++ b/backends/vulkan/test/op_tests/quantized_linear_test.cpp @@ -581,6 +581,181 @@ void test_vulkan_linear_qcs4w( B, M, K, N, vkcompute::utils::kTexture3D, vkcompute::utils::kTexture3D); } +void test_vulkan_linear_qta8a_qga4w_impl( + const int B, + const int M, + const int K, + const int N, + const int group_size = 8, + const vkcompute::utils::StorageType in_storage = + vkcompute::utils::kTexture3D, + const vkcompute::utils::StorageType out_storage = + vkcompute::utils::kTexture3D) { + assert(K % group_size == 0); + + const int64_t input_num_tokens = B * M; + const int k_groups = K / group_size; + + at::Tensor input_scale = + at::rand({input_num_tokens}, at::device(at::kCPU).dtype(at::kFloat)); + at::Tensor input_zero_point = at::randint( + -10, 10, {input_num_tokens}, at::device(at::kCPU).dtype(at::kInt)); + + at::Tensor float_x = + at::rand({B, M, K}, at::device(at::kCPU).dtype(at::kFloat)); + + // Create a reference quantized tensor using per-token quantization + // Mimic per-token quantization using at::quantize_per_channel by reshaping + // [num_tokens, features] + at::Tensor float_x_reshaped = float_x.view({input_num_tokens, K}); + at::Tensor qx_ref_reshaped = 
at::quantize_per_channel( + float_x_reshaped, + input_scale.to(at::kDouble), + input_zero_point.to(at::kLong), + 0, // axis 0 for per-token (first dimension after reshape) + c10::ScalarType::QInt8); + + at::Tensor x = + at::int_repr(qx_ref_reshaped).view(float_x.sizes()).to(at::kChar); + + at::Tensor weights_4x2 = + at::randint(0, 256, {N, K / 2}, at::device(at::kCPU).dtype(at::kByte)); + at::Tensor weight_scales = + at::rand({k_groups, N}, at::device(at::kCPU).dtype(at::kFloat)); + at::Tensor weight_zeros = at::randint( + -128, 128, {k_groups, N}, at::device(at::kCPU).dtype(at::kInt)); + + at::Tensor out_ref = linear_qta8a_qga4w_4bit_dequant_impl( + x, + input_scale, + input_zero_point, + weights_4x2, + group_size, + weight_scales, + weight_zeros); + + // Build Vulkan graph + using namespace vkcompute; + + GraphConfig config; + config.set_storage_type_override(utils::kTexture3D); + ComputeGraph graph(config); + +#define MAKE_TENSORREF_FOR(x) \ + ValueRef r_##x = graph.add_tensorref( \ + x.sizes().vec(), \ + from_at_scalartype(x.scalar_type()), \ + x.const_data_ptr()); + + MAKE_TENSORREF_FOR(weights_4x2); + MAKE_TENSORREF_FOR(weight_scales); + MAKE_TENSORREF_FOR(weight_zeros); + + IOValueRef r_x = graph.add_input_tensor( + x.sizes().vec(), from_at_scalartype(x.scalar_type()), in_storage); + + IOValueRef r_input_scale = graph.add_input_tensor( + input_scale.sizes().vec(), + from_at_scalartype(input_scale.scalar_type()), + utils::kBuffer); + + IOValueRef r_input_zero_point = graph.add_input_tensor( + input_zero_point.sizes().vec(), + from_at_scalartype(input_zero_point.scalar_type()), + utils::kBuffer); + + const ValueRef r_out = graph.add_tensor( + out_ref.sizes().vec(), + from_at_scalartype(out_ref.scalar_type()), + out_storage); + + VK_GET_OP_FN("et_vk.linear_qta8a_qga4w.default") + (graph, + {r_x.value, + r_input_scale.value, + r_input_zero_point.value, + r_weights_4x2, + graph.add_scalar(group_size), + r_weight_scales, + r_weight_zeros, + r_out}); + + ValueRef staging_out = graph.set_output_tensor(r_out); + + graph.prepare(); + graph.encode_prepack(); + graph.prepack(); + graph.encode_execute(); + + // + // Run model + // + + graph.propagate_resize(); + graph.copy_into_staging(r_x.staging, x.const_data_ptr(), x.numel()); + graph.copy_into_staging( + r_input_scale.staging, input_scale.const_data_ptr(), input_scale.numel()); + graph.copy_into_staging( + r_input_zero_point.staging, + input_zero_point.const_data_ptr(), + input_zero_point.numel()); + + graph.execute(); + + at::Tensor vk_out = at::empty_like(out_ref); + graph.copy_from_staging( + staging_out, vk_out.mutable_data_ptr(), vk_out.numel()); + + // This is a reference implementation that uses the quantized + // matmul paradigm. It should follow closely with how the vulkan + // implementation works, and demonstrates reasonably close results. 
+ at::Tensor qmm_ref = linear_qta8a_qga4w_quantized_matmul( + x, + input_scale, + input_zero_point, + weights_4x2, + group_size, + weight_scales, + weight_zeros); + + // For quantized int8 operations, allow for 1-unit differences due to rounding + bool is_close = at::allclose(vk_out, out_ref, 5e-3, 5e-3); + if (!is_close) { + std::cout << "qmm_ref: \n" << qmm_ref << std::endl; + std::cout << "out_ref: \n" << out_ref << std::endl; + std::cout << "vk_out: \n" << vk_out << std::endl; + } + + ASSERT_TRUE(is_close); +} + +void test_vulkan_linear_qta8a_qga4w( + const int B, + const int M, + const int K, + const int N, + const int group_size = 32) { + test_vulkan_linear_qta8a_qga4w_impl( + B, + M, + K, + N, + group_size, + vkcompute::utils::kBuffer, + vkcompute::utils::kBuffer); + + test_vulkan_linear_qta8a_qga4w_impl( + B, + M, + K, + N, + group_size, + vkcompute::utils::kTexture3D, + vkcompute::utils::kTexture3D); +} + +// Test linear_qga4w operator + TEST(VulkanLinearQGA4WTest, test_reference_impl) { test_reference_linear_qga4w( /*B = */ 1, @@ -611,6 +786,8 @@ TEST(VulkanLinearQGA4WTest, test_vulkan_impl_gemm) { /*N = */ 256); } +// Test linear_qcs4w operator + TEST_F(VulkanLinearQCS4WTest, test_reference_impl) { test_reference_linear_qcs4w( /*B = */ 1, @@ -640,3 +817,87 @@ TEST_F(VulkanLinearQCS4WTest, test_vulkan_impl_gemm) { /*K = */ 32, /*N = */ 32); } + +// Test linear_qta8a_qga4w operator + +TEST_F( + VulkanLinearQTA8AQGA4WTest, + test_vulkan_linear_quant_gemm_custom_groupsize) { + test_vulkan_linear_qta8a_qga4w( + /*B = */ 1, + /*M = */ 2, + /*K = */ 8, + /*N = */ 8, + /*group_size = */ 8); + + test_vulkan_linear_qta8a_qga4w( + /*B = */ 1, + /*M = */ 2, + /*K = */ 16, + /*N = */ 8, + /*group_size = */ 8); +} + +TEST_F(VulkanLinearQTA8AQGA4WTest, test_vulkan_linear_quant_gemm) { + test_vulkan_linear_qta8a_qga4w( + /*B = */ 1, + /*M = */ 4, + /*K = */ 64, + /*N = */ 32); + + test_vulkan_linear_qta8a_qga4w( + /*B = */ 1, + /*M = */ 4, + /*K = */ 128, + /*N = */ 32); + + test_vulkan_linear_qta8a_qga4w( + /*B = */ 1, + /*M = */ 8, + /*K = */ 64, + /*N = */ 16); + + test_vulkan_linear_qta8a_qga4w( + /*B = */ 1, + /*M = */ 256, + /*K = */ 256, + /*N = */ 256); +} + +TEST_F( + VulkanLinearQTA8AQGA4WTest, + test_vulkan_linear_quant_gemv_custom_groupsize) { + test_vulkan_linear_qta8a_qga4w( + /*B = */ 1, + /*M = */ 1, + /*K = */ 8, + /*N = */ 8, + /*group_size = */ 8); + + test_vulkan_linear_qta8a_qga4w( + /*B = */ 1, + /*M = */ 1, + /*K = */ 16, + /*N = */ 8, + /*group_size = */ 8); +} + +TEST_F(VulkanLinearQTA8AQGA4WTest, test_vulkan_linear_quant_gemv) { + test_vulkan_linear_qta8a_qga4w( + /*B = */ 1, + /*M = */ 1, + /*K = */ 32, + /*N = */ 32); + + test_vulkan_linear_qta8a_qga4w( + /*B = */ 1, + /*M = */ 1, + /*K = */ 64, + /*N = */ 16); + + test_vulkan_linear_qta8a_qga4w( + /*B = */ 1, + /*M = */ 1, + /*K = */ 256, + /*N = */ 256); +}