Skip to content

Commit f041df3

Browse files
prathikrjnagi-intel
authored andcommitted
[WebGPU EP] allow concat operator to handle large number of inputs (microsoft#25390)
### Description <!-- Describe your changes. --> Adjusts concat operator to batch inputs based on maxStorageBuffersPerShaderStage to allow unlimited number of inputs. ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. --> Fixes patchtst model for transformers.js <img width="960" height="367" alt="{31C75CD1-7A7D-48E3-A090-FB153925D165}" src="https://github.com/user-attachments/assets/f5772709-80b7-4a05-8927-40f496be908c" />
1 parent 26ecacf commit f041df3

File tree

3 files changed

+171
-53
lines changed

3 files changed

+171
-53
lines changed

onnxruntime/core/providers/webgpu/tensor/concat.cc

Lines changed: 63 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -38,19 +38,19 @@ WEBGPU_CONCAT_VERSIONED_KERNEL(4, 10)
3838
WEBGPU_CONCAT_VERSIONED_KERNEL(11, 12)
3939
WEBGPU_CONCAT_KERNEL(13)
4040

41-
void AppendCalCulateInputIndexFunction(std::ostream& os, size_t input_count) {
42-
os << "fn calculate_input_index(index: u32) -> u32 {\n"
43-
<< " for (var i = 0u; i < " << input_count << "; i = i + 1u) {\n"
44-
<< " if (index < " << GetElementAt("uniforms.size_in_concat_axis", "i", input_count) << ") {\n"
45-
<< " return i;\n"
41+
void AppendCalculateInputIndexFunction(std::ostream& os, size_t input_count) {
42+
os << "fn calculate_input_index(global_idx: u32) -> u32 {\n"
43+
<< " for (var i = 1u; i < " << input_count << "; i = i + 1u) {\n"
44+
<< " if (global_idx < " << GetElementAt("uniforms.offsets", "i", input_count) << ") {\n"
45+
<< " return i - 1;\n"
4646
<< " }\n"
4747
<< " }\n"
48-
<< " return " << input_count << ";\n"
48+
<< " return " << input_count - 1 << ";\n"
4949
<< "}\n";
5050
}
5151

52-
void AppendAssignOutputDataFunction(std::ostream& os, gsl::span<const ShaderVariableHelper*> inputs, const ShaderVariableHelper& output) {
53-
os << "fn assign_output_data(global_idx: u32, input_index: u32, indices: output_indices_t) {\n";
52+
void AppendAssignOutputDataFunction(std::ostream& os, gsl::span<const ShaderVariableHelper*> inputs, const ShaderVariableHelper& output, size_t axis, size_t input_count) {
53+
os << "fn assign_output_data(global_idx: u32, input_index: u32) {\n";
5454
for (size_t i = 0; i < inputs.size(); ++i) {
5555
if (i == 0) {
5656
os << " if (input_index == 0u) {\n";
@@ -59,7 +59,12 @@ void AppendAssignOutputDataFunction(std::ostream& os, gsl::span<const ShaderVari
5959
} else {
6060
os << " } else if (input_index == " << i << "u) {\n";
6161
}
62-
os << " " << output.SetByOffset("global_idx", inputs[i]->GetByIndices("indices")) << ";\n";
62+
std::string offset = GetElementAt("uniforms.offsets", "input_index", input_count);
63+
std::string concat_axis_offset = GetElementAt("uniforms.sizes_in_concat_axis", std::to_string(i), input_count);
64+
std::string output_indices_axis = "output_indices" + (inputs[i]->Rank() > 1 ? "[" + std::to_string(axis) + "]" : "");
65+
os << " var output_indices = " << inputs[i]->OffsetToIndices("global_idx - " + offset) << ";\n"
66+
<< " " << output_indices_axis << " += " << concat_axis_offset << ";\n"
67+
<< " " << output.SetByIndices("output_indices", inputs[i]->GetByOffset("global_idx - " + offset)) << "\n";
6368
}
6469
os << " }\n"
6570
"}\n";
@@ -74,27 +79,21 @@ Status ConcatProgram::GenerateShaderCode(ShaderHelper& shader) const {
7479
}
7580
const auto& output = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias);
7681

77-
// add implementation of fn calculate_input_index
78-
AppendCalCulateInputIndexFunction(shader.AdditionalImplementation(), input_count);
79-
// add implementation of fn assign_output_data
80-
AppendAssignOutputDataFunction(shader.AdditionalImplementation(), inputs, output);
81-
const std::string size_in_concat_axis = GetElementAt("uniforms.size_in_concat_axis", "input_index - 1", input_count);
82+
AppendCalculateInputIndexFunction(shader.AdditionalImplementation(), input_count);
83+
AppendAssignOutputDataFunction(shader.AdditionalImplementation(), inputs, output, axis_, input_count);
84+
8285
shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.output_size")
83-
<< " var indices = " << output.OffsetToIndices("global_idx") << ";\n"
84-
<< " let indices_axis = " << output.IndicesGet("indices", axis_) << ";\n"
85-
<< " let input_index = calculate_input_index(indices_axis);\n"
86-
<< " if (input_index != 0u) {\n"
87-
<< " " << output.IndicesSet("indices", axis_, "indices_axis - " + size_in_concat_axis) << ";\n"
88-
<< " }\n"
89-
" assign_output_data(global_idx, input_index, indices);\n";
86+
<< "let input_index = calculate_input_index(global_idx);\n"
87+
<< "assign_output_data(global_idx, input_index);\n";
88+
9089
return Status::OK();
9190
}
9291

9392
Status Concat::ComputeInternal(ComputeContext& context) const {
94-
int input_count = context.InputCount();
93+
uint32_t input_count = context.InputCount();
9594
InlinedTensorsVector input_tensors;
9695
input_tensors.reserve(input_count);
97-
for (int i = 0; i < input_count; ++i) {
96+
for (uint32_t i = 0; i < input_count; ++i) {
9897
input_tensors.push_back(context.Input<Tensor>(i));
9998
}
10099

@@ -104,42 +103,55 @@ Status Concat::ComputeInternal(ComputeContext& context) const {
104103
return Status::OK();
105104
}
106105

107-
uint32_t output_size = onnxruntime::narrow<int32_t>(prepare.output_tensor->Shape().Size());
106+
uint32_t axis = static_cast<uint32_t>(prepare.axis);
107+
uint32_t max_inputs_per_concat = context.DeviceLimits().maxStorageBuffersPerShaderStage - 1;
108+
109+
uint32_t input_index = 0;
110+
uint32_t cumulative_size_in_concat_axis = 0;
111+
112+
while (input_index < input_count) {
113+
ConcatProgram program{axis};
114+
uint32_t num_inputs_this_concat = std::min(max_inputs_per_concat, input_count - input_index);
115+
116+
std::vector<uint32_t> offsets;
117+
offsets.reserve(num_inputs_this_concat + 1);
118+
offsets.push_back(0);
108119

109-
size_t axis = static_cast<size_t>(prepare.axis);
110-
ConcatProgram program{axis};
120+
std::vector<uint32_t> sizes_in_concat_axis;
121+
sizes_in_concat_axis.reserve(num_inputs_this_concat + 1);
122+
sizes_in_concat_axis.push_back(cumulative_size_in_concat_axis);
111123

112-
std::vector<uint32_t> sizes_in_concat_axis;
113-
sizes_in_concat_axis.reserve(input_count);
114-
uint32_t sum = 0;
115-
for (int i = 0; i < input_count; ++i) {
116-
const auto& input = prepare.inputs[i];
117-
if (input.tensor->Shape().Size() == 0) {
118-
continue;
124+
uint32_t output_size = 0;
125+
for (uint32_t i = 0; i < num_inputs_this_concat; i++) {
126+
auto& input = prepare.inputs[input_index + i];
127+
if (input.tensor->Shape().Size() == 0) {
128+
continue;
129+
}
130+
program.AddInput({input.tensor, ProgramTensorMetadataDependency::TypeAndRank});
131+
132+
uint32_t size = onnxruntime::narrow<int32_t>(input.tensor->Shape().Size());
133+
uint32_t axis_size = static_cast<uint32_t>(input.tensor->Shape()[axis]);
134+
135+
output_size += size;
136+
offsets.push_back(output_size);
137+
cumulative_size_in_concat_axis += axis_size;
138+
sizes_in_concat_axis.push_back(cumulative_size_in_concat_axis);
119139
}
120-
program.AddInput({input.tensor, ProgramTensorMetadataDependency::TypeAndRank});
121140

122-
auto axis_size = input.tensor->Shape()[axis];
123-
sum += static_cast<uint32_t>(axis_size);
124-
sizes_in_concat_axis.push_back(sum);
125-
}
141+
offsets.pop_back();
142+
sizes_in_concat_axis.pop_back();
126143

127-
size_t non_empty_input_count = sizes_in_concat_axis.size();
144+
program.CacheHint(absl::StrJoin(std::make_tuple(num_inputs_this_concat, prepare.axis), ","))
145+
.AddOutputs({prepare.output_tensor})
146+
.SetDispatchGroupSize((output_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE)
147+
.AddUniformVariables({gsl::span<const uint32_t>(offsets.data(), offsets.size()), gsl::span<const uint32_t>(sizes_in_concat_axis.data(), sizes_in_concat_axis.size()), output_size});
148+
ORT_RETURN_IF_ERROR(context.RunProgram(program));
128149

129-
if (non_empty_input_count + 1 > context.DeviceLimits().maxStorageBuffersPerShaderStage) {
130-
// TODO: support when input_count + 1 > maxStorageBuffersPerShaderStage, by raising the limit or run the program in multiple passes.
131-
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "The number of storage buffer (input=",
132-
input_count, ", output=1) exceeds the limit (",
133-
context.DeviceLimits().maxStorageBuffersPerShaderStage, ") of the device.");
150+
input_index += num_inputs_this_concat;
134151
}
135152

136-
program.CacheHint(absl::StrJoin(std::make_tuple(non_empty_input_count, prepare.axis), ","))
137-
.AddOutputs({prepare.output_tensor})
138-
.SetDispatchGroupSize((prepare.output_num_elements + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE)
139-
.AddUniformVariables({gsl::span<const uint32_t>(sizes_in_concat_axis.data(), sizes_in_concat_axis.size()),
140-
output_size});
141-
return context.RunProgram(program);
153+
return Status::OK();
142154
}
143155

144156
} // namespace webgpu
145-
} // namespace onnxruntime
157+
} // namespace onnxruntime

onnxruntime/core/providers/webgpu/tensor/concat.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,8 @@ class ConcatProgram final : public Program<ConcatProgram> {
1717

1818
Status GenerateShaderCode(ShaderHelper& sh) const override;
1919

20-
WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"size_in_concat_axis", ProgramUniformVariableDataType::Uint32},
20+
WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"offsets", ProgramUniformVariableDataType::Uint32},
21+
{"sizes_in_concat_axis", ProgramUniformVariableDataType::Uint32},
2122
{"output_size", ProgramUniformVariableDataType::Uint32});
2223

2324
private:
@@ -33,4 +34,4 @@ class Concat final : public WebGpuKernel, public ConcatBase {
3334
};
3435

3536
} // namespace webgpu
36-
} // namespace onnxruntime
37+
} // namespace onnxruntime

