Commit 915a999

xiaofeihan1 and fs-eire authored
[WebGPU] Unify core implementations of GEMM and MatMul (#24586)
### Description

This PR extracts the core implementations into gemm_utils.cc, which is used to generate the shaders for both the Gemm and MatMul ops. The core implementations include scalar and vec4 versions of Gemm and MatMul.

### Motivation and Context

Gemm and MatMul share a lot of common code, so we extract that common code to unify their implementations.

![Blank diagram (1)](https://github.com/user-attachments/assets/45f8d7ac-6705-4cea-8b8c-966ded6a6ca5)

---------

Co-authored-by: Yulong Wang <[email protected]>
1 parent d5fa2ac commit 915a999
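The refactoring pattern, in a nutshell: both ops ask one shared helper for the core accumulation loop and only differ in their epilogue. The sketch below is a standalone illustration of that idea; the names (`EmitMatMulCore`, `BuildGemmShader`, `BuildMatMulShader`) are made up for the example and are not the actual helpers in gemm_utils.cc.

```cpp
// Minimal standalone sketch (not onnxruntime code): both "Gemm" and "MatMul"
// shader builders delegate the inner-product loop to one shared helper, which
// is the shape of the refactoring this commit performs via gemm_utils.cc.
#include <iostream>
#include <sstream>
#include <string>

// Shared core: emits the WGSL accumulation loop once, parameterized by transpose flags.
std::string EmitMatMulCore(bool transA, bool transB) {
  std::ostringstream wgsl;
  wgsl << "var value = output_value_t(0);\n"
       << "for (var k = 0u; k < uniforms.K; k = k + 1u) {\n"
       << "  value = value + " << (transA ? "A[k * uniforms.M + m]" : "A[m * uniforms.K + k]")
       << " * " << (transB ? "B[n * uniforms.K + k]" : "B[k * uniforms.N + n]") << ";\n"
       << "}\n";
  return wgsl.str();
}

// Gemm adds an alpha/beta/bias epilogue around the shared core.
std::string BuildGemmShader(bool transA, bool transB) {
  return EmitMatMulCore(transA, transB) +
         "value = value * output_value_t(uniforms.alpha);\n"
         "value = value + output_value_t(uniforms.beta) * C[m * uniforms.N + n];\n";
}

// MatMul uses the same core with no epilogue.
std::string BuildMatMulShader(bool transA, bool transB) {
  return EmitMatMulCore(transA, transB);
}

int main() {
  std::cout << BuildGemmShader(/*transA=*/false, /*transB=*/true) << "\n";
  std::cout << BuildMatMulShader(false, false);
}
```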

File tree

14 files changed: +1144 −1058 lines changed


onnxruntime/core/providers/webgpu/math/gemm.cc

Lines changed: 45 additions & 129 deletions
```diff
@@ -2,7 +2,7 @@
 // Licensed under the MIT License.
 
 #include "core/providers/webgpu/math/gemm.h"
-#include "core/providers/webgpu/math/gemm_vec4.h"
+#include "core/providers/webgpu/math/gemm_packed.h"
 
 #include <vector>
 
@@ -38,130 +38,52 @@ WEBGPU_GEMM_VERSIONED_KERNEL(9, 10)
 WEBGPU_GEMM_VERSIONED_KERNEL(11, 12)
 WEBGPU_GEMM_KERNEL(13)
 
-Status GemmProgram::GenerateShaderCode(ShaderHelper& shader) const {
-  const uint32_t TILE_SIZE = 16;
-
-  // Add shared memory arrays
-  shader.AdditionalImplementation() << "var<workgroup> tile_a: array<array<output_value_t, " << TILE_SIZE << ">, " << TILE_SIZE << ">;\n"
-                                    << "var<workgroup> tile_b: array<array<output_value_t, " << TILE_SIZE << ">, " << TILE_SIZE << ">;\n\n";
-
+Status GemmNaiveProgram::GenerateShaderCode(ShaderHelper& shader) const {
   const ShaderVariableHelper& output = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias);
 
-  shader.MainFunctionBody() << " var value = output_value_t(0);\n\n"
-                            << " let tile_col_start = (workgroup_idx % uniforms.num_tile_n) * " << TILE_SIZE << "u;\n"
-                            << " let tile_row_start = (workgroup_idx / uniforms.num_tile_n) * " << TILE_SIZE << "u;\n";
+  shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.output_size")
+                            << " let m = global_idx / uniforms.N;\n"
+                            << " let n = global_idx % uniforms.N;\n"
+                            << " var value = output_value_t(0);\n"
+                            << "\n";
 
   // When A or B is empty, we don't bind A and B. Because WebGPU doesn't support binding a zero-sized buffer.
   if (need_handle_matmul_) {
     const ShaderVariableHelper& A = shader.AddInput("A", ShaderUsage::UseUniform);
     const ShaderVariableHelper& B = shader.AddInput("B", ShaderUsage::UseUniform);
 
-    shader.MainFunctionBody()
-        << " let num_tiles = (uniforms.K - 1u) / " << TILE_SIZE << "u + 1u;\n"
-        << " var k_start = 0u;\n"
-        << " for (var t = 0u; t < num_tiles; t = t + 1u) {\n";
-
-    // Fill workgroup shared memory
-    if (transA_ && transB_) {
-      shader.MainFunctionBody() << " var col = tile_row_start + local_id.x;\n"
-                                << " var row = k_start + local_id.y;\n"
-                                << " if (col < uniforms.M && row < uniforms.K) {\n"
-                                << " tile_a[local_id.y][local_id.x] = " << A.GetByOffset("row * uniforms.M + col") << ";\n"
-                                << " } else {\n"
-                                << " tile_a[local_id.y][local_id.x] = output_value_t(0);\n"
-                                << " }\n\n"
-                                << " col = k_start + local_id.x;\n"
-                                << " row = tile_col_start + local_id.y;\n"
-                                << " if (col < uniforms.K && row < uniforms.N) {\n"
-                                << " tile_b[local_id.y][local_id.x] = " << B.GetByOffset("row * uniforms.K + col") << ";\n"
-                                << " } else {\n"
-                                << " tile_b[local_id.y][local_id.x] = output_value_t(0);\n"
-                                << " }\n";
-    } else if (transA_ && !transB_) {
-      shader.MainFunctionBody() << " var col = tile_row_start + local_id.x;\n"
-                                << " var row = k_start + local_id.y;\n"
-                                << " if (col < uniforms.M && row < uniforms.K) {\n"
-                                << " tile_a[local_id.y][local_id.x] = " << A.GetByOffset("row * uniforms.M + col") << ";\n"
-                                << " } else {\n"
-                                << " tile_a[local_id.y][local_id.x] = output_value_t(0);\n"
-                                << " }\n\n"
-                                << " col = tile_col_start + local_id.x;\n"
-                                << " row = k_start + local_id.y;\n"
-                                << " if (col < uniforms.N && row < uniforms.K) {\n"
-                                << " tile_b[local_id.y][local_id.x] = " << B.GetByOffset("row * uniforms.N + col") << ";\n"
-                                << " } else {\n"
-                                << " tile_b[local_id.y][local_id.x] = output_value_t(0);\n"
-                                << " }\n";
-    } else if (!transA_ && transB_) {
-      shader.MainFunctionBody() << " var col = k_start + local_id.x;\n"
-                                << " var row = tile_row_start + local_id.y;\n"
-                                << " if (col < uniforms.K && row < uniforms.M) {\n"
-                                << " tile_a[local_id.y][local_id.x] = " << A.GetByOffset("row * uniforms.K + col") << ";\n"
-                                << " } else {\n"
-                                << " tile_a[local_id.y][local_id.x] = output_value_t(0);\n"
-                                << " }\n\n"
-                                << " col = k_start + local_id.x;\n"
-                                << " row = tile_col_start + local_id.y;\n"
-                                << " if (col < uniforms.K && row < uniforms.N) {\n"
-                                << " tile_b[local_id.y][local_id.x] = " << B.GetByOffset("row * uniforms.K + col") << ";\n"
-                                << " } else {\n"
-                                << " tile_b[local_id.y][local_id.x] = output_value_t(0);\n"
-                                << " }\n";
-    } else {
-      shader.MainFunctionBody() << " var col = k_start + local_id.x;\n"
-                                << " var row = tile_row_start + local_id.y;\n"
-                                << " if (col < uniforms.K && row < uniforms.M) {\n"
-                                << " tile_a[local_id.y][local_id.x] = " << A.GetByOffset("row * uniforms.K + col") << ";\n"
-                                << " } else {\n"
-                                << " tile_a[local_id.y][local_id.x] = output_value_t(0);\n"
-                                << " }\n\n"
-                                << " col = tile_col_start + local_id.x;\n"
-                                << " row = k_start + local_id.y;\n"
-                                << " if (col < uniforms.N && row < uniforms.K) {\n"
-                                << " tile_b[local_id.y][local_id.x] = " << B.GetByOffset("row * uniforms.N + col") << ";\n"
-                                << " } else {\n"
-                                << " tile_b[local_id.y][local_id.x] = output_value_t(0);\n"
-                                << " }\n";
-    }
-
-    shader.MainFunctionBody() << " k_start = k_start + " << TILE_SIZE << "u;\n"
-                              << " workgroupBarrier();\n\n"
-                              << " for (var k = 0u; k < " << TILE_SIZE << "u; k = k + 1u) {\n";
+    shader.MainFunctionBody() << " for (var k = 0u; k < uniforms.K; k = k + 1u) {\n";
 
     if (transA_ && transB_) {
-      shader.MainFunctionBody() << " value = value + tile_a[k][local_id.y] * tile_b[local_id.x][k];\n";
+      shader.MainFunctionBody() << " value = value + " << A.GetByOffset("k * uniforms.M + m")
+                                << " * " << B.GetByOffset("n * uniforms.K + k") << ";\n";
     } else if (transA_ && !transB_) {
-      shader.MainFunctionBody() << " value = value + tile_a[k][local_id.y] * tile_b[k][local_id.x];\n";
+      shader.MainFunctionBody() << " value = value + " << A.GetByOffset("k * uniforms.M + m")
+                                << " * " << B.GetByOffset("k * uniforms.N + n") << ";\n";
     } else if (!transA_ && transB_) {
-      shader.MainFunctionBody() << " value = value + tile_a[local_id.y][k] * tile_b[local_id.x][k];\n";
+      shader.MainFunctionBody() << " value = value + " << A.GetByOffset("m * uniforms.K + k")
+                                << " * " << B.GetByOffset("n * uniforms.K + k") << ";\n";
     } else {
-      shader.MainFunctionBody() << " value = value + tile_a[local_id.y][k] * tile_b[k][local_id.x];\n";
+      shader.MainFunctionBody() << " value = value + " << A.GetByOffset("m * uniforms.K + k")
+                                << " * " << B.GetByOffset("k * uniforms.N + n") << ";\n";
     }
-
-    shader.MainFunctionBody() << " }\n"
-                              << " workgroupBarrier();\n"
-                              << " }\n\n";
+    shader.MainFunctionBody() << " }\n"
+                              << "\n";
   }
 
   // Calculate Alpha
   if (alpha_) {
     shader.MainFunctionBody() << " value = value * output_value_t(uniforms.alpha);\n";
   }
 
-  shader.MainFunctionBody() << " let m = tile_row_start + local_id.y;\n"
-                            << " let n = tile_col_start + local_id.x;\n";
-
   // Calculate Bias
   if (need_handle_bias_) {
     const ShaderVariableHelper& C = shader.AddInput("C", ShaderUsage::UseUniform);
     shader.MainFunctionBody() << " value = value + output_value_t(uniforms.beta) * "
                               << C.GetByOffset(C.BroadcastedIndicesToOffset("vec2(m, n)", output)) << ";\n";
   }
 
-  // Write output
-  shader.MainFunctionBody() << " if (m < uniforms.M && n < uniforms.N) {\n"
-                            << " " << output.SetByOffset("m * uniforms.N + n", "value") << "\n"
-                            << " }\n";
+  shader.MainFunctionBody() << output.SetByOffset("global_idx", "value") << "\n";
 
   return Status::OK();
 }
@@ -182,14 +104,14 @@ Status Gemm::ComputeInternal(ComputeContext& context) const {
     return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input tensors A and B must be 2 dimensional.");
   }
 
-  uint32_t M = onnxruntime::narrow<uint32_t>(transA_ ? A_shape[1] : A_shape[0]);
-  uint32_t K = onnxruntime::narrow<uint32_t>(transA_ ? A_shape[0] : A_shape[1]);
-  uint32_t N = onnxruntime::narrow<uint32_t>(transB_ ? B_shape[0] : B_shape[1]);
-
   if ((transA_ ? A_shape[0] : A_shape[1]) != (transB_ ? B_shape[1] : B_shape[0])) {
     return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Inner dimensions of A and B must match.");
   }
 
+  int64_t M = transA_ ? A_shape[1] : A_shape[0];
+  int64_t K = transA_ ? A_shape[0] : A_shape[1];
+  int64_t N = transB_ ? B_shape[0] : B_shape[1];
+
   std::vector<int64_t> output_dims{M, N};
   auto* Y = context.Output(0, output_dims);
   int64_t output_size = Y->Shape().Size();
@@ -198,42 +120,36 @@ Status Gemm::ComputeInternal(ComputeContext& context) const {
     return Status::OK();
   }
 
-  // First try vec4 optimization if possible
-  if (CanApplyGemmVec4(A, B)) {
-    return ApplyGemmVec4(A, B, C, transA_, transB_, alpha_, beta_, context, Y);
-  }
-
   // WebGPU doesn't support binding a zero-sized buffer, so we need to check if A or B is empty.
   bool need_handle_matmul = A_shape.Size() > 0 && B_shape.Size() > 0;
   bool need_handle_bias = C && beta_;
 
-  GemmProgram program{transA_, transB_, alpha_, need_handle_bias, need_handle_matmul};
+  if (M <= 8 && N <= 8 && K <= 8) {
+    // Use naive implementation for small matrices
+    GemmNaiveProgram program{transA_, transB_, alpha_, need_handle_bias, need_handle_matmul};
+    if (need_handle_matmul) {
+      program.AddInputs({{A, ProgramTensorMetadataDependency::Type},
+                         {B, ProgramTensorMetadataDependency::Type}});
+    }
 
-  if (need_handle_matmul) {
-    program.AddInputs({{A, ProgramTensorMetadataDependency::Type},
-                       {B, ProgramTensorMetadataDependency::Type}});
-  }
+    if (need_handle_bias) {
+      program.AddInput({C, ProgramTensorMetadataDependency::Rank});
+    }
 
-  if (need_handle_bias) {
-    program.AddInput({C, ProgramTensorMetadataDependency::Rank});
+    program.CacheHint(alpha_, transA_, transB_)
+        .AddOutputs({{Y, ProgramTensorMetadataDependency::Type}})
+        .SetDispatchGroupSize((output_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE)
+        .SetWorkgroupSize(WORKGROUP_SIZE)
+        .AddUniformVariables({{static_cast<uint32_t>(output_size)},
+                              {static_cast<uint32_t>(M)},
+                              {static_cast<uint32_t>(N)},
+                              {static_cast<uint32_t>(K)},
+                              {alpha_},
+                              {beta_}});
+    return context.RunProgram(program);
   }
 
-  const uint32_t TILE_SIZE = 16;
-  const uint32_t num_tile_n = (N + TILE_SIZE - 1) / TILE_SIZE;
-  const uint32_t num_tile_m = (M + TILE_SIZE - 1) / TILE_SIZE;
-
-  program.CacheHint(alpha_, transA_, transB_)
-      .AddOutputs({{Y, ProgramTensorMetadataDependency::Type}})
-      .SetDispatchGroupSize(num_tile_n * num_tile_m)
-      .SetWorkgroupSize(TILE_SIZE, TILE_SIZE)
-      .AddUniformVariables({{num_tile_n},
-                            {M},
-                            {N},
-                            {K},
-                            {alpha_},
-                            {beta_}});
-
-  return context.RunProgram(program);
+  return ApplyGemmPacked(A, B, C, transA_, transB_, alpha_, beta_, context);
 }
 
 } // namespace webgpu
```
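For reference, the offsets emitted above (`k * uniforms.M + m`, `m * uniforms.K + k`, `n * uniforms.K + k`, `k * uniforms.N + n`) are the flat indices of op(A)[m, k] and op(B)[k, n] for the four transA/transB combinations. The following standalone CPU sketch mirrors the same indexing and epilogue; it is illustrative only and not part of the commit.

```cpp
// Standalone CPU reference (illustrative, not part of this commit): computes
// Y = alpha * op(A) * op(B) + beta * C with the same flat-offset indexing the
// GemmNaiveProgram shader emits for each transA/transB combination.
#include <cstdio>
#include <vector>

void GemmNaiveReference(const std::vector<float>& A, const std::vector<float>& B,
                        const std::vector<float>& C, std::vector<float>& Y,
                        int M, int N, int K, bool transA, bool transB,
                        float alpha, float beta) {
  for (int m = 0; m < M; ++m) {
    for (int n = 0; n < N; ++n) {
      float value = 0.0f;
      for (int k = 0; k < K; ++k) {
        // op(A)[m, k]: A is stored K x M when transA, otherwise M x K.
        float a = transA ? A[k * M + m] : A[m * K + k];
        // op(B)[k, n]: B is stored N x K when transB, otherwise K x N.
        float b = transB ? B[n * K + k] : B[k * N + n];
        value += a * b;
      }
      // Same epilogue as the shader: scale by alpha, then add beta * C (no broadcasting here).
      Y[m * N + n] = alpha * value + beta * C[m * N + n];
    }
  }
}

int main() {
  int M = 2, N = 2, K = 2;
  std::vector<float> A{1, 2, 3, 4}, B{5, 6, 7, 8}, C(M * N, 1.0f), Y(M * N);
  GemmNaiveReference(A, B, C, Y, M, N, K, false, false, 1.0f, 1.0f);
  for (float v : Y) std::printf("%g ", v);  // prints: 20 23 44 51
}
```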

onnxruntime/core/providers/webgpu/math/gemm.h

Lines changed: 4 additions & 4 deletions
```diff
@@ -10,10 +10,10 @@
 namespace onnxruntime {
 namespace webgpu {
 
-class GemmProgram final : public Program<GemmProgram> {
+class GemmNaiveProgram final : public Program<GemmNaiveProgram> {
  public:
-  GemmProgram(bool transA, bool transB, float alpha, bool need_handle_bias, bool need_handle_matmul)
-      : Program{"Gemm"},
+  GemmNaiveProgram(bool transA, bool transB, float alpha, bool need_handle_bias, bool need_handle_matmul)
+      : Program{"GemmNaive"},
         transA_{transA},
         transB_{transB},
         alpha_{alpha},
@@ -23,7 +23,7 @@ class GemmProgram final : public Program<GemmProgram> {
   Status GenerateShaderCode(ShaderHelper& sh) const override;
 
   WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES(
-      {"num_tile_n", ProgramUniformVariableDataType::Uint32},
+      {"output_size", ProgramUniformVariableDataType::Uint32},
       {"M", ProgramUniformVariableDataType::Uint32},
       {"N", ProgramUniformVariableDataType::Uint32},
       {"K", ProgramUniformVariableDataType::Uint32},
```
Lines changed: 114 additions & 0 deletions
```diff
@@ -0,0 +1,114 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "core/providers/webgpu/math/gemm_packed.h"
+
+#include "core/providers/webgpu/webgpu_utils.h"
+
+#include "core/providers/webgpu/math/matmul_utils.h"
+#include "core/providers/webgpu/math/gemm_utils.h"
+
+namespace onnxruntime {
+namespace webgpu {
+
+Status GemmProgram::GenerateShaderCode(ShaderHelper& shader) const {
+  const ShaderVariableHelper& output = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
+
+  // Each thread compute 4*4 elements
+  InlinedVector<int64_t> elements_per_thread = InlinedVector<int64_t>({4, 4, 1});
+
+  const std::string data_type = "output_element_t";
+
+  if (need_handle_matmul_) {
+    const auto& a = shader.AddInput("a", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
+    const auto& b = shader.AddInput("b", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias);
+
+    MatMulReadFnSource(shader, a, b, nullptr, transA_, transB_, is_vec4_);
+  }
+  if (is_vec4_) {
+    ORT_RETURN_IF_ERROR(MakeMatMulPackedVec4Source(shader, elements_per_thread, WorkgroupSizeX(), WorkgroupSizeY(), data_type, nullptr, transA_, transB_, alpha_, need_handle_matmul_, output_components_));
+  } else {
+    ORT_RETURN_IF_ERROR(MakeMatMulPackedSource(shader, elements_per_thread, WorkgroupSizeX(), WorkgroupSizeY(), data_type, nullptr, transA_, transB_, alpha_, need_handle_matmul_));
+  }
+  MatMulWriteFnSource(shader, output, need_handle_bias_, true, c_components_, output_components_, c_is_scalar_);
+
+  return Status::OK();
+}
+
+Status ApplyGemmPacked(const Tensor* a,
+                       const Tensor* b,
+                       const Tensor* c,
+                       bool transA,
+                       bool transB,
+                       float alpha,
+                       float beta,
+                       ComputeContext& context) {
+  const auto& a_shape = a->Shape();
+  const auto& b_shape = b->Shape();
+
+  uint32_t M = onnxruntime::narrow<uint32_t>(transA ? a_shape[1] : a_shape[0]);
+  uint32_t K = onnxruntime::narrow<uint32_t>(transA ? a_shape[0] : a_shape[1]);
+  uint32_t N = onnxruntime::narrow<uint32_t>(transB ? b_shape[0] : b_shape[1]);
+
+  std::vector<int64_t> output_dims{M, N};
+  auto* y = context.Output(0, output_dims);
+  int64_t output_size = y->Shape().Size();
+
+  if (output_size == 0) {
+    return Status::OK();
+  }
+
+  // WebGPU doesn't support binding a zero-sized buffer, so we need to check if A or B is empty.
+  bool need_handle_matmul = a_shape.Size() > 0 && b_shape.Size() > 0;
+  bool need_handle_bias = c && beta;
+
+  const bool is_vec4 = a_shape[1] % 4 == 0 && b_shape[1] % 4 == 0;
+
+  // Components for A, B
+  int components = is_vec4 ? 4 : 1;
+  // Components for Y
+  int output_components = (is_vec4 && N % 4 == 0) ? 4 : 1;
+  // Components for C.
+  int c_components = 1;
+
+  bool c_is_scalar = false;
+  if (need_handle_bias) {
+    const auto& c_shape = c->Shape();
+    int64_t c_last_dim = c_shape[c_shape.NumDimensions() - 1];
+    // `C` in GEMM might be broadcast to the output, and broadcasting requires the components to be consistent.
+    // So we use vec4 for C when its last dimension is N, and the output is also a vec4.
+    c_components = (c_last_dim == N && output_components == 4) ? 4 : 1;
+    c_is_scalar = c_shape.Size() == 1;
+  }
+
+  GemmProgram program{transA, transB, alpha, need_handle_bias, need_handle_matmul, c_components, c_is_scalar, output_components, is_vec4};
+
+  if (need_handle_matmul) {
+    program.AddInputs({{a, ProgramTensorMetadataDependency::TypeAndRank, components},
+                       {b, ProgramTensorMetadataDependency::TypeAndRank, components}});
+  }
+
+  if (need_handle_bias) {
+    program.AddInput({c, ProgramTensorMetadataDependency::TypeAndRank, c_components});
+  }
+
+  const uint32_t TILE_SIZE = 32;
+  const uint32_t num_tile_n = (N + TILE_SIZE - 1) / TILE_SIZE;
+  const uint32_t num_tile_m = (M + TILE_SIZE - 1) / TILE_SIZE;
+
+  program.CacheHint(alpha, transA, transB, c_is_scalar)
+      .AddOutputs({{y, ProgramTensorMetadataDependency::TypeAndRank, output_components}})
+      .SetDispatchGroupSize(num_tile_n, num_tile_m, 1)
+      .SetWorkgroupSize(GemmProgram::MATMUL_PACKED_WORKGROUP_SIZE_X, GemmProgram::MATMUL_PACKED_WORKGROUP_SIZE_Y, GemmProgram::MATMUL_PACKED_WORKGROUP_SIZE_Z)
+      .AddUniformVariables({{alpha},
+                            {beta},
+                            {M}, /* dim_a_outer */
+                            {N}, /* dim_b_outer */
+                            {K}} /*dim_inner */
+      );
+
+  return context.RunProgram(program);
+}
+
+} // namespace webgpu
+} // namespace onnxruntime
```
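The host-side launch logic above boils down to: use vec4 packing when the last dimensions of A and B are multiples of 4, write the output as vec4 only when N is also a multiple of 4, and dispatch one workgroup per 32x32 output tile. Below is a standalone sketch of that arithmetic; the 8x8 workgroup / 4x4 elements-per-thread figure is an assumption for illustration, since the real constants live in gemm_packed.h, which is not part of this diff.

```cpp
// Standalone sketch (illustrative only) of the host-side launch math used by
// ApplyGemmPacked: pick vec4 vs. scalar packing and count 32x32 output tiles.
// The 8x8 workgroup size is assumed here; the actual values come from
// GemmProgram::MATMUL_PACKED_WORKGROUP_SIZE_* in gemm_packed.h.
#include <cstdint>
#include <cstdio>

struct GemmDispatch {
  bool is_vec4;
  int a_b_components;     // components used to bind A and B
  int output_components;  // components used to bind Y
  uint32_t num_tile_n;    // dispatch group count along N
  uint32_t num_tile_m;    // dispatch group count along M
};

GemmDispatch PlanGemmDispatch(uint32_t M, uint32_t N, uint32_t a_cols, uint32_t b_cols) {
  constexpr uint32_t kTileSize = 32;  // 8x8 threads * 4x4 elements per thread (assumed)
  GemmDispatch d{};
  // vec4 loads require the last dimension of both A and B to be a multiple of 4.
  d.is_vec4 = (a_cols % 4 == 0) && (b_cols % 4 == 0);
  d.a_b_components = d.is_vec4 ? 4 : 1;
  // The output can only be written as vec4 if N itself is a multiple of 4.
  d.output_components = (d.is_vec4 && N % 4 == 0) ? 4 : 1;
  d.num_tile_n = (N + kTileSize - 1) / kTileSize;
  d.num_tile_m = (M + kTileSize - 1) / kTileSize;
  return d;
}

int main() {
  GemmDispatch d = PlanGemmDispatch(/*M=*/100, /*N=*/64, /*a_cols=*/128, /*b_cols=*/64);
  std::printf("vec4=%d tiles=%ux%u out_components=%d\n",
              d.is_vec4, d.num_tile_n, d.num_tile_m, d.output_components);
  // -> vec4=1 tiles=2x4 out_components=4
}
```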
