Skip to content
37 changes: 32 additions & 5 deletions onnxruntime/core/providers/webgpu/tensor/concat.cc
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,19 @@
return Status::OK();
}

Status ConcatProgramSingle::GenerateShaderCode(ShaderHelper& shader) const {
const auto& input = shader.AddInput("input", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias);
const auto& output = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias);
const std::string output_indices_str = "output_indices" + (input.Rank() > 1 ? "[" + std::to_string(axis_) + "]" : "");

Check warning on line 96 in onnxruntime/core/providers/webgpu/tensor/concat.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <string> for string [build/include_what_you_use] [4] Raw Output: onnxruntime/core/providers/webgpu/tensor/concat.cc:96: Add #include <string> for string [build/include_what_you_use] [4]

shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.output_size")
<< " var output_indices = " << input.OffsetToIndices("global_idx") << ";\n"
<< " " << output_indices_str << " += uniforms.concat_axis_offset;\n"
<< " " << output.SetByIndices("output_indices", input.GetByOffset("global_idx")) << "\n";

return Status::OK();
}

Status Concat::ComputeInternal(ComputeContext& context) const {
int input_count = context.InputCount();
InlinedTensorsVector input_tensors;
Expand Down Expand Up @@ -127,10 +140,24 @@
size_t non_empty_input_count = sizes_in_concat_axis.size();

if (non_empty_input_count + 1 > context.DeviceLimits().maxStorageBuffersPerShaderStage) {
// TODO: support when input_count + 1 > maxStorageBuffersPerShaderStage, by raising the limit or run the program in multiple passes.
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "The number of storage buffer (input=",
input_count, ", output=1) exceeds the limit (",
context.DeviceLimits().maxStorageBuffersPerShaderStage, ") of the device.");
LOGS_DEFAULT(WARNING) << "Storage buffer limit exceeded for Concat. Running operation one input at a time.";
uint32_t concat_axis_offset = 0;
for (uint32_t input_index = 0; input_index < static_cast<uint32_t>(input_count); input_index++) {
const auto& input = prepare.inputs[input_index];
auto axis_size = input.tensor->Shape()[axis];

ConcatProgramSingle pass_program{axis};
pass_program.CacheHint(absl::StrJoin(std::make_tuple(prepare.axis), ","))
.AddInput({input.tensor, ProgramTensorMetadataDependency::TypeAndRank})
.AddOutputs({prepare.output_tensor})
.SetDispatchGroupSize((input.tensor->Shape().Size() + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE)
.AddUniformVariables({output_size, concat_axis_offset});
ORT_RETURN_IF_ERROR(context.RunProgram(pass_program));

concat_axis_offset += static_cast<uint32_t>(axis_size);
}

return Status::OK();
}

program.CacheHint(absl::StrJoin(std::make_tuple(non_empty_input_count, prepare.axis), ","))
Expand All @@ -142,4 +169,4 @@
}

} // namespace webgpu
} // namespace onnxruntime
} // namespace onnxruntime
16 changes: 15 additions & 1 deletion onnxruntime/core/providers/webgpu/tensor/concat.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,20 @@ class ConcatProgram final : public Program<ConcatProgram> {
size_t axis_;
};

class ConcatProgramSingle final : public Program<ConcatProgramSingle> {
public:
ConcatProgramSingle(size_t axis) : Program{"Concat"}, axis_{axis} {
}

Status GenerateShaderCode(ShaderHelper& sh) const override;

WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"concat_axis_offset", ProgramUniformVariableDataType::Uint32},
{"output_size", ProgramUniformVariableDataType::Uint32});

private:
size_t axis_;
};

class Concat final : public WebGpuKernel, public ConcatBase {
public:
Concat(const OpKernelInfo& info) : WebGpuKernel(info), ConcatBase(info) {
Expand All @@ -33,4 +47,4 @@ class Concat final : public WebGpuKernel, public ConcatBase {
};

} // namespace webgpu
} // namespace onnxruntime
} // namespace onnxruntime
77 changes: 77 additions & 0 deletions onnxruntime/test/providers/cpu/tensor/concat_op_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -434,5 +434,82 @@ TEST(ConcatOpTest, Concat4D_2) {
test.Run();
}

