From 3644030c34dad5982eb76602bc3d1d006063d17b Mon Sep 17 00:00:00 2001 From: Prathik Rao Date: Mon, 14 Jul 2025 13:24:26 -0700 Subject: [PATCH 01/12] redo concat impl --- .../core/providers/webgpu/tensor/concat.cc | 96 ++++--------------- .../core/providers/webgpu/tensor/concat.h | 4 +- .../providers/cpu/tensor/concat_op_test.cc | 93 ++++++++++++++++++ 3 files changed, 116 insertions(+), 77 deletions(-) diff --git a/onnxruntime/core/providers/webgpu/tensor/concat.cc b/onnxruntime/core/providers/webgpu/tensor/concat.cc index 5cfd6c78f8929..17ed935a642d2 100644 --- a/onnxruntime/core/providers/webgpu/tensor/concat.cc +++ b/onnxruntime/core/providers/webgpu/tensor/concat.cc @@ -38,63 +38,24 @@ WEBGPU_CONCAT_VERSIONED_KERNEL(4, 10) WEBGPU_CONCAT_VERSIONED_KERNEL(11, 12) WEBGPU_CONCAT_KERNEL(13) -void AppendCalCulateInputIndexFunction(std::ostream& os, size_t input_count) { - os << "fn calculate_input_index(index: u32) -> u32 {\n" - << " for (var i = 0u; i < " << input_count << "; i = i + 1u) {\n" - << " if (index < " << GetElementAt("uniforms.size_in_concat_axis", "i", input_count) << ") {\n" - << " return i;\n" - << " }\n" - << " }\n" - << " return " << input_count << ";\n" - << "}\n"; -} - -void AppendAssignOutputDataFunction(std::ostream& os, gsl::span inputs, const ShaderVariableHelper& output) { - os << "fn assign_output_data(global_idx: u32, input_index: u32, indices: output_indices_t) {\n"; - for (size_t i = 0; i < inputs.size(); ++i) { - if (i == 0) { - os << " if (input_index == 0u) {\n"; - } else if (i == inputs.size() - 1) { - os << " } else {\n"; - } else { - os << " } else if (input_index == " << i << "u) {\n"; - } - os << " " << output.SetByOffset("global_idx", inputs[i]->GetByIndices("indices")) << ";\n"; - } - os << " }\n" - "}\n"; -} - Status ConcatProgram::GenerateShaderCode(ShaderHelper& shader) const { - size_t input_count = Inputs().size(); - std::vector inputs; - inputs.reserve(input_count); - for (size_t i = 0; i < input_count; ++i) { - inputs.push_back(&shader.AddInput("input_" + std::to_string(i), ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias)); - } + const auto& input = shader.AddInput("input", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias); const auto& output = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias); + const std::string output_indices_str = "output_indices" + (input.Rank() > 1 ? "[" + std::to_string(axis_) + "]" : ""); - // add implementation of fn calculate_input_index - AppendCalCulateInputIndexFunction(shader.AdditionalImplementation(), input_count); - // add implementation of fn assign_output_data - AppendAssignOutputDataFunction(shader.AdditionalImplementation(), inputs, output); - const std::string size_in_concat_axis = GetElementAt("uniforms.size_in_concat_axis", "input_index - 1", input_count); shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.output_size") - << " var indices = " << output.OffsetToIndices("global_idx") << ";\n" - << " let indices_axis = " << output.IndicesGet("indices", axis_) << ";\n" - << " let input_index = calculate_input_index(indices_axis);\n" - << " if (input_index != 0u) {\n" - << " " << output.IndicesSet("indices", axis_, "indices_axis - " + size_in_concat_axis) << ";\n" - << " }\n" - " assign_output_data(global_idx, input_index, indices);\n"; + << " var output_indices = " << input.OffsetToIndices("global_idx") << ";\n" + << " " << output_indices_str << " += uniforms.concat_axis_offset;\n" + << " " << output.SetByIndices("output_indices", input.GetByOffset("global_idx")) << "\n"; + return Status::OK(); } Status Concat::ComputeInternal(ComputeContext& context) const { - int input_count = context.InputCount(); + uint32_t input_count = context.InputCount(); InlinedTensorsVector input_tensors; input_tensors.reserve(input_count); - for (int i = 0; i < input_count; ++i) { + for (uint32_t i = 0; i < input_count; ++i) { input_tensors.push_back(context.Input(i)); } @@ -105,41 +66,26 @@ Status Concat::ComputeInternal(ComputeContext& context) const { } uint32_t output_size = onnxruntime::narrow(prepare.output_tensor->Shape().Size()); - size_t axis = static_cast(prepare.axis); - ConcatProgram program{axis}; - - std::vector sizes_in_concat_axis; - sizes_in_concat_axis.reserve(input_count); - uint32_t sum = 0; - for (int i = 0; i < input_count; ++i) { - const auto& input = prepare.inputs[i]; - if (input.tensor->Shape().Size() == 0) { - continue; - } - program.AddInput({input.tensor, ProgramTensorMetadataDependency::TypeAndRank}); + uint32_t concat_axis_offset = 0; + for (uint32_t input_index = 0; input_index < input_count; input_index++) { + const auto& input = prepare.inputs[input_index]; auto axis_size = input.tensor->Shape()[axis]; - sum += static_cast(axis_size); - sizes_in_concat_axis.push_back(sum); - } - size_t non_empty_input_count = sizes_in_concat_axis.size(); + ConcatProgram pass_program{axis}; + pass_program.CacheHint(absl::StrJoin(std::make_tuple(prepare.axis), ",")) + .AddInput({input.tensor, ProgramTensorMetadataDependency::TypeAndRank}) + .AddOutputs({prepare.output_tensor}) + .SetDispatchGroupSize((input.tensor->Shape().Size() + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) + .AddUniformVariables({output_size, concat_axis_offset}); + ORT_RETURN_IF_ERROR(context.RunProgram(pass_program)); - if (non_empty_input_count + 1 > context.DeviceLimits().maxStorageBuffersPerShaderStage) { - // TODO: support when input_count + 1 > maxStorageBuffersPerShaderStage, by raising the limit or run the program in multiple passes. - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "The number of storage buffer (input=", - input_count, ", output=1) exceeds the limit (", - context.DeviceLimits().maxStorageBuffersPerShaderStage, ") of the device."); + concat_axis_offset += static_cast(axis_size); } - program.CacheHint(absl::StrJoin(std::make_tuple(non_empty_input_count, prepare.axis), ",")) - .AddOutputs({prepare.output_tensor}) - .SetDispatchGroupSize((prepare.output_num_elements + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) - .AddUniformVariables({gsl::span(sizes_in_concat_axis.data(), sizes_in_concat_axis.size()), - output_size}); - return context.RunProgram(program); + return Status::OK(); } } // namespace webgpu -} // namespace onnxruntime +} // namespace onnxruntime \ No newline at end of file diff --git a/onnxruntime/core/providers/webgpu/tensor/concat.h b/onnxruntime/core/providers/webgpu/tensor/concat.h index 0f6e6dd327e33..a8b74f6c40b40 100644 --- a/onnxruntime/core/providers/webgpu/tensor/concat.h +++ b/onnxruntime/core/providers/webgpu/tensor/concat.h @@ -17,8 +17,8 @@ class ConcatProgram final : public Program { Status GenerateShaderCode(ShaderHelper& sh) const override; - WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"size_in_concat_axis", ProgramUniformVariableDataType::Uint32}, - {"output_size", ProgramUniformVariableDataType::Uint32}); + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"output_size", ProgramUniformVariableDataType::Uint32}, + {"concat_axis_offset", ProgramUniformVariableDataType::Uint32}); private: size_t axis_; diff --git a/onnxruntime/test/providers/cpu/tensor/concat_op_test.cc b/onnxruntime/test/providers/cpu/tensor/concat_op_test.cc index 9e0fb81cbb0fc..db462d72465d1 100644 --- a/onnxruntime/test/providers/cpu/tensor/concat_op_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/concat_op_test.cc @@ -434,5 +434,98 @@ TEST(ConcatOpTest, Concat4D_2) { test.Run(); } +#ifdef USE_WEBGPU +TEST(ConcatOpTest, Concat1D_exceed_maxStorageBuffersPerShaderStage) { + // maxStorageBuffersPerShaderStage==8 + OpTester test("Concat"); + test.AddAttribute("axis", int64_t{0}); + + test.AddInput("input1", {1}, {1}); + test.AddInput("input2", {1}, {2}); + test.AddInput("input3", {1}, {3}); + test.AddInput("input4", {1}, {4}); + test.AddInput("input5", {1}, {5}); + test.AddInput("input6", {1}, {6}); + test.AddInput("input7", {1}, {7}); + test.AddInput("input8", {1}, {8}); + test.AddInput("input9", {1}, {9}); + test.AddOutput("concat_result", {9}, {1, 2, 3, 4, 5, 6, 7, 8, 9}); + test.Run(); +} + +TEST(ConcatOpTest, Concat2D_exceed_maxStorageBuffersPerShaderStage_axis0) { + // maxStorageBuffersPerShaderStage==8 + OpTester test("Concat"); + test.AddAttribute("axis", int64_t{0}); + + test.AddInput("input1", {1, 2}, {1, 1}); + test.AddInput("input2", {1, 2}, {2, 2}); + test.AddInput("input3", {1, 2}, {3, 3}); + test.AddInput("input4", {1, 2}, {4, 4}); + test.AddInput("input5", {1, 2}, {5, 5}); + test.AddInput("input6", {1, 2}, {6, 6}); + test.AddInput("input7", {1, 2}, {7, 7}); + test.AddInput("input8", {1, 2}, {8, 8}); + test.AddInput("input9", {1, 2}, {9, 9}); + test.AddOutput("concat_result", {9, 2}, {1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9, 9}); + test.Run(); +} + +TEST(ConcatOpTest, Concat2D_exceed_maxStorageBuffersPerShaderStage_axis1) { + // maxStorageBuffersPerShaderStage==8 + OpTester test("Concat"); + test.AddAttribute("axis", int64_t{1}); + + test.AddInput("input1", {1, 2}, {1, 1}); + test.AddInput("input2", {1, 2}, {2, 2}); + test.AddInput("input3", {1, 2}, {3, 3}); + test.AddInput("input4", {1, 2}, {4, 4}); + test.AddInput("input5", {1, 2}, {5, 5}); + test.AddInput("input6", {1, 2}, {6, 6}); + test.AddInput("input7", {1, 2}, {7, 7}); + test.AddInput("input8", {1, 2}, {8, 8}); + test.AddInput("input9", {1, 2}, {9, 9}); + test.AddOutput("concat_result", {1, 18}, {1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9, 9}); + test.Run(); +} + +TEST(ConcatOpTest, Concat3D_exceed_maxStorageBuffersPerShaderStage) { + // maxStorageBuffersPerShaderStage==8 + OpTester test("Concat"); + test.AddAttribute("axis", int64_t{1}); + + test.AddInput("input1", {2, 1, 1}, {1, 2}); + test.AddInput("input2", {2, 1, 1}, {3, 4}); + test.AddInput("input3", {2, 1, 1}, {5, 6}); + test.AddInput("input4", {2, 1, 1}, {7, 8}); + test.AddInput("input5", {2, 1, 1}, {9, 10}); + test.AddInput("input6", {2, 1, 1}, {11, 12}); + test.AddInput("input7", {2, 1, 1}, {13, 14}); + test.AddInput("input8", {2, 1, 1}, {15, 16}); + test.AddInput("input9", {2, 1, 1}, {17, 18}); + test.AddOutput("concat_result", {2, 9, 1}, {// batch 0 + 1, 3, 5, 7, 9, 11, 13, 15, 17, + // batch 1 + 2, 4, 6, 8, 10, 12, 14, 16, 18}); + test.Run(); +} + +TEST(ConcatOpTest, Concat3D_exceed_maxStorageBuffersPerShaderStage_small) { + // maxStorageBuffersPerShaderStage==8 + OpTester test("Concat"); + test.AddAttribute("axis", int64_t{1}); + + test.AddInput("input1", {2, 1, 1}, {1, 2}); + test.AddInput("input2", {2, 3, 1}, {3, 4, 5, 6, 7, 8}); + test.AddInput("input3", {2, 2, 1}, {9, 10, 11, 12}); + test.AddInput("input4", {2, 1, 1}, {13, 14}); + test.AddOutput("concat_result", {2, 7, 1}, {// batch 0 + 1, 3, 4, 5, 9, 10, 13, + // batch 1 + 2, 6, 7, 8, 11, 12, 14}); + test.Run(); +} +#endif // USE_WEBGPU + } // namespace test } // namespace onnxruntime From e5761716ae13f124aaab1f3919f776b668a22696 Mon Sep 17 00:00:00 2001 From: Prathik Rao Date: Mon, 14 Jul 2025 14:01:13 -0700 Subject: [PATCH 02/12] prevent perf regression --- .../core/providers/webgpu/tensor/concat.cc | 109 +++++++++++++++--- .../core/providers/webgpu/tensor/concat.h | 20 +++- 2 files changed, 112 insertions(+), 17 deletions(-) diff --git a/onnxruntime/core/providers/webgpu/tensor/concat.cc b/onnxruntime/core/providers/webgpu/tensor/concat.cc index 17ed935a642d2..b379c657aa7cf 100644 --- a/onnxruntime/core/providers/webgpu/tensor/concat.cc +++ b/onnxruntime/core/providers/webgpu/tensor/concat.cc @@ -38,7 +38,59 @@ WEBGPU_CONCAT_VERSIONED_KERNEL(4, 10) WEBGPU_CONCAT_VERSIONED_KERNEL(11, 12) WEBGPU_CONCAT_KERNEL(13) +void AppendCalCulateInputIndexFunction(std::ostream& os, size_t input_count) { + os << "fn calculate_input_index(index: u32) -> u32 {\n" + << " for (var i = 0u; i < " << input_count << "; i = i + 1u) {\n" + << " if (index < " << GetElementAt("uniforms.size_in_concat_axis", "i", input_count) << ") {\n" + << " return i;\n" + << " }\n" + << " }\n" + << " return " << input_count << ";\n" + << "}\n"; +} + +void AppendAssignOutputDataFunction(std::ostream& os, gsl::span inputs, const ShaderVariableHelper& output) { + os << "fn assign_output_data(global_idx: u32, input_index: u32, indices: output_indices_t) {\n"; + for (size_t i = 0; i < inputs.size(); ++i) { + if (i == 0) { + os << " if (input_index == 0u) {\n"; + } else if (i == inputs.size() - 1) { + os << " } else {\n"; + } else { + os << " } else if (input_index == " << i << "u) {\n"; + } + os << " " << output.SetByOffset("global_idx", inputs[i]->GetByIndices("indices")) << ";\n"; + } + os << " }\n" + "}\n"; +} + Status ConcatProgram::GenerateShaderCode(ShaderHelper& shader) const { + size_t input_count = Inputs().size(); + std::vector inputs; + inputs.reserve(input_count); + for (size_t i = 0; i < input_count; ++i) { + inputs.push_back(&shader.AddInput("input_" + std::to_string(i), ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias)); + } + const auto& output = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias); + + // add implementation of fn calculate_input_index + AppendCalCulateInputIndexFunction(shader.AdditionalImplementation(), input_count); + // add implementation of fn assign_output_data + AppendAssignOutputDataFunction(shader.AdditionalImplementation(), inputs, output); + const std::string size_in_concat_axis = GetElementAt("uniforms.size_in_concat_axis", "input_index - 1", input_count); + shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.output_size") + << " var indices = " << output.OffsetToIndices("global_idx") << ";\n" + << " let indices_axis = " << output.IndicesGet("indices", axis_) << ";\n" + << " let input_index = calculate_input_index(indices_axis);\n" + << " if (input_index != 0u) {\n" + << " " << output.IndicesSet("indices", axis_, "indices_axis - " + size_in_concat_axis) << ";\n" + << " }\n" + " assign_output_data(global_idx, input_index, indices);\n"; + return Status::OK(); +} + +Status ConcatProgramSingle::GenerateShaderCode(ShaderHelper& shader) const { const auto& input = shader.AddInput("input", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias); const auto& output = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias); const std::string output_indices_str = "output_indices" + (input.Rank() > 1 ? "[" + std::to_string(axis_) + "]" : ""); @@ -52,10 +104,10 @@ Status ConcatProgram::GenerateShaderCode(ShaderHelper& shader) const { } Status Concat::ComputeInternal(ComputeContext& context) const { - uint32_t input_count = context.InputCount(); + int input_count = context.InputCount(); InlinedTensorsVector input_tensors; input_tensors.reserve(input_count); - for (uint32_t i = 0; i < input_count; ++i) { + for (int i = 0; i < input_count; ++i) { input_tensors.push_back(context.Input(i)); } @@ -66,25 +118,54 @@ Status Concat::ComputeInternal(ComputeContext& context) const { } uint32_t output_size = onnxruntime::narrow(prepare.output_tensor->Shape().Size()); + size_t axis = static_cast(prepare.axis); + ConcatProgram program{axis}; + + std::vector sizes_in_concat_axis; + sizes_in_concat_axis.reserve(input_count); + uint32_t sum = 0; + for (int i = 0; i < input_count; ++i) { + const auto& input = prepare.inputs[i]; + if (input.tensor->Shape().Size() == 0) { + continue; + } + program.AddInput({input.tensor, ProgramTensorMetadataDependency::TypeAndRank}); - uint32_t concat_axis_offset = 0; - for (uint32_t input_index = 0; input_index < input_count; input_index++) { - const auto& input = prepare.inputs[input_index]; auto axis_size = input.tensor->Shape()[axis]; + sum += static_cast(axis_size); + sizes_in_concat_axis.push_back(sum); + } + + size_t non_empty_input_count = sizes_in_concat_axis.size(); - ConcatProgram pass_program{axis}; - pass_program.CacheHint(absl::StrJoin(std::make_tuple(prepare.axis), ",")) - .AddInput({input.tensor, ProgramTensorMetadataDependency::TypeAndRank}) - .AddOutputs({prepare.output_tensor}) - .SetDispatchGroupSize((input.tensor->Shape().Size() + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) - .AddUniformVariables({output_size, concat_axis_offset}); - ORT_RETURN_IF_ERROR(context.RunProgram(pass_program)); + if (non_empty_input_count + 1 > context.DeviceLimits().maxStorageBuffersPerShaderStage) { + LOGS_DEFAULT(WARNING) << "Storage buffer limit exceeded for Concat. Running operation one input at a time."; + uint32_t concat_axis_offset = 0; + for (uint32_t input_index = 0; input_index < input_count; input_index++) { + const auto& input = prepare.inputs[input_index]; + auto axis_size = input.tensor->Shape()[axis]; - concat_axis_offset += static_cast(axis_size); + ConcatProgramSingle pass_program{axis}; + pass_program.CacheHint(absl::StrJoin(std::make_tuple(prepare.axis), ",")) + .AddInput({input.tensor, ProgramTensorMetadataDependency::TypeAndRank}) + .AddOutputs({prepare.output_tensor}) + .SetDispatchGroupSize((input.tensor->Shape().Size() + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) + .AddUniformVariables({output_size, concat_axis_offset}); + ORT_RETURN_IF_ERROR(context.RunProgram(pass_program)); + + concat_axis_offset += static_cast(axis_size); + } + + return Status::OK(); } - return Status::OK(); + program.CacheHint(absl::StrJoin(std::make_tuple(non_empty_input_count, prepare.axis), ",")) + .AddOutputs({prepare.output_tensor}) + .SetDispatchGroupSize((prepare.output_num_elements + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) + .AddUniformVariables({gsl::span(sizes_in_concat_axis.data(), sizes_in_concat_axis.size()), + output_size}); + return context.RunProgram(program); } } // namespace webgpu diff --git a/onnxruntime/core/providers/webgpu/tensor/concat.h b/onnxruntime/core/providers/webgpu/tensor/concat.h index a8b74f6c40b40..d0c19fad5ef29 100644 --- a/onnxruntime/core/providers/webgpu/tensor/concat.h +++ b/onnxruntime/core/providers/webgpu/tensor/concat.h @@ -17,8 +17,22 @@ class ConcatProgram final : public Program { Status GenerateShaderCode(ShaderHelper& sh) const override; - WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"output_size", ProgramUniformVariableDataType::Uint32}, - {"concat_axis_offset", ProgramUniformVariableDataType::Uint32}); + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"size_in_concat_axis", ProgramUniformVariableDataType::Uint32}, + {"output_size", ProgramUniformVariableDataType::Uint32}); + + private: + size_t axis_; +}; + +class ConcatProgramSingle final : public Program { + public: + ConcatProgramSingle(size_t axis) : Program{"Concat"}, axis_{axis} { + } + + Status GenerateShaderCode(ShaderHelper& sh) const override; + + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"concat_axis_offset", ProgramUniformVariableDataType::Uint32}, + {"output_size", ProgramUniformVariableDataType::Uint32}); private: size_t axis_; @@ -33,4 +47,4 @@ class Concat final : public WebGpuKernel, public ConcatBase { }; } // namespace webgpu -} // namespace onnxruntime +} // namespace onnxruntime \ No newline at end of file From 224e278d6fea2ccbcc4f0913ea84eb7f3787ddea Mon Sep 17 00:00:00 2001 From: Prathik Rao Date: Mon, 14 Jul 2025 14:03:59 -0700 Subject: [PATCH 03/12] remove test case --- .../test/providers/cpu/tensor/concat_op_test.cc | 16 ---------------- 1 file changed, 16 deletions(-) diff --git a/onnxruntime/test/providers/cpu/tensor/concat_op_test.cc b/onnxruntime/test/providers/cpu/tensor/concat_op_test.cc index db462d72465d1..ec858ce1fc409 100644 --- a/onnxruntime/test/providers/cpu/tensor/concat_op_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/concat_op_test.cc @@ -509,22 +509,6 @@ TEST(ConcatOpTest, Concat3D_exceed_maxStorageBuffersPerShaderStage) { 2, 4, 6, 8, 10, 12, 14, 16, 18}); test.Run(); } - -TEST(ConcatOpTest, Concat3D_exceed_maxStorageBuffersPerShaderStage_small) { - // maxStorageBuffersPerShaderStage==8 - OpTester test("Concat"); - test.AddAttribute("axis", int64_t{1}); - - test.AddInput("input1", {2, 1, 1}, {1, 2}); - test.AddInput("input2", {2, 3, 1}, {3, 4, 5, 6, 7, 8}); - test.AddInput("input3", {2, 2, 1}, {9, 10, 11, 12}); - test.AddInput("input4", {2, 1, 1}, {13, 14}); - test.AddOutput("concat_result", {2, 7, 1}, {// batch 0 - 1, 3, 4, 5, 9, 10, 13, - // batch 1 - 2, 6, 7, 8, 11, 12, 14}); - test.Run(); -} #endif // USE_WEBGPU } // namespace test From edd57f7fa4cc572f64a5d27784481627ded7347a Mon Sep 17 00:00:00 2001 From: Prathik Rao Date: Mon, 14 Jul 2025 14:29:56 -0700 Subject: [PATCH 04/12] cast bug fix --- onnxruntime/core/providers/webgpu/tensor/concat.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/core/providers/webgpu/tensor/concat.cc b/onnxruntime/core/providers/webgpu/tensor/concat.cc index b379c657aa7cf..abc6ee6763a6d 100644 --- a/onnxruntime/core/providers/webgpu/tensor/concat.cc +++ b/onnxruntime/core/providers/webgpu/tensor/concat.cc @@ -142,7 +142,7 @@ Status Concat::ComputeInternal(ComputeContext& context) const { if (non_empty_input_count + 1 > context.DeviceLimits().maxStorageBuffersPerShaderStage) { LOGS_DEFAULT(WARNING) << "Storage buffer limit exceeded for Concat. Running operation one input at a time."; uint32_t concat_axis_offset = 0; - for (uint32_t input_index = 0; input_index < input_count; input_index++) { + for (uint32_t input_index = 0; input_index < static_cast(input_count); input_index++) { const auto& input = prepare.inputs[input_index]; auto axis_size = input.tensor->Shape()[axis]; From a175358ea9ea09203a7f747403c337b66dc4aa1e Mon Sep 17 00:00:00 2001 From: Prathik Rao Date: Tue, 15 Jul 2025 06:49:23 -0700 Subject: [PATCH 05/12] batched version --- .../core/providers/webgpu/tensor/concat.cc | 137 ++++++------------ .../core/providers/webgpu/tensor/concat.h | 16 +- .../providers/cpu/tensor/concat_op_test.cc | 1 + 3 files changed, 48 insertions(+), 106 deletions(-) diff --git a/onnxruntime/core/providers/webgpu/tensor/concat.cc b/onnxruntime/core/providers/webgpu/tensor/concat.cc index abc6ee6763a6d..9146ccfcef98d 100644 --- a/onnxruntime/core/providers/webgpu/tensor/concat.cc +++ b/onnxruntime/core/providers/webgpu/tensor/concat.cc @@ -38,33 +38,6 @@ WEBGPU_CONCAT_VERSIONED_KERNEL(4, 10) WEBGPU_CONCAT_VERSIONED_KERNEL(11, 12) WEBGPU_CONCAT_KERNEL(13) -void AppendCalCulateInputIndexFunction(std::ostream& os, size_t input_count) { - os << "fn calculate_input_index(index: u32) -> u32 {\n" - << " for (var i = 0u; i < " << input_count << "; i = i + 1u) {\n" - << " if (index < " << GetElementAt("uniforms.size_in_concat_axis", "i", input_count) << ") {\n" - << " return i;\n" - << " }\n" - << " }\n" - << " return " << input_count << ";\n" - << "}\n"; -} - -void AppendAssignOutputDataFunction(std::ostream& os, gsl::span inputs, const ShaderVariableHelper& output) { - os << "fn assign_output_data(global_idx: u32, input_index: u32, indices: output_indices_t) {\n"; - for (size_t i = 0; i < inputs.size(); ++i) { - if (i == 0) { - os << " if (input_index == 0u) {\n"; - } else if (i == inputs.size() - 1) { - os << " } else {\n"; - } else { - os << " } else if (input_index == " << i << "u) {\n"; - } - os << " " << output.SetByOffset("global_idx", inputs[i]->GetByIndices("indices")) << ";\n"; - } - os << " }\n" - "}\n"; -} - Status ConcatProgram::GenerateShaderCode(ShaderHelper& shader) const { size_t input_count = Inputs().size(); std::vector inputs; @@ -74,40 +47,25 @@ Status ConcatProgram::GenerateShaderCode(ShaderHelper& shader) const { } const auto& output = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias); - // add implementation of fn calculate_input_index - AppendCalCulateInputIndexFunction(shader.AdditionalImplementation(), input_count); - // add implementation of fn assign_output_data - AppendAssignOutputDataFunction(shader.AdditionalImplementation(), inputs, output); - const std::string size_in_concat_axis = GetElementAt("uniforms.size_in_concat_axis", "input_index - 1", input_count); - shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.output_size") - << " var indices = " << output.OffsetToIndices("global_idx") << ";\n" - << " let indices_axis = " << output.IndicesGet("indices", axis_) << ";\n" - << " let input_index = calculate_input_index(indices_axis);\n" - << " if (input_index != 0u) {\n" - << " " << output.IndicesSet("indices", axis_, "indices_axis - " + size_in_concat_axis) << ";\n" - << " }\n" - " assign_output_data(global_idx, input_index, indices);\n"; - return Status::OK(); -} - -Status ConcatProgramSingle::GenerateShaderCode(ShaderHelper& shader) const { - const auto& input = shader.AddInput("input", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias); - const auto& output = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias); - const std::string output_indices_str = "output_indices" + (input.Rank() > 1 ? "[" + std::to_string(axis_) + "]" : ""); + shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.output_size"); + for (size_t i = 0; i < input_count; ++i) { + const std::string output_indices_i = absl::StrCat("output_indices_", i); + const std::string output_indices_i_axis = output_indices_i + (inputs[i]->Rank() > 1 ? "[" + std::to_string(axis_) + "]" : ""); + const std::string concat_axis_offset = GetElementAt("uniforms.sizes_in_concat_axis", std::to_string(i), input_count); - shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.output_size") - << " var output_indices = " << input.OffsetToIndices("global_idx") << ";\n" - << " " << output_indices_str << " += uniforms.concat_axis_offset;\n" - << " " << output.SetByIndices("output_indices", input.GetByOffset("global_idx")) << "\n"; + shader.MainFunctionBody() << " var " << output_indices_i << " = " << inputs[i]->OffsetToIndices("global_idx") << ";\n" + << " " << output_indices_i_axis << " += " << concat_axis_offset << ";\n" + << " " << output.SetByIndices(output_indices_i, inputs[i]->GetByOffset("global_idx")) << "\n"; + } return Status::OK(); } Status Concat::ComputeInternal(ComputeContext& context) const { - int input_count = context.InputCount(); + uint32_t input_count = context.InputCount(); InlinedTensorsVector input_tensors; input_tensors.reserve(input_count); - for (int i = 0; i < input_count; ++i) { + for (uint32_t i = 0; i < input_count; ++i) { input_tensors.push_back(context.Input(i)); } @@ -117,55 +75,52 @@ Status Concat::ComputeInternal(ComputeContext& context) const { return Status::OK(); } - uint32_t output_size = onnxruntime::narrow(prepare.output_tensor->Shape().Size()); + uint32_t axis = static_cast(prepare.axis); + uint32_t max_inputs_per_concat = context.DeviceLimits().maxStorageBuffersPerShaderStage - 1; - size_t axis = static_cast(prepare.axis); - ConcatProgram program{axis}; + uint32_t input_index = 0; + uint32_t cumulative_output_size = 0; + uint32_t cumulative_size_in_concat_axis = 0; - std::vector sizes_in_concat_axis; - sizes_in_concat_axis.reserve(input_count); - uint32_t sum = 0; - for (int i = 0; i < input_count; ++i) { - const auto& input = prepare.inputs[i]; - if (input.tensor->Shape().Size() == 0) { - continue; - } - program.AddInput({input.tensor, ProgramTensorMetadataDependency::TypeAndRank}); + while (input_index < input_count) { + ConcatProgram program{axis}; + uint32_t num_inputs_this_concat = std::min(max_inputs_per_concat, input_count - input_index); - auto axis_size = input.tensor->Shape()[axis]; - sum += static_cast(axis_size); - sizes_in_concat_axis.push_back(sum); - } + std::vector sizes; + std::vector sizes_in_concat_axis; + sizes.reserve(num_inputs_this_concat + 1); + sizes_in_concat_axis.reserve(num_inputs_this_concat + 1); - size_t non_empty_input_count = sizes_in_concat_axis.size(); + // Start with the cumulative size from previous dispatches + sizes.push_back(cumulative_output_size); + sizes_in_concat_axis.push_back(cumulative_size_in_concat_axis); - if (non_empty_input_count + 1 > context.DeviceLimits().maxStorageBuffersPerShaderStage) { - LOGS_DEFAULT(WARNING) << "Storage buffer limit exceeded for Concat. Running operation one input at a time."; - uint32_t concat_axis_offset = 0; - for (uint32_t input_index = 0; input_index < static_cast(input_count); input_index++) { - const auto& input = prepare.inputs[input_index]; - auto axis_size = input.tensor->Shape()[axis]; + uint32_t dispatch_size = 0; + for (uint32_t i = 0; i < num_inputs_this_concat; i++) { + auto& input = prepare.inputs[input_index + i]; + program.AddInput({input.tensor, ProgramTensorMetadataDependency::TypeAndRank}); - ConcatProgramSingle pass_program{axis}; - pass_program.CacheHint(absl::StrJoin(std::make_tuple(prepare.axis), ",")) - .AddInput({input.tensor, ProgramTensorMetadataDependency::TypeAndRank}) - .AddOutputs({prepare.output_tensor}) - .SetDispatchGroupSize((input.tensor->Shape().Size() + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) - .AddUniformVariables({output_size, concat_axis_offset}); - ORT_RETURN_IF_ERROR(context.RunProgram(pass_program)); + uint32_t size = onnxruntime::narrow(input.tensor->Shape().Size()); + uint32_t axis_size = static_cast(input.tensor->Shape()[axis]); - concat_axis_offset += static_cast(axis_size); + cumulative_output_size += size; + dispatch_size += size; + sizes.push_back(cumulative_output_size); + + cumulative_size_in_concat_axis += axis_size; + sizes_in_concat_axis.push_back(cumulative_size_in_concat_axis); } - return Status::OK(); + program.CacheHint(absl::StrJoin(std::make_tuple(num_inputs_this_concat, prepare.axis), ",")) + .AddOutputs({prepare.output_tensor}) + .SetDispatchGroupSize((dispatch_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) + .AddUniformVariables({gsl::span(sizes_in_concat_axis.data(), sizes_in_concat_axis.size()), dispatch_size}); + ORT_RETURN_IF_ERROR(context.RunProgram(program)); + + input_index += num_inputs_this_concat; } - program.CacheHint(absl::StrJoin(std::make_tuple(non_empty_input_count, prepare.axis), ",")) - .AddOutputs({prepare.output_tensor}) - .SetDispatchGroupSize((prepare.output_num_elements + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) - .AddUniformVariables({gsl::span(sizes_in_concat_axis.data(), sizes_in_concat_axis.size()), - output_size}); - return context.RunProgram(program); + return Status::OK(); } } // namespace webgpu diff --git a/onnxruntime/core/providers/webgpu/tensor/concat.h b/onnxruntime/core/providers/webgpu/tensor/concat.h index d0c19fad5ef29..512146836323b 100644 --- a/onnxruntime/core/providers/webgpu/tensor/concat.h +++ b/onnxruntime/core/providers/webgpu/tensor/concat.h @@ -17,21 +17,7 @@ class ConcatProgram final : public Program { Status GenerateShaderCode(ShaderHelper& sh) const override; - WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"size_in_concat_axis", ProgramUniformVariableDataType::Uint32}, - {"output_size", ProgramUniformVariableDataType::Uint32}); - - private: - size_t axis_; -}; - -class ConcatProgramSingle final : public Program { - public: - ConcatProgramSingle(size_t axis) : Program{"Concat"}, axis_{axis} { - } - - Status GenerateShaderCode(ShaderHelper& sh) const override; - - WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"concat_axis_offset", ProgramUniformVariableDataType::Uint32}, + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"sizes_in_concat_axis", ProgramUniformVariableDataType::Uint32}, {"output_size", ProgramUniformVariableDataType::Uint32}); private: diff --git a/onnxruntime/test/providers/cpu/tensor/concat_op_test.cc b/onnxruntime/test/providers/cpu/tensor/concat_op_test.cc index ec858ce1fc409..e4f4cc413bbde 100644 --- a/onnxruntime/test/providers/cpu/tensor/concat_op_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/concat_op_test.cc @@ -73,6 +73,7 @@ TEST(ConcatOpTest, Concat1D_2) { test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, // TensorRT: no support for dynamic shape tensor kNnapiExecutionProvider, // NNAPI: concat does not support 0 size input + kWebGpuExecutionProvider, // WebGPU: concat does not support 0 size input kQnnExecutionProvider}); // QNN: not support dynamic shape tensor } From 235d855ddec9ee95af369f6c2314acde5acc6229 Mon Sep 17 00:00:00 2001 From: Prathik Rao Date: Tue, 15 Jul 2025 06:56:15 -0700 Subject: [PATCH 06/12] adjust test --- .../core/providers/webgpu/tensor/concat.cc | 1 - .../providers/cpu/tensor/concat_op_test.cc | 40 +++++++++---------- 2 files changed, 20 insertions(+), 21 deletions(-) diff --git a/onnxruntime/core/providers/webgpu/tensor/concat.cc b/onnxruntime/core/providers/webgpu/tensor/concat.cc index 9146ccfcef98d..ededabacdb45b 100644 --- a/onnxruntime/core/providers/webgpu/tensor/concat.cc +++ b/onnxruntime/core/providers/webgpu/tensor/concat.cc @@ -91,7 +91,6 @@ Status Concat::ComputeInternal(ComputeContext& context) const { sizes.reserve(num_inputs_this_concat + 1); sizes_in_concat_axis.reserve(num_inputs_this_concat + 1); - // Start with the cumulative size from previous dispatches sizes.push_back(cumulative_output_size); sizes_in_concat_axis.push_back(cumulative_size_in_concat_axis); diff --git a/onnxruntime/test/providers/cpu/tensor/concat_op_test.cc b/onnxruntime/test/providers/cpu/tensor/concat_op_test.cc index e4f4cc413bbde..e8e1c34ca7c58 100644 --- a/onnxruntime/test/providers/cpu/tensor/concat_op_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/concat_op_test.cc @@ -459,16 +459,16 @@ TEST(ConcatOpTest, Concat2D_exceed_maxStorageBuffersPerShaderStage_axis0) { OpTester test("Concat"); test.AddAttribute("axis", int64_t{0}); - test.AddInput("input1", {1, 2}, {1, 1}); - test.AddInput("input2", {1, 2}, {2, 2}); - test.AddInput("input3", {1, 2}, {3, 3}); - test.AddInput("input4", {1, 2}, {4, 4}); - test.AddInput("input5", {1, 2}, {5, 5}); - test.AddInput("input6", {1, 2}, {6, 6}); - test.AddInput("input7", {1, 2}, {7, 7}); - test.AddInput("input8", {1, 2}, {8, 8}); - test.AddInput("input9", {1, 2}, {9, 9}); - test.AddOutput("concat_result", {9, 2}, {1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9, 9}); + test.AddInput("input1", {1, 2}, {1, 2}); + test.AddInput("input2", {1, 2}, {3, 4}); + test.AddInput("input3", {1, 2}, {5, 6}); + test.AddInput("input4", {1, 2}, {7, 8}); + test.AddInput("input5", {1, 2}, {9, 10}); + test.AddInput("input6", {1, 2}, {11, 12}); + test.AddInput("input7", {1, 2}, {13, 14}); + test.AddInput("input8", {1, 2}, {15, 16}); + test.AddInput("input9", {1, 2}, {17, 18}); + test.AddOutput("concat_result", {9, 2}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18}); test.Run(); } @@ -477,16 +477,16 @@ TEST(ConcatOpTest, Concat2D_exceed_maxStorageBuffersPerShaderStage_axis1) { OpTester test("Concat"); test.AddAttribute("axis", int64_t{1}); - test.AddInput("input1", {1, 2}, {1, 1}); - test.AddInput("input2", {1, 2}, {2, 2}); - test.AddInput("input3", {1, 2}, {3, 3}); - test.AddInput("input4", {1, 2}, {4, 4}); - test.AddInput("input5", {1, 2}, {5, 5}); - test.AddInput("input6", {1, 2}, {6, 6}); - test.AddInput("input7", {1, 2}, {7, 7}); - test.AddInput("input8", {1, 2}, {8, 8}); - test.AddInput("input9", {1, 2}, {9, 9}); - test.AddOutput("concat_result", {1, 18}, {1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9, 9}); + test.AddInput("input1", {1, 2}, {1, 2}); + test.AddInput("input2", {1, 2}, {3, 4}); + test.AddInput("input3", {1, 2}, {5, 6}); + test.AddInput("input4", {1, 2}, {7, 8}); + test.AddInput("input5", {1, 2}, {9, 10}); + test.AddInput("input6", {1, 2}, {11, 12}); + test.AddInput("input7", {1, 2}, {13, 14}); + test.AddInput("input8", {1, 2}, {15, 16}); + test.AddInput("input9", {1, 2}, {17, 18}); + test.AddOutput("concat_result", {1, 18}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18}); test.Run(); } From 898e0d0e356a532d5a215e58daa0951f4092a77a Mon Sep 17 00:00:00 2001 From: Prathik Rao Date: Tue, 15 Jul 2025 08:29:30 -0700 Subject: [PATCH 07/12] fix edge case --- onnxruntime/core/providers/webgpu/tensor/concat.cc | 4 ++++ .../test/providers/cpu/tensor/concat_op_test.cc | 12 ++++++++++++ 2 files changed, 16 insertions(+) diff --git a/onnxruntime/core/providers/webgpu/tensor/concat.cc b/onnxruntime/core/providers/webgpu/tensor/concat.cc index ededabacdb45b..c07e5fa4cfbb6 100644 --- a/onnxruntime/core/providers/webgpu/tensor/concat.cc +++ b/onnxruntime/core/providers/webgpu/tensor/concat.cc @@ -110,6 +110,10 @@ Status Concat::ComputeInternal(ComputeContext& context) const { sizes_in_concat_axis.push_back(cumulative_size_in_concat_axis); } + // Remove the last element from both vectors to prevent out of bounds writes + sizes.pop_back(); + sizes_in_concat_axis.pop_back(); + program.CacheHint(absl::StrJoin(std::make_tuple(num_inputs_this_concat, prepare.axis), ",")) .AddOutputs({prepare.output_tensor}) .SetDispatchGroupSize((dispatch_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) diff --git a/onnxruntime/test/providers/cpu/tensor/concat_op_test.cc b/onnxruntime/test/providers/cpu/tensor/concat_op_test.cc index e8e1c34ca7c58..9ecb34c41bd59 100644 --- a/onnxruntime/test/providers/cpu/tensor/concat_op_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/concat_op_test.cc @@ -436,6 +436,18 @@ TEST(ConcatOpTest, Concat4D_2) { } #ifdef USE_WEBGPU +TEST(ConcatOpTest, Concat1D_int32_4inputs) { + OpTester test("Concat"); + test.AddAttribute("axis", int64_t{0}); + + test.AddInput("input1", {1}, {1}); + test.AddInput("input2", {2}, {2, 3}); + test.AddInput("input3", {4}, {4, 5, 6, 7}); + test.AddInput("input4", {2}, {8, 9}); + test.AddOutput("concat_result", {9}, {1, 2, 3, 4, 5, 6, 7, 8, 9}); + test.Run(); +} + TEST(ConcatOpTest, Concat1D_exceed_maxStorageBuffersPerShaderStage) { // maxStorageBuffersPerShaderStage==8 OpTester test("Concat"); From c9e6303693159ec0564980e013eab25f4ef8c81a Mon Sep 17 00:00:00 2001 From: Prathik Rao Date: Tue, 15 Jul 2025 13:15:22 -0700 Subject: [PATCH 08/12] remove extra variable --- onnxruntime/core/providers/webgpu/tensor/concat.cc | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/onnxruntime/core/providers/webgpu/tensor/concat.cc b/onnxruntime/core/providers/webgpu/tensor/concat.cc index c07e5fa4cfbb6..f12c9a0b6e138 100644 --- a/onnxruntime/core/providers/webgpu/tensor/concat.cc +++ b/onnxruntime/core/providers/webgpu/tensor/concat.cc @@ -79,19 +79,14 @@ Status Concat::ComputeInternal(ComputeContext& context) const { uint32_t max_inputs_per_concat = context.DeviceLimits().maxStorageBuffersPerShaderStage - 1; uint32_t input_index = 0; - uint32_t cumulative_output_size = 0; uint32_t cumulative_size_in_concat_axis = 0; while (input_index < input_count) { ConcatProgram program{axis}; uint32_t num_inputs_this_concat = std::min(max_inputs_per_concat, input_count - input_index); - std::vector sizes; std::vector sizes_in_concat_axis; - sizes.reserve(num_inputs_this_concat + 1); sizes_in_concat_axis.reserve(num_inputs_this_concat + 1); - - sizes.push_back(cumulative_output_size); sizes_in_concat_axis.push_back(cumulative_size_in_concat_axis); uint32_t dispatch_size = 0; @@ -102,16 +97,11 @@ Status Concat::ComputeInternal(ComputeContext& context) const { uint32_t size = onnxruntime::narrow(input.tensor->Shape().Size()); uint32_t axis_size = static_cast(input.tensor->Shape()[axis]); - cumulative_output_size += size; dispatch_size += size; - sizes.push_back(cumulative_output_size); - cumulative_size_in_concat_axis += axis_size; sizes_in_concat_axis.push_back(cumulative_size_in_concat_axis); } - // Remove the last element from both vectors to prevent out of bounds writes - sizes.pop_back(); sizes_in_concat_axis.pop_back(); program.CacheHint(absl::StrJoin(std::make_tuple(num_inputs_this_concat, prepare.axis), ",")) From 9e560e9723f7ace666196b5cf955e6cb01160294 Mon Sep 17 00:00:00 2001 From: Prathik Rao Date: Tue, 15 Jul 2025 16:43:45 -0700 Subject: [PATCH 09/12] simplify --- onnxruntime/core/providers/webgpu/tensor/concat.cc | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/onnxruntime/core/providers/webgpu/tensor/concat.cc b/onnxruntime/core/providers/webgpu/tensor/concat.cc index f12c9a0b6e138..05b13fc6efbc1 100644 --- a/onnxruntime/core/providers/webgpu/tensor/concat.cc +++ b/onnxruntime/core/providers/webgpu/tensor/concat.cc @@ -86,11 +86,11 @@ Status Concat::ComputeInternal(ComputeContext& context) const { uint32_t num_inputs_this_concat = std::min(max_inputs_per_concat, input_count - input_index); std::vector sizes_in_concat_axis; - sizes_in_concat_axis.reserve(num_inputs_this_concat + 1); + sizes_in_concat_axis.reserve(num_inputs_this_concat); sizes_in_concat_axis.push_back(cumulative_size_in_concat_axis); uint32_t dispatch_size = 0; - for (uint32_t i = 0; i < num_inputs_this_concat; i++) { + for (uint32_t i = 0; i < num_inputs_this_concat - 1; i++) { auto& input = prepare.inputs[input_index + i]; program.AddInput({input.tensor, ProgramTensorMetadataDependency::TypeAndRank}); @@ -102,8 +102,6 @@ Status Concat::ComputeInternal(ComputeContext& context) const { sizes_in_concat_axis.push_back(cumulative_size_in_concat_axis); } - sizes_in_concat_axis.pop_back(); - program.CacheHint(absl::StrJoin(std::make_tuple(num_inputs_this_concat, prepare.axis), ",")) .AddOutputs({prepare.output_tensor}) .SetDispatchGroupSize((dispatch_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) From 25e786f99d75180bc4dac00adda552e6e32f1c6c Mon Sep 17 00:00:00 2001 From: Prathik Rao Date: Tue, 15 Jul 2025 16:44:20 -0700 Subject: [PATCH 10/12] simplify --- onnxruntime/core/providers/webgpu/tensor/concat.cc | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/onnxruntime/core/providers/webgpu/tensor/concat.cc b/onnxruntime/core/providers/webgpu/tensor/concat.cc index 05b13fc6efbc1..dc5818a93224d 100644 --- a/onnxruntime/core/providers/webgpu/tensor/concat.cc +++ b/onnxruntime/core/providers/webgpu/tensor/concat.cc @@ -89,7 +89,7 @@ Status Concat::ComputeInternal(ComputeContext& context) const { sizes_in_concat_axis.reserve(num_inputs_this_concat); sizes_in_concat_axis.push_back(cumulative_size_in_concat_axis); - uint32_t dispatch_size = 0; + uint32_t output_size = 0; for (uint32_t i = 0; i < num_inputs_this_concat - 1; i++) { auto& input = prepare.inputs[input_index + i]; program.AddInput({input.tensor, ProgramTensorMetadataDependency::TypeAndRank}); @@ -97,15 +97,15 @@ Status Concat::ComputeInternal(ComputeContext& context) const { uint32_t size = onnxruntime::narrow(input.tensor->Shape().Size()); uint32_t axis_size = static_cast(input.tensor->Shape()[axis]); - dispatch_size += size; + output_size += size; cumulative_size_in_concat_axis += axis_size; sizes_in_concat_axis.push_back(cumulative_size_in_concat_axis); } program.CacheHint(absl::StrJoin(std::make_tuple(num_inputs_this_concat, prepare.axis), ",")) .AddOutputs({prepare.output_tensor}) - .SetDispatchGroupSize((dispatch_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) - .AddUniformVariables({gsl::span(sizes_in_concat_axis.data(), sizes_in_concat_axis.size()), dispatch_size}); + .SetDispatchGroupSize((output_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) + .AddUniformVariables({gsl::span(sizes_in_concat_axis.data(), sizes_in_concat_axis.size()), output_size}); ORT_RETURN_IF_ERROR(context.RunProgram(program)); input_index += num_inputs_this_concat; From fa6d6c6a851bafe17a7a99406e0c9b0f57f005e7 Mon Sep 17 00:00:00 2001 From: Prathik Rao Date: Tue, 15 Jul 2025 19:44:58 -0700 Subject: [PATCH 11/12] revert --- onnxruntime/core/providers/webgpu/tensor/concat.cc | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/onnxruntime/core/providers/webgpu/tensor/concat.cc b/onnxruntime/core/providers/webgpu/tensor/concat.cc index dc5818a93224d..4c9bcff3a61ac 100644 --- a/onnxruntime/core/providers/webgpu/tensor/concat.cc +++ b/onnxruntime/core/providers/webgpu/tensor/concat.cc @@ -86,11 +86,11 @@ Status Concat::ComputeInternal(ComputeContext& context) const { uint32_t num_inputs_this_concat = std::min(max_inputs_per_concat, input_count - input_index); std::vector sizes_in_concat_axis; - sizes_in_concat_axis.reserve(num_inputs_this_concat); + sizes_in_concat_axis.reserve(num_inputs_this_concat + 1); sizes_in_concat_axis.push_back(cumulative_size_in_concat_axis); uint32_t output_size = 0; - for (uint32_t i = 0; i < num_inputs_this_concat - 1; i++) { + for (uint32_t i = 0; i < num_inputs_this_concat; i++) { auto& input = prepare.inputs[input_index + i]; program.AddInput({input.tensor, ProgramTensorMetadataDependency::TypeAndRank}); @@ -102,6 +102,8 @@ Status Concat::ComputeInternal(ComputeContext& context) const { sizes_in_concat_axis.push_back(cumulative_size_in_concat_axis); } + sizes_in_concat_axis.pop_back(); + program.CacheHint(absl::StrJoin(std::make_tuple(num_inputs_this_concat, prepare.axis), ",")) .AddOutputs({prepare.output_tensor}) .SetDispatchGroupSize((output_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) From ae9f06bea0747e164613fdfc4996082314112bce Mon Sep 17 00:00:00 2001 From: Prathik Rao Date: Wed, 16 Jul 2025 14:36:04 -0700 Subject: [PATCH 12/12] optimal solution --- .../core/providers/webgpu/tensor/concat.cc | 57 +++++++++++++++---- .../core/providers/webgpu/tensor/concat.h | 3 +- .../providers/cpu/tensor/concat_op_test.cc | 17 +++++- 3 files changed, 65 insertions(+), 12 deletions(-) diff --git a/onnxruntime/core/providers/webgpu/tensor/concat.cc b/onnxruntime/core/providers/webgpu/tensor/concat.cc index 4c9bcff3a61ac..283a9e5fe8262 100644 --- a/onnxruntime/core/providers/webgpu/tensor/concat.cc +++ b/onnxruntime/core/providers/webgpu/tensor/concat.cc @@ -38,6 +38,38 @@ WEBGPU_CONCAT_VERSIONED_KERNEL(4, 10) WEBGPU_CONCAT_VERSIONED_KERNEL(11, 12) WEBGPU_CONCAT_KERNEL(13) +void AppendCalculateInputIndexFunction(std::ostream& os, size_t input_count) { + os << "fn calculate_input_index(global_idx: u32) -> u32 {\n" + << " for (var i = 1u; i < " << input_count << "; i = i + 1u) {\n" + << " if (global_idx < " << GetElementAt("uniforms.offsets", "i", input_count) << ") {\n" + << " return i - 1;\n" + << " }\n" + << " }\n" + << " return " << input_count - 1 << ";\n" + << "}\n"; +} + +void AppendAssignOutputDataFunction(std::ostream& os, gsl::span inputs, const ShaderVariableHelper& output, size_t axis, size_t input_count) { + os << "fn assign_output_data(global_idx: u32, input_index: u32) {\n"; + for (size_t i = 0; i < inputs.size(); ++i) { + if (i == 0) { + os << " if (input_index == 0u) {\n"; + } else if (i == inputs.size() - 1) { + os << " } else {\n"; + } else { + os << " } else if (input_index == " << i << "u) {\n"; + } + std::string offset = GetElementAt("uniforms.offsets", "input_index", input_count); + std::string concat_axis_offset = GetElementAt("uniforms.sizes_in_concat_axis", std::to_string(i), input_count); + std::string output_indices_axis = "output_indices" + (inputs[i]->Rank() > 1 ? "[" + std::to_string(axis) + "]" : ""); + os << " var output_indices = " << inputs[i]->OffsetToIndices("global_idx - " + offset) << ";\n" + << " " << output_indices_axis << " += " << concat_axis_offset << ";\n" + << " " << output.SetByIndices("output_indices", inputs[i]->GetByOffset("global_idx - " + offset)) << "\n"; + } + os << " }\n" + "}\n"; +} + Status ConcatProgram::GenerateShaderCode(ShaderHelper& shader) const { size_t input_count = Inputs().size(); std::vector inputs; @@ -47,16 +79,12 @@ Status ConcatProgram::GenerateShaderCode(ShaderHelper& shader) const { } const auto& output = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias); - shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.output_size"); - for (size_t i = 0; i < input_count; ++i) { - const std::string output_indices_i = absl::StrCat("output_indices_", i); - const std::string output_indices_i_axis = output_indices_i + (inputs[i]->Rank() > 1 ? "[" + std::to_string(axis_) + "]" : ""); - const std::string concat_axis_offset = GetElementAt("uniforms.sizes_in_concat_axis", std::to_string(i), input_count); + AppendCalculateInputIndexFunction(shader.AdditionalImplementation(), input_count); + AppendAssignOutputDataFunction(shader.AdditionalImplementation(), inputs, output, axis_, input_count); - shader.MainFunctionBody() << " var " << output_indices_i << " = " << inputs[i]->OffsetToIndices("global_idx") << ";\n" - << " " << output_indices_i_axis << " += " << concat_axis_offset << ";\n" - << " " << output.SetByIndices(output_indices_i, inputs[i]->GetByOffset("global_idx")) << "\n"; - } + shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.output_size") + << "let input_index = calculate_input_index(global_idx);\n" + << "assign_output_data(global_idx, input_index);\n"; return Status::OK(); } @@ -85,6 +113,10 @@ Status Concat::ComputeInternal(ComputeContext& context) const { ConcatProgram program{axis}; uint32_t num_inputs_this_concat = std::min(max_inputs_per_concat, input_count - input_index); + std::vector offsets; + offsets.reserve(num_inputs_this_concat + 1); + offsets.push_back(0); + std::vector sizes_in_concat_axis; sizes_in_concat_axis.reserve(num_inputs_this_concat + 1); sizes_in_concat_axis.push_back(cumulative_size_in_concat_axis); @@ -92,22 +124,27 @@ Status Concat::ComputeInternal(ComputeContext& context) const { uint32_t output_size = 0; for (uint32_t i = 0; i < num_inputs_this_concat; i++) { auto& input = prepare.inputs[input_index + i]; + if (input.tensor->Shape().Size() == 0) { + continue; + } program.AddInput({input.tensor, ProgramTensorMetadataDependency::TypeAndRank}); uint32_t size = onnxruntime::narrow(input.tensor->Shape().Size()); uint32_t axis_size = static_cast(input.tensor->Shape()[axis]); output_size += size; + offsets.push_back(output_size); cumulative_size_in_concat_axis += axis_size; sizes_in_concat_axis.push_back(cumulative_size_in_concat_axis); } + offsets.pop_back(); sizes_in_concat_axis.pop_back(); program.CacheHint(absl::StrJoin(std::make_tuple(num_inputs_this_concat, prepare.axis), ",")) .AddOutputs({prepare.output_tensor}) .SetDispatchGroupSize((output_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) - .AddUniformVariables({gsl::span(sizes_in_concat_axis.data(), sizes_in_concat_axis.size()), output_size}); + .AddUniformVariables({gsl::span(offsets.data(), offsets.size()), gsl::span(sizes_in_concat_axis.data(), sizes_in_concat_axis.size()), output_size}); ORT_RETURN_IF_ERROR(context.RunProgram(program)); input_index += num_inputs_this_concat; diff --git a/onnxruntime/core/providers/webgpu/tensor/concat.h b/onnxruntime/core/providers/webgpu/tensor/concat.h index 512146836323b..7980556e0a1f4 100644 --- a/onnxruntime/core/providers/webgpu/tensor/concat.h +++ b/onnxruntime/core/providers/webgpu/tensor/concat.h @@ -17,7 +17,8 @@ class ConcatProgram final : public Program { Status GenerateShaderCode(ShaderHelper& sh) const override; - WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"sizes_in_concat_axis", ProgramUniformVariableDataType::Uint32}, + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"offsets", ProgramUniformVariableDataType::Uint32}, + {"sizes_in_concat_axis", ProgramUniformVariableDataType::Uint32}, {"output_size", ProgramUniformVariableDataType::Uint32}); private: diff --git a/onnxruntime/test/providers/cpu/tensor/concat_op_test.cc b/onnxruntime/test/providers/cpu/tensor/concat_op_test.cc index 9ecb34c41bd59..b5e13c6377ccb 100644 --- a/onnxruntime/test/providers/cpu/tensor/concat_op_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/concat_op_test.cc @@ -73,7 +73,6 @@ TEST(ConcatOpTest, Concat1D_2) { test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, // TensorRT: no support for dynamic shape tensor kNnapiExecutionProvider, // NNAPI: concat does not support 0 size input - kWebGpuExecutionProvider, // WebGPU: concat does not support 0 size input kQnnExecutionProvider}); // QNN: not support dynamic shape tensor } @@ -522,6 +521,22 @@ TEST(ConcatOpTest, Concat3D_exceed_maxStorageBuffersPerShaderStage) { 2, 4, 6, 8, 10, 12, 14, 16, 18}); test.Run(); } + +TEST(ConcatOpTest, Concat3D_exceed_maxStorageBuffersPerShaderStage_mixed_sizes) { + // maxStorageBuffersPerShaderStage==8 + OpTester test("Concat"); + test.AddAttribute("axis", int64_t{1}); + + test.AddInput("input1", {2, 1, 1}, {1, 2}); + test.AddInput("input2", {2, 3, 1}, {3, 4, 5, 6, 7, 8}); + test.AddInput("input3", {2, 2, 1}, {9, 10, 11, 12}); + test.AddInput("input4", {2, 1, 1}, {13, 14}); + test.AddOutput("concat_result", {2, 7, 1}, {// batch 0 + 1, 3, 4, 5, 9, 10, 13, + // batch 1 + 2, 6, 7, 8, 11, 12, 14}); + test.Run(); +} #endif // USE_WEBGPU } // namespace test