 // Licensed under the MIT License.
 
 #include "core/providers/webgpu/math/gemm.h"
-#include "core/providers/webgpu/math/gemm_vec4.h"
+#include "core/providers/webgpu/math/gemm_packed.h"
 
 #include <vector>
 
@@ -38,130 +38,52 @@ WEBGPU_GEMM_VERSIONED_KERNEL(9, 10)
 WEBGPU_GEMM_VERSIONED_KERNEL(11, 12)
 WEBGPU_GEMM_KERNEL(13)
 
-Status GemmProgram::GenerateShaderCode(ShaderHelper& shader) const {
-  const uint32_t TILE_SIZE = 16;
-
-  // Add shared memory arrays
-  shader.AdditionalImplementation() << "var<workgroup> tile_a: array<array<output_value_t, " << TILE_SIZE << ">, " << TILE_SIZE << ">;\n"
-                                    << "var<workgroup> tile_b: array<array<output_value_t, " << TILE_SIZE << ">, " << TILE_SIZE << ">;\n\n";
-
+Status GemmNaiveProgram::GenerateShaderCode(ShaderHelper& shader) const {
   const ShaderVariableHelper& output = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias);
 
-  shader.MainFunctionBody() << "var value = output_value_t(0);\n\n"
-                            << "let tile_col_start = (workgroup_idx % uniforms.num_tile_n) * " << TILE_SIZE << "u;\n"
-                            << "let tile_row_start = (workgroup_idx / uniforms.num_tile_n) * " << TILE_SIZE << "u;\n";
+  shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.output_size")
+                            << "let m = global_idx / uniforms.N;\n"
+                            << "let n = global_idx % uniforms.N;\n"
+                            << "var value = output_value_t(0);\n"
+                            << "\n";
 
   // When A or B is empty, we don't bind A and B. Because WebGPU doesn't support binding a zero-sized buffer.
   if (need_handle_matmul_) {
     const ShaderVariableHelper& A = shader.AddInput("A", ShaderUsage::UseUniform);
     const ShaderVariableHelper& B = shader.AddInput("B", ShaderUsage::UseUniform);
 
-    shader.MainFunctionBody()
-        << "let num_tiles = (uniforms.K - 1u) / " << TILE_SIZE << "u + 1u;\n"
-        << "var k_start = 0u;\n"
-        << "for (var t = 0u; t < num_tiles; t = t + 1u) {\n";
-
-    // Fill workgroup shared memory
-    if (transA_ && transB_) {
-      shader.MainFunctionBody() << "var col = tile_row_start + local_id.x;\n"
-                                << "var row = k_start + local_id.y;\n"
-                                << "if (col < uniforms.M && row < uniforms.K) {\n"
-                                << "  tile_a[local_id.y][local_id.x] = " << A.GetByOffset("row * uniforms.M + col") << ";\n"
-                                << "} else {\n"
-                                << "  tile_a[local_id.y][local_id.x] = output_value_t(0);\n"
-                                << "}\n\n"
-                                << "col = k_start + local_id.x;\n"
-                                << "row = tile_col_start + local_id.y;\n"
-                                << "if (col < uniforms.K && row < uniforms.N) {\n"
-                                << "  tile_b[local_id.y][local_id.x] = " << B.GetByOffset("row * uniforms.K + col") << ";\n"
-                                << "} else {\n"
-                                << "  tile_b[local_id.y][local_id.x] = output_value_t(0);\n"
-                                << "}\n";
-    } else if (transA_ && !transB_) {
-      shader.MainFunctionBody() << "var col = tile_row_start + local_id.x;\n"
-                                << "var row = k_start + local_id.y;\n"
-                                << "if (col < uniforms.M && row < uniforms.K) {\n"
-                                << "  tile_a[local_id.y][local_id.x] = " << A.GetByOffset("row * uniforms.M + col") << ";\n"
-                                << "} else {\n"
-                                << "  tile_a[local_id.y][local_id.x] = output_value_t(0);\n"
-                                << "}\n\n"
-                                << "col = tile_col_start + local_id.x;\n"
-                                << "row = k_start + local_id.y;\n"
-                                << "if (col < uniforms.N && row < uniforms.K) {\n"
-                                << "  tile_b[local_id.y][local_id.x] = " << B.GetByOffset("row * uniforms.N + col") << ";\n"
-                                << "} else {\n"
-                                << "  tile_b[local_id.y][local_id.x] = output_value_t(0);\n"
-                                << "}\n";
-    } else if (!transA_ && transB_) {
-      shader.MainFunctionBody() << "var col = k_start + local_id.x;\n"
-                                << "var row = tile_row_start + local_id.y;\n"
-                                << "if (col < uniforms.K && row < uniforms.M) {\n"
-                                << "  tile_a[local_id.y][local_id.x] = " << A.GetByOffset("row * uniforms.K + col") << ";\n"
-                                << "} else {\n"
-                                << "  tile_a[local_id.y][local_id.x] = output_value_t(0);\n"
-                                << "}\n\n"
-                                << "col = k_start + local_id.x;\n"
-                                << "row = tile_col_start + local_id.y;\n"
-                                << "if (col < uniforms.K && row < uniforms.N) {\n"
-                                << "  tile_b[local_id.y][local_id.x] = " << B.GetByOffset("row * uniforms.K + col") << ";\n"
-                                << "} else {\n"
-                                << "  tile_b[local_id.y][local_id.x] = output_value_t(0);\n"
-                                << "}\n";
-    } else {
-      shader.MainFunctionBody() << "var col = k_start + local_id.x;\n"
-                                << "var row = tile_row_start + local_id.y;\n"
-                                << "if (col < uniforms.K && row < uniforms.M) {\n"
-                                << "  tile_a[local_id.y][local_id.x] = " << A.GetByOffset("row * uniforms.K + col") << ";\n"
-                                << "} else {\n"
-                                << "  tile_a[local_id.y][local_id.x] = output_value_t(0);\n"
-                                << "}\n\n"
-                                << "col = tile_col_start + local_id.x;\n"
-                                << "row = k_start + local_id.y;\n"
-                                << "if (col < uniforms.N && row < uniforms.K) {\n"
-                                << "  tile_b[local_id.y][local_id.x] = " << B.GetByOffset("row * uniforms.N + col") << ";\n"
-                                << "} else {\n"
-                                << "  tile_b[local_id.y][local_id.x] = output_value_t(0);\n"
-                                << "}\n";
-    }
-
-    shader.MainFunctionBody() << "k_start = k_start + " << TILE_SIZE << "u;\n"
-                              << "workgroupBarrier();\n\n"
-                              << "for (var k = 0u; k < " << TILE_SIZE << "u; k = k + 1u) {\n";
+    shader.MainFunctionBody() << "for (var k = 0u; k < uniforms.K; k = k + 1u) {\n";
 
     if (transA_ && transB_) {
-      shader.MainFunctionBody() << "value = value + tile_a[k][local_id.y] * tile_b[local_id.x][k];\n";
+      shader.MainFunctionBody() << "value = value + " << A.GetByOffset("k * uniforms.M + m")
+                                << " * " << B.GetByOffset("n * uniforms.K + k") << ";\n";
    } else if (transA_ && !transB_) {
-      shader.MainFunctionBody() << "value = value + tile_a[k][local_id.y] * tile_b[k][local_id.x];\n";
+      shader.MainFunctionBody() << "value = value + " << A.GetByOffset("k * uniforms.M + m")
+                                << " * " << B.GetByOffset("k * uniforms.N + n") << ";\n";
    } else if (!transA_ && transB_) {
-      shader.MainFunctionBody() << "value = value + tile_a[local_id.y][k] * tile_b[local_id.x][k];\n";
+      shader.MainFunctionBody() << "value = value + " << A.GetByOffset("m * uniforms.K + k")
+                                << " * " << B.GetByOffset("n * uniforms.K + k") << ";\n";
    } else {
-      shader.MainFunctionBody() << "value = value + tile_a[local_id.y][k] * tile_b[k][local_id.x];\n";
+      shader.MainFunctionBody() << "value = value + " << A.GetByOffset("m * uniforms.K + k")
+                                << " * " << B.GetByOffset("k * uniforms.N + n") << ";\n";
    }
-
-    shader.MainFunctionBody() << "}\n"
-                              << "workgroupBarrier();\n"
-                              << "}\n\n";
+    shader.MainFunctionBody() << "}\n"
+                              << "\n";
  }
 
   // Calculate Alpha
   if (alpha_) {
     shader.MainFunctionBody() << "value = value * output_value_t(uniforms.alpha);\n";
   }
 
-  shader.MainFunctionBody() << "let m = tile_row_start + local_id.y;\n"
-                            << "let n = tile_col_start + local_id.x;\n";
-
   // Calculate Bias
   if (need_handle_bias_) {
     const ShaderVariableHelper& C = shader.AddInput("C", ShaderUsage::UseUniform);
     shader.MainFunctionBody() << "value = value + output_value_t(uniforms.beta) * "
                               << C.GetByOffset(C.BroadcastedIndicesToOffset("vec2(m, n)", output)) << ";\n";
   }
 
-  // Write output
-  shader.MainFunctionBody() << "if (m < uniforms.M && n < uniforms.N) {\n"
-                            << "  " << output.SetByOffset("m * uniforms.N + n", "value") << "\n"
-                            << "}\n";
+  shader.MainFunctionBody() << output.SetByOffset("global_idx", "value") << "\n";
 
   return Status::OK();
 }
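For reference, here is a minimal sketch of the WGSL that `GemmNaiveProgram::GenerateShaderCode` would assemble for the `transA_ = false`, `transB_ = false`, f32 case, with `GetByOffset`/`SetByOffset` expanded to plain indexed accesses. The binding slots, the uniform struct layout, the `@workgroup_size(64)` (assuming ORT WebGPU's default `WORKGROUP_SIZE`), and the omission of the bias term are illustrative assumptions, not the verbatim generated shader:

```wgsl
// Hedged sketch: approximate shader for a non-transposed f32 GEMM without
// bias. Binding slots and the uniform layout are assumed for illustration.
struct Uniforms {
  output_size : u32,  // M * N
  M : u32,
  N : u32,
  K : u32,
  alpha : f32,
  beta : f32,
};

@group(0) @binding(0) var<storage, read> A : array<f32>;             // M x K, row-major
@group(0) @binding(1) var<storage, read> B : array<f32>;             // K x N, row-major
@group(0) @binding(2) var<storage, read_write> output : array<f32>;  // M x N
@group(0) @binding(3) var<uniform> uniforms : Uniforms;

@compute @workgroup_size(64)
fn main(@builtin(global_invocation_id) gid : vec3<u32>) {
  let global_idx = gid.x;
  // GuardAgainstOutOfBoundsWorkgroupSizes: the last workgroup may be partial.
  if (global_idx >= uniforms.output_size) {
    return;
  }
  // One invocation per output element; recover the 2D coordinate.
  let m = global_idx / uniforms.N;
  let n = global_idx % uniforms.N;
  var value = 0.0;
  for (var k = 0u; k < uniforms.K; k = k + 1u) {
    value = value + A[m * uniforms.K + k] * B[k * uniforms.N + n];
  }
  value = value * uniforms.alpha;
  output[global_idx] = value;
}
```

Because each invocation computes exactly one output element, correctness does not depend on workgroup shape or barriers, which is what makes this path a reasonable fit for the tiny matrices it is gated to below, where tiling overhead would dominate.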
@@ -182,14 +104,14 @@ Status Gemm::ComputeInternal(ComputeContext& context) const {
     return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input tensors A and B must be 2 dimensional.");
   }
 
-  uint32_t M = onnxruntime::narrow<uint32_t>(transA_ ? A_shape[1] : A_shape[0]);
-  uint32_t K = onnxruntime::narrow<uint32_t>(transA_ ? A_shape[0] : A_shape[1]);
-  uint32_t N = onnxruntime::narrow<uint32_t>(transB_ ? B_shape[0] : B_shape[1]);
-
   if ((transA_ ? A_shape[0] : A_shape[1]) != (transB_ ? B_shape[1] : B_shape[0])) {
     return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Inner dimensions of A and B must match.");
   }
 
+  int64_t M = transA_ ? A_shape[1] : A_shape[0];
+  int64_t K = transA_ ? A_shape[0] : A_shape[1];
+  int64_t N = transB_ ? B_shape[0] : B_shape[1];
+
   std::vector<int64_t> output_dims{M, N};
   auto* Y = context.Output(0, output_dims);
   int64_t output_size = Y->Shape().Size();
@@ -198,42 +120,36 @@ Status Gemm::ComputeInternal(ComputeContext& context) const {
     return Status::OK();
   }
 
-  // First try vec4 optimization if possible
-  if (CanApplyGemmVec4(A, B)) {
-    return ApplyGemmVec4(A, B, C, transA_, transB_, alpha_, beta_, context, Y);
-  }
-
   // WebGPU doesn't support binding a zero-sized buffer, so we need to check if A or B is empty.
   bool need_handle_matmul = A_shape.Size() > 0 && B_shape.Size() > 0;
   bool need_handle_bias = C && beta_;
 
-  GemmProgram program{transA_, transB_, alpha_, need_handle_bias, need_handle_matmul};
+  if (M <= 8 && N <= 8 && K <= 8) {
+    // Use naive implementation for small matrices
+    GemmNaiveProgram program{transA_, transB_, alpha_, need_handle_bias, need_handle_matmul};
+    if (need_handle_matmul) {
+      program.AddInputs({{A, ProgramTensorMetadataDependency::Type},
+                         {B, ProgramTensorMetadataDependency::Type}});
+    }
 
-  if (need_handle_matmul) {
-    program.AddInputs({{A, ProgramTensorMetadataDependency::Type},
-                       {B, ProgramTensorMetadataDependency::Type}});
-  }
+    if (need_handle_bias) {
+      program.AddInput({C, ProgramTensorMetadataDependency::Rank});
+    }
 
-  if (need_handle_bias) {
-    program.AddInput({C, ProgramTensorMetadataDependency::Rank});
+    program.CacheHint(alpha_, transA_, transB_)
+        .AddOutputs({{Y, ProgramTensorMetadataDependency::Type}})
+        .SetDispatchGroupSize((output_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE)
+        .SetWorkgroupSize(WORKGROUP_SIZE)
+        .AddUniformVariables({{static_cast<uint32_t>(output_size)},
+                              {static_cast<uint32_t>(M)},
+                              {static_cast<uint32_t>(N)},
+                              {static_cast<uint32_t>(K)},
+                              {alpha_},
+                              {beta_}});
+    return context.RunProgram(program);
   }
 
-  const uint32_t TILE_SIZE = 16;
-  const uint32_t num_tile_n = (N + TILE_SIZE - 1) / TILE_SIZE;
-  const uint32_t num_tile_m = (M + TILE_SIZE - 1) / TILE_SIZE;
-
-  program.CacheHint(alpha_, transA_, transB_)
-      .AddOutputs({{Y, ProgramTensorMetadataDependency::Type}})
-      .SetDispatchGroupSize(num_tile_n * num_tile_m)
-      .SetWorkgroupSize(TILE_SIZE, TILE_SIZE)
-      .AddUniformVariables({{num_tile_n},
-                            {M},
-                            {N},
-                            {K},
-                            {alpha_},
-                            {beta_}});
-
-  return context.RunProgram(program);
+  return ApplyGemmPacked(A, B, C, transA_, transB_, alpha_, beta_, context);
 }
 
 }  // namespace webgpu
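For the transposed variants, the same one-thread-per-output scheme only changes the operand strides: with `transA_` set, `A` is stored K x M, so logical element (m, k) is read as `A[k * uniforms.M + m]`; with `transB_` set, `B` is stored N x K, so logical element (k, n) is read as `B[n * uniforms.K + k]`. A sketch of the K-loop for the four cases, reusing the bindings assumed in the sketch above:

```wgsl
// Hedged sketch of the inner K-loop only; the surrounding shader is as in
// the previous sketch. Exactly one of the four statements is emitted,
// matching the if/else chain in GemmNaiveProgram::GenerateShaderCode.
for (var k = 0u; k < uniforms.K; k = k + 1u) {
  // transA_ && transB_:   value = value + A[k * uniforms.M + m] * B[n * uniforms.K + k];
  // transA_ && !transB_:  value = value + A[k * uniforms.M + m] * B[k * uniforms.N + n];
  // !transA_ && transB_:  value = value + A[m * uniforms.K + k] * B[n * uniforms.K + k];
  // !transA_ && !transB_ (the case shown earlier):
  value = value + A[m * uniforms.K + k] * B[k * uniforms.N + n];
}
```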