-
Notifications
You must be signed in to change notification settings - Fork 3.6k
add webgpu support for GatherBlockQuantized #25413
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 3 commits
Commits
Show all changes
8 commits
Select commit
Hold shift + click to select a range
1a6dbe8
webgpu support for GatherBlockQuantized
guschmue 64d5b7e
add 8bit quantization
guschmue b7c7148
cleanup
guschmue 5fb5e39
copilot feedback
guschmue 5e2010b
lintrunner
guschmue 6ce7e05
move some code to 4bit case
guschmue 9a05ad6
Merge branch 'main' into gs/GatherBlockQuantized
guschmue 0d106c9
fix issue with memory going out of scope
guschmue File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
249 changes: 249 additions & 0 deletions
249
onnxruntime/contrib_ops/webgpu/quantization/gather_block_quantized.cc
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,249 @@ | ||
| // Copyright (c) Microsoft Corporation. All rights reserved. | ||
| // Licensed under the MIT License. | ||
|
|
||
| #include "core/providers/webgpu/shader_helper.h" | ||
| #include "core/providers/webgpu/webgpu_utils.h" | ||
| #include "core/providers/webgpu/webgpu_supported_types.h" | ||
| #include "contrib_ops/webgpu/webgpu_contrib_kernels.h" | ||
| #include "contrib_ops/webgpu/quantization/gather_block_quantized.h" | ||
|
|
||
| namespace onnxruntime { | ||
| namespace contrib { | ||
| namespace webgpu { | ||
|
|
||
| using namespace onnxruntime::webgpu; | ||
| using onnxruntime::webgpu::ComputeContext; | ||
|
|
||
| Status GatherBlockQuantizedProgram::GenerateShaderCode(ShaderHelper& shader) const { | ||
| const auto& x = shader.AddInput("input", ShaderUsage::UseElementTypeAlias); | ||
| const auto& x_shape = shader.AddIndices("input_shape", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias); | ||
| const auto& indices = shader.AddInput("indices", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseIndicesToOffset); | ||
| const auto& scales = shader.AddInput("scales", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias); | ||
| const auto& output = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseShapeAndStride | ShaderUsage::UseValueTypeAlias); | ||
|
|
||
| bool is_4bit = bits_ == 4; | ||
| const std::string unpack = (is_signed_) ? "unpack4xI8" : "unpack4xU8"; | ||
|
|
||
| shader.MainFunctionBody() | ||
| << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.output_size") | ||
| << "let output_indices = " << output.OffsetToIndices("global_idx") << ";\n"; | ||
|
|
||
| if (indices_rank_ > 1) { | ||
| shader.MainFunctionBody() | ||
| << "var indices_indices = indices_indices_t(0);\n" | ||
| << "for (var i: u32 = 0; i < " << indices_rank_ << "; i++) {\n" | ||
| << " let index = " << output.IndicesGet("output_indices", "uniforms.gather_axis + i") << ";\n" | ||
| << " " << indices.IndicesSet("indices_indices", "i", "index") << ";\n};\n"; | ||
| } else { | ||
| shader.MainFunctionBody() | ||
| << "let indices_indices = " << output.IndicesGet("output_indices", "uniforms.gather_axis") << ";\n"; | ||
| } | ||
| shader.MainFunctionBody() | ||
| << "var data_indices = input_shape_indices_t(0);\n" | ||
| << "for (var i: u32 = 0; i < uniforms.gather_axis; i++) {\n" | ||
| << " let index = " << output.IndicesGet("output_indices", "i") << ";\n " | ||
| << x_shape.IndicesSet("data_indices", "i", "index") << ";\n};\n" | ||
| << "var index_from_indices = " << indices.GetByIndices("indices_indices") << ";\n" | ||
| << "if (index_from_indices < 0) { index_from_indices += " << x_shape_[gather_axis_] << ";}\n" | ||
| << x_shape.IndicesSet("data_indices", "uniforms.gather_axis", "u32(index_from_indices)") << ";\n" | ||
| << "for (var i = uniforms.gather_axis + 1; i < " << output_shape_.NumDimensions() << "; i++) {\n" | ||
| << " let index = " << output.IndicesGet("output_indices", "i + " + std::to_string(indices_rank_ - 1)) << ";\n " | ||
| << x_shape.IndicesSet("data_indices", "i", "index") << ";\n};\n" | ||
| << " let data_offset = " << x_shape.IndicesToOffset("data_indices") << ";\n"; | ||
|
|
||
| if (is_4bit) { | ||
| shader.MainFunctionBody() | ||
| << " let data_index = data_offset % 8;\n" | ||
| << " let packed_4bit_quantized_data = " << x.GetByOffset("data_offset / 8") << ";\n" | ||
| << " let packed_8bit_quantized_data = (packed_4bit_quantized_data >> (4 * (data_index % 2))) & 0x0f0f0f0f;\n" | ||
| << " let quantized_data_vec = " << unpack << "(u32(packed_8bit_quantized_data));\n" | ||
| << " var quantized_data = quantized_data_vec[data_index / 2];\n"; | ||
| } else { | ||
| shader.MainFunctionBody() | ||
| << " let data_index = data_offset % 4;\n" | ||
| << " let packed_8bit_quantized_data = " << x.GetByOffset("data_offset / 4") << ";\n" | ||
| << " let quantized_data_vec = " << unpack << "(u32(packed_8bit_quantized_data));\n" | ||
| << " var quantized_data = quantized_data_vec[data_index];\n"; | ||
| } | ||
|
|
||
| if (is_signed_) { | ||
| shader.MainFunctionBody() | ||
| << " if((quantized_data & 0x8) != 0) { quantized_data = quantized_data - 16 ;};\n"; | ||
| } | ||
| shader.MainFunctionBody() | ||
| << " var scale_indices = data_indices;\n" | ||
| << " let quantize_axis_index = " << scales.IndicesGet("data_indices", "uniforms.quantize_axis") << "/ uniforms.block_size;\n " | ||
| << scales.IndicesSet("scale_indices", "uniforms.quantize_axis", "quantize_axis_index") << ";\n" | ||
| << " var scale = " << scales.GetByIndices("scale_indices") << ";\n"; | ||
|
|
||
| if (!has_zeropoint_) { | ||
| const std::string default_zero_point = is_uint8_ ? is_4bit ? "input_element_t(8)" : "input_element_t(128)" : "input_element_t(0)"; | ||
| shader.MainFunctionBody() | ||
| << " let zero_point = " << default_zero_point << ";\n"; | ||
| } else { | ||
| const auto& zero_point = shader.AddInput("zero_point", ShaderUsage::None); | ||
| shader.MainFunctionBody() | ||
| << " let zero_point_indices = scale_indices;\n" | ||
| << " let zero_point_offset = " << scales.IndicesToOffset("zero_point_indices") << ";\n"; | ||
| if (is_4bit) { | ||
| shader.MainFunctionBody() | ||
| << " let zero_point_index = zero_point_offset % 8;\n" | ||
| << " let packed_4bit_zero_points = " << zero_point.GetByOffset("zero_point_offset / 8") << ";\n" | ||
| << " let packed_8bit_zero_points = (packed_4bit_zero_points >> (4 * (zero_point_index % 2))) & 0x0f0f0f0f;\n" | ||
| << " let zero_point_vec = " << unpack << "(u32(packed_8bit_zero_points));\n" | ||
| << " var zero_point = zero_point_vec[zero_point_index / 2];\n"; | ||
| } else { | ||
| shader.MainFunctionBody() | ||
| << " let zero_point_index = zero_point_offset % 4;\n" | ||
| << " let packed_8bit_zero_points = " << zero_point.GetByOffset("zero_point_offset / 4") << ";\n" | ||
| << " let zero_point_vec = " << unpack << "(u32(packed_8bit_zero_points));\n" | ||
| << " var zero_point = zero_point_vec[zero_point_index];\n"; | ||
| } | ||
guschmue marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| if (is_signed_) { | ||
| shader.MainFunctionBody() | ||
| << " if((zero_point & 0x8) != 0) { zero_point = zero_point - 16 ;};\n"; | ||
| } | ||
| } | ||
| shader.MainFunctionBody() | ||
| << " let dequantized_data = (output_value_t(quantized_data) - output_value_t(zero_point)) * scale;\n " | ||
| << output.SetByOffset("global_idx", "dequantized_data") << ";\n"; | ||
|
|
||
| return Status::OK(); | ||
| } | ||
|
|
||
| TensorShapeVector splice(TensorShapeVector vec, size_t start, size_t deleteCount, const TensorShapeVector toInsert = {}) { | ||
| TensorShapeVector new_vec; | ||
|
|
||
| for (size_t i = 0; i < vec.size(); i++) { | ||
| if (i < start) { | ||
| new_vec.push_back(vec[i]); | ||
| } else if (i == start) { | ||
| new_vec.insert(new_vec.end(), toInsert.begin(), toInsert.end()); | ||
| } else if (i >= start + deleteCount) { | ||
| new_vec.push_back(vec[i]); | ||
| } | ||
| } | ||
| return new_vec; | ||
| } | ||
|
|
||
| Status GatherBlockQuantized::ComputeInternal(ComputeContext& context) const { | ||
| const auto* x = context.Input(0); | ||
| const auto* indices = context.Input(1); | ||
| const auto* scales = context.Input(2); | ||
| const auto* zero_points = context.Input(3); | ||
|
|
||
| // auto x_shape = x->Shape(); | ||
| int64_t x_size = x->Shape().Size(); | ||
| int x_rank = static_cast<int>(x->Shape().NumDimensions()); | ||
| int64_t x_dtype = x->GetElementType(); | ||
| bool is_signed = x_dtype == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8 || x_dtype == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT4; | ||
| bool is_int8 = x_dtype == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8 || x_dtype == ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8; | ||
|
|
||
| if (bits_ == 4 && is_int8) { | ||
| std::optional<Tensor> data_representation_4bit; | ||
| std::optional<Tensor> zero_points_representation_4bit; | ||
| TensorShape data_representation_4bit_shape{x->Shape()}; | ||
| MLDataType new_dtype = (x_dtype == ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8) ? | ||
| DataTypeImpl::GetType<UInt4x2>() : DataTypeImpl::GetType<Int4x2>(); | ||
| auto memory_info = OrtMemoryInfo{ | ||
| "WebGPU_Buffer", | ||
| OrtDeviceAllocator, | ||
| OrtDevice{OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::NONE, 0}}; | ||
|
|
||
guschmue marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| data_representation_4bit_shape[x_rank - 1] = data_representation_4bit_shape[x_rank - 1] * 2; | ||
| data_representation_4bit.emplace( | ||
| new_dtype, | ||
| data_representation_4bit_shape, | ||
| const_cast<void*>(x->DataRaw()), | ||
| memory_info); | ||
|
|
||
| if (zero_points) { | ||
| TensorShape zero_points_representation_4bit_shape{zero_points->Shape()}; | ||
| zero_points_representation_4bit_shape[zero_points->Shape().NumDimensions() - 1] = | ||
| zero_points_representation_4bit_shape[zero_points->Shape().NumDimensions() - 1] * 2; | ||
| zero_points_representation_4bit.emplace( | ||
| new_dtype, | ||
| zero_points_representation_4bit_shape, | ||
| const_cast<void*>(zero_points->DataRaw()), | ||
| memory_info); | ||
| } | ||
| x = data_representation_4bit.has_value() ? &data_representation_4bit.value() : x; | ||
| zero_points = zero_points_representation_4bit.has_value() ? &zero_points_representation_4bit.value() : zero_points; | ||
| } | ||
|
|
||
| const auto& x_shape = x->Shape(); | ||
|
|
||
| size_t indices_rank = indices->Shape().NumDimensions(); | ||
| const auto scales_shape = scales->Shape(); | ||
| size_t scales_rank = scales_shape.NumDimensions(); | ||
| int gather_axis = (gather_axis_ >= 0) ? gather_axis_ : gather_axis_ + x_rank; | ||
| int quantize_axis = (quantize_axis_ >= 0) ? quantize_axis_ : quantize_axis_ + x_rank; | ||
|
|
||
| ORT_RETURN_IF_NOT(x_shape.NumDimensions() == scales_rank, | ||
| "data and scales must have the same rank."); | ||
| for (size_t i = 0; i < x_shape.NumDimensions(); ++i) { | ||
| ORT_RETURN_IF_NOT(i == static_cast<size_t>(quantize_axis) | ||
| ? (x_shape[i] * 1 + block_size_ - 1) / block_size_ == scales_shape[i] | ||
| : x_shape[i] == scales_shape[i], | ||
| "data and scales do not match shapes."); | ||
| } | ||
|
|
||
| TensorShape output_shape = splice(x_shape.AsShapeVector(), gather_axis, 1, indices->Shape().AsShapeVector()); | ||
| int64_t output_size = output_shape.Size(); | ||
| auto* output_tensor = context.Output(0, output_shape); | ||
|
|
||
| GatherBlockQuantizedProgram program{is_signed, is_int8, indices_rank, gather_axis, bits_, zero_points != nullptr, x_shape, output_shape}; | ||
|
|
||
| program | ||
| .AddInputs({{x, ProgramTensorMetadataDependency::Type, ProgramInput::Flatten, (bits_ == 4) ? 8 : 4}}) | ||
| .AddIndices(x_shape) | ||
| .AddInputs({{indices, ProgramTensorMetadataDependency::TypeAndRank}}) | ||
| .AddInputs({{scales, ProgramTensorMetadataDependency::TypeAndRank}}) | ||
| .AddOutput({output_tensor, ProgramTensorMetadataDependency::None}) | ||
| .SetDispatchGroupSize((output_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) | ||
| .AddUniformVariables({{static_cast<uint32_t>(x_size)}}) | ||
guschmue marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| .AddUniformVariables({{static_cast<uint32_t>(quantize_axis)}}) | ||
| .AddUniformVariables({{static_cast<uint32_t>(gather_axis)}}) | ||
| .AddUniformVariables({{static_cast<uint32_t>(block_size_)}}) | ||
| .CacheHint(std::to_string(gather_axis), std::to_string(quantize_axis), std::to_string(block_size_)); | ||
|
|
||
| if (zero_points != nullptr) { | ||
| ORT_RETURN_IF_NOT(scales_shape == zero_points->Shape(), | ||
| "scales and zero_points must have the same shape."); | ||
| auto zero_points_shape = zero_points->Shape(); | ||
| program.AddInputs({{zero_points, ProgramTensorMetadataDependency::None, ProgramInput::Flatten, (bits_ == 4) ? 8 : 4}}); | ||
| } | ||
|
|
||
| return context.RunProgram(program); | ||
| } | ||
|
|
||
| namespace { | ||
| const std::vector<MLDataType>& GatherBlockQuantizedT1Constraint() { | ||
| static std::vector<MLDataType> types{ | ||
| DataTypeImpl::GetTensorType<Int4x2>(), | ||
| DataTypeImpl::GetTensorType<UInt4x2>(), | ||
| DataTypeImpl::GetTensorType<uint8_t>()}; | ||
| return types; | ||
| } | ||
| const std::vector<MLDataType>& GatherBlockQuantizedTindConstraint() { | ||
| static std::vector<MLDataType> types{ | ||
| DataTypeImpl::GetTensorType<int32_t>(), | ||
| DataTypeImpl::GetTensorType<int64_t>()}; | ||
| return types; | ||
| } | ||
| } // namespace | ||
|
|
||
// Register the WebGPU kernel for the com.microsoft.GatherBlockQuantized contrib
// op (opset 1). T1 = quantized data, T2 = float output, Tind = gather indices.
ONNX_OPERATOR_KERNEL_EX(
    GatherBlockQuantized,
    kMSDomain,
    1,
    kWebGpuExecutionProvider,
    (*KernelDefBuilder::Create())
        .TypeConstraint("T1", GatherBlockQuantizedT1Constraint())
        .TypeConstraint("T2", WebGpuSupportedFloatTypes())
        .TypeConstraint("Tind", GatherBlockQuantizedTindConstraint()),
    GatherBlockQuantized);
|
|
||
| } // namespace webgpu | ||
| } // namespace contrib | ||
| } // namespace onnxruntime | ||
71 changes: 71 additions & 0 deletions
71
onnxruntime/contrib_ops/webgpu/quantization/gather_block_quantized.h
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,71 @@ | ||
| // Copyright (c) Microsoft Corporation. All rights reserved. | ||
| // Licensed under the MIT License. | ||
|
|
||
| #pragma once | ||
|
|
||
| #include "core/providers/webgpu/program.h" | ||
| #include "core/providers/webgpu/webgpu_kernel.h" | ||
|
|
||
| namespace onnxruntime { | ||
| namespace contrib { | ||
| namespace webgpu { | ||
|
|
||
| using namespace onnxruntime::webgpu; | ||
| using onnxruntime::webgpu::ComputeContext; | ||
|
|
||
// WebGPU program that performs the gather + dequantize in a single shader.
// Shape/axis parameters are captured at construction; the shader itself is
// emitted by GenerateShaderCode.
class GatherBlockQuantizedProgram final : public Program<GatherBlockQuantizedProgram> {
 public:
  // NOTE(review): the kernel passes `is_int8` (true for INT8 *or* UINT8
  // storage, see ComputeInternal) into the `is_uint8` parameter — the name
  // suggests unsigned-only; confirm intent.
  GatherBlockQuantizedProgram(const bool is_signed, const bool is_uint8, size_t indices_rank, int gather_axis, int bits, bool has_zeropoint,
                              TensorShape x_shape, TensorShape output_shape) : Program<GatherBlockQuantizedProgram>{"GatherBlockQuantized"},
                                                                               is_signed_{is_signed},
                                                                               is_uint8_{is_uint8},
                                                                               indices_rank_{indices_rank},
                                                                               gather_axis_{gather_axis},
                                                                               bits_{bits},
                                                                               has_zeropoint_{has_zeropoint},
                                                                               x_shape_{x_shape},
                                                                               output_shape_{output_shape} {}

  // Emits the WGSL main body; see the .cc for the dequantization logic.
  Status GenerateShaderCode(ShaderHelper& sh) const override;

  // Uniforms consumed by the generated shader, in binding order.
  WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"output_size", ProgramUniformVariableDataType::Uint32},
                                          {"quantize_axis", ProgramUniformVariableDataType::Uint32},
                                          {"gather_axis", ProgramUniformVariableDataType::Uint32},
                                          {"block_size", ProgramUniformVariableDataType::Uint32});

 private:
  bool is_signed_;          // data element type is signed (int8/int4)
  bool is_uint8_;           // data storage came in as an 8-bit type (see NOTE above)
  size_t indices_rank_;     // rank of the gather-indices tensor
  int gather_axis_;         // normalized (non-negative) gather axis
  int bits_;                // 4 or 8
  bool has_zeropoint_;      // optional zero_points input was provided
  TensorShape x_shape_;     // (possibly 4-bit-reinterpreted) data shape
  TensorShape output_shape_;
};
|
|
||
// WebGPU kernel for the GatherBlockQuantized contrib op: gathers slices from
// block-quantized data and dequantizes them on the fly.
class GatherBlockQuantized final : public WebGpuKernel {
 public:
  GatherBlockQuantized(const OpKernelInfo& info) : WebGpuKernel(info) {
    // Attribute defaults; negative axes are normalized later in ComputeInternal.
    gather_axis_ = static_cast<int>(info.GetAttrOrDefault<int64_t>("gather_axis", 0));
    block_size_ = static_cast<int>(info.GetAttrOrDefault<int64_t>("block_size", 128));
    quantize_axis_ = static_cast<int>(info.GetAttrOrDefault<int64_t>("quantize_axis", 1));
    bits_ = static_cast<int>(info.GetAttrOrDefault<int64_t>("bits", 4));

    ORT_ENFORCE(bits_ == 4 || bits_ == 8, "'bits' must be 4 or 8.");
    // ((x - 1) & x) == 0 is the standard power-of-two test.
    ORT_ENFORCE(block_size_ >= 16 && ((block_size_ - 1) & block_size_) == 0,
                "'block_size' must be 2's power and not less than 16.");
  }

  // Validates shapes, builds the shader program and dispatches it.
  Status ComputeInternal(ComputeContext& context) const override;

 private:
  int gather_axis_;    // axis along which to gather (attribute may be negative)
  int quantize_axis_;  // axis along which the data was block-quantized
  int block_size_;     // elements sharing one scale/zero point
  int bits_;           // quantization width: 4 or 8
};
|
|
||
| } // namespace webgpu | ||
| } // namespace contrib | ||
| } // namespace onnxruntime |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.