From 10512c5ea9cb243a98e10f689af160b53e640b7d Mon Sep 17 00:00:00 2001 From: morelos Date: Thu, 3 Jul 2025 11:17:31 -0700 Subject: [PATCH] [ET-VK][Ops] dequantize_per_tensor.tensor variant # Context We need a tensor variant for dequantize/quantize operators since that is the expected output of choose_qparams. # Changes This extends the logic that currently exists to support a tensor variant for scales and zeros. Differential Revision: [D77746135](https://our.internmc.facebook.com/intern/diff/D77746135/) [ghstack-poisoned] --- .../graph/ops/glsl/dequantize_buffer.glsl | 14 +- .../graph/ops/glsl/dequantize_buffer.yaml | 4 + .../graph/ops/glsl/dequantize_texture.glsl | 15 +- .../graph/ops/glsl/dequantize_texture.yaml | 4 + .../runtime/graph/ops/impl/Dequantize.cpp | 36 +- .../vulkan/test/op_tests/dequantize_test.cpp | 356 ++++++++++++++++++ 6 files changed, 416 insertions(+), 13 deletions(-) diff --git a/backends/vulkan/runtime/graph/ops/glsl/dequantize_buffer.glsl b/backends/vulkan/runtime/graph/ops/glsl/dequantize_buffer.glsl index faafa3fd266..d6ca802ae2a 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/dequantize_buffer.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/dequantize_buffer.glsl @@ -27,9 +27,14 @@ ${layout_declare_tensor(B, "w", "t_out", OUT_DTYPE, "buffer")} ${layout_declare_tensor(B, "r", "t_in", IN_DTYPE, "buffer")} $if MODE == "per_tensor": + $if SHAPE == "tensor": + ${layout_declare_tensor(B, "r", "t_scale", "float", "buffer")} + ${layout_declare_tensor(B, "r", "t_zero_point", "int", "buffer")} + layout(push_constant) uniform restrict Block { - float scale; - int zero_point; + $if SHAPE == "scalar": + float scale; + int zero_point; int quant_min; int quant_max; }; @@ -146,7 +151,10 @@ void dequantize_per_tensor() { const int in_bufi = tidx_to_bufi(out_tidx, t_in_strides); IN_T qvalue = t_in[in_bufi]; - OUT_T value = dequantize_val(qvalue, scale, zero_point); + $if SHAPE == "scalar": + OUT_T value = dequantize_val(qvalue, scale, zero_point); + $if SHAPE == "tensor": + OUT_T value = dequantize_val(qvalue, t_scale[0], t_zero_point[0]); t_out[out_bufi] = value; } diff --git a/backends/vulkan/runtime/graph/ops/glsl/dequantize_buffer.yaml b/backends/vulkan/runtime/graph/ops/glsl/dequantize_buffer.yaml index b9a53217452..9e08471076e 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/dequantize_buffer.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/dequantize_buffer.yaml @@ -3,6 +3,7 @@ dequantize_buffer: IN_DTYPE: int32 OUT_DTYPE: float MODE: per_tensor + SHAPE: tensor generate_variant_forall: IN_DTYPE: - VALUE: uint8 @@ -15,6 +16,9 @@ dequantize_buffer: shader_variants: - NAME: dequantize_per_tensor_buffer MODE: per_tensor + SHAPE: scalar + - NAME: dequantize_per_tensor_tensor_buffer + MODE: per_tensor - NAME: dequantize_per_token_buffer MODE: per_token - NAME: dequantize_per_channel_buffer diff --git a/backends/vulkan/runtime/graph/ops/glsl/dequantize_texture.glsl b/backends/vulkan/runtime/graph/ops/glsl/dequantize_texture.glsl index ef3f5cca869..17cf60e1d63 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/dequantize_texture.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/dequantize_texture.glsl @@ -30,9 +30,14 @@ ${layout_declare_tensor(B, "w", "t_out", OUT_DTYPE, "texture3d")} ${layout_declare_tensor(B, "r", "t_in", IN_DTYPE, "texture3d")} $if MODE == "per_tensor": + $if SHAPE == "tensor": + ${layout_declare_tensor(B, "r", "t_scale", "float", "buffer")} + ${layout_declare_tensor(B, "r", "t_zero_point", "int", "buffer")} + layout(push_constant) uniform restrict Block { - float scale; - int zero_point; + $if SHAPE == "scalar": + float scale; + int zero_point; int quant_min; int quant_max; }; @@ -148,7 +153,11 @@ void dequantize_per_tensor() { [[unroll]] for (int i = 0; i < 4; ++i) { IN_T qvalue = IN_T(intex[i]); - OUT_T value = dequantize_val(qvalue, scale, zero_point); + $if SHAPE == "scalar": + OUT_T value = dequantize_val(qvalue, scale, zero_point); + $if SHAPE == "tensor": + OUT_T value = dequantize_val(qvalue, t_scale[0], t_zero_point[0]); + $if OUT_DTYPE == "double": outtex[i] = float(value); $else: diff --git a/backends/vulkan/runtime/graph/ops/glsl/dequantize_texture.yaml b/backends/vulkan/runtime/graph/ops/glsl/dequantize_texture.yaml index 88ccc6e3274..ca7b7f4d6ab 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/dequantize_texture.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/dequantize_texture.yaml @@ -3,6 +3,7 @@ dequantize_texture: IN_DTYPE: int32 OUT_DTYPE: float MODE: per_tensor + SHAPE: tensor generate_variant_forall: IN_DTYPE: - VALUE: uint8 @@ -15,6 +16,9 @@ dequantize_texture: shader_variants: - NAME: dequantize_per_tensor_texture3d MODE: per_tensor + SHAPE: scalar + - NAME: dequantize_per_tensor_tensor_texture3d + MODE: per_tensor - NAME: dequantize_per_token_texture3d MODE: per_token - NAME: dequantize_per_channel_texture3d diff --git a/backends/vulkan/runtime/graph/ops/impl/Dequantize.cpp b/backends/vulkan/runtime/graph/ops/impl/Dequantize.cpp index 4000a9cd06e..355410766be 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Dequantize.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Dequantize.cpp @@ -78,13 +78,23 @@ void add_dequantize_per_tensor_node( const ValueRef& quant_min, const ValueRef& quant_max, const ValueRef& output) { + const bool is_tensor_scale_zp = + graph.val_is_tensor(scale) && graph.val_is_tensor(zero_point); + std::string kernel_name("dequantize_per_tensor"); + if (is_tensor_scale_zp) { + kernel_name += "_tensor"; + } add_storage_type_suffix(kernel_name, graph.storage_type_of(input)); add_dtype_suffix(kernel_name, graph.dtype_of(input)); add_dtype_suffix(kernel_name, graph.dtype_of(output)); - float scale_val = static_cast(graph.get_double(scale)); - int zero_point_val = static_cast(graph.get_int(zero_point)); + float scale_val = 1.0; + int zero_point_val = 0; + if (!is_tensor_scale_zp) { + scale_val = static_cast(graph.get_double(scale)); + zero_point_val = static_cast(graph.get_int(zero_point)); + } int quant_min_val = static_cast(graph.get_int(quant_min)); int quant_max_val = static_cast(graph.get_int(quant_max)); @@ -98,15 +108,17 @@ void add_dequantize_per_tensor_node( graph.strides_ubo(input), graph.sizes_ubo(output), graph.strides_ubo(output)}; + } else { + param_ubos = { + graph.logical_limits_ubo(input), graph.logical_limits_ubo(output)}; + } + + if (is_tensor_scale_zp) { push_constants = { - PushConstantDataInfo(&scale_val, sizeof(float)), - PushConstantDataInfo(&zero_point_val, sizeof(int)), PushConstantDataInfo(&quant_min_val, sizeof(int)), PushConstantDataInfo(&quant_max_val, sizeof(int)), }; } else { - param_ubos = { - graph.logical_limits_ubo(input), graph.logical_limits_ubo(output)}; push_constants = { PushConstantDataInfo(&scale_val, sizeof(float)), PushConstantDataInfo(&zero_point_val, sizeof(int)), @@ -120,13 +132,20 @@ void add_dequantize_per_tensor_node( graph.hashed_layout_of(input), }; + std::vector inputs_and_outputs = { + {output, vkapi::kWrite}, {input, vkapi::kRead}}; + if (is_tensor_scale_zp) { + inputs_and_outputs.emplace_back( + ArgGroup{{scale, zero_point}, vkapi::kRead}); + } + graph.execute_nodes().emplace_back(new DynamicDispatchNode( graph, VK_KERNEL_FROM_STR(kernel_name), default_pick_global_wg_size, default_pick_local_wg_size, // Inputs and Outputs - {{output, vkapi::kWrite}, {input, vkapi::kRead}}, + inputs_and_outputs, // Shader param buffers param_ubos, // Push Constants @@ -517,6 +536,9 @@ REGISTER_OPERATORS { VK_REGISTER_OP( quantized_decomposed.dequantize_per_tensor.default, dequantize_per_tensor_impl); + VK_REGISTER_OP( + quantized_decomposed.dequantize_per_tensor.tensor, + dequantize_per_tensor_impl); VK_REGISTER_OP( quantized_decomposed.dequantize_per_token.default, dequantize_per_token_impl); diff --git a/backends/vulkan/test/op_tests/dequantize_test.cpp b/backends/vulkan/test/op_tests/dequantize_test.cpp index 53c74293d9a..87e0a3b5167 100644 --- a/backends/vulkan/test/op_tests/dequantize_test.cpp +++ b/backends/vulkan/test/op_tests/dequantize_test.cpp @@ -60,6 +60,16 @@ Tensor& dequantize_per_channel_out( executorch::aten::optional out_dtype, Tensor& out); +Tensor& dequantize_per_tensor_tensor_args_out( + const Tensor& input, + const Tensor& scale, + const Tensor& zero_point, + int64_t quant_min, + int64_t quant_max, + ScalarType dtype, + executorch::aten::optional out_dtype, + Tensor& out); + // Wrapper function for dequantize_per_tensor_out without context Tensor& dequantize_per_tensor_out_no_context( const Tensor& input, @@ -103,6 +113,20 @@ Tensor& dequantize_per_channel_out_no_context( input, scale, zero_points, axis, quant_min, quant_max, dtype, out_dtype, out); } +// Wrapper function for dequantize_per_tensor_tensor_args_out without context +Tensor& dequantize_per_tensor_tensor_args_out_no_context( + const Tensor& input, + const Tensor& scale, + const Tensor& zero_point, + int64_t quant_min, + int64_t quant_max, + ScalarType dtype, + executorch::aten::optional out_dtype, + Tensor& out) { + return torch::executor::native::dequantize_per_tensor_tensor_args_out( + input, scale, zero_point, quant_min, quant_max, dtype, out_dtype, out); +} + // ATen wrapper for dequantize_per_tensor at::Tensor dequantize_per_tensor_aten( const at::Tensor& input, @@ -187,6 +211,34 @@ at::Tensor dequantize_per_channel_aten( return out; } +// ATen wrapper for dequantize_per_tensor with tensor args +at::Tensor dequantize_per_tensor_tensor_args_aten( + const at::Tensor& input, + const at::Tensor& scale, + const at::Tensor& zero_point, + int64_t quant_min, + int64_t quant_max, + at::ScalarType dtype, + at::ScalarType out_dtype) { + auto out = at::empty_like(input, out_dtype); + // Convert at::ScalarType to executorch::ScalarType + ScalarType et_dtype = at_scalartype_to_et_scalartype(dtype); + ScalarType et_out_dtype = at_scalartype_to_et_scalartype(out_dtype); + + executorch::aten::optional opt_et_out_dtype(et_out_dtype); + + WRAP_TO_ATEN(dequantize_per_tensor_tensor_args_out_no_context, 7) + (input, + scale, + zero_point, + quant_min, + quant_max, + et_dtype, + opt_et_out_dtype, + out); + return out; +} + } // namespace native } // namespace executor } // namespace torch @@ -555,6 +607,17 @@ void test_vulkan_dequantize_per_channel_impl( const vkcompute::utils::StorageType in_storage, const vkcompute::utils::StorageType out_storage); +void test_vulkan_dequantize_per_tensor_tensor_impl( + const std::vector& input_sizes, + float scale, + int zero_point, + int64_t quant_min, + int64_t quant_max, + at::ScalarType dtype, + at::ScalarType out_dtype, + const vkcompute::utils::StorageType in_storage, + const vkcompute::utils::StorageType out_storage); + // Wrapper function to test both buffer and texture storage types void test_vulkan_dequantize_per_tensor( const std::vector& input_sizes, @@ -678,6 +741,46 @@ void test_vulkan_dequantize_per_channel( vkcompute::utils::kTexture3D); } +// Wrapper function to test both buffer and texture storage types +void test_vulkan_dequantize_per_tensor_tensor( + const std::vector& input_sizes, + float scale, + int zero_point, + int64_t quant_min, + int64_t quant_max, + at::ScalarType dtype, + at::ScalarType out_dtype) { + // Test with buffer storage + test_vulkan_dequantize_per_tensor_tensor_impl( + input_sizes, + scale, + zero_point, + quant_min, + quant_max, + dtype, + out_dtype, + vkcompute::utils::kBuffer, + vkcompute::utils::kBuffer); + + // Telling the system to expect a float instead of a double + // since the shader can only return 32bit anyways + if (out_dtype == at::kDouble) { + out_dtype = at::kFloat; + } + + // Test with texture storage + test_vulkan_dequantize_per_tensor_tensor_impl( + input_sizes, + scale, + zero_point, + quant_min, + quant_max, + dtype, + out_dtype, + vkcompute::utils::kTexture3D, + vkcompute::utils::kTexture3D); +} + void test_reference_dequantize_per_tensor( const std::vector& input_sizes, float scale, @@ -2345,3 +2448,256 @@ TEST(VulkanDequantizePerChannelTest, test_vulkan_dequantize_per_channel_8bit_to_ at::kByte, at::kDouble); } + +void test_vulkan_dequantize_per_tensor_tensor_impl( + const std::vector& input_sizes, + float scale, + int zero_point, + int64_t quant_min, + int64_t quant_max, + at::ScalarType dtype, + at::ScalarType out_dtype, + const vkcompute::utils::StorageType in_storage, + const vkcompute::utils::StorageType out_storage) { + check_dequantize_args(quant_min, quant_max, dtype, out_dtype); + std::vector input_sizes_int64( + input_sizes.begin(), input_sizes.end()); + + // Create a quantized input tensor with values from quant_min to quant_max + at::Tensor input; + if (dtype == at::kByte) { + input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kByte)); + } else if (dtype == at::kChar) { + input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kChar)); + } else if (dtype == at::kShort) { + input = + at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kShort)); + } else if (dtype == at::kInt) { + input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kInt)); + } else { + input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kLong)); + } + + // Fill with a simple pattern: values from quant_min to quant_max in steps + float step = 1.0f; + if (input.numel() > 1) { + step = static_cast(quant_max - quant_min) / (input.numel() - 1); + } + + auto flat_input = input.flatten(); + for (int i = 0; i < flat_input.numel(); i++) { + int64_t qvalue = quant_min + i * step; + if (dtype == at::kByte) { + flat_input[i] = static_cast(qvalue); + } else if (dtype == at::kChar) { + flat_input[i] = static_cast(qvalue); + } else if (dtype == at::kShort) { + flat_input[i] = static_cast(qvalue); + } else if (dtype == at::kInt) { + flat_input[i] = static_cast(qvalue); + } else if (dtype == at::kLong) { + flat_input[i] = static_cast(qvalue); + } + } + + // Reshape back to original dimensions + input = flat_input.reshape(input_sizes_int64); + + // Create scale and zero_point as tensors (single element tensors) + at::Tensor scale_tensor = + at::tensor({scale}, at::device(at::kCPU).dtype(at::kDouble)); + at::Tensor zero_point_tensor = + at::tensor({zero_point}, at::device(at::kCPU).dtype(at::kLong)); + + // Get reference output using tensor variant + at::Tensor reference_out = torch::executor::native::dequantize_per_tensor_tensor_args_aten( + input, scale_tensor, zero_point_tensor, quant_min, quant_max, dtype, out_dtype); + + // Build Vulkan dequantize_per_tensor.tensor graph + using namespace vkcompute; + + GraphConfig config; + config.set_storage_type_override(in_storage); + ComputeGraph graph(config); + + IOValueRef r_input = graph.add_input_tensor( + input.sizes().vec(), from_at_scalartype(dtype), in_storage); + + // Add scale and zero_point as tensor inputs (buffer storage, width packed) + IOValueRef r_scale = graph.add_input_tensor( + scale_tensor.sizes().vec(), + vkapi::kFloat, + utils::kBuffer, + utils::kWidthPacked); + IOValueRef r_zero_point = graph.add_input_tensor( + zero_point_tensor.sizes().vec(), + vkapi::kInt, + utils::kBuffer, + utils::kWidthPacked); + + const ValueRef r_quant_min = graph.add_scalar(quant_min); + const ValueRef r_quant_max = graph.add_scalar(quant_max); + + const ValueRef r_out = graph.add_tensor( + input.sizes().vec(), from_at_scalartype(out_dtype), out_storage); + + const ValueRef r_dtype = + graph.add_scalar(static_cast(dtype)); + const ValueRef r_out_dtype = + graph.add_scalar(static_cast(out_dtype)); + + VK_GET_OP_FN("quantized_decomposed.dequantize_per_tensor.tensor") + (graph, + { + r_input.value, + r_scale.value, + r_zero_point.value, + r_quant_min, + r_quant_max, + r_dtype, + r_out_dtype, + r_out, + }); + + ValueRef staging_out = graph.set_output_tensor(r_out); + + graph.prepare(); + graph.encode_prepack(); + graph.prepack(); + graph.encode_execute(); + + // Run Vulkan dequantize_per_tensor.tensor + graph.copy_into_staging( + r_input.staging, input.const_data_ptr(), input.numel()); + + // Convert scale tensor to float and copy to GPU + at::Tensor scale_float = scale_tensor.to(at::kFloat); + graph.copy_into_staging( + r_scale.staging, scale_float.const_data_ptr(), scale_float.numel()); + + // Convert zero_point tensor to int and copy to GPU + at::Tensor zero_point_int = zero_point_tensor.to(at::kInt); + graph.copy_into_staging( + r_zero_point.staging, + zero_point_int.const_data_ptr(), + zero_point_int.numel()); + + graph.execute(); + + at::Tensor vk_out = at::empty_like(reference_out).contiguous(); + graph.copy_from_staging( + staging_out, vk_out.mutable_data_ptr(), vk_out.numel()); + + // Compare outputs with appropriate tolerance for half precision + bool output_correct; + if (out_dtype == at::kHalf) { + // Use higher tolerance for half precision due to limited precision + output_correct = + at::allclose(reference_out, vk_out, /*rtol=*/1e-2, /*atol=*/1e-2); + } else { + output_correct = at::allclose(reference_out, vk_out, /*rtol=*/1e-5, /*atol=*/1e-5); + } + if (!output_correct) { + std::cout << "\n" + << "Failed with parameters: " << std::endl; + std::cout << " scale: " << scale << std::endl; + std::cout << " zero_point: " << zero_point << std::endl; + std::cout << " quant_min: " << quant_min << std::endl; + std::cout << " quant_max: " << quant_max << std::endl; + std::cout << " storage type: " + << (in_storage == vkcompute::utils::kBuffer ? "buffer" + : "texture") + << std::endl; + std::cout << " input dtype: " << dtype << std::endl; + std::cout << " output dtype: " << out_dtype << std::endl; + + std::cout << "input:" << std::endl; + std::cout << input << std::endl; + std::cout << "reference:" << std::endl; + std::cout << reference_out << std::endl; + std::cout << "vulkan:" << std::endl; + std::cout << vk_out << std::endl; + } + + ASSERT_TRUE(output_correct); +} + +TEST(VulkanDequantizePerTensorTensorTest, test_vulkan_dequantize_per_tensor_tensor_int8_to_float) { + if (!vkcompute::api::context() + ->adapter_ptr() + ->has_full_int8_buffers_support()) { + GTEST_SKIP(); + } + test_vulkan_dequantize_per_tensor_tensor( + {2, 3, 4}, // input sizes + 0.01, // scale + 1, // zero_point + -128, // quant_min + 127, // quant_max + at::kChar, // input dtype + at::kFloat); // output dtype +} + +TEST(VulkanDequantizePerTensorTensorTest, test_vulkan_dequantize_per_tensor_tensor_uint8_to_float) { + if (!vkcompute::api::context() + ->adapter_ptr() + ->has_full_int8_buffers_support()) { + GTEST_SKIP(); + } + test_vulkan_dequantize_per_tensor_tensor( + {2, 3, 4, 12}, // input sizes + 0.1, // scale + 5, // zero_point + 0, // quant_min + 255, // quant_max + at::kByte, // input dtype + at::kFloat); // output dtype +} + +TEST(VulkanDequantizePerTensorTensorTest, test_vulkan_dequantize_per_tensor_tensor_int32_to_float) { + if (!vkcompute::api::context() + ->adapter_ptr() + ->has_full_int8_buffers_support()) { + GTEST_SKIP(); + } + test_vulkan_dequantize_per_tensor_tensor( + {2, 3}, // input sizes + 0.01, // scale + 12, // zero_point + std::numeric_limits::min(), // quant_min + std::numeric_limits::max(), // quant_max + at::kInt, // input dtype + at::kFloat); // output dtype +} + +TEST(VulkanDequantizePerTensorTensorTest, test_vulkan_dequantize_per_tensor_tensor_uint8_to_half) { + if (!vkcompute::api::context() + ->adapter_ptr() + ->has_full_int8_buffers_support()) { + GTEST_SKIP(); + } + test_vulkan_dequantize_per_tensor_tensor( + {3, 4}, // input sizes + 0.3, // scale + 2, // zero_point + 0, // quant_min + 255, // quant_max + at::kByte, // input dtype + at::kHalf); // output dtype +} + +TEST(VulkanDequantizePerTensorTensorTest, test_vulkan_dequantize_per_tensor_tensor_int8_to_double) { + if (!vkcompute::api::context() + ->adapter_ptr() + ->has_full_int8_buffers_support()) { + GTEST_SKIP(); + } + test_vulkan_dequantize_per_tensor_tensor( + {2, 3, 4}, // input sizes + 0.03, // scale + -2, // zero_point + -128, // quant_min + 127, // quant_max + at::kChar, // input dtype + at::kDouble); // output dtype +}