From 6ecf0d1e6f26a260d5cbde72bbf3b47b26c95b97 Mon Sep 17 00:00:00 2001
From: morelos
Date: Thu, 3 Jul 2025 11:17:28 -0700
Subject: [PATCH] [ET-VK][Ops] quantize_per_tensor.tensor variant

# Context

The quantize/dequantize operators need a variant that accepts scale and
zero_point as tensors, since choose_qparams produces its outputs as tensors.

# Changes

This extends the existing per-tensor logic to support tensor-typed scale and
zero_point arguments.

Differential Revision: [D77746136](https://our.internmc.facebook.com/intern/diff/D77746136/)

[ghstack-poisoned]
---
 .../graph/ops/glsl/quantize_buffer.glsl       | 14 +-
 .../graph/ops/glsl/quantize_buffer.yaml       |  4 +
 .../graph/ops/glsl/quantize_texture.glsl      | 15 +-
 .../graph/ops/glsl/quantize_texture.yaml      |  4 +
 .../runtime/graph/ops/impl/Quantize.cpp       | 62 ++-
 .../vulkan/test/op_tests/quantize_test.cpp    | 355 +++++++++++++++++-
 6 files changed, 413 insertions(+), 41 deletions(-)

diff --git a/backends/vulkan/runtime/graph/ops/glsl/quantize_buffer.glsl b/backends/vulkan/runtime/graph/ops/glsl/quantize_buffer.glsl
index c3e58286efe..ada34223696 100644
--- a/backends/vulkan/runtime/graph/ops/glsl/quantize_buffer.glsl
+++ b/backends/vulkan/runtime/graph/ops/glsl/quantize_buffer.glsl
@@ -27,9 +27,14 @@ ${layout_declare_tensor(B, "w", "t_out", OUT_DTYPE, "buffer")}
 ${layout_declare_tensor(B, "r", "t_in", IN_DTYPE, "buffer")}

 $if MODE == "per_tensor":
+  $if SHAPE == "tensor":
+    ${layout_declare_tensor(B, "r", "t_scale", "float", "buffer")}
+    ${layout_declare_tensor(B, "r", "t_zero_point", "int", "buffer")}
+
   layout(push_constant) uniform restrict Block {
-    float scale;
-    int zero_point;
+    $if SHAPE == "scalar":
+      float scale;
+      int zero_point;
     int quant_min;
     int quant_max;
   };
@@ -142,7 +147,10 @@ void quantize_per_tensor() {
   const int in_bufi = tidx_to_bufi(out_tidx, t_in_strides);

   IN_T value = t_in[in_bufi];
-  OUT_T qvalue = quantize_val(value, scale, zero_point);
+  $if SHAPE == "scalar":
+    OUT_T qvalue = quantize_val(value, scale, zero_point);
+  $if SHAPE == "tensor":
+    OUT_T qvalue = quantize_val(value, t_scale[0], t_zero_point[0]);

   t_out[out_bufi] = qvalue;
 }
diff --git a/backends/vulkan/runtime/graph/ops/glsl/quantize_buffer.yaml b/backends/vulkan/runtime/graph/ops/glsl/quantize_buffer.yaml
index 1dd8e6e2ffe..abab883861f 100644
--- a/backends/vulkan/runtime/graph/ops/glsl/quantize_buffer.yaml
+++ b/backends/vulkan/runtime/graph/ops/glsl/quantize_buffer.yaml
@@ -3,6 +3,7 @@ quantize_buffer:
     IN_DTYPE: float
     OUT_DTYPE: int32
     MODE: per_tensor
+    SHAPE: tensor
   generate_variant_forall:
     IN_DTYPE:
       - VALUE: half
@@ -15,6 +16,9 @@ quantize_buffer:
   shader_variants:
     - NAME: quantize_per_tensor_buffer
      MODE: per_tensor
+      SHAPE: scalar
+    - NAME: quantize_per_tensor_tensor_buffer
+      MODE: per_tensor
     - NAME: quantize_per_token_buffer
       MODE: per_token
     - NAME: quantize_per_channel_buffer
diff --git a/backends/vulkan/runtime/graph/ops/glsl/quantize_texture.glsl b/backends/vulkan/runtime/graph/ops/glsl/quantize_texture.glsl
index bdaba3ffaf9..d76c059c145 100644
--- a/backends/vulkan/runtime/graph/ops/glsl/quantize_texture.glsl
+++ b/backends/vulkan/runtime/graph/ops/glsl/quantize_texture.glsl
@@ -17,6 +17,7 @@
 #define IVEC4_T ${texel_load_type(OUT_DTYPE, "texture3d")}

 #define ${MODE}
+#define ${SHAPE}

 ${define_active_storage_type("texture3d")}
 ${define_required_extensions(IN_DTYPE)}
@@ -32,9 +33,14 @@ ${layout_declare_tensor(B, "w", "t_out", OUT_DTYPE, "texture3d")}
 ${layout_declare_tensor(B, "r", "t_in", IN_DTYPE, "texture3d")}

 $if MODE == "per_tensor":
+  $if SHAPE == "tensor":
+    ${layout_declare_tensor(B, "r", "t_scale", "float", "buffer")}
"t_scale", "float", "buffer")} + ${layout_declare_tensor(B, "r", "t_zero_point", "int", "buffer")} + layout(push_constant) uniform restrict Block { - float scale; - int zero_point; + $if SHAPE == "scalar": + float scale; + int zero_point; int quant_min; int quant_max; }; @@ -146,7 +152,10 @@ void quantize_per_tensor() { [[unroll]] for (int i = 0; i < 4; ++i) { IN_T value = IN_T(intex[i]); - OUT_T qvalue = quantize_val(value, scale, zero_point); + $if SHAPE == "scalar": + OUT_T qvalue = quantize_val(value, scale, zero_point); + $if SHAPE == "tensor": + OUT_T qvalue = quantize_val(value, t_scale[0], t_zero_point[0]); outtex[i] = qvalue; } write_texel(t_out, pos, outtex); diff --git a/backends/vulkan/runtime/graph/ops/glsl/quantize_texture.yaml b/backends/vulkan/runtime/graph/ops/glsl/quantize_texture.yaml index 47e532be8b9..5b9d8da4f2e 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/quantize_texture.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/quantize_texture.yaml @@ -3,6 +3,7 @@ quantize_texture: IN_DTYPE: float OUT_DTYPE: int32 MODE: per_tensor + SHAPE: tensor generate_variant_forall: IN_DTYPE: - VALUE: half @@ -15,6 +16,9 @@ quantize_texture: shader_variants: - NAME: quantize_per_tensor_texture3d MODE: per_tensor + SHAPE: scalar + - NAME: quantize_per_tensor_tensor_texture3d + MODE: per_tensor - NAME: quantize_per_token_texture3d MODE: per_token - NAME: quantize_per_channel_texture3d diff --git a/backends/vulkan/runtime/graph/ops/impl/Quantize.cpp b/backends/vulkan/runtime/graph/ops/impl/Quantize.cpp index f08df99373f..59db9143756 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Quantize.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Quantize.cpp @@ -51,17 +51,19 @@ utils::uvec3 quantize_per_channel_local_wg_size( const ValueRef input = args.at(1).refs.at(0); - utils::uvec3 local_wg_size = graph->create_local_wg_size(global_workgroup_size); - - // WORKAROUND: The CommandBuffer::dispatch function divides global_workgroup_size - // by local_workgroup_size to get the number of workgroups to dispatch. - // For per-channel quantization along the batch axis, we need to ensure that - // we dispatch the correct number of workgroups in the Z dimension to cover - // all batch-channel combinations. + utils::uvec3 local_wg_size = + graph->create_local_wg_size(global_workgroup_size); + + // WORKAROUND: The CommandBuffer::dispatch function divides + // global_workgroup_size by local_workgroup_size to get the number of + // workgroups to dispatch. For per-channel quantization along the batch axis, + // we need to ensure that we dispatch the correct number of workgroups in the + // Z dimension to cover all batch-channel combinations. // - // If local_wg_size[2] > 1, then div_up(global_workgroup_size[2], local_wg_size[2]) - // might reduce the number of workgroups dispatched. To ensure we dispatch - // global_workgroup_size[2] workgroups in the Z dimension, we set local_wg_size[2] = 1. + // If local_wg_size[2] > 1, then div_up(global_workgroup_size[2], + // local_wg_size[2]) might reduce the number of workgroups dispatched. To + // ensure we dispatch global_workgroup_size[2] workgroups in the Z dimension, + // we set local_wg_size[2] = 1. 
   const auto input_sizes = graph->sizes_of(input);
   if (global_workgroup_size[2] > 1 && input_sizes[3] > 0) {
     local_wg_size[2] = 1;
@@ -78,13 +80,23 @@ void add_quantize_per_tensor_node(
     const ValueRef& quant_min,
     const ValueRef& quant_max,
     const ValueRef& output) {
+  const bool is_tensor_scale_zp =
+      graph.val_is_tensor(scale) && graph.val_is_tensor(zero_point);
+
   std::string kernel_name("quantize_per_tensor");
+  if (is_tensor_scale_zp) {
+    kernel_name += "_tensor";
+  }
   add_storage_type_suffix(kernel_name, graph.storage_type_of(input));
   add_dtype_suffix(kernel_name, graph.dtype_of(input));
   add_dtype_suffix(kernel_name, graph.dtype_of(output));

-  float scale_val = static_cast<float>(graph.get_double(scale));
-  int zero_point_val = static_cast<int>(graph.get_int(zero_point));
+  float scale_val = 1.0;
+  int zero_point_val = 0;
+  if (!is_tensor_scale_zp) {
+    scale_val = static_cast<float>(graph.get_double(scale));
+    zero_point_val = static_cast<int>(graph.get_int(zero_point));
+  }
   int quant_min_val = static_cast<int>(graph.get_int(quant_min));
   int quant_max_val = static_cast<int>(graph.get_int(quant_max));

@@ -98,15 +110,17 @@ void add_quantize_per_tensor_node(
         graph.strides_ubo(input),
         graph.sizes_ubo(output),
         graph.strides_ubo(output)};
+  } else {
+    param_ubos = {
+        graph.logical_limits_ubo(input), graph.logical_limits_ubo(output)};
+  }
+
+  if (is_tensor_scale_zp) {
     push_constants = {
-        PushConstantDataInfo(&scale_val, sizeof(float)),
-        PushConstantDataInfo(&zero_point_val, sizeof(int)),
         PushConstantDataInfo(&quant_min_val, sizeof(int)),
         PushConstantDataInfo(&quant_max_val, sizeof(int)),
     };
   } else {
-    param_ubos = {
-        graph.logical_limits_ubo(input), graph.logical_limits_ubo(output)};
     push_constants = {
         PushConstantDataInfo(&scale_val, sizeof(float)),
         PushConstantDataInfo(&zero_point_val, sizeof(int)),
@@ -120,13 +134,20 @@ void add_quantize_per_tensor_node(
       graph.hashed_layout_of(input),
   };

+  std::vector<ArgGroup> inputs_and_outputs = {
+      {output, vkapi::kWrite}, {input, vkapi::kRead}};
+  if (is_tensor_scale_zp) {
+    inputs_and_outputs.emplace_back(
+        ArgGroup{{scale, zero_point}, vkapi::kRead});
+  }
+
   graph.execute_nodes().emplace_back(new DynamicDispatchNode(
       graph,
       VK_KERNEL_FROM_STR(kernel_name),
       default_pick_global_wg_size,
       default_pick_local_wg_size,
       // Inputs and Outputs
-      {{output, vkapi::kWrite}, {input, vkapi::kRead}},
+      inputs_and_outputs,
       // Shader param buffers
       param_ubos,
       // Push Constants
@@ -241,8 +262,8 @@ void add_quantize_per_channel_node(

   int num_channels;
   if (axis_val == 0 && ndim == 4 && !graph.is_buffer_storage(input)) {
-    // For batch dimension quantization in 4D tensors, pass the actual number of channels
-    // so the shader can correctly unfold the batch-channel folding
+    // For batch dimension quantization in 4D tensors, pass the actual number
+    // of channels so the shader can correctly unfold the batch-channel folding
     num_channels = static_cast<int>(input_sizes[1]); // Channel dimension
   } else {
     num_channels = static_cast<int>(input_sizes[axis_val]);
@@ -487,6 +508,9 @@ REGISTER_OPERATORS {
   VK_REGISTER_OP(
       quantized_decomposed.quantize_per_tensor.default,
       quantize_per_tensor_impl);
+  VK_REGISTER_OP(
+      quantized_decomposed.quantize_per_tensor.tensor,
+      quantize_per_tensor_impl);
   VK_REGISTER_OP(
       quantized_decomposed.quantize_per_token.default, quantize_per_token_impl);
   VK_REGISTER_OP(
diff --git a/backends/vulkan/test/op_tests/quantize_test.cpp b/backends/vulkan/test/op_tests/quantize_test.cpp
index 51b07ac106b..c8191a654bb 100644
--- a/backends/vulkan/test/op_tests/quantize_test.cpp
+++ b/backends/vulkan/test/op_tests/quantize_test.cpp
@@ -58,6 +58,15 @@ Tensor& quantize_per_channel_out(
     ScalarType dtype,
     Tensor& out);

+Tensor& quantize_per_tensor_tensor_args_out(
+    const Tensor& input,
+    const Tensor& scale,
+    const Tensor& zero_point,
+    int64_t quant_min,
+    int64_t quant_max,
+    ScalarType dtype,
+    Tensor& out);
+
 // Wrapper function for quantize_per_tensor_out without context
 Tensor& quantize_per_tensor_out_no_context(
     const Tensor& input,
@@ -98,6 +107,19 @@ Tensor& quantize_per_channel_out_no_context(
       input, scale, zero_point, axis, quant_min, quant_max, dtype, out);
 }

+// Wrapper function for quantize_per_tensor_tensor_args_out without context
+Tensor& quantize_per_tensor_tensor_args_out_no_context(
+    const Tensor& input,
+    const Tensor& scale,
+    const Tensor& zero_point,
+    int64_t quant_min,
+    int64_t quant_max,
+    ScalarType dtype,
+    Tensor& out) {
+  return torch::executor::native::quantize_per_tensor_tensor_args_out(
+      input, scale, zero_point, quant_min, quant_max, dtype, out);
+}
+
 // ATen wrapper for quantize_per_tensor
 at::Tensor quantize_per_tensor_aten(
     const at::Tensor& input,
@@ -147,6 +169,22 @@ at::Tensor quantize_per_channel_aten(
   return out;
 }

+// ATen wrapper for quantize_per_tensor with tensor args
+at::Tensor quantize_per_tensor_tensor_args_aten(
+    const at::Tensor& input,
+    const at::Tensor& scale,
+    const at::Tensor& zero_point,
+    int64_t quant_min,
+    int64_t quant_max,
+    at::ScalarType dtype) {
+  auto out = at::empty_like(input, dtype);
+  ScalarType et_dtype = at_scalartype_to_et_scalartype(dtype);
+
+  WRAP_TO_ATEN(quantize_per_tensor_tensor_args_out_no_context, 6)
+  (input, scale, zero_point, quant_min, quant_max, et_dtype, out);
+  return out;
+}
+
 } // namespace native
 } // namespace executor
 } // namespace torch
@@ -485,6 +523,17 @@ void test_vulkan_quantize_per_channel_impl(
     const vkcompute::utils::StorageType in_storage,
     const vkcompute::utils::StorageType out_storage);

+void test_vulkan_quantize_per_tensor_tensor_impl(
+    const std::vector<int>& input_sizes,
+    float scale,
+    int zero_point,
+    int64_t quant_min,
+    int64_t quant_max,
+    at::ScalarType in_dtype,
+    at::ScalarType dtype,
+    const vkcompute::utils::StorageType in_storage,
+    const vkcompute::utils::StorageType out_storage);
+
 // Wrapper function to test both buffer and texture storage types
 void test_vulkan_quantize_per_tensor(
     const std::vector<int>& input_sizes,
@@ -607,6 +656,46 @@ void test_vulkan_quantize_per_channel(
       vkcompute::utils::kTexture3D);
 }

+// Wrapper function to test both buffer and texture storage types
+void test_vulkan_quantize_per_tensor_tensor(
+    const std::vector<int>& input_sizes,
+    float scale,
+    int zero_point,
+    int64_t quant_min,
+    int64_t quant_max,
+    at::ScalarType in_dtype = at::kFloat,
+    at::ScalarType dtype = at::kInt) {
+  // Test with buffer storage
+  test_vulkan_quantize_per_tensor_tensor_impl(
+      input_sizes,
+      scale,
+      zero_point,
+      quant_min,
+      quant_max,
+      in_dtype,
+      dtype,
+      vkcompute::utils::kBuffer,
+      vkcompute::utils::kBuffer);
+
+  // If in_dtype is double, convert to float for the texture implementation
+  // since textures don't support 64-bit inputs
+  if (in_dtype == at::kDouble) {
+    in_dtype = at::kFloat;
+  }
+
+  // Test with texture storage
+  test_vulkan_quantize_per_tensor_tensor_impl(
+      input_sizes,
+      scale,
+      zero_point,
+      quant_min,
+      quant_max,
+      in_dtype,
+      dtype,
+      vkcompute::utils::kTexture3D,
+      vkcompute::utils::kTexture3D);
+}
+
 void test_reference_quantize_per_tensor(
     const std::vector<int>& input_sizes,
     float scale,
@@ -746,8 +835,10 @@ void test_vulkan_quantize_per_tensor_impl(
   at::Tensor reference_int = reference_out.to(at::kInt);
   at::Tensor vk_int = vk_out.to(at::kInt);

-  // Tolerance is 1 to address rounding errors and fp math differences between CPU/GPU
-  const bool output_correct = at::allclose(reference_int, vk_int, /*rtol=*/1, /*atol=*/1);
+  // Tolerance is 1 to address rounding errors and fp math differences between
+  // CPU/GPU
+  const bool output_correct =
+      at::allclose(reference_int, vk_int, /*rtol=*/1, /*atol=*/1);

   if (!output_correct) {
     at::Tensor diffs = at::abs(reference_int - vk_int);
@@ -1123,8 +1214,10 @@ void test_vulkan_quantize_per_token_impl(
   at::Tensor reference_int = reference_out.to(at::kInt);
   at::Tensor vk_int = vk_out.to(at::kInt);

-  // Tolerance is 1 to address rounding errors and fp math differences between CPU/GPU
-  const bool output_correct = at::allclose(reference_int, vk_int, /*rtol=*/1, /*atol=*/1);
+  // Tolerance is 1 to address rounding errors and fp math differences between
+  // CPU/GPU
+  const bool output_correct =
+      at::allclose(reference_int, vk_int, /*rtol=*/1, /*atol=*/1);

   if (!output_correct) {
     at::Tensor diffs = at::abs(reference_int - vk_int);
@@ -1244,9 +1337,7 @@ TEST(
       at::kByte);
 }

-TEST(
-    VulkanQuantizePerTokenTest,
-    test_vulkan_quantize_per_token_float_to_int8) {
+TEST(VulkanQuantizePerTokenTest, test_vulkan_quantize_per_token_float_to_int8) {
   if (!vkcompute::api::context()
            ->adapter_ptr()
            ->has_full_int8_buffers_support()) {
@@ -1606,8 +1697,10 @@ void test_vulkan_quantize_per_channel_impl(
   at::Tensor reference_int = reference_out.to(at::kInt);
   at::Tensor vk_int = vk_out.to(at::kInt);

-  // Tolerance is 1 to address rounding errors and fp math differences between CPU/GPU
-  const bool output_correct = at::allclose(reference_int, vk_int, /*rtol=*/1, /*atol=*/1);
+  // Tolerance is 1 to address rounding errors and fp math differences between
+  // CPU/GPU
+  const bool output_correct =
+      at::allclose(reference_int, vk_int, /*rtol=*/1, /*atol=*/1);

   if (!output_correct) {
     at::Tensor diffs = at::abs(reference_int - vk_int);
@@ -1717,7 +1810,9 @@ TEST(

 // END OF REFERENCE TESTS

-TEST(VulkanQuantizePerChannelTest, test_vulkan_quantize_per_channel_float_to_int8_axis0) {
+TEST(
+    VulkanQuantizePerChannelTest,
+    test_vulkan_quantize_per_channel_float_to_int8_axis0) {
   std::vector<float> scales(9, 0.1f);
   std::vector<int> zero_points(9, 2);

@@ -1777,7 +1872,9 @@ TEST(VulkanQuantizePerChannelTest, test_vulkan_quantize_per_channel_float_to_int
       at::kChar);
 }

-TEST(VulkanQuantizePerChannelTest, test_vulkan_quantize_per_channel_float_to_int8_axis1) {
+TEST(
+    VulkanQuantizePerChannelTest,
+    test_vulkan_quantize_per_channel_float_to_int8_axis1) {
   std::vector<float> scales(14, 0.001f);
   std::vector<int> zero_points(14, -5);

@@ -1826,7 +1923,9 @@ TEST(VulkanQuantizePerChannelTest, test_vulkan_quantize_per_channel_float_to_int
       at::kChar);
 }

-TEST(VulkanQuantizePerChannelTest, test_vulkan_quantize_per_channel_float_to_int8_axis2) {
+TEST(
+    VulkanQuantizePerChannelTest,
+    test_vulkan_quantize_per_channel_float_to_int8_axis2) {
   std::vector<float> scales(11, 0.5f);
   std::vector<int> zero_points(11, 12);

@@ -1864,7 +1963,9 @@ TEST(VulkanQuantizePerChannelTest, test_vulkan_quantize_per_channel_float_to_int
       at::kChar);
 }

-TEST(VulkanQuantizePerChannelTest, test_vulkan_quantize_per_channel_float_to_int8_axis3) {
+TEST(
+    VulkanQuantizePerChannelTest,
+    test_vulkan_quantize_per_channel_float_to_int8_axis3) {
   std::vector<float> scales(7, 0.5f);
   std::vector<int> zero_points(7, 12);

@@ -1891,7 +1992,9 @@ TEST(VulkanQuantizePerChannelTest, test_vulkan_quantize_per_channel_float_to_int
       at::kChar);
 }

-TEST(VulkanQuantizePerChannelTest, test_vulkan_quantize_per_channel_float_to_uint8_comprehensive) {
+TEST(
+    VulkanQuantizePerChannelTest,
+    test_vulkan_quantize_per_channel_float_to_uint8_comprehensive) {
   std::vector<float> scales = {0.1, 0.2, 0.0001, 0.5, 0.02};
   std::vector<int> zero_points = {0, 5, -5, 1, 12};

@@ -1951,7 +2054,9 @@ TEST(VulkanQuantizePerChannelTest, test_vulkan_quantize_per_channel_float_to_uin
       at::kByte);
 }

-TEST(VulkanQuantizePerChannelTest, test_vulkan_quantize_per_channel_half_to_8bit) {
+TEST(
+    VulkanQuantizePerChannelTest,
+    test_vulkan_quantize_per_channel_half_to_8bit) {
   std::vector<float> scales = {0.1, 0.2, 0.01, 0.5, 0.02};
   std::vector<int> zero_points = {0, 5, 5, 1, 12};

@@ -2011,7 +2116,9 @@ TEST(VulkanQuantizePerChannelTest, test_vulkan_quantize_per_channel_half_to_8bit
       at::kByte);
 }

-TEST(VulkanQuantizePerChannelTest, test_vulkan_quantize_per_channel_double_to_8bit) {
+TEST(
+    VulkanQuantizePerChannelTest,
+    test_vulkan_quantize_per_channel_double_to_8bit) {
   std::vector<float> scales = {0.1, 0.2, 0.01, 0.5, 0.02};
   std::vector<int> zero_points = {0, 5, 5, 1, 12};

@@ -2070,3 +2177,219 @@ TEST(VulkanQuantizePerChannelTest, test_vulkan_quantize_per_channel_double_to_8b
       at::kDouble,
       at::kByte);
 }
+
+void test_vulkan_quantize_per_tensor_tensor_impl(
+    const std::vector<int>& input_sizes,
+    float scale,
+    int zero_point,
+    int64_t quant_min,
+    int64_t quant_max,
+    at::ScalarType in_dtype = at::kFloat,
+    at::ScalarType dtype = at::kInt,
+    const vkcompute::utils::StorageType in_storage =
+        vkcompute::utils::kTexture3D,
+    const vkcompute::utils::StorageType out_storage =
+        vkcompute::utils::kTexture3D) {
+  check_quantize_args(quant_min, quant_max, dtype);
+  std::vector<int64_t> input_sizes_int64(
+      input_sizes.begin(), input_sizes.end());
+  at::Tensor input =
+      at::rand(input_sizes_int64, at::device(at::kCPU).dtype(in_dtype));
+
+  scale = scale < eps ? eps : scale;
+
+  // Create scale and zero_point as tensors (single element tensors)
+  at::Tensor scale_tensor =
+      at::tensor({scale}, at::device(at::kCPU).dtype(at::kDouble));
+  at::Tensor zero_point_tensor =
+      at::tensor({zero_point}, at::device(at::kCPU).dtype(at::kLong));
+
+  // Get reference output using tensor variant
+  at::Tensor reference_out = torch::executor::native::quantize_per_tensor_tensor_args_aten(
+      input, scale_tensor, zero_point_tensor, quant_min, quant_max, dtype);
+
+  // Build Vulkan quantize_per_tensor.tensor graph
+  using namespace vkcompute;
+
+  GraphConfig config;
+  config.set_storage_type_override(in_storage);
+  ComputeGraph graph(config);
+
+  IOValueRef r_input = graph.add_input_tensor(
+      input.sizes().vec(), from_at_scalartype(input.scalar_type()), in_storage);
+
+  // Add scale and zero_point as tensor inputs (buffer storage, width packed)
+  IOValueRef r_scale = graph.add_input_tensor(
+      scale_tensor.sizes().vec(),
+      vkapi::kFloat,
+      utils::kBuffer,
+      utils::kWidthPacked);
+  IOValueRef r_zero_point = graph.add_input_tensor(
+      zero_point_tensor.sizes().vec(),
+      vkapi::kInt,
+      utils::kBuffer,
+      utils::kWidthPacked);
+
+  const ValueRef r_quant_min = graph.add_scalar<int64_t>(quant_min);
+  const ValueRef r_quant_max = graph.add_scalar<int64_t>(quant_max);
+
+  const ValueRef r_out = graph.add_tensor(
+      input.sizes().vec(), from_at_scalartype(dtype), out_storage);
+
+  const ValueRef r_dtype =
+      graph.add_scalar<int64_t>(static_cast<int64_t>(dtype));
+
+  VK_GET_OP_FN("quantized_decomposed.quantize_per_tensor.tensor")
+  (graph,
+   {
+       r_input.value,
+       r_scale.value,
+       r_zero_point.value,
+       r_quant_min,
+       r_quant_max,
+       r_dtype,
+       r_out,
+   });
+
+  ValueRef staging_out = graph.set_output_tensor(r_out);
+
+  graph.prepare();
+  graph.encode_prepack();
+  graph.prepack();
+  graph.encode_execute();
+
+  // Run Vulkan quantize_per_tensor.tensor
+  graph.copy_into_staging(
+      r_input.staging, input.const_data_ptr(), input.numel());
+
+  // Convert scale tensor to float and copy to GPU
+  at::Tensor scale_float = scale_tensor.to(at::kFloat);
+  graph.copy_into_staging(
+      r_scale.staging, scale_float.const_data_ptr(), scale_float.numel());
+
+  // Convert zero_point tensor to int and copy to GPU
+  at::Tensor zero_point_int = zero_point_tensor.to(at::kInt);
+  graph.copy_into_staging(
+      r_zero_point.staging,
+      zero_point_int.const_data_ptr(),
+      zero_point_int.numel());
+
+  graph.execute();
+
+  at::Tensor vk_out = at::empty_like(reference_out).contiguous();
+  graph.copy_from_staging(
+      staging_out, vk_out.mutable_data_ptr(), vk_out.numel());
+
+  // Compare outputs
+  // For quantized types, we need to compare the actual integer values
+  at::Tensor reference_int = reference_out.to(at::kInt);
+  at::Tensor vk_int = vk_out.to(at::kInt);
+
+  // Tolerance is 1 to address rounding errors and fp math differences between
+  // CPU/GPU
+  const bool output_correct =
+      at::allclose(reference_int, vk_int, /*rtol=*/1, /*atol=*/1);
+  if (!output_correct) {
+    at::Tensor diffs = at::abs(reference_int - vk_int);
+
+    std::cout << "\n"
+              << "Failed with parameters: " << std::endl;
+    std::cout << "  scale: " << scale << std::endl;
+    std::cout << "  zero_point: " << zero_point << std::endl;
+    std::cout << "  quant_min: " << quant_min << std::endl;
+    std::cout << "  quant_max: " << quant_max << std::endl;
+    std::cout << "  storage type: "
+              << (in_storage == vkcompute::utils::kBuffer ? "buffer"
+                                                          : "texture")
+              << std::endl;
"buffer" + : "texture") + << std::endl; + + std::cout << "input:" << std::endl; + std::cout << input << std::endl; + std::cout << "reference:" << std::endl; + std::cout << reference_int << std::endl; + std::cout << "vulkan:" << std::endl; + std::cout << vk_int << std::endl; + } + + ASSERT_TRUE(output_correct); +} + +TEST(VulkanQuantizePerTensorTensorTest, test_vulkan_quantize_per_tensor_tensor_float_to_int8) { + if (!vkcompute::api::context() + ->adapter_ptr() + ->has_full_int8_buffers_support()) { + GTEST_SKIP(); + } + test_vulkan_quantize_per_tensor_tensor( + {2, 3, 4}, // input sizes + 0.01, // scale + 1, // zero_point + -128, // quant_min + 127, // quant_max + at::kFloat, // input dtype + at::kChar); // output dtype +} + +TEST(VulkanQuantizePerTensorTensorTest, test_vulkan_quantize_per_tensor_tensor_float_to_uint8) { + if (!vkcompute::api::context() + ->adapter_ptr() + ->has_full_int8_buffers_support()) { + GTEST_SKIP(); + } + test_vulkan_quantize_per_tensor_tensor( + {2, 3, 4, 12}, // input sizes + 0.1, // scale + 5, // zero_point + 0, // quant_min + 255, // quant_max + at::kFloat, // input dtype + at::kByte); // output dtype +} + +TEST(VulkanQuantizePerTensorTensorTest, test_vulkan_quantize_per_tensor_tensor_float_to_int32) { + if (!vkcompute::api::context() + ->adapter_ptr() + ->has_full_int8_buffers_support()) { + GTEST_SKIP(); + } + test_vulkan_quantize_per_tensor_tensor( + {2, 3}, // input sizes + 0.01, // scale + 12, // zero_point + std::numeric_limits::min(), // quant_min + std::numeric_limits::max(), // quant_max + at::kFloat, // input dtype + at::kInt); // output dtype +} + +TEST(VulkanQuantizePerTensorTensorTest, test_vulkan_quantize_per_tensor_tensor_half_to_uint8) { + if (!vkcompute::api::context() + ->adapter_ptr() + ->has_full_int8_buffers_support()) { + GTEST_SKIP(); + } + test_vulkan_quantize_per_tensor_tensor( + {3, 4}, // input sizes + 0.3, // scale + 2, // zero_point + 0, // quant_min + 255, // quant_max + at::kHalf, // input dtype + at::kByte); // output dtype +} + +TEST(VulkanQuantizePerTensorTensorTest, test_vulkan_quantize_per_tensor_tensor_double_to_int8) { + if (!vkcompute::api::context() + ->adapter_ptr() + ->has_full_int8_buffers_support()) { + GTEST_SKIP(); + } + test_vulkan_quantize_per_tensor_tensor( + {2, 3, 4}, // input sizes + 0.03, // scale + -2, // zero_point + -128, // quant_min + 127, // quant_max + at::kDouble, // input dtype + at::kChar); // output dtype +}