diff --git a/backends/vulkan/runtime/graph/ops/glsl/dequantize_buffer.glsl b/backends/vulkan/runtime/graph/ops/glsl/dequantize_buffer.glsl
index faafa3fd266..94072dfbfea 100644
--- a/backends/vulkan/runtime/graph/ops/glsl/dequantize_buffer.glsl
+++ b/backends/vulkan/runtime/graph/ops/glsl/dequantize_buffer.glsl
@@ -27,9 +27,10 @@ ${layout_declare_tensor(B, "w", "t_out", OUT_DTYPE, "buffer")}
 ${layout_declare_tensor(B, "r", "t_in", IN_DTYPE, "buffer")}
 
 $if MODE == "per_tensor":
+  ${layout_declare_tensor(B, "r", "t_scale", "float", "buffer")}
+  ${layout_declare_tensor(B, "r", "t_zero_point", "int", "buffer")}
+
   layout(push_constant) uniform restrict Block {
-    float scale;
-    int zero_point;
     int quant_min;
     int quant_max;
   };
@@ -146,7 +147,7 @@ void dequantize_per_tensor() {
   const int in_bufi = tidx_to_bufi(out_tidx, t_in_strides);
 
   IN_T qvalue = t_in[in_bufi];
-  OUT_T value = dequantize_val(qvalue, scale, zero_point);
+  OUT_T value = dequantize_val(qvalue, t_scale[0], t_zero_point[0]);
 
   t_out[out_bufi] = value;
 }
diff --git a/backends/vulkan/runtime/graph/ops/glsl/dequantize_texture.glsl b/backends/vulkan/runtime/graph/ops/glsl/dequantize_texture.glsl
index ef3f5cca869..5c978c61846 100644
--- a/backends/vulkan/runtime/graph/ops/glsl/dequantize_texture.glsl
+++ b/backends/vulkan/runtime/graph/ops/glsl/dequantize_texture.glsl
@@ -30,9 +30,10 @@ ${layout_declare_tensor(B, "w", "t_out", OUT_DTYPE, "texture3d")}
 ${layout_declare_tensor(B, "r", "t_in", IN_DTYPE, "texture3d")}
 
 $if MODE == "per_tensor":
+  ${layout_declare_tensor(B, "r", "t_scale", "float", "buffer")}
+  ${layout_declare_tensor(B, "r", "t_zero_point", "int", "buffer")}
+
   layout(push_constant) uniform restrict Block {
-    float scale;
-    int zero_point;
     int quant_min;
     int quant_max;
   };