#ifdef USE_WEBGPU
TEST(ConcatOpTest, Concat1D_exceed_maxStorageBuffersPerShaderStage) {
// maxStorageBuffersPerShaderStage==8
OpTester test("Concat");
test.AddAttribute("axis", int64_t{0});

test.AddInput<int32_t>("input1", {1}, {1});
test.AddInput<int32_t>("input2", {1}, {2});
test.AddInput<int32_t>("input3", {1}, {3});
test.AddInput<int32_t>("input4", {1}, {4});
test.AddInput<int32_t>("input5", {1}, {5});
test.AddInput<int32_t>("input6", {1}, {6});
test.AddInput<int32_t>("input7", {1}, {7});
test.AddInput<int32_t>("input8", {1}, {8});
test.AddInput<int32_t>("input9", {1}, {9});
test.AddOutput<int32_t>("concat_result", {9}, {1, 2, 3, 4, 5, 6, 7, 8, 9});
test.Run();
}

TEST(ConcatOpTest, Concat2D_exceed_maxStorageBuffersPerShaderStage_axis0) {
// maxStorageBuffersPerShaderStage==8
OpTester test("Concat");
test.AddAttribute("axis", int64_t{0});

test.AddInput<int32_t>("input1", {1, 2}, {1, 1});
test.AddInput<int32_t>("input2", {1, 2}, {2, 2});
test.AddInput<int32_t>("input3", {1, 2}, {3, 3});
test.AddInput<int32_t>("input4", {1, 2}, {4, 4});
test.AddInput<int32_t>("input5", {1, 2}, {5, 5});
test.AddInput<int32_t>("input6", {1, 2}, {6, 6});
test.AddInput<int32_t>("input7", {1, 2}, {7, 7});
test.AddInput<int32_t>("input8", {1, 2}, {8, 8});
test.AddInput<int32_t>("input9", {1, 2}, {9, 9});
test.AddOutput<int32_t>("concat_result", {9, 2}, {1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9, 9});
test.Run();
}

TEST(ConcatOpTest, Concat2D_exceed_maxStorageBuffersPerShaderStage_axis1) {
// maxStorageBuffersPerShaderStage==8
OpTester test("Concat");
test.AddAttribute("axis", int64_t{1});

test.AddInput<int32_t>("input1", {1, 2}, {1, 1});
test.AddInput<int32_t>("input2", {1, 2}, {2, 2});
test.AddInput<int32_t>("input3", {1, 2}, {3, 3});
test.AddInput<int32_t>("input4", {1, 2}, {4, 4});
test.AddInput<int32_t>("input5", {1, 2}, {5, 5});
test.AddInput<int32_t>("input6", {1, 2}, {6, 6});
test.AddInput<int32_t>("input7", {1, 2}, {7, 7});
test.AddInput<int32_t>("input8", {1, 2}, {8, 8});
test.AddInput<int32_t>("input9", {1, 2}, {9, 9});
test.AddOutput<int32_t>("concat_result", {1, 18}, {1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9, 9});
test.Run();
}

TEST(ConcatOpTest, Concat3D_exceed_maxStorageBuffersPerShaderStage) {
// maxStorageBuffersPerShaderStage==8
OpTester test("Concat");
test.AddAttribute("axis", int64_t{1});

test.AddInput<int32_t>("input1", {2, 1, 1}, {1, 2});
test.AddInput<int32_t>("input2", {2, 1, 1}, {3, 4});
test.AddInput<int32_t>("input3", {2, 1, 1}, {5, 6});
test.AddInput<int32_t>("input4", {2, 1, 1}, {7, 8});
test.AddInput<int32_t>("input5", {2, 1, 1}, {9, 10});
test.AddInput<int32_t>("input6", {2, 1, 1}, {11, 12});
test.AddInput<int32_t>("input7", {2, 1, 1}, {13, 14});
test.AddInput<int32_t>("input8", {2, 1, 1}, {15, 16});
test.AddInput<int32_t>("input9", {2, 1, 1}, {17, 18});
test.AddOutput<int32_t>("concat_result", {2, 9, 1}, {// batch 0
1, 3, 5, 7, 9, 11, 13, 15, 17,
// batch 1
2, 4, 6, 8, 10, 12, 14, 16, 18});
test.Run();
}
#endif // USE_WEBGPU

} // namespace test
} // namespace onnxruntime
Loading