@@ -38,6 +38,38 @@ WEBGPU_CONCAT_VERSIONED_KERNEL(4, 10)
3838WEBGPU_CONCAT_VERSIONED_KERNEL(11 , 12 )
3939WEBGPU_CONCAT_KERNEL(13 )
4040
41+ void AppendCalculateInputIndexFunction(std::ostream& os, size_t input_count) {
42+ os << " fn calculate_input_index(global_idx: u32) -> u32 {\n "
43+ << " for (var i = 1u; i < " << input_count << " ; i = i + 1u) {\n "
44+ << " if (global_idx < " << GetElementAt (" uniforms.offsets" , " i" , input_count) << " ) {\n "
45+ << " return i - 1;\n "
46+ << " }\n "
47+ << " }\n "
48+ << " return " << input_count - 1 << " ;\n "
49+ << " }\n " ;
50+ }
51+
52+ void AppendAssignOutputDataFunction (std::ostream& os, gsl::span<const ShaderVariableHelper*> inputs, const ShaderVariableHelper& output, size_t axis, size_t input_count) {
53+ os << " fn assign_output_data(global_idx: u32, input_index: u32) {\n " ;
54+ for (size_t i = 0 ; i < inputs.size (); ++i) {
55+ if (i == 0 ) {
56+ os << " if (input_index == 0u) {\n " ;
57+ } else if (i == inputs.size () - 1 ) {
58+ os << " } else {\n " ;
59+ } else {
60+ os << " } else if (input_index == " << i << " u) {\n " ;
61+ }
62+ std::string offset = GetElementAt (" uniforms.offsets" , " input_index" , input_count);
63+ std::string concat_axis_offset = GetElementAt (" uniforms.sizes_in_concat_axis" , std::to_string (i), input_count);
64+ std::string output_indices_axis = " output_indices" + (inputs[i]->Rank () > 1 ? " [" + std::to_string (axis) + " ]" : " " );
65+ os << " var output_indices = " << inputs[i]->OffsetToIndices (" global_idx - " + offset) << " ;\n "
66+ << " " << output_indices_axis << " += " << concat_axis_offset << " ;\n "
67+ << " " << output.SetByIndices (" output_indices" , inputs[i]->GetByOffset (" global_idx - " + offset)) << " \n " ;
68+ }
69+ os << " }\n "
70+ " }\n " ;
71+ }
72+
4173Status ConcatProgram::GenerateShaderCode (ShaderHelper& shader) const {
4274 size_t input_count = Inputs ().size ();
4375 std::vector<const ShaderVariableHelper*> inputs;
@@ -47,16 +79,12 @@ Status ConcatProgram::GenerateShaderCode(ShaderHelper& shader) const {
4779 }
4880 const auto & output = shader.AddOutput (" output" , ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias);
4981
50- shader.MainFunctionBody () << shader.GuardAgainstOutOfBoundsWorkgroupSizes (" uniforms.output_size" );
51- for (size_t i = 0 ; i < input_count; ++i) {
52- const std::string output_indices_i = absl::StrCat (" output_indices_" , i);
53- const std::string output_indices_i_axis = output_indices_i + (inputs[i]->Rank () > 1 ? " [" + std::to_string (axis_) + " ]" : " " );
54- const std::string concat_axis_offset = GetElementAt (" uniforms.sizes_in_concat_axis" , std::to_string (i), input_count);
82+ AppendCalculateInputIndexFunction (shader.AdditionalImplementation (), input_count);
83+ AppendAssignOutputDataFunction (shader.AdditionalImplementation (), inputs, output, axis_, input_count);
5584
56- shader.MainFunctionBody () << " var " << output_indices_i << " = " << inputs[i]->OffsetToIndices (" global_idx" ) << " ;\n "
57- << " " << output_indices_i_axis << " += " << concat_axis_offset << " ;\n "
58- << " " << output.SetByIndices (output_indices_i, inputs[i]->GetByOffset (" global_idx" )) << " \n " ;
59- }
85+ shader.MainFunctionBody () << shader.GuardAgainstOutOfBoundsWorkgroupSizes (" uniforms.output_size" )
86+ << " let input_index = calculate_input_index(global_idx);\n "
87+ << " assign_output_data(global_idx, input_index);\n " ;
6088
6189 return Status::OK ();
6290}
@@ -85,29 +113,38 @@ Status Concat::ComputeInternal(ComputeContext& context) const {
85113 ConcatProgram program{axis};
86114 uint32_t num_inputs_this_concat = std::min (max_inputs_per_concat, input_count - input_index);
87115
116+ std::vector<uint32_t > offsets;
117+ offsets.reserve (num_inputs_this_concat + 1 );
118+ offsets.push_back (0 );
119+
88120 std::vector<uint32_t > sizes_in_concat_axis;
89121 sizes_in_concat_axis.reserve (num_inputs_this_concat + 1 );
90122 sizes_in_concat_axis.push_back (cumulative_size_in_concat_axis);
91123
92124 uint32_t output_size = 0 ;
93125 for (uint32_t i = 0 ; i < num_inputs_this_concat; i++) {
94126 auto & input = prepare.inputs [input_index + i];
127+ if (input.tensor ->Shape ().Size () == 0 ) {
128+ continue ;
129+ }
95130 program.AddInput ({input.tensor , ProgramTensorMetadataDependency::TypeAndRank});
96131
97132 uint32_t size = onnxruntime::narrow<int32_t >(input.tensor ->Shape ().Size ());
98133 uint32_t axis_size = static_cast <uint32_t >(input.tensor ->Shape ()[axis]);
99134
100135 output_size += size;
136+ offsets.push_back (output_size);
101137 cumulative_size_in_concat_axis += axis_size;
102138 sizes_in_concat_axis.push_back (cumulative_size_in_concat_axis);
103139 }
104140
141+ offsets.pop_back ();
105142 sizes_in_concat_axis.pop_back ();
106143
107144 program.CacheHint (absl::StrJoin (std::make_tuple (num_inputs_this_concat, prepare.axis ), " ," ))
108145 .AddOutputs ({prepare.output_tensor })
109146 .SetDispatchGroupSize ((output_size + WORKGROUP_SIZE - 1 ) / WORKGROUP_SIZE)
110- .AddUniformVariables ({gsl::span<const uint32_t >(sizes_in_concat_axis.data (), sizes_in_concat_axis.size ()), output_size});
147+ .AddUniformVariables ({gsl::span<const uint32_t >(offsets. data (), offsets. size ()), gsl::span< const uint32_t >( sizes_in_concat_axis.data (), sizes_in_concat_axis.size ()), output_size});
111148 ORT_RETURN_IF_ERROR (context.RunProgram (program));
112149
113150 input_index += num_inputs_this_concat;
0 commit comments