@@ -148,7 +149,8 @@ void dequantize_per_tensor() {
 
   [[unroll]] for (int i = 0; i < 4; ++i) {
     IN_T qvalue = IN_T(intex[i]);
-    OUT_T value = dequantize_val(qvalue, scale, zero_point);
+    OUT_T value = dequantize_val(qvalue, t_scale[0], t_zero_point[0]);
+
     $if OUT_DTYPE == "double":
       outtex[i] = float(value);
     $else:
diff --git a/backends/vulkan/runtime/graph/ops/impl/Dequantize.cpp b/backends/vulkan/runtime/graph/ops/impl/Dequantize.cpp
index 8845d6f6254..1578b515f55 100644
--- a/backends/vulkan/runtime/graph/ops/impl/Dequantize.cpp
+++ b/backends/vulkan/runtime/graph/ops/impl/Dequantize.cpp
@@ -85,8 +85,6 @@ void add_dequantize_per_tensor_node(
   add_dtype_suffix(kernel_name, graph.dtype_of(input));
   add_dtype_suffix(kernel_name, graph.dtype_of(output));
 
-  float scale_val = static_cast<float>(graph.get_double(scale));
-  int zero_point_val = static_cast<int>(graph.get_int(zero_point));
   int quant_min_val = static_cast<int>(graph.get_int(quant_min));
   int quant_max_val = static_cast<int>(graph.get_int(quant_max));
 
@@ -100,23 +98,16 @@ void add_dequantize_per_tensor_node(
         graph.strides_ubo(input),
         graph.sizes_ubo(output),
         graph.strides_ubo(output)};
-    push_constants = {
-        PushConstantDataInfo(&scale_val, sizeof(float)),
-        PushConstantDataInfo(&zero_point_val, sizeof(int)),
-        PushConstantDataInfo(&quant_min_val, sizeof(int)),
-        PushConstantDataInfo(&quant_max_val, sizeof(int)),
-    };
   } else {
     param_ubos = {
        graph.logical_limits_ubo(input), graph.logical_limits_ubo(output)};
-    push_constants = {
-        PushConstantDataInfo(&scale_val, sizeof(float)),
-        PushConstantDataInfo(&zero_point_val, sizeof(int)),
-        PushConstantDataInfo(&quant_min_val, sizeof(int)),
-        PushConstantDataInfo(&quant_max_val, sizeof(int)),
-    };
   }
+  push_constants = {
+      PushConstantDataInfo(&quant_min_val, sizeof(int)),
+      PushConstantDataInfo(&quant_max_val, sizeof(int)),
+  };
+
   vkapi::SpecVarList spec_vars = {
       graph.hashed_layout_of(output),
       graph.hashed_layout_of(input),
@@ -128,7 +119,9 @@ void add_dequantize_per_tensor_node(
       default_pick_global_wg_size,
       default_pick_local_wg_size,
       // Inputs and Outputs
-      {{output, vkapi::kWrite}, {input, vkapi::kRead}},
+      {{output, vkapi::kWrite},
+       {input, vkapi::kRead},
+       {{scale, zero_point}, vkapi::kRead}},
       // Shader param buffers
       param_ubos,
       // Push Constants
@@ -517,7 +510,7 @@ void dequantize_per_channel_impl(
 
 REGISTER_OPERATORS {
   VK_REGISTER_OP(
-      quantized_decomposed.dequantize_per_tensor.default,
+      quantized_decomposed.dequantize_per_tensor.tensor,
       dequantize_per_tensor_impl);
   VK_REGISTER_OP(
       quantized_decomposed.dequantize_per_token.default,
diff --git a/backends/vulkan/test/op_tests/dequantize_test.cpp b/backends/vulkan/test/op_tests/dequantize_test.cpp
index cb9c04ee089..b4c4ac274dc 100644
--- a/backends/vulkan/test/op_tests/dequantize_test.cpp
+++ b/backends/vulkan/test/op_tests/dequantize_test.cpp
@@ -60,6 +60,16 @@ Tensor& dequantize_per_channel_out(
     executorch::aten::optional<ScalarType> out_dtype,
     Tensor& out);
 
+Tensor& dequantize_per_tensor_tensor_args_out(
+    const Tensor& input,
+    const Tensor& scale,
+    const Tensor& zero_point,
+    int64_t quant_min,
+    int64_t quant_max,
+    ScalarType dtype,
+    executorch::aten::optional<ScalarType> out_dtype,
+    Tensor& out);
+
 // Wrapper function for dequantize_per_tensor_out without context
 Tensor& dequantize_per_tensor_out_no_context(
     const Tensor& input,
@@ -111,6 +121,20 @@ Tensor& dequantize_per_channel_out_no_context(
       out);
 }
 
+// Wrapper function for dequantize_per_tensor_tensor_args_out without context
+Tensor& dequantize_per_tensor_tensor_args_out_no_context(
+    const Tensor& input,
+    const Tensor& scale,
+    const Tensor& zero_point,
+    int64_t quant_min,
+    int64_t quant_max,
+    ScalarType dtype,
+    executorch::aten::optional<ScalarType> out_dtype,
+    Tensor& out) {
+  return torch::executor::native::dequantize_per_tensor_tensor_args_out(
+      input, scale, zero_point, quant_min, quant_max, dtype, out_dtype, out);
+}
+
 // ATen wrapper for dequantize_per_tensor
 at::Tensor dequantize_per_tensor_aten(
     const at::Tensor& input,
@@ -195,6 +219,34 @@ at::Tensor dequantize_per_channel_aten(
   return out;
 }
 
+// ATen wrapper for dequantize_per_tensor with tensor args
+at::Tensor dequantize_per_tensor_tensor_args_aten(
+    const at::Tensor& input,
+    const at::Tensor& scale,
+    const at::Tensor& zero_point,
+    int64_t quant_min,
+    int64_t quant_max,
+    at::ScalarType dtype,
+    at::ScalarType out_dtype) {
+  auto out = at::empty_like(input, out_dtype);
+  // Convert at::ScalarType to executorch::ScalarType
+  ScalarType et_dtype = at_scalartype_to_et_scalartype(dtype);
+  ScalarType et_out_dtype = at_scalartype_to_et_scalartype(out_dtype);
+
+  executorch::aten::optional<ScalarType> opt_et_out_dtype(et_out_dtype);
+
+  WRAP_TO_ATEN(dequantize_per_tensor_tensor_args_out_no_context, 7)
+  (input,
+   scale,
+   zero_point,
+   quant_min,
+   quant_max,
+   et_dtype,
+   opt_et_out_dtype,
+   out);
+  return out;
+}
+
 } // namespace native
 } // namespace executor
 } // namespace torch
@@ -535,10 +587,10 @@ at::Tensor dequantize_per_channel_reference_impl(
 }
 
 // Forward declaration of implementation functions
-void test_vulkan_dequantize_per_tensor_impl(
+void test_vulkan_dequantize_per_token_impl(
     const std::vector<int>& input_sizes,
-    float scale,
-    int zero_point,
+    const std::vector<float>& scales,
+    const std::vector<int>& zero_points,
     int64_t quant_min,
     int64_t quant_max,
     at::ScalarType dtype,
@@ -546,10 +598,11 @@ void test_vulkan_dequantize_per_tensor_impl(
     const vkcompute::utils::StorageType in_storage,
     const vkcompute::utils::StorageType out_storage);
 
-void test_vulkan_dequantize_per_token_impl(
+void test_vulkan_dequantize_per_channel_impl(
     const std::vector<int>& input_sizes,
     const std::vector<float>& scales,
     const std::vector<int>& zero_points,
+    int64_t axis,
     int64_t quant_min,
     int64_t quant_max,
     at::ScalarType dtype,
@@ -557,11 +610,10 @@ void test_vulkan_dequantize_per_token_impl(
     const vkcompute::utils::StorageType in_storage,
     const vkcompute::utils::StorageType out_storage);
 
-void test_vulkan_dequantize_per_channel_impl(
+void test_vulkan_dequantize_per_tensor_tensor_impl(
     const std::vector<int>& input_sizes,
-    const std::vector<float>& scales,
-    const std::vector<int>& zero_points,
-    int64_t axis,
+    float scale,
+    int zero_point,
     int64_t quant_min,
     int64_t quant_max,
     at::ScalarType dtype,
@@ -570,19 +622,19 @@ void test_vulkan_dequantize_per_channel_impl(
     const vkcompute::utils::StorageType out_storage);
 
 // Wrapper function to test both buffer and texture storage types
-void test_vulkan_dequantize_per_tensor(
+void test_vulkan_dequantize_per_token(
     const std::vector<int>& input_sizes,
-    float scale,
-    int zero_point,
+    const std::vector<float>& scales,
+    const std::vector<int>& zero_points,
     int64_t quant_min,
     int64_t quant_max,
     at::ScalarType dtype,
     at::ScalarType out_dtype) {
   // Test with buffer storage
-  test_vulkan_dequantize_per_tensor_impl(
+  test_vulkan_dequantize_per_token_impl(
       input_sizes,
-      scale,
-      zero_point,
+      scales,
+      zero_points,
       quant_min,
       quant_max,
       dtype,
@@ -597,10 +649,10 @@ void test_vulkan_dequantize_per_tensor(
   }
 
   // Test with texture storage
-  test_vulkan_dequantize_per_tensor_impl(
+  test_vulkan_dequantize_per_token_impl(
       input_sizes,
-      scale,
-      zero_point,
+      scales,
+      zero_points,
       quant_min,
       quant_max,
       dtype,
@@ -610,19 +662,21 @@ void test_vulkan_dequantize_per_token(
 }
 
 // Wrapper function to test both buffer and texture storage types
-void test_vulkan_dequantize_per_token(
+void test_vulkan_dequantize_per_channel(
     const std::vector<int>& input_sizes,
     const std::vector<float>& scales,
     const std::vector<int>& zero_points,
+    int64_t axis,
     int64_t quant_min,
     int64_t quant_max,
     at::ScalarType dtype,
     at::ScalarType out_dtype) {
   // Test with buffer storage
-  test_vulkan_dequantize_per_token_impl(
+  test_vulkan_dequantize_per_channel_impl(
       input_sizes,
       scales,
       zero_points,
+      axis,
       quant_min,
       quant_max,
       dtype,
@@ -637,10 +691,11 @@ void test_vulkan_dequantize_per_token(
   }
 
   // Test with texture storage
-  test_vulkan_dequantize_per_token_impl(
+  test_vulkan_dequantize_per_channel_impl(
       input_sizes,
       scales,
       zero_points,
+      axis,
       quant_min,
       quant_max,
       dtype,
@@ -650,21 +705,19 @@ void test_vulkan_dequantize_per_token(
 }
 
 // Wrapper function to test both buffer and texture storage types
-void test_vulkan_dequantize_per_channel(
+void test_vulkan_dequantize_per_tensor_tensor(
     const std::vector<int>& input_sizes,
-    const std::vector<float>& scales,
-    const std::vector<int>& zero_points,
-    int64_t axis,
+    float scale,
+    int zero_point,
     int64_t quant_min,
     int64_t quant_max,
     at::ScalarType dtype,
     at::ScalarType out_dtype) {
   // Test with buffer storage
-  test_vulkan_dequantize_per_channel_impl(
+  test_vulkan_dequantize_per_tensor_tensor_impl(
       input_sizes,
-      scales,
-      zero_points,
-      axis,
+      scale,
+      zero_point,
      quant_min,
       quant_max,
       dtype,
@@ -679,11 +732,10 @@ void test_vulkan_dequantize_per_channel(
   }
 
   // Test with texture storage
-  test_vulkan_dequantize_per_channel_impl(
+  test_vulkan_dequantize_per_tensor_tensor_impl(
       input_sizes,
-      scales,
-      zero_points,
-      axis,
+      scale,
+      zero_point,
       quant_min,
       quant_max,
       dtype,
@@ -775,151 +827,6 @@ void test_reference_dequantize_per_tensor(
 
   ASSERT_TRUE(output_correct);
 }
 
-void test_vulkan_dequantize_per_tensor_impl(
-    const std::vector<int>& input_sizes,
-    float scale,
-    int zero_point,
-    int64_t quant_min,
-    int64_t quant_max,
-    at::ScalarType dtype,
-    at::ScalarType out_dtype,
-    const vkcompute::utils::StorageType in_storage,
-    const vkcompute::utils::StorageType out_storage) {
-  check_dequantize_args(quant_min, quant_max, dtype, out_dtype);
-  std::vector<int64_t> input_sizes_int64(
-      input_sizes.begin(), input_sizes.end());
-
-  // Create a quantized input tensor with values from quant_min to quant_max
-  at::Tensor input;
-  if (dtype == at::kByte) {
-    input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kByte));
-  } else if (dtype == at::kChar) {
-    input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kChar));
-  } else if (dtype == at::kShort) {
-    input =
-        at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kShort));
-  } else if (dtype == at::kInt) {
-    input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kInt));
-  } else {
-    input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kLong));
-  }
-
-  // Fill with a simple pattern: values from quant_min to quant_max in steps
-  float step = 1.0f;
-  if (input.numel() > 1) {
-    step = static_cast<float>(quant_max - quant_min) / (input.numel() - 1);
-  }
-
-  auto flat_input = input.flatten();
-  for (int i = 0; i < flat_input.numel(); i++) {
-    int64_t qvalue = quant_min + i * step;
-    if (dtype == at::kByte) {
-      flat_input[i] = static_cast<uint8_t>(qvalue);
-    } else if (dtype == at::kChar) {
-      flat_input[i] = static_cast<int8_t>(qvalue);
-    } else if (dtype == at::kShort) {
-      flat_input[i] = static_cast<int16_t>(qvalue);
-    } else if (dtype == at::kInt) {
-      flat_input[i] = static_cast<int32_t>(qvalue);
-    } else if (dtype == at::kLong) {
-      flat_input[i] = static_cast<int64_t>(qvalue);
-    }
-  }
-
-  // Reshape back to original dimensions
-  input = flat_input.reshape(input_sizes_int64);
-
-  // Get reference output
-  at::Tensor reference_out =
-      torch::executor::native::dequantize_per_tensor_aten(
-          input, scale, zero_point, quant_min, quant_max, dtype, out_dtype);
-
-  // Build Vulkan dequantize_per_tensor graph
-  using namespace vkcompute;
-
-  GraphConfig config;
-  config.set_storage_type_override(in_storage);
-  ComputeGraph graph(config);
-
-  IOValueRef r_input = graph.add_input_tensor(
-      input.sizes().vec(), from_at_scalartype(dtype), in_storage);
-
-  const ValueRef r_scale = graph.add_scalar<double>(scale);
-  const ValueRef r_zero_point = graph.add_scalar<int64_t>(zero_point);
-  const ValueRef r_quant_min = graph.add_scalar<int64_t>(quant_min);
-  const ValueRef r_quant_max = graph.add_scalar<int64_t>(quant_max);
-
-  const ValueRef r_out = graph.add_tensor(
-      input.sizes().vec(), from_at_scalartype(out_dtype), out_storage);
-
-  const ValueRef r_dtype =
-      graph.add_scalar<int64_t>(static_cast<int64_t>(out_dtype));
-
-  VK_GET_OP_FN("quantized_decomposed.dequantize_per_tensor.default")
-  (graph,
-   {
-       r_input.value,
-       r_scale,
-       r_zero_point,
-       r_quant_min,
-       r_quant_max,
-       r_dtype,
-       r_dtype,
-       r_out,
-   });
-
-  ValueRef staging_out = graph.set_output_tensor(r_out);
-
-  graph.prepare();
-  graph.encode_prepack();
-  graph.prepack();
-  graph.encode_execute();
-
-  // Run Vulkan dequantize_per_tensor
-  graph.copy_into_staging(
-      r_input.staging, input.const_data_ptr(), input.numel());
-
-  graph.execute();
-
-  at::Tensor vk_out = at::empty_like(reference_out).contiguous();
-  graph.copy_from_staging(
-      staging_out, vk_out.mutable_data_ptr(), vk_out.numel());
-
-  // Compare outputs with appropriate tolerance for half precision
-  bool output_correct;
-  if (out_dtype == at::kHalf) {
-    // Use higher tolerance for half precision due to limited precision
-    output_correct =
-        at::allclose(reference_out, vk_out, /*rtol=*/1e-2, /*atol=*/1e-2);
-  } else {
-    output_correct =
-        at::allclose(reference_out, vk_out, /*rtol=*/1e-5, /*atol=*/1e-5);
-  }
-  if (!output_correct) {
-    std::cout << "\n"
-              << "Failed with parameters: " << std::endl;
-    std::cout << "  scale: " << scale << std::endl;
-    std::cout << "  zero_point: " << zero_point << std::endl;
-    std::cout << "  quant_min: " << quant_min << std::endl;
-    std::cout << "  quant_max: " << quant_max << std::endl;
-    std::cout << "  storage type: "
-              << (in_storage == vkcompute::utils::kBuffer ? "buffer"
-                                                          : "texture")
-              << std::endl;
-    std::cout << "  input dtype: " << dtype << std::endl;
-    std::cout << "  output dtype: " << out_dtype << std::endl;
-
-    std::cout << "input:" << std::endl;
-    std::cout << input << std::endl;
-    std::cout << "reference:" << std::endl;
-    std::cout << reference_out << std::endl;
-    std::cout << "vulkan:" << std::endl;
-    std::cout << vk_out << std::endl;
-  }
-
-  ASSERT_TRUE(output_correct);
-}
-
 TEST(
     VulkanDequantizePerTensorTest,
     test_reference_dequantize_per_tensor_uint8_to_float) {
@@ -974,128 +881,21 @@ TEST(
     VulkanDequantizePerTensorTest,
-    test_reference_dequantize_per_tensor_int32_to_half) {
-  test_reference_dequantize_per_tensor(
-      {2, 6, 5}, // input sizes
-      0.3, // scale
-      -10, // zero_point
-      std::numeric_limits<int32_t>::min(), // quant_min
-      std::numeric_limits<int32_t>::max(), // quant_max
-      at::kInt, // input dtype
-      at::kHalf); // output dtype
-}
-
-TEST(
-    VulkanDequantizePerTensorTest,
-    test_vulkan_dequantize_per_tensor_uint8_to_float) {
-  if (!vkcompute::api::context()
-           ->adapter_ptr()
-           ->has_full_int8_buffers_support()) {
-    GTEST_SKIP();
-  }
-  test_vulkan_dequantize_per_tensor(
-      {2, 3, 4}, // input sizes
-      0.1, // scale
-      5, // zero_point
-      0, // quant_min
-      255, // quant_max
-      at::kByte, // input dtype
-      at::kFloat); // output dtype
-}
-
-TEST(
-    VulkanDequantizePerTensorTest,
-    test_vulkan_dequantize_per_tensor_int8_to_float) {
-  if (!vkcompute::api::context()
-           ->adapter_ptr()
-           ->has_full_int8_buffers_support()) {
-    GTEST_SKIP();
-  }
-  test_vulkan_dequantize_per_tensor(
-      {3, 4}, // input sizes
-      0.05, // scale
-      0, // zero_point
-      -128, // quant_min
-      127, // quant_max
-      at::kChar, // input dtype
-      at::kFloat); // output dtype
-}
-
-TEST(
-    VulkanDequantizePerTensorTest,
-    test_vulkan_dequantize_per_tensor_int32_to_float) {
-  test_vulkan_dequantize_per_tensor(
-      {2, 4, 3, 12}, // input sizes
-      0.0001, // scale
-      100, // zero_point
-      -2147483648, // quant_min
-      2147483647, // quant_max
-      at::kInt, // input dtype
-      at::kFloat); // output dtype
-}
-
-TEST(
-    VulkanDequantizePerTensorTest,
-    test_vulkan_dequantize_per_tensor_int8_to_half) {
-  if (!vkcompute::api::context()
-           ->adapter_ptr()
-           ->has_full_int8_buffers_support()) {
-    GTEST_SKIP();
-  }
-  if (!vkcompute::api::context()
-           ->adapter_ptr()
-           ->has_full_float16_buffers_support()) {
-    GTEST_SKIP();
-  }
-  test_vulkan_dequantize_per_tensor(
-      {2, 3}, // input sizes
-      0.05, // scale
-      10, // zero_point
-      -128, // quant_min
-      127, // quant_max
-      at::kChar, // input dtype
-      at::kHalf); // output dtype
-}
-
-TEST(
-    VulkanDequantizePerTensorTest,
-    test_vulkan_dequantize_per_tensor_int32_to_half) {
-  if (!vkcompute::api::context()
-           ->adapter_ptr()
-           ->has_full_float16_buffers_support()) {
-    GTEST_SKIP();
-  }
-  // Use much smaller scale to avoid overflow to infinity in half precision
-  // Half precision max value is ~65504, so with int32 values around 2e9,
-  // we need scales smaller than 65504/2e9 ≈ 3e-5 to avoid overflow
-  test_vulkan_dequantize_per_tensor(
-      {7}, // input sizes
-      1e-5, // scale (much smaller to avoid overflow)
-      5, // zero_point
-      std::numeric_limits<int32_t>::min(), // quant_min
-      std::numeric_limits<int32_t>::max(), // quant_max
-      at::kInt, // input dtype
-      at::kHalf); // output dtype
-}
-
-TEST(
-    VulkanDequantizePerTensorTest,
-    test_vulkan_dequantize_per_tensor_int8_to_double) {
-  if (!vkcompute::api::context()
-           ->adapter_ptr()
-           ->has_full_int8_buffers_support()) {
-    GTEST_SKIP();
-  }
-  test_vulkan_dequantize_per_tensor(
-      {2, 3}, // input sizes
-      0.05, // scale
-      10, // zero_point
-      -128, // quant_min
-      127, // quant_max
-      at::kChar, // input dtype
-      at::kDouble); // output dtype
+    test_reference_dequantize_per_tensor_int32_to_half) {
+  test_reference_dequantize_per_tensor(
+      {2, 6, 5}, // input sizes
+      0.3, // scale
+      -10, // zero_point
+      std::numeric_limits<int32_t>::min(), // quant_min
+      std::numeric_limits<int32_t>::max(), // quant_max
+      at::kInt, // input dtype
+      at::kHalf); // output dtype
 }
 
+// No Vulkan tests for quantized_decomposed.dequantize_per_tensor.default
+// because it is not going to be implemented in Vulkan since we will
+// be handling any future calls to this op via the export stage
+
 void test_reference_dequantize_per_token(
     const std::vector<int>& input_sizes,
     const std::vector<float>& scales,
     const std::vector<int>& zero_points,
     int64_t quant_min,
     int64_t quant_max,
     at::ScalarType dtype,
@@ -2424,3 +2224,274 @@
       at::kByte,
       at::kDouble);
 }
+
+void test_vulkan_dequantize_per_tensor_tensor_impl(
+    const std::vector<int>& input_sizes,
+    float scale,
+    int zero_point,
+    int64_t quant_min,
+    int64_t quant_max,
+    at::ScalarType dtype,
+    at::ScalarType out_dtype,
+    const vkcompute::utils::StorageType in_storage,
+    const vkcompute::utils::StorageType out_storage) {
+  check_dequantize_args(quant_min, quant_max, dtype, out_dtype);
+  std::vector<int64_t> input_sizes_int64(
+      input_sizes.begin(), input_sizes.end());
+
+  // Create a quantized input tensor with values from quant_min to quant_max
+  at::Tensor input;
+  if (dtype == at::kByte) {
+    input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kByte));
+  } else if (dtype == at::kChar) {
+    input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kChar));
+  } else if (dtype == at::kShort) {
+    input =
+        at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kShort));
+  } else if (dtype == at::kInt) {
+    input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kInt));
+  } else {
+    input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kLong));
+  }
+
+  // Fill with a simple pattern: values from quant_min to quant_max in steps
+  float step = 1.0f;
+  if (input.numel() > 1) {
+    step = static_cast<float>(quant_max - quant_min) / (input.numel() - 1);
+  }
+
+  auto flat_input = input.flatten();
+  for (int i = 0; i < flat_input.numel(); i++) {
+    int64_t qvalue = quant_min + i * step;
+    if (dtype == at::kByte) {
+      flat_input[i] = static_cast<uint8_t>(qvalue);
+    } else if (dtype == at::kChar) {
+      flat_input[i] = static_cast<int8_t>(qvalue);
+    } else if (dtype == at::kShort) {
+      flat_input[i] = static_cast<int16_t>(qvalue);
+    } else if (dtype == at::kInt) {
+      flat_input[i] = static_cast<int32_t>(qvalue);
+    } else if (dtype == at::kLong) {
+      flat_input[i] = static_cast<int64_t>(qvalue);
+    }
+  }
+
+  // Reshape back to original dimensions
+  input = flat_input.reshape(input_sizes_int64);
+
+  // Create scale and zero_point as tensors (single element tensors)
+  at::Tensor scale_tensor =
+      at::tensor({scale}, at::device(at::kCPU).dtype(at::kDouble));
+  at::Tensor zero_point_tensor =
+      at::tensor({zero_point}, at::device(at::kCPU).dtype(at::kLong));
+
+  // Get reference output using tensor variant
+  at::Tensor reference_out =
+      torch::executor::native::dequantize_per_tensor_tensor_args_aten(
+          input,
+          scale_tensor,
+          zero_point_tensor,
+          quant_min,
+          quant_max,
+          dtype,
+          out_dtype);
+
+  // Build Vulkan dequantize_per_tensor.tensor graph
+  using namespace vkcompute;
+
+  GraphConfig config;
+  config.set_storage_type_override(in_storage);
+  ComputeGraph graph(config);
+
+  IOValueRef r_input = graph.add_input_tensor(
+      input.sizes().vec(), from_at_scalartype(dtype), in_storage);
+
+  // Add scale and zero_point as tensor inputs (buffer storage, width packed)
+  IOValueRef r_scale = graph.add_input_tensor(
+      scale_tensor.sizes().vec(),
+      vkapi::kFloat,
+      utils::kBuffer,
+      utils::kWidthPacked);
+  IOValueRef r_zero_point = graph.add_input_tensor(
+      zero_point_tensor.sizes().vec(),
+      vkapi::kInt,
+      utils::kBuffer,
+      utils::kWidthPacked);
+
+  const ValueRef r_quant_min = graph.add_scalar<int64_t>(quant_min);
+  const ValueRef r_quant_max = graph.add_scalar<int64_t>(quant_max);
+
+  const ValueRef r_out = graph.add_tensor(
+      input.sizes().vec(), from_at_scalartype(out_dtype), out_storage);
+
+  const ValueRef r_dtype =
+      graph.add_scalar<int64_t>(static_cast<int64_t>(dtype));
+  const ValueRef r_out_dtype =
+      graph.add_scalar<int64_t>(static_cast<int64_t>(out_dtype));
+
+  VK_GET_OP_FN("quantized_decomposed.dequantize_per_tensor.tensor")
+  (graph,
+   {
+       r_input.value,
+       r_scale.value,
+       r_zero_point.value,
+       r_quant_min,
+       r_quant_max,
+       r_dtype,
+       r_out_dtype,
+       r_out,
+   });
+
+  ValueRef staging_out = graph.set_output_tensor(r_out);
+
+  graph.prepare();
+  graph.encode_prepack();
+  graph.prepack();
+  graph.encode_execute();
+
+  // Run Vulkan dequantize_per_tensor.tensor
+  graph.copy_into_staging(
+      r_input.staging, input.const_data_ptr(), input.numel());
+
+  // Convert scale tensor to float and copy to GPU
+  at::Tensor scale_float = scale_tensor.to(at::kFloat);
+  graph.copy_into_staging(
+      r_scale.staging, scale_float.const_data_ptr(), scale_float.numel());
+
+  // Convert zero_point tensor to int and copy to GPU
+  at::Tensor zero_point_int = zero_point_tensor.to(at::kInt);
+  graph.copy_into_staging(
+      r_zero_point.staging,
+      zero_point_int.const_data_ptr(),
+      zero_point_int.numel());
+
+  graph.execute();
+
+  at::Tensor vk_out = at::empty_like(reference_out).contiguous();
+  graph.copy_from_staging(
+      staging_out, vk_out.mutable_data_ptr(), vk_out.numel());
+
+  // Compare outputs with appropriate tolerance for half precision
+  bool output_correct;
+  if (out_dtype == at::kHalf) {
+    // Use higher tolerance for half precision due to limited precision
+    output_correct =
+        at::allclose(reference_out, vk_out, /*rtol=*/1e-2, /*atol=*/1e-2);
+  } else {
+    output_correct =
+        at::allclose(reference_out, vk_out, /*rtol=*/1e-5, /*atol=*/1e-5);
+  }
+  if (!output_correct) {
+    std::cout << "\n"
+              << "Failed with parameters: " << std::endl;
+    std::cout << "  scale: " << scale << std::endl;
+    std::cout << "  zero_point: " << zero_point << std::endl;
+    std::cout << "  quant_min: " << quant_min << std::endl;
+    std::cout << "  quant_max: " << quant_max << std::endl;
+    std::cout << "  storage type: "
+              << (in_storage == vkcompute::utils::kBuffer ? "buffer"
"buffer" + : "texture") + << std::endl; + std::cout << " input dtype: " << dtype << std::endl; + std::cout << " output dtype: " << out_dtype << std::endl; + + std::cout << "input:" << std::endl; + std::cout << input << std::endl; + std::cout << "reference:" << std::endl; + std::cout << reference_out << std::endl; + std::cout << "vulkan:" << std::endl; + std::cout << vk_out << std::endl; + } + + ASSERT_TRUE(output_correct); +} + +TEST( + VulkanDequantizePerTensorTensorTest, + test_vulkan_dequantize_per_tensor_tensor_int8_to_float) { + if (!vkcompute::api::context() + ->adapter_ptr() + ->has_full_int8_buffers_support()) { + GTEST_SKIP(); + } + test_vulkan_dequantize_per_tensor_tensor( + {2, 3, 4}, // input sizes + 0.01, // scale + 1, // zero_point + -128, // quant_min + 127, // quant_max + at::kChar, // input dtype + at::kFloat); // output dtype +} + +TEST( + VulkanDequantizePerTensorTensorTest, + test_vulkan_dequantize_per_tensor_tensor_uint8_to_float) { + if (!vkcompute::api::context() + ->adapter_ptr() + ->has_full_int8_buffers_support()) { + GTEST_SKIP(); + } + test_vulkan_dequantize_per_tensor_tensor( + {2, 3, 4, 12}, // input sizes + 0.1, // scale + 5, // zero_point + 0, // quant_min + 255, // quant_max + at::kByte, // input dtype + at::kFloat); // output dtype +} + +TEST( + VulkanDequantizePerTensorTensorTest, + test_vulkan_dequantize_per_tensor_tensor_int32_to_float) { + if (!vkcompute::api::context() + ->adapter_ptr() + ->has_full_int8_buffers_support()) { + GTEST_SKIP(); + } + test_vulkan_dequantize_per_tensor_tensor( + {2, 3}, // input sizes + 0.01, // scale + 12, // zero_point + std::numeric_limits::min(), // quant_min + std::numeric_limits::max(), // quant_max + at::kInt, // input dtype + at::kFloat); // output dtype +} + +TEST( + VulkanDequantizePerTensorTensorTest, + test_vulkan_dequantize_per_tensor_tensor_uint8_to_half) { + if (!vkcompute::api::context() + ->adapter_ptr() + ->has_full_int8_buffers_support()) { + GTEST_SKIP(); + } + test_vulkan_dequantize_per_tensor_tensor( + {3, 4}, // input sizes + 0.3, // scale + 2, // zero_point + 0, // quant_min + 255, // quant_max + at::kByte, // input dtype + at::kHalf); // output dtype +} + +TEST( + VulkanDequantizePerTensorTensorTest, + test_vulkan_dequantize_per_tensor_tensor_int8_to_double) { + if (!vkcompute::api::context() + ->adapter_ptr() + ->has_full_int8_buffers_support()) { + GTEST_SKIP(); + } + test_vulkan_dequantize_per_tensor_tensor( + {2, 3, 4}, // input sizes + 0.03, // scale + -2, // zero_point + -128, // quant_min + 127, // quant_max + at::kChar, // input dtype + at::kDouble); // output dtype +}