onnxruntime/test/providers/cpu/tensor/concat_op_test.cc

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -434,5 +434,110 @@ TEST(ConcatOpTest, Concat4D_2) {
434434
test.Run();
435435
}
436436

437+
#ifdef USE_WEBGPU
438+
TEST(ConcatOpTest, Concat1D_int32_4inputs) {
439+
OpTester test("Concat");
440+
test.AddAttribute("axis", int64_t{0});
441+
442+
test.AddInput<int32_t>("input1", {1}, {1});
443+
test.AddInput<int32_t>("input2", {2}, {2, 3});
444+
test.AddInput<int32_t>("input3", {4}, {4, 5, 6, 7});
445+
test.AddInput<int32_t>("input4", {2}, {8, 9});
446+
test.AddOutput<int32_t>("concat_result", {9}, {1, 2, 3, 4, 5, 6, 7, 8, 9});
447+
test.Run();
448+
}
449+
450+
TEST(ConcatOpTest, Concat1D_exceed_maxStorageBuffersPerShaderStage) {
451+
// maxStorageBuffersPerShaderStage==8
452+
OpTester test("Concat");
453+
test.AddAttribute("axis", int64_t{0});
454+
455+
test.AddInput<int32_t>("input1", {1}, {1});
456+
test.AddInput<int32_t>("input2", {1}, {2});
457+
test.AddInput<int32_t>("input3", {1}, {3});
458+
test.AddInput<int32_t>("input4", {1}, {4});
459+
test.AddInput<int32_t>("input5", {1}, {5});
460+
test.AddInput<int32_t>("input6", {1}, {6});
461+
test.AddInput<int32_t>("input7", {1}, {7});
462+
test.AddInput<int32_t>("input8", {1}, {8});
463+
test.AddInput<int32_t>("input9", {1}, {9});
464+
test.AddOutput<int32_t>("concat_result", {9}, {1, 2, 3, 4, 5, 6, 7, 8, 9});
465+
test.Run();
466+
}
467+
468+
TEST(ConcatOpTest, Concat2D_exceed_maxStorageBuffersPerShaderStage_axis0) {
469+
// maxStorageBuffersPerShaderStage==8
470+
OpTester test("Concat");
471+
test.AddAttribute("axis", int64_t{0});
472+
473+
test.AddInput<int32_t>("input1", {1, 2}, {1, 2});
474+
test.AddInput<int32_t>("input2", {1, 2}, {3, 4});
475+
test.AddInput<int32_t>("input3", {1, 2}, {5, 6});
476+
test.AddInput<int32_t>("input4", {1, 2}, {7, 8});
477+
test.AddInput<int32_t>("input5", {1, 2}, {9, 10});
478+
test.AddInput<int32_t>("input6", {1, 2}, {11, 12});
479+
test.AddInput<int32_t>("input7", {1, 2}, {13, 14});
480+
test.AddInput<int32_t>("input8", {1, 2}, {15, 16});
481+
test.AddInput<int32_t>("input9", {1, 2}, {17, 18});
482+
test.AddOutput<int32_t>("concat_result", {9, 2}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18});
483+
test.Run();
484+
}
485+
486+
TEST(ConcatOpTest, Concat2D_exceed_maxStorageBuffersPerShaderStage_axis1) {
487+
// maxStorageBuffersPerShaderStage==8
488+
OpTester test("Concat");
489+
test.AddAttribute("axis", int64_t{1});
490+
491+
test.AddInput<int32_t>("input1", {1, 2}, {1, 2});
492+
test.AddInput<int32_t>("input2", {1, 2}, {3, 4});
493+
test.AddInput<int32_t>("input3", {1, 2}, {5, 6});
494+
test.AddInput<int32_t>("input4", {1, 2}, {7, 8});
495+
test.AddInput<int32_t>("input5", {1, 2}, {9, 10});
496+
test.AddInput<int32_t>("input6", {1, 2}, {11, 12});
497+
test.AddInput<int32_t>("input7", {1, 2}, {13, 14});
498+
test.AddInput<int32_t>("input8", {1, 2}, {15, 16});
499+
test.AddInput<int32_t>("input9", {1, 2}, {17, 18});
500+
test.AddOutput<int32_t>("concat_result", {1, 18}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18});
501+
test.Run();
502+
}
503+
504+
TEST(ConcatOpTest, Concat3D_exceed_maxStorageBuffersPerShaderStage) {
505+
// maxStorageBuffersPerShaderStage==8
506+
OpTester test("Concat");
507+
test.AddAttribute("axis", int64_t{1});
508+
509+
test.AddInput<int32_t>("input1", {2, 1, 1}, {1, 2});
510+
test.AddInput<int32_t>("input2", {2, 1, 1}, {3, 4});
511+
test.AddInput<int32_t>("input3", {2, 1, 1}, {5, 6});
512+
test.AddInput<int32_t>("input4", {2, 1, 1}, {7, 8});
513+
test.AddInput<int32_t>("input5", {2, 1, 1}, {9, 10});
514+
test.AddInput<int32_t>("input6", {2, 1, 1}, {11, 12});
515+
test.AddInput<int32_t>("input7", {2, 1, 1}, {13, 14});
516+
test.AddInput<int32_t>("input8", {2, 1, 1}, {15, 16});
517+
test.AddInput<int32_t>("input9", {2, 1, 1}, {17, 18});
518+
test.AddOutput<int32_t>("concat_result", {2, 9, 1}, {// batch 0
519+
1, 3, 5, 7, 9, 11, 13, 15, 17,
520+
// batch 1
521+
2, 4, 6, 8, 10, 12, 14, 16, 18});
522+
test.Run();
523+
}
524+
525+
TEST(ConcatOpTest, Concat3D_exceed_maxStorageBuffersPerShaderStage_mixed_sizes) {
526+
// maxStorageBuffersPerShaderStage==8
527+
OpTester test("Concat");
528+
test.AddAttribute("axis", int64_t{1});
529+
530+
test.AddInput<int32_t>("input1", {2, 1, 1}, {1, 2});
531+
test.AddInput<int32_t>("input2", {2, 3, 1}, {3, 4, 5, 6, 7, 8});
532+
test.AddInput<int32_t>("input3", {2, 2, 1}, {9, 10, 11, 12});
533+
test.AddInput<int32_t>("input4", {2, 1, 1}, {13, 14});
534+
test.AddOutput<int32_t>("concat_result", {2, 7, 1}, {// batch 0
535+
1, 3, 4, 5, 9, 10, 13,
536+
// batch 1
537+
2, 6, 7, 8, 11, 12, 14});
538+
test.Run();
539+
}
540+
#endif // USE_WEBGPU
541+
437542
} // namespace test
438543
} // namespace onnxruntime

0 commit comments

Comments
 (0)