diff --git a/backends/vulkan/test/op_tests/dequantize_test.cpp b/backends/vulkan/test/op_tests/dequantize_test.cpp index 82f316abe82..f32a93e2b6a 100644 --- a/backends/vulkan/test/op_tests/dequantize_test.cpp +++ b/backends/vulkan/test/op_tests/dequantize_test.cpp @@ -49,6 +49,17 @@ Tensor& dequantize_per_token_out( ScalarType out_dtype, Tensor& out); +Tensor& dequantize_per_channel_out( + const Tensor& input, + const Tensor& scale, + const std::optional& zero_points, + int64_t axis, + int64_t quant_min, + int64_t quant_max, + ScalarType dtype, + executorch::aten::optional out_dtype, + Tensor& out); + // Wrapper function for dequantize_per_tensor_out without context Tensor& dequantize_per_tensor_out_no_context( const Tensor& input, @@ -77,6 +88,29 @@ Tensor& dequantize_per_token_out_no_context( input, scale, zero_points, quant_min, quant_max, dtype, out_dtype, out); } +// Wrapper function for dequantize_per_channel_out without context +Tensor& dequantize_per_channel_out_no_context( + const Tensor& input, + const Tensor& scale, + const std::optional& zero_points, + int64_t axis, + int64_t quant_min, + int64_t quant_max, + ScalarType dtype, + executorch::aten::optional out_dtype, + Tensor& out) { + return torch::executor::native::dequantize_per_channel_out( + input, + scale, + zero_points, + axis, + quant_min, + quant_max, + dtype, + out_dtype, + out); +} + // ATen wrapper for dequantize_per_tensor at::Tensor dequantize_per_tensor_aten( const at::Tensor& input, @@ -131,6 +165,36 @@ at::Tensor dequantize_per_token_aten( return out; } +// ATen wrapper for dequantize_per_channel +at::Tensor dequantize_per_channel_aten( + const at::Tensor& input, + const at::Tensor& scale, + const std::optional& zero_points, + int64_t axis, + int64_t quant_min, + int64_t quant_max, + at::ScalarType dtype, + at::ScalarType out_dtype) { + auto out = at::empty_like(input, out_dtype); + // Convert at::ScalarType to executorch::ScalarType + ScalarType et_dtype = at_scalartype_to_et_scalartype(dtype); + ScalarType et_out_dtype = at_scalartype_to_et_scalartype(out_dtype); + + executorch::aten::optional opt_et_out_dtype(et_out_dtype); + + WRAP_TO_ATEN(dequantize_per_channel_out_no_context, 8) + (input, + scale, + zero_points, + axis, + quant_min, + quant_max, + et_dtype, + opt_et_out_dtype, + out); + return out; +} + } // namespace native } // namespace executor } // namespace torch @@ -183,6 +247,40 @@ void check_dequantize_args( } } +/** + * Helper function to validate dequantize_per_channel arguments + * Similar to the validation in quantize_test.cpp + */ +void check_dequantize_per_channel_args( + const std::vector& input_sizes, + const std::vector& scales, + const std::vector& zero_points, + int64_t axis) { + // Normalize axis + int64_t normalized_axis = axis; + if (normalized_axis < 0) { + normalized_axis += input_sizes.size(); + } + + ASSERT_GE(normalized_axis, 0) + << "axis " << axis << " is not legal, normalized axis " << normalized_axis + << " should be >= 0"; + + ASSERT_LT(normalized_axis, static_cast(input_sizes.size())) + << "axis " << axis << " is not legal, normalized axis " << normalized_axis + << " should be < input.dim() " << input_sizes.size(); + + int64_t num_channels = input_sizes[normalized_axis]; + + ASSERT_EQ(num_channels, static_cast(scales.size())) + << "Expected scales.size() to match input.size(axis) (" << num_channels + << "), but got " << scales.size(); + + ASSERT_EQ(num_channels, static_cast(zero_points.size())) + << "Expected zero_points.size() to match input.size(axis) (" + << num_channels << "), but got " << zero_points.size(); +} + // // Reference Implementation // @@ -322,6 +420,120 @@ at::Tensor dequantize_per_token_reference_impl( return out; } +/* + * Reference implementation of dequantize_per_channel + */ +at::Tensor dequantize_per_channel_reference_impl( + const at::Tensor& input, + const at::Tensor& scale, + const std::optional& zero_point, + int64_t axis, + int64_t quant_min, + int64_t quant_max, + at::ScalarType dtype, + at::ScalarType out_dtype) { + // Normalize axis to handle negative values + int64_t normalized_axis = axis; + if (normalized_axis < 0) { + normalized_axis += input.dim(); + } + + // Create output tensor with the same shape as input but with target dtype + at::Tensor output = at::empty_like(input, out_dtype); + + // Get the number of channels along the quantization axis + int64_t num_channels = input.size(normalized_axis); + + // Calculate strides for efficient indexing + std::vector input_strides; + std::vector input_sizes; + for (int64_t i = 0; i < input.dim(); i++) { + input_sizes.push_back(input.size(i)); + input_strides.push_back(input.stride(i)); + } + + // Get data pointers + const double* scale_data = scale.const_data_ptr(); + const int64_t* zero_point_data = nullptr; + if (zero_point.has_value()) { + zero_point_data = zero_point.value().const_data_ptr(); + } + + // Iterate through all elements in the tensor + int64_t total_elements = input.numel(); + + // Helper lambda to convert flat index to multi-dimensional coordinates + auto flat_to_coords = [&](int64_t flat_idx, std::vector& coords) { + int64_t remaining = flat_idx; + for (int64_t dim = input.dim() - 1; dim >= 0; dim--) { + coords[dim] = remaining % input_sizes[dim]; + remaining /= input_sizes[dim]; + } + }; + + // Process each element + std::vector coords(input.dim()); + for (int64_t flat_idx = 0; flat_idx < total_elements; flat_idx++) { + // Convert flat index to coordinates + flat_to_coords(flat_idx, coords); + + // Get the channel index for this element + int64_t channel_idx = coords[normalized_axis]; + + // Get the quantization parameters for this channel + double channel_scale = scale_data[channel_idx]; + int64_t channel_zero_point = 0; + if (zero_point_data != nullptr) { + channel_zero_point = zero_point_data[channel_idx]; + } + + // Store casted values to avoid repeated casting + const int32_t channel_zero_point_int32 = + static_cast(channel_zero_point); + const float channel_scale_float = static_cast(channel_scale); + + // Get the input value and dequantize + double dequantized_value = 0.0; + + // Extract quantized value and dequantize based on input dtype + // Following the CPU implementation pattern: (input - zero_point) * scale + if (dtype == at::kByte) { + uint8_t qvalue = input.flatten()[flat_idx].item(); + dequantized_value = + (qvalue - channel_zero_point_int32) * channel_scale_float; + } else if (dtype == at::kChar) { + int8_t qvalue = input.flatten()[flat_idx].item(); + dequantized_value = + (qvalue - channel_zero_point_int32) * channel_scale_float; + } else if (dtype == at::kShort) { + int16_t qvalue = input.flatten()[flat_idx].item(); + dequantized_value = + (qvalue - channel_zero_point_int32) * channel_scale_float; + } else if (dtype == at::kInt) { + int32_t qvalue = input.flatten()[flat_idx].item(); + dequantized_value = + (qvalue - channel_zero_point_int32) * channel_scale_float; + } else if (dtype == at::kLong) { + int64_t qvalue = input.flatten()[flat_idx].item(); + dequantized_value = + (qvalue - channel_zero_point_int32) * channel_scale_float; + } else { + throw std::runtime_error("Unsupported input dtype"); + } + + // Store the result based on output dtype + if (out_dtype == at::kFloat) { + output.flatten()[flat_idx] = static_cast(dequantized_value); + } else if (out_dtype == at::kDouble) { + output.flatten()[flat_idx] = dequantized_value; + } else if (out_dtype == at::kHalf) { + output.flatten()[flat_idx] = static_cast(dequantized_value); + } + } + + return output; +} + // Forward declaration of implementation functions void test_vulkan_dequantize_per_tensor_impl( const std::vector& input_sizes, @@ -625,7 +837,8 @@ void test_vulkan_dequantize_per_tensor_impl( output_correct = at::allclose(reference_out, vk_out, /*rtol=*/1e-2, /*atol=*/1e-2); } else { - output_correct = at::allclose(reference_out, vk_out); + output_correct = + at::allclose(reference_out, vk_out, /*rtol=*/1e-5, /*atol=*/1e-5); } if (!output_correct) { std::cout << "\n" @@ -1105,7 +1318,8 @@ void test_vulkan_dequantize_per_token_impl( output_correct = at::allclose(reference_out, vk_out, /*rtol=*/1e-2, /*atol=*/1e-2); } else { - output_correct = at::allclose(reference_out, vk_out); + output_correct = + at::allclose(reference_out, vk_out, /*rtol=*/1e-5, /*atol=*/1e-5); } if (!output_correct) { std::cout << "\n" @@ -1349,3 +1563,191 @@ TEST( at::kChar, // input dtype at::kDouble); // output dtype } + +void test_reference_dequantize_per_channel( + const std::vector& input_sizes, + const std::vector& scales, + const std::vector& zero_points, + int64_t axis, + int64_t quant_min, + int64_t quant_max, + at::ScalarType dtype, + at::ScalarType out_dtype) { + check_dequantize_args(quant_min, quant_max, dtype, out_dtype); + check_dequantize_per_channel_args(input_sizes, scales, zero_points, axis); + + std::vector input_sizes_int64( + input_sizes.begin(), input_sizes.end()); + + // Create input tensor with quantized values + at::Tensor input; + if (dtype == at::kByte) { + input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kByte)); + } else if (dtype == at::kChar) { + input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kChar)); + } else if (dtype == at::kShort) { + input = + at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kShort)); + } else if (dtype == at::kInt) { + input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kInt)); + } else { + input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kLong)); + } + + // Fill with a simple pattern: values from quant_min to quant_max in steps + float step = 1.0f; + if (input.numel() > 1) { + step = static_cast(quant_max - quant_min) / (input.numel() - 1); + } + + auto flat_input = input.flatten(); + for (int i = 0; i < flat_input.numel(); i++) { + int64_t qvalue = quant_min + i * step; + if (dtype == at::kByte) { + flat_input[i] = static_cast(qvalue); + } else if (dtype == at::kChar) { + flat_input[i] = static_cast(qvalue); + } else if (dtype == at::kShort) { + flat_input[i] = static_cast(qvalue); + } else if (dtype == at::kInt) { + flat_input[i] = static_cast(qvalue); + } else if (dtype == at::kLong) { + flat_input[i] = static_cast(qvalue); + } + } + + // Reshape back to original dimensions + input = flat_input.reshape(input_sizes_int64); + + // Create scale and zero_point tensors + at::Tensor scale_tensor = + at::tensor(scales, at::device(at::kCPU).dtype(at::kDouble)); + at::Tensor zero_point_tensor = + at::tensor(zero_points, at::device(at::kCPU).dtype(at::kLong)); + + // Get reference output + at::Tensor my_ref = dequantize_per_channel_reference_impl( + input, + scale_tensor, + zero_point_tensor, + axis, + quant_min, + quant_max, + dtype, + out_dtype); + + // Get implementation output + at::Tensor cpu_ref = torch::executor::native::dequantize_per_channel_aten( + input, + scale_tensor, + zero_point_tensor, + axis, + quant_min, + quant_max, + dtype, + out_dtype); + + // Compare outputs + const bool output_correct = at::allclose(my_ref, cpu_ref); + if (!output_correct) { + std::cout << "\n" + << "Failed with parameters: " << std::endl; + std::cout << " axis: " << axis << std::endl; + std::cout << " input sizes:"; + for (size_t i = 0; i < input_sizes.size(); i++) { + std::cout << " " << input_sizes[i] << " "; + } + std::cout << "" << std::endl; + std::cout << " scale(s):"; + for (size_t i = 0; i < scales.size(); i++) { + std::cout << " " << scales[i] << " "; + } + std::cout << "" << std::endl; + std::cout << " zero_point(s):"; + for (size_t i = 0; i < zero_points.size(); i++) { + std::cout << " " << zero_points[i] << " "; + } + std::cout << "" << std::endl; + std::cout << " quant_min: " << quant_min << std::endl; + std::cout << " quant_max: " << quant_max << std::endl; + std::cout << " input dtype: " << dtype << std::endl; + std::cout << " output dtype: " << out_dtype << std::endl; + + std::cout << "input:" << std::endl; + std::cout << input << std::endl; + std::cout << "cpu_ref:" << std::endl; + std::cout << cpu_ref << std::endl; + std::cout << "my_ref:" << std::endl; + std::cout << my_ref << std::endl; + } + + ASSERT_TRUE(output_correct); +} + +TEST( + VulkanDequantizePerChannelTest, + test_reference_dequantize_per_channel_uint8_to_float_3D_axis0) { + std::vector scales = {0.1, 0.2, 0.3}; + std::vector zero_points = {0, 5, -2}; + + test_reference_dequantize_per_channel( + {3, 4, 2}, // input sizes + scales, + zero_points, + 0, // axis + 0, // quant_min + 255, // quant_max + at::kByte, + at::kFloat); +} + +TEST( + VulkanDequantizePerChannelTest, + test_reference_dequantize_per_channel_int8_to_float_3D_axis2) { + std::vector scales = {0.1, 0.2}; + std::vector zero_points = {0, 5}; + + test_reference_dequantize_per_channel( + {3, 4, 2}, // input sizes + scales, + zero_points, + 2, // axis + -128, // quant_min + 127, // quant_max + at::kChar, + at::kFloat); +} + +TEST( + VulkanDequantizePerChannelTest, + test_reference_dequantize_per_channel_int8_to_float_3D_axisn1) { + std::vector scales = {0.1, 0.2}; + std::vector zero_points = {0, 5}; + + test_reference_dequantize_per_channel( + {3, 4, 2}, // input sizes + scales, + zero_points, + -1, // axis + -128, // quant_min + 127, // quant_max + at::kChar, + at::kFloat); +} + +TEST( + VulkanDequantizePerChannelTest, + test_reference_dequantize_per_channel_int32_to_float_4D_axis0) { + std::vector scales = {0.1, 0.2, 0.00002}; + std::vector zero_points = {0, 5, -4}; + + test_reference_dequantize_per_channel( + {3, 4, 2, 5}, // input sizes + scales, + zero_points, + 0, // axis + std::numeric_limits::min(), // quant_min + std::numeric_limits::max(), // quant_max + at::kInt, + at::kFloat); +} diff --git a/extension/aten_util/make_aten_functor_from_et_functor.h b/extension/aten_util/make_aten_functor_from_et_functor.h index cb7b36a5fc1..104531f0fbb 100644 --- a/extension/aten_util/make_aten_functor_from_et_functor.h +++ b/extension/aten_util/make_aten_functor_from_et_functor.h @@ -155,19 +155,39 @@ struct type_convert< }; // Optionals: ATen to ETen. -template -struct type_convert, torch::executor::optional> final { +template +struct type_convert< + AOptional, + EOptional, + std::enable_if_t< + std::is_same_v< + typename remove_const_ref::type, + std::optional< + typename remove_const_ref::type::value_type>> && + std::is_same_v< + typename remove_const_ref::type, + torch::executor::optional< + typename remove_const_ref::type::value_type>>>> + final { public: - std::optional val; - std::unique_ptr> convert_struct; - explicit type_convert(std::optional value) : val(value) {} - torch::executor::optional call() { + typename remove_const_ref::type val; + std::unique_ptr::type::value_type, + typename remove_const_ref::type::value_type>> + convert_struct; + explicit type_convert(AOptional value) : val(value) {} + typename remove_const_ref::type call() { if (val.has_value()) { - convert_struct = std::make_unique>( - type_convert(val.value())); - return torch::executor::optional(convert_struct->call()); + convert_struct = std::make_unique::type::value_type, + typename remove_const_ref::type::value_type>>( + type_convert< + typename remove_const_ref::type::value_type, + typename remove_const_ref::type::value_type>( + val.value())); + return typename remove_const_ref::type(convert_struct->call()); } else { - return torch::executor::optional(); + return typename remove_const_ref::type(); } } }; diff --git a/extension/aten_util/test/make_aten_functor_from_et_functor_test.cpp b/extension/aten_util/test/make_aten_functor_from_et_functor_test.cpp index 17d0f7a4d63..a5b53096ae2 100644 --- a/extension/aten_util/test/make_aten_functor_from_et_functor_test.cpp +++ b/extension/aten_util/test/make_aten_functor_from_et_functor_test.cpp @@ -421,3 +421,92 @@ TEST_F(MakeATenFunctorFromETFunctorTest, TestWrap_ArrayRefOptional) { EXPECT_EQ(stack.size(), 1); EXPECT_EQ(stack[0].toTensor().const_data_ptr()[0], 4); } + +TEST_F(MakeATenFunctorFromETFunctorTest, TestConvert_ConstRefOptionals) { + // Test const optional scalar conversion + const std::optional const_optional_at_in = + std::optional(42); + auto const_optional_et = + type_convert< + const std::optional, + torch::executor::optional>(const_optional_at_in) + .call(); + EXPECT_TRUE(const_optional_et.has_value()); + EXPECT_EQ(const_optional_et.value(), 42); + + // Test optional scalar reference conversion + std::optional optional_at_ref_in = std::optional(24); + auto optional_et_from_ref = + type_convert&, torch::executor::optional>( + optional_at_ref_in) + .call(); + EXPECT_TRUE(optional_et_from_ref.has_value()); + EXPECT_EQ(optional_et_from_ref.value(), 24); + + // Test const optional scalar reference conversion + const std::optional const_optional_at_ref_in = + std::optional(84); + auto const_optional_et_from_ref = + type_convert< + const std::optional&, + torch::executor::optional>(const_optional_at_ref_in) + .call(); + EXPECT_TRUE(const_optional_et_from_ref.has_value()); + EXPECT_EQ(const_optional_et_from_ref.value(), 84); + + // Test const optional tensor conversion + const std::optional const_optional_tensor_at_in = + std::optional(torch::tensor({5})); + auto const_optional_tensor_converter = type_convert< + const std::optional, + torch::executor::optional>( + const_optional_tensor_at_in); + auto const_optional_tensor_et = const_optional_tensor_converter.call(); + EXPECT_TRUE(const_optional_tensor_et.has_value()); + EXPECT_EQ(const_optional_tensor_et.value().const_data_ptr()[0], 5); + + // Test optional tensor reference conversion + std::optional optional_tensor_at_ref_in = + std::optional(torch::tensor({7})); + auto optional_tensor_converter_from_ref = type_convert< + std::optional&, + torch::executor::optional>( + optional_tensor_at_ref_in); + auto optional_tensor_et_from_ref = optional_tensor_converter_from_ref.call(); + EXPECT_TRUE(optional_tensor_et_from_ref.has_value()); + EXPECT_EQ( + optional_tensor_et_from_ref.value().const_data_ptr()[0], 7); + + // Test const optional tensor reference conversion + const std::optional const_optional_tensor_at_ref_in = + std::optional(torch::tensor({9})); + auto const_optional_tensor_converter_from_ref = type_convert< + const std::optional&, + torch::executor::optional>( + const_optional_tensor_at_ref_in); + auto const_optional_tensor_et_from_ref = + const_optional_tensor_converter_from_ref.call(); + EXPECT_TRUE(const_optional_tensor_et_from_ref.has_value()); + EXPECT_EQ( + const_optional_tensor_et_from_ref.value().const_data_ptr()[0], + 9); + + // Test empty const optional conversions + const std::optional empty_const_optional_at_in = std::nullopt; + auto empty_const_optional_et = + type_convert< + const std::optional, + torch::executor::optional>(empty_const_optional_at_in) + .call(); + EXPECT_FALSE(empty_const_optional_et.has_value()); + + const std::optional empty_const_optional_tensor_at_in = + std::nullopt; + auto empty_const_optional_tensor_et = + type_convert< + const std::optional, + torch::executor::optional>( + empty_const_optional_tensor_at_in) + .call(); + EXPECT_FALSE(empty_const_optional_tensor_et.has_value()); +}