From 4d697b75693146ee69d26cae7a86d8198ecc6014 Mon Sep 17 00:00:00 2001 From: morelos Date: Thu, 3 Jul 2025 11:17:23 -0700 Subject: [PATCH] [ET-VK][Ops] dequantize_per_channel reference impl and testing # Context In order to properly enable dynamic quantization, we create the dequantize_per_channel operator as its seemingly useful to have for the pipeline. # Changes This creates the wrapper for the cpu reference implementation, and also a dummy reference implementation I created just to test against it. Differential Revision: [D77746138](https://our.internmc.facebook.com/intern/diff/D77746138/) [ghstack-poisoned] --- .../vulkan/test/op_tests/dequantize_test.cpp | 445 +++++++++++++++++- .../make_aten_functor_from_et_functor.h | 34 +- 2 files changed, 470 insertions(+), 9 deletions(-) diff --git a/backends/vulkan/test/op_tests/dequantize_test.cpp b/backends/vulkan/test/op_tests/dequantize_test.cpp index 82f316abe82..aaebf877668 100644 --- a/backends/vulkan/test/op_tests/dequantize_test.cpp +++ b/backends/vulkan/test/op_tests/dequantize_test.cpp @@ -49,6 +49,17 @@ Tensor& dequantize_per_token_out( ScalarType out_dtype, Tensor& out); +Tensor& dequantize_per_channel_out( + const Tensor& input, + const Tensor& scale, + const std::optional& zero_points, + int64_t axis, + int64_t quant_min, + int64_t quant_max, + ScalarType dtype, + executorch::aten::optional out_dtype, + Tensor& out); + // Wrapper function for dequantize_per_tensor_out without context Tensor& dequantize_per_tensor_out_no_context( const Tensor& input, @@ -77,6 +88,21 @@ Tensor& dequantize_per_token_out_no_context( input, scale, zero_points, quant_min, quant_max, dtype, out_dtype, out); } +// Wrapper function for dequantize_per_channel_out without context +Tensor& dequantize_per_channel_out_no_context( + const Tensor& input, + const Tensor& scale, + const std::optional& zero_points, + int64_t axis, + int64_t quant_min, + int64_t quant_max, + ScalarType dtype, + executorch::aten::optional out_dtype, + Tensor& out) { + return torch::executor::native::dequantize_per_channel_out( + input, scale, zero_points, axis, quant_min, quant_max, dtype, out_dtype, out); +} + // ATen wrapper for dequantize_per_tensor at::Tensor dequantize_per_tensor_aten( const at::Tensor& input, @@ -131,6 +157,36 @@ at::Tensor dequantize_per_token_aten( return out; } +// ATen wrapper for dequantize_per_channel +at::Tensor dequantize_per_channel_aten( + const at::Tensor& input, + const at::Tensor& scale, + const std::optional& zero_points, + int64_t axis, + int64_t quant_min, + int64_t quant_max, + at::ScalarType dtype, + at::ScalarType out_dtype) { + auto out = at::empty_like(input, out_dtype); + // Convert at::ScalarType to executorch::ScalarType + ScalarType et_dtype = at_scalartype_to_et_scalartype(dtype); + ScalarType et_out_dtype = at_scalartype_to_et_scalartype(out_dtype); + + executorch::aten::optional opt_et_out_dtype(et_out_dtype); + + WRAP_TO_ATEN(dequantize_per_channel_out_no_context, 8) + (input, + scale, + zero_points, + axis, + quant_min, + quant_max, + et_dtype, + opt_et_out_dtype, + out); + return out; +} + } // namespace native } // namespace executor } // namespace torch @@ -183,6 +239,40 @@ void check_dequantize_args( } } +/** + * Helper function to validate dequantize_per_channel arguments + * Similar to the validation in quantize_test.cpp + */ +void check_dequantize_per_channel_args( + const std::vector& input_sizes, + const std::vector& scales, + const std::vector& zero_points, + int64_t axis) { + // Normalize axis + int64_t normalized_axis = axis; + if (normalized_axis < 0) { + normalized_axis += input_sizes.size(); + } + + ASSERT_GE(normalized_axis, 0) + << "axis " << axis << " is not legal, normalized axis " << normalized_axis + << " should be >= 0"; + + ASSERT_LT(normalized_axis, static_cast(input_sizes.size())) + << "axis " << axis << " is not legal, normalized axis " << normalized_axis + << " should be < input.dim() " << input_sizes.size(); + + int64_t num_channels = input_sizes[normalized_axis]; + + ASSERT_EQ(num_channels, static_cast(scales.size())) + << "Expected scales.size() to match input.size(axis) (" << num_channels + << "), but got " << scales.size(); + + ASSERT_EQ(num_channels, static_cast(zero_points.size())) + << "Expected zero_points.size() to match input.size(axis) (" + << num_channels << "), but got " << zero_points.size(); +} + // // Reference Implementation // @@ -322,6 +412,114 @@ at::Tensor dequantize_per_token_reference_impl( return out; } +/* + * Reference implementation of dequantize_per_channel + */ +at::Tensor dequantize_per_channel_reference_impl( + const at::Tensor& input, + const at::Tensor& scale, + const std::optional& zero_point, + int64_t axis, + int64_t quant_min, + int64_t quant_max, + at::ScalarType dtype, + at::ScalarType out_dtype) { + // Normalize axis to handle negative values + int64_t normalized_axis = axis; + if (normalized_axis < 0) { + normalized_axis += input.dim(); + } + + // Create output tensor with the same shape as input but with target dtype + at::Tensor output = at::empty_like(input, out_dtype); + + // Get the number of channels along the quantization axis + int64_t num_channels = input.size(normalized_axis); + + // Calculate strides for efficient indexing + std::vector input_strides; + std::vector input_sizes; + for (int64_t i = 0; i < input.dim(); i++) { + input_sizes.push_back(input.size(i)); + input_strides.push_back(input.stride(i)); + } + + // Get data pointers + const double* scale_data = scale.const_data_ptr(); + const int64_t* zero_point_data = nullptr; + if (zero_point.has_value()) { + zero_point_data = zero_point.value().const_data_ptr(); + } + + // Iterate through all elements in the tensor + int64_t total_elements = input.numel(); + + // Helper lambda to convert flat index to multi-dimensional coordinates + auto flat_to_coords = [&](int64_t flat_idx, std::vector& coords) { + int64_t remaining = flat_idx; + for (int64_t dim = input.dim() - 1; dim >= 0; dim--) { + coords[dim] = remaining % input_sizes[dim]; + remaining /= input_sizes[dim]; + } + }; + + // Process each element + std::vector coords(input.dim()); + for (int64_t flat_idx = 0; flat_idx < total_elements; flat_idx++) { + // Convert flat index to coordinates + flat_to_coords(flat_idx, coords); + + // Get the channel index for this element + int64_t channel_idx = coords[normalized_axis]; + + // Get the quantization parameters for this channel + double channel_scale = scale_data[channel_idx]; + int64_t channel_zero_point = 0; + if (zero_point_data != nullptr) { + channel_zero_point = zero_point_data[channel_idx]; + } + + // Store casted values to avoid repeated casting + const int32_t channel_zero_point_int32 = static_cast(channel_zero_point); + const float channel_scale_float = static_cast(channel_scale); + + // Get the input value and dequantize + double dequantized_value = 0.0; + + // Extract quantized value and dequantize based on input dtype + // Following the CPU implementation pattern: (input - zero_point) * scale + if (dtype == at::kByte) { + uint8_t qvalue = input.flatten()[flat_idx].item(); + dequantized_value = (qvalue - channel_zero_point_int32) * channel_scale_float; + } else if (dtype == at::kChar) { + int8_t qvalue = input.flatten()[flat_idx].item(); + dequantized_value = (qvalue - channel_zero_point_int32) * channel_scale_float; + } else if (dtype == at::kShort) { + int16_t qvalue = input.flatten()[flat_idx].item(); + dequantized_value = (qvalue - channel_zero_point_int32) * channel_scale_float; + } else if (dtype == at::kInt) { + int32_t qvalue = input.flatten()[flat_idx].item(); + dequantized_value = (qvalue - channel_zero_point_int32) * channel_scale_float; + } else if (dtype == at::kLong) { + int64_t qvalue = input.flatten()[flat_idx].item(); + dequantized_value = (qvalue - channel_zero_point_int32) * channel_scale_float; + } else { + throw std::runtime_error("Unsupported input dtype"); + } + + // Store the result based on output dtype + if (out_dtype == at::kFloat) { + output.flatten()[flat_idx] = static_cast(dequantized_value); + } else if (out_dtype == at::kDouble) { + output.flatten()[flat_idx] = dequantized_value; + } else if (out_dtype == at::kHalf) { + output.flatten()[flat_idx] = static_cast(dequantized_value); + } + } + + return output; +} + // Forward declaration of implementation functions void test_vulkan_dequantize_per_tensor_impl( const std::vector& input_sizes, @@ -345,6 +543,18 @@ void test_vulkan_dequantize_per_token_impl( const vkcompute::utils::StorageType in_storage, const vkcompute::utils::StorageType out_storage); +void test_vulkan_dequantize_per_channel_impl( + const std::vector& input_sizes, + const std::vector& scales, + const std::vector& zero_points, + int64_t axis, + int64_t quant_min, + int64_t quant_max, + at::ScalarType dtype, + at::ScalarType out_dtype, + const vkcompute::utils::StorageType in_storage, + const vkcompute::utils::StorageType out_storage); + // Wrapper function to test both buffer and texture storage types void test_vulkan_dequantize_per_tensor( const std::vector& input_sizes, @@ -425,6 +635,49 @@ void test_vulkan_dequantize_per_token( vkcompute::utils::kTexture3D); } +// Wrapper function to test both buffer and texture storage types +void test_vulkan_dequantize_per_channel( + const std::vector& input_sizes, + const std::vector& scales, + const std::vector& zero_points, + int64_t axis, + int64_t quant_min, + int64_t quant_max, + at::ScalarType dtype, + at::ScalarType out_dtype) { + // Test with buffer storage + test_vulkan_dequantize_per_channel_impl( + input_sizes, + scales, + zero_points, + axis, + quant_min, + quant_max, + dtype, + out_dtype, + vkcompute::utils::kBuffer, + vkcompute::utils::kBuffer); + + // Telling the system to expect a float instead of a double + // since the shader can only return 32bit anyways + if (out_dtype == at::kDouble) { + out_dtype = at::kFloat; + } + + // Test with texture storage + test_vulkan_dequantize_per_channel_impl( + input_sizes, + scales, + zero_points, + axis, + quant_min, + quant_max, + dtype, + out_dtype, + vkcompute::utils::kTexture3D, + vkcompute::utils::kTexture3D); +} + void test_reference_dequantize_per_tensor( const std::vector& input_sizes, float scale, @@ -625,7 +878,7 @@ void test_vulkan_dequantize_per_tensor_impl( output_correct = at::allclose(reference_out, vk_out, /*rtol=*/1e-2, /*atol=*/1e-2); } else { - output_correct = at::allclose(reference_out, vk_out); + output_correct = at::allclose(reference_out, vk_out, /*rtol=*/1e-5, /*atol=*/1e-5); } if (!output_correct) { std::cout << "\n" @@ -1105,7 +1358,7 @@ void test_vulkan_dequantize_per_token_impl( output_correct = at::allclose(reference_out, vk_out, /*rtol=*/1e-2, /*atol=*/1e-2); } else { - output_correct = at::allclose(reference_out, vk_out); + output_correct = at::allclose(reference_out, vk_out, /*rtol=*/1e-5, /*atol=*/1e-5); } if (!output_correct) { std::cout << "\n" @@ -1349,3 +1602,191 @@ TEST( at::kChar, // input dtype at::kDouble); // output dtype } + +void test_reference_dequantize_per_channel( + const std::vector& input_sizes, + const std::vector& scales, + const std::vector& zero_points, + int64_t axis, + int64_t quant_min, + int64_t quant_max, + at::ScalarType dtype, + at::ScalarType out_dtype) { + check_dequantize_args(quant_min, quant_max, dtype, out_dtype); + check_dequantize_per_channel_args(input_sizes, scales, zero_points, axis); + + std::vector input_sizes_int64( + input_sizes.begin(), input_sizes.end()); + + // Create input tensor with quantized values + at::Tensor input; + if (dtype == at::kByte) { + input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kByte)); + } else if (dtype == at::kChar) { + input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kChar)); + } else if (dtype == at::kShort) { + input = + at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kShort)); + } else if (dtype == at::kInt) { + input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kInt)); + } else { + input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kLong)); + } + + // Fill with a simple pattern: values from quant_min to quant_max in steps + float step = 1.0f; + if (input.numel() > 1) { + step = static_cast(quant_max - quant_min) / (input.numel() - 1); + } + + auto flat_input = input.flatten(); + for (int i = 0; i < flat_input.numel(); i++) { + int64_t qvalue = quant_min + i * step; + if (dtype == at::kByte) { + flat_input[i] = static_cast(qvalue); + } else if (dtype == at::kChar) { + flat_input[i] = static_cast(qvalue); + } else if (dtype == at::kShort) { + flat_input[i] = static_cast(qvalue); + } else if (dtype == at::kInt) { + flat_input[i] = static_cast(qvalue); + } else if (dtype == at::kLong) { + flat_input[i] = static_cast(qvalue); + } + } + + // Reshape back to original dimensions + input = flat_input.reshape(input_sizes_int64); + + // Create scale and zero_point tensors + at::Tensor scale_tensor = + at::tensor(scales, at::device(at::kCPU).dtype(at::kDouble)); + at::Tensor zero_point_tensor = + at::tensor(zero_points, at::device(at::kCPU).dtype(at::kLong)); + + // Get reference output + at::Tensor my_ref = dequantize_per_channel_reference_impl( + input, + scale_tensor, + zero_point_tensor, + axis, + quant_min, + quant_max, + dtype, + out_dtype); + + // Get implementation output + at::Tensor cpu_ref = torch::executor::native::dequantize_per_channel_aten( + input, + scale_tensor, + zero_point_tensor, + axis, + quant_min, + quant_max, + dtype, + out_dtype); + + // Compare outputs + const bool output_correct = at::allclose(my_ref, cpu_ref); + if (!output_correct) { + std::cout << "\n" + << "Failed with parameters: " << std::endl; + std::cout << " axis: " << axis << std::endl; + std::cout << " input sizes:"; + for (size_t i = 0; i < input_sizes.size(); i++) { + std::cout << " " << input_sizes[i] << " "; + } + std::cout << "" << std::endl; + std::cout << " scale(s):"; + for (size_t i = 0; i < scales.size(); i++) { + std::cout << " " << scales[i] << " "; + } + std::cout << "" << std::endl; + std::cout << " zero_point(s):"; + for (size_t i = 0; i < zero_points.size(); i++) { + std::cout << " " << zero_points[i] << " "; + } + std::cout << "" << std::endl; + std::cout << " quant_min: " << quant_min << std::endl; + std::cout << " quant_max: " << quant_max << std::endl; + std::cout << " input dtype: " << dtype << std::endl; + std::cout << " output dtype: " << out_dtype << std::endl; + + std::cout << "input:" << std::endl; + std::cout << input << std::endl; + std::cout << "cpu_ref:" << std::endl; + std::cout << cpu_ref << std::endl; + std::cout << "my_ref:" << std::endl; + std::cout << my_ref << std::endl; + } + + ASSERT_TRUE(output_correct); +} + +TEST( + VulkanDequantizePerChannelTest, + test_reference_dequantize_per_channel_uint8_to_float_3D_axis0) { + std::vector scales = {0.1, 0.2, 0.3}; + std::vector zero_points = {0, 5, -2}; + + test_reference_dequantize_per_channel( + {3, 4, 2}, // input sizes + scales, + zero_points, + 0, // axis + 0, // quant_min + 255, // quant_max + at::kByte, + at::kFloat); +} + +TEST( + VulkanDequantizePerChannelTest, + test_reference_dequantize_per_channel_int8_to_float_3D_axis2) { + std::vector scales = {0.1, 0.2}; + std::vector zero_points = {0, 5}; + + test_reference_dequantize_per_channel( + {3, 4, 2}, // input sizes + scales, + zero_points, + 2, // axis + -128, // quant_min + 127, // quant_max + at::kChar, + at::kFloat); +} + +TEST( + VulkanDequantizePerChannelTest, + test_reference_dequantize_per_channel_int8_to_float_3D_axisn1) { + std::vector scales = {0.1, 0.2}; + std::vector zero_points = {0, 5}; + + test_reference_dequantize_per_channel( + {3, 4, 2}, // input sizes + scales, + zero_points, + -1, // axis + -128, // quant_min + 127, // quant_max + at::kChar, + at::kFloat); +} + +TEST( + VulkanDequantizePerChannelTest, + test_reference_dequantize_per_channel_int32_to_float_4D_axis0) { + std::vector scales = {0.1, 0.2, 0.00002}; + std::vector zero_points = {0, 5, -4}; + + test_reference_dequantize_per_channel( + {3, 4, 2, 5}, // input sizes + scales, + zero_points, + 0, // axis + std::numeric_limits::min(), // quant_min + std::numeric_limits::max(), // quant_max + at::kInt, + at::kFloat); +} diff --git a/extension/aten_util/make_aten_functor_from_et_functor.h b/extension/aten_util/make_aten_functor_from_et_functor.h index cb7b36a5fc1..8889df70ef0 100644 --- a/extension/aten_util/make_aten_functor_from_et_functor.h +++ b/extension/aten_util/make_aten_functor_from_et_functor.h @@ -154,24 +154,44 @@ struct type_convert< at::Tensor converted_; }; -// Optionals: ATen to ETen. +// Optionals: ETen to ATen. template -struct type_convert, torch::executor::optional> final { +struct type_convert, std::optional> final { public: - std::optional val; + torch::executor::optional val; std::unique_ptr> convert_struct; - explicit type_convert(std::optional value) : val(value) {} - torch::executor::optional call() { + explicit type_convert(torch::executor::optional value) : val(value) {} + std::optional call() { if (val.has_value()) { convert_struct = std::make_unique>( type_convert(val.value())); - return torch::executor::optional(convert_struct->call()); + return std::optional(convert_struct->call()); } else { - return torch::executor::optional(); + return std::optional(); } } }; +// Specific specialization for optional tensor conversion: std::optional to std::optional +template <> +struct type_convert&, const std::optional&> final { + public: + const std::optional& val; + std::unique_ptr> convert_struct; + explicit type_convert(const std::optional& value) : val(value) {} + const std::optional& call() { + static std::optional result; + if (val.has_value()) { + convert_struct = std::make_unique>( + type_convert(val.value())); + result = std::optional(convert_struct->call()); + } else { + result = std::optional(); + } + return result; + } +}; + // ArrayRefs: ATen to ETen. template struct type_convert, torch::executor::ArrayRef> final {