Skip to content

Commit ae9f06b

Browse files
committed
optimal solution
1 parent bc4fd5b commit ae9f06b

File tree

3 files changed

+65
-12
lines changed

3 files changed

+65
-12
lines changed

onnxruntime/core/providers/webgpu/tensor/concat.cc

Lines changed: 47 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,38 @@ WEBGPU_CONCAT_VERSIONED_KERNEL(4, 10)
3838
WEBGPU_CONCAT_VERSIONED_KERNEL(11, 12)
3939
WEBGPU_CONCAT_KERNEL(13)
4040

41+
void AppendCalculateInputIndexFunction(std::ostream& os, size_t input_count) {
42+
os << "fn calculate_input_index(global_idx: u32) -> u32 {\n"
43+
<< " for (var i = 1u; i < " << input_count << "; i = i + 1u) {\n"
44+
<< " if (global_idx < " << GetElementAt("uniforms.offsets", "i", input_count) << ") {\n"
45+
<< " return i - 1;\n"
46+
<< " }\n"
47+
<< " }\n"
48+
<< " return " << input_count - 1 << ";\n"
49+
<< "}\n";
50+
}
51+
52+
void AppendAssignOutputDataFunction(std::ostream& os, gsl::span<const ShaderVariableHelper*> inputs, const ShaderVariableHelper& output, size_t axis, size_t input_count) {
53+
os << "fn assign_output_data(global_idx: u32, input_index: u32) {\n";
54+
for (size_t i = 0; i < inputs.size(); ++i) {
55+
if (i == 0) {
56+
os << " if (input_index == 0u) {\n";
57+
} else if (i == inputs.size() - 1) {
58+
os << " } else {\n";
59+
} else {
60+
os << " } else if (input_index == " << i << "u) {\n";
61+
}
62+
std::string offset = GetElementAt("uniforms.offsets", "input_index", input_count);
63+
std::string concat_axis_offset = GetElementAt("uniforms.sizes_in_concat_axis", std::to_string(i), input_count);
64+
std::string output_indices_axis = "output_indices" + (inputs[i]->Rank() > 1 ? "[" + std::to_string(axis) + "]" : "");
65+
os << " var output_indices = " << inputs[i]->OffsetToIndices("global_idx - " + offset) << ";\n"
66+
<< " " << output_indices_axis << " += " << concat_axis_offset << ";\n"
67+
<< " " << output.SetByIndices("output_indices", inputs[i]->GetByOffset("global_idx - " + offset)) << "\n";
68+
}
69+
os << " }\n"
70+
"}\n";
71+
}
72+
4173
Status ConcatProgram::GenerateShaderCode(ShaderHelper& shader) const {
4274
size_t input_count = Inputs().size();
4375
std::vector<const ShaderVariableHelper*> inputs;
@@ -47,16 +79,12 @@ Status ConcatProgram::GenerateShaderCode(ShaderHelper& shader) const {
4779
}
4880
const auto& output = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias);
4981

50-
shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.output_size");
51-
for (size_t i = 0; i < input_count; ++i) {
52-
const std::string output_indices_i = absl::StrCat("output_indices_", i);
53-
const std::string output_indices_i_axis = output_indices_i + (inputs[i]->Rank() > 1 ? "[" + std::to_string(axis_) + "]" : "");
54-
const std::string concat_axis_offset = GetElementAt("uniforms.sizes_in_concat_axis", std::to_string(i), input_count);
82+
AppendCalculateInputIndexFunction(shader.AdditionalImplementation(), input_count);
83+
AppendAssignOutputDataFunction(shader.AdditionalImplementation(), inputs, output, axis_, input_count);
5584

56-
shader.MainFunctionBody() << " var " << output_indices_i << " = " << inputs[i]->OffsetToIndices("global_idx") << ";\n"
57-
<< " " << output_indices_i_axis << " += " << concat_axis_offset << ";\n"
58-
<< " " << output.SetByIndices(output_indices_i, inputs[i]->GetByOffset("global_idx")) << "\n";
59-
}
85+
shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.output_size")
86+
<< "let input_index = calculate_input_index(global_idx);\n"
87+
<< "assign_output_data(global_idx, input_index);\n";
6088

6189
return Status::OK();
6290
}
@@ -85,29 +113,38 @@ Status Concat::ComputeInternal(ComputeContext& context) const {
85113
ConcatProgram program{axis};
86114
uint32_t num_inputs_this_concat = std::min(max_inputs_per_concat, input_count - input_index);
87115

116+
std::vector<uint32_t> offsets;
117+
offsets.reserve(num_inputs_this_concat + 1);
118+
offsets.push_back(0);
119+
88120
std::vector<uint32_t> sizes_in_concat_axis;
89121
sizes_in_concat_axis.reserve(num_inputs_this_concat + 1);
90122
sizes_in_concat_axis.push_back(cumulative_size_in_concat_axis);
91123

92124
uint32_t output_size = 0;
93125
for (uint32_t i = 0; i < num_inputs_this_concat; i++) {
94126
auto& input = prepare.inputs[input_index + i];
127+
if (input.tensor->Shape().Size() == 0) {
128+
continue;
129+
}
95130
program.AddInput({input.tensor, ProgramTensorMetadataDependency::TypeAndRank});
96131

97132
uint32_t size = onnxruntime::narrow<int32_t>(input.tensor->Shape().Size());
98133
uint32_t axis_size = static_cast<uint32_t>(input.tensor->Shape()[axis]);
99134

100135
output_size += size;
136+
offsets.push_back(output_size);
101137
cumulative_size_in_concat_axis += axis_size;
102138
sizes_in_concat_axis.push_back(cumulative_size_in_concat_axis);
103139
}
104140

141+
offsets.pop_back();
105142
sizes_in_concat_axis.pop_back();
106143

107144
program.CacheHint(absl::StrJoin(std::make_tuple(num_inputs_this_concat, prepare.axis), ","))
108145
.AddOutputs({prepare.output_tensor})
109146
.SetDispatchGroupSize((output_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE)
110-
.AddUniformVariables({gsl::span<const uint32_t>(sizes_in_concat_axis.data(), sizes_in_concat_axis.size()), output_size});
147+
.AddUniformVariables({gsl::span<const uint32_t>(offsets.data(), offsets.size()), gsl::span<const uint32_t>(sizes_in_concat_axis.data(), sizes_in_concat_axis.size()), output_size});
111148
ORT_RETURN_IF_ERROR(context.RunProgram(program));
112149

113150
input_index += num_inputs_this_concat;

onnxruntime/core/providers/webgpu/tensor/concat.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,8 @@ class ConcatProgram final : public Program<ConcatProgram> {
1717

1818
Status GenerateShaderCode(ShaderHelper& sh) const override;
1919

20-
WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"sizes_in_concat_axis", ProgramUniformVariableDataType::Uint32},
20+
WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"offsets", ProgramUniformVariableDataType::Uint32},
21+
{"sizes_in_concat_axis", ProgramUniformVariableDataType::Uint32},
2122
{"output_size", ProgramUniformVariableDataType::Uint32});
2223

2324
private:

onnxruntime/test/providers/cpu/tensor/concat_op_test.cc

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,6 @@ TEST(ConcatOpTest, Concat1D_2) {
7373
test.Run(OpTester::ExpectResult::kExpectSuccess, "",
7474
{kTensorrtExecutionProvider, // TensorRT: no support for dynamic shape tensor
7575
kNnapiExecutionProvider, // NNAPI: concat does not support 0 size input
76-
kWebGpuExecutionProvider, // WebGPU: concat does not support 0 size input
7776
kQnnExecutionProvider}); // QNN: not support dynamic shape tensor
7877
}
7978

@@ -522,6 +521,22 @@ TEST(ConcatOpTest, Concat3D_exceed_maxStorageBuffersPerShaderStage) {
522521
2, 4, 6, 8, 10, 12, 14, 16, 18});
523522
test.Run();
524523
}
524+
525+
// Concatenates four rank-3 int32 tensors along axis 1 where the inputs have
// different extents on that axis (1, 3, 2 and 1), and checks the interleaved result.
TEST(ConcatOpTest, Concat3D_exceed_maxStorageBuffersPerShaderStage_mixed_sizes) {
  // maxStorageBuffersPerShaderStage==8
  // NOTE(review): only four inputs are supplied here, fewer than the limit quoted
  // above - confirm whether more inputs were intended for this test name.
  OpTester concat_tester("Concat");
  concat_tester.AddAttribute("axis", int64_t{1});

  // Each input holds two batches; the middle dimension is the concat axis.
  concat_tester.AddInput<int32_t>("input1", {2, 1, 1}, {1, 2});
  concat_tester.AddInput<int32_t>("input2", {2, 3, 1}, {3, 4, 5, 6, 7, 8});
  concat_tester.AddInput<int32_t>("input3", {2, 2, 1}, {9, 10, 11, 12});
  concat_tester.AddInput<int32_t>("input4", {2, 1, 1}, {13, 14});

  // Axis-1 extents add up to 1 + 3 + 2 + 1 = 7 per batch.
  concat_tester.AddOutput<int32_t>("concat_result", {2, 7, 1},
                                   {// batch 0
                                    1, 3, 4, 5, 9, 10, 13,
                                    // batch 1
                                    2, 6, 7, 8, 11, 12, 14});
  concat_tester.Run();
}
525540
#endif // USE_WEBGPU
526541

527542
} // namespace test

0 commit comments

Comments
 (0)