Commit 58954ba

[webgpu] Apply template to MatMulNBitsWideTile (#25353)
### Description

This commit applies the WGSL template to `MatMulNBitsWideTile` to improve code readability and enable more flexible data handling. As part of this change, support for the 4-bit and 8-bit shaders has been consolidated, and a common `CEIL_DIV` utility has been introduced. The previous `ShaderUsage::UseUniform` and `ShaderUsage::UseIndicesTypeAlias` flags are no longer necessary and have been removed.

### Motivation and Context

See above.
1 parent 2d6a525 commit 58954ba
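For context, the `ceil_div` helper added in `matmul_nbits.cc` below replaces the hand-written `(x + y - 1) / y` expressions used for tile counts. A minimal, self-contained sketch of that pattern; the specific `N`/`M` values here are illustrative only and not taken from the kernel:

```cpp
// Sketch of the ceil_div helper introduced in this commit (see the
// matmul_nbits.cc diff below); the sample sizes are made up for illustration.
#include <cstdint>
#include <iostream>

template <typename T>
inline T ceil_div(T numerator, T denominator) {
  return (numerator + denominator - 1) / denominator;
}

int main() {
  // Example tile-count computation in the style of the wide-tile program:
  // a 128-wide tile over N = 300 columns needs 3 workgroups in the N dimension.
  constexpr uint32_t tile_n = 128;
  constexpr uint32_t tile_m = 16;
  const uint32_t N = 300;
  const uint32_t M = 20;
  std::cout << "num_N_tile = " << ceil_div(N, tile_n) << "\n";  // prints 3
  std::cout << "num_M_tile = " << ceil_div(M, tile_m) << "\n";  // prints 2
  return 0;
}
```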

5 files changed: +277 −178 lines

cmake/onnxruntime_providers_webgpu.cmake

Lines changed: 16 additions & 4 deletions
```diff
@@ -172,12 +172,24 @@
   file(MAKE_DIRECTORY ${WGSL_GENERATED_DIR})
 
   # Find all WGSL template input files
-  file(GLOB_RECURSE WGSL_TEMPLATE_FILES
-    "${ONNXRUNTIME_ROOT}/core/providers/webgpu/*.wgsl.template"
-    "${ONNXRUNTIME_ROOT}/contrib_ops/webgpu/*.wgsl.template")
+  set(WGSL_SEARCH_PATHS "${ONNXRUNTIME_ROOT}/core/providers/webgpu/*.wgsl.template")
+  if(NOT onnxruntime_DISABLE_CONTRIB_OPS)
+    list(APPEND WGSL_SEARCH_PATHS "${ONNXRUNTIME_ROOT}/contrib_ops/webgpu/*.wgsl.template")
+  endif()
+  file(GLOB_RECURSE WGSL_TEMPLATE_FILES ${WGSL_SEARCH_PATHS})
 
   # Set wgsl-gen command line options as a list
-  set(WGSL_GEN_OPTIONS "-i" "${ONNXRUNTIME_ROOT}/core/providers/webgpu/" "-i" "${ONNXRUNTIME_ROOT}/contrib_ops/webgpu/" "--output" "${WGSL_GENERATED_DIR}" "-I" "wgsl_template_gen/" "--preserve-code-ref" "--verbose")
+  set(WGSL_GEN_OPTIONS
+    "--output" "${WGSL_GENERATED_DIR}"
+    "-I" "wgsl_template_gen/"
+    "--preserve-code-ref"
+    "--verbose"
+    "-i" "${ONNXRUNTIME_ROOT}/core/providers/webgpu"
+  )
+  if(NOT onnxruntime_DISABLE_CONTRIB_OPS)
+    list(APPEND WGSL_GEN_OPTIONS "-i" "${ONNXRUNTIME_ROOT}/contrib_ops/webgpu")
+  endif()
+
   if (onnxruntime_WGSL_TEMPLATE STREQUAL "static")
     if (CMAKE_BUILD_TYPE STREQUAL "Debug")
       list(APPEND WGSL_GEN_OPTIONS "--generator" "static-cpp-literal")
```

onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc

Lines changed: 53 additions & 172 deletions
```diff
@@ -22,6 +22,11 @@ namespace {
 
 constexpr unsigned int kMinMForTileOptimization = 4;
 
+template <typename T>
+inline T ceil_div(T numerator, T denominator) {
+  return (numerator + denominator - 1) / denominator;
+}
+
 }  // namespace
 
 ONNX_OPERATOR_KERNEL_EX(
@@ -37,165 +42,24 @@ ONNX_OPERATOR_KERNEL_EX(
     MatMulNBits);
 
 Status MatMulNBitsWideTileProgram::GenerateShaderCode(ShaderHelper& shader) const {
-  const auto& a = shader.AddInput("input_a", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias);
-  const auto& b = shader.AddInput("input_b", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias);
-  shader.AddInput("scales", ShaderUsage::UseUniform);
+  shader.AddInput("input_a", ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
+  shader.AddInput("input_b", ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
+  shader.AddInput("scales", ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
   if (has_zero_points_) {
-    shader.AddInput("zero_points", ShaderUsage::UseUniform);
+    shader.AddInput("zero_points", ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
   }
-  const auto& y = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias | ShaderUsage::UseIndicesTypeAlias);
-
-  // Bock size 32, `a` component size 4, 8 `a` components per block.
-  constexpr uint32_t kAComponentsForBlock32 = 8;
+  shader.AddOutput("output", ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
 
   const uint32_t workgroup_size = WorkgroupSizeX() * WorkgroupSizeY();
   ORT_ENFORCE(tile_m_ == workgroup_size / 8, "tile_m must be workgroup_size / 8.");
   ORT_ENFORCE(tile_n_ == workgroup_size, "tile_n must be workgroup_size.");
+  ORT_ENFORCE(nbits_ == 4 || nbits_ == 8, "Only 4/8 bits are supported for webgpu matmulnbits.");
 
-  // memory read/write helpers
-  shader.AdditionalImplementation() << "fn mm_read_a(batch : u32, row : u32, col : u32) -> input_a_value_t {\n"
-                                    << "  if (batch < uniforms.input_a_shape[0] && row < uniforms.input_a_shape[1] && col < uniforms.input_a_shape[2]) {\n"
-                                    << "    return " << a.GetByIndices("input_a_indices_t(batch, row, col)") << ";\n"
-                                    << "  }\n"
-                                    << "  return input_a_value_t(0);\n"
-                                    << "}\n";
-  if (nbits_ == 4) {
-    shader.AdditionalImplementation() << "\n"
-                                      << "fn mm_read_b(row : u32, col : u32) -> input_b_value_t {\n"
-                                      << "  if (row < uniforms.input_b_shape[0] && col < uniforms.input_b_shape[1]) {\n"
-                                      << "    return " << b.GetByIndices("input_b_indices_t(row, col, 0)") << ";\n"
-                                      << "  }\n"
-                                      << "  return input_b_value_t(0);\n"
-                                      << "}\n";
-
-    shader.AdditionalImplementation() << R"(
-fn dequantize_packed8xU4(packed_value : u32, zero_point : output_element_t, scale : output_element_t) -> mat2x4<output_element_t> {
-  let lower_values: vec4<u32> = unpack4xU8(packed_value & 0x0F0F0F0Fu);
-  let upper_values: vec4<u32> = unpack4xU8((packed_value >> 4u) & 0x0F0F0F0Fu);
-
-  let zero_matrix: mat2x4<output_element_t> = mat2x4<output_element_t>(
-      zero_point, zero_point, zero_point, zero_point,
-      zero_point, zero_point, zero_point, zero_point
-  );
-
-  var dequantized_values: mat2x4<output_element_t> = mat2x4<output_element_t>(
-      output_element_t(lower_values[0]), output_element_t(upper_values[0]),
-      output_element_t(lower_values[1]), output_element_t(upper_values[1]),
-      output_element_t(lower_values[2]), output_element_t(upper_values[2]),
-      output_element_t(lower_values[3]), output_element_t(upper_values[3])
-  );
-
-  dequantized_values = (dequantized_values - zero_matrix) * scale;
-  return dequantized_values;
-}
-)";
-  }
-
-  shader.AdditionalImplementation() << "\n"
-                                    << "fn mm_read_scale(row : u32, col : u32) -> output_element_t {\n"
-                                    << "  if (row < uniforms.input_b_shape[0] && col < uniforms.input_b_shape[1]) {\n"
-                                    << "    return scales[row * uniforms.input_b_shape[1] + col];\n"
-                                    << "  }\n"
-                                    << "  return output_element_t(0);\n"
-                                    << "}\n"
-                                    << GenerateZeroPointReadingCode(nbits_, has_zero_points_);
-
-  shader.AdditionalImplementation() << "\n"
-                                    << "fn mm_write_y(batch : u32, row : u32, col : u32, value : output_value_t) {\n"
-                                    << "  if (row < uniforms.output_shape[1] && col < uniforms.output_shape[2]) {\n"
-                                    << "    " << y.SetByIndices("output_indices_t(batch, row, col)", "value") << "\n"
-                                    << "  }\n"
-                                    << "}\n";
-
-  // declare const variables
-  shader.AdditionalImplementation() << "\n"
-                                    << "// A block32 containing 8 components of `a`." << "\n"
-                                    << "const kAComponentsForBlock32 = " << kAComponentsForBlock32 << "u;\n"
-                                    << "const kTileM = " << tile_m_ << "u;\n"
-                                    << "const kTileN = " << tile_n_ << "u;\n";
-
-  // declare workgroup memory
-  shader.AdditionalImplementation() << "\n"
-                                    << "var<workgroup> a_data_tile: array<array<input_a_value_t, kAComponentsForBlock32>, kTileM>;\n"
-                                    << "\n";
-
-  // main
-  shader.MainFunctionBody() << R"MAIN_FN(
-  let batch = workgroup_idx / (uniforms.num_M_tile * uniforms.num_N_tile);
-  let row = ((workgroup_idx / uniforms.num_N_tile) % uniforms.num_M_tile) * kTileM;
-  let col = (workgroup_idx % uniforms.num_N_tile) * kTileN;
-
-  let a_elements_per_col = uniforms.input_a_shape[2];
-  let a_blocks_per_col = (a_elements_per_col + kAComponentsForBlock32 - 1) / kAComponentsForBlock32;
-
-  // Utilizing an f32 accumulator mitigated precision loss with minimal
-  // performance impact compared to an f16 accumulator.
-  var results : array<f32, kTileM>;
-  for (var a_block_idx = 0u; a_block_idx < a_blocks_per_col; a_block_idx++) {
-    // Load `a` elements into workgroup memory, TileM x kAComponentsForBlock32 (block32)
-    let a_row_idx = local_idx / kAComponentsForBlock32;
-    let a_col_idx = local_idx % kAComponentsForBlock32;
-    a_data_tile[a_row_idx][a_col_idx] = mm_read_a(batch, row + a_row_idx, a_block_idx * kAComponentsForBlock32 + a_col_idx);
-    workgroupBarrier();
-
-    let b_row = col + local_idx;
-    let b_col = a_block_idx;
-
-    let scale = mm_read_scale(b_row, b_col);
-    let zero_point = mm_read_zero(b_row, b_col, uniforms.input_b_shape[0], uniforms.zero_blocks_per_col);
-)MAIN_FN";
-
-  if (nbits_ == 4) {
-    shader.MainFunctionBody() << R"MAIN_FN(
-    let b_data = mm_read_b(b_row, b_col);
-    // `b` component size is 4.
-    for (var b_idx = 0u; b_idx < 4u; b_idx++) {
-      let b_dequantized = dequantize_packed8xU4(b_data[b_idx], zero_point, scale);
-      for (var m_idx = 0u; m_idx < kTileM; m_idx++) {
-        let a_data0 = a_data_tile[m_idx][b_idx * 2u];
-        let a_data1 = a_data_tile[m_idx][b_idx * 2u + 1u];
-
-        results[m_idx] += f32(dot(a_data0, b_dequantized[0])) + f32(dot(a_data1, b_dequantized[1]));
-      }
-    }
-)MAIN_FN";
-  } else {
-    shader.MainFunctionBody() << "    var b_data0 = vec4<u32>(0);\n"
-                                 "    var b_data1 = vec4<u32>(0);\n"
-                                 "    if (b_row < uniforms.input_b_shape[0] && b_col < uniforms.input_b_shape[1]) {\n"
-                              << "      b_data0 = " << b.GetByIndices("input_b_indices_t(b_row, b_col, 0)") << ";\n"
-                              << "      b_data1 = " << b.GetByIndices("input_b_indices_t(b_row, b_col, 1)") << ";\n"
-                                 "    }"
-                              << R"MAIN_FN(
-    for (var b_idx = 0u; b_idx < 4u; b_idx++) {
-      let b_dequantized0 = (vec4<output_element_t>(unpack4xU8(b_data0[b_idx])) - vec4<output_element_t>(zero_point)) * scale;
-      let b_dequantized1 = (vec4<output_element_t>(unpack4xU8(b_data1[b_idx])) - vec4<output_element_t>(zero_point)) * scale;
-      for (var m_idx = 0u; m_idx < kTileM; m_idx++) {
-        let a_data0 = a_data_tile[m_idx][b_idx];
-        let a_data1 = a_data_tile[m_idx][b_idx + 4u];
-
-        results[m_idx] += f32(dot(a_data0, b_dequantized0)) + f32(dot(a_data1, b_dequantized1));
-      }
-    }
-)MAIN_FN";
-  }
-
-  shader.MainFunctionBody() << R"MAIN_FN(
-
-    workgroupBarrier();
-  }
-
-  if (batch >= uniforms.input_a_shape[0]) {
-    return;
-  }
-
-  // Write the results.
-  for (var m_idx = 0u; m_idx < kTileM; m_idx++) {
-    mm_write_y(batch, row + m_idx, col + local_idx, output_value_t(results[m_idx]));
-  }
-)MAIN_FN";
-
-  return Status::OK();
+  return WGSL_TEMPLATE_APPLY(shader, "quantization/matmul_nbits_wide_tile.wgsl.template",
+                             WGSL_TEMPLATE_PARAMETER(has_zero_points, has_zero_points_),
+                             WGSL_TEMPLATE_PARAMETER(nbits, nbits_),
+                             WGSL_TEMPLATE_PARAMETER(tile_m, tile_m_),
+                             WGSL_TEMPLATE_PARAMETER(tile_n, tile_n_));
 }
 
 // Apply similar idea with DP4AMatMulNBitsSmallMProgram algorithm.
@@ -408,38 +272,55 @@ Status MatMulNBits::ComputeInternal(onnxruntime::webgpu::ComputeContext& context
 
   // WideTileProgram
   // This program is optimized for Block32 prefill using Tile16x128.
-  const bool use_wide_tile_program = block_size == 32 && components_a == 4 && components_b == 4 && M >= kMinMForTileOptimization;
+  const bool use_wide_tile_program = block_size == 32 &&
+                                     components_a == 4 &&
+                                     components_b == 4 &&
+                                     M >= kMinMForTileOptimization;
   if (use_wide_tile_program) {
     // Enforce output components to 1.
     components = 1;
 
    constexpr uint32_t workgroup_size = 128;
    constexpr uint32_t tile_m = workgroup_size / 8;
    constexpr uint32_t tile_n = workgroup_size;
-    uint32_t num_N_tile = (N + tile_n - 1) / tile_n;
-    uint32_t num_M_tile = (M + tile_m - 1) / tile_m;
+    const uint32_t num_N_tile = ceil_div(N, tile_n);
+    const uint32_t num_M_tile = ceil_div(M, tile_m);
 
    MatMulNBitsWideTileProgram program{has_zero_points, tile_m, tile_n, nbits};
    program.SetWorkgroupSize(workgroup_size);
-    program.SetDispatchGroupSize((N + tile_n - 1) / tile_n,
-                                 (M + tile_m - 1) / tile_m,
-                                 batch_count);
-    program.CacheHint("Tile" + std::to_string(tile_m) + "x" + std::to_string(tile_n) + "_Block32");
-
-    TensorShape reshaped_a_shape{batch_count, M, K / components_a};
-    TensorShape reshaped_b_shape{N, n_blocks_per_col, blob_size_in_words / components_b};
-    TensorShape reshaped_y_shape{batch_count, M, N / components};
-
-    program
-        .AddInputs({{a, ProgramTensorMetadataDependency::TypeAndRank, reshaped_a_shape, onnxruntime::narrow<int>(components_a)},
-                    {b, ProgramTensorMetadataDependency::TypeAndRank, reshaped_b_shape, onnxruntime::narrow<int>(components_b * 4)},
-                    {scales, ProgramTensorMetadataDependency::None}})
-        .AddOutput({y, ProgramTensorMetadataDependency::TypeAndRank, reshaped_y_shape, onnxruntime::narrow<int>(components)})
-        .AddUniformVariables({{block_size}, {zero_blocks_per_col}, {num_N_tile}, {num_M_tile}})
-        .CacheHint(nbits, has_zero_points);
+    program.SetDispatchGroupSize(num_N_tile, num_M_tile, batch_count);
+
+    constexpr uint32_t kU32Components = 4;
+    const uint32_t components_b_with_u32 = components_b * kU32Components;
+    const uint32_t K_of_b = n_blocks_per_col * blob_size / components_b_with_u32;
+    const uint32_t K_of_a = K / components_a;
+
+    program.AddInput({a,
+                      ProgramTensorMetadataDependency::TypeAndRank,
+                      onnxruntime::narrow<int>(components_a)});
+    program.AddInput({b,
+                      ProgramTensorMetadataDependency::TypeAndRank,
+                      onnxruntime::narrow<int>(components_b_with_u32)});
+    program.AddInput({scales, ProgramTensorMetadataDependency::TypeAndRank});
    if (has_zero_points) {
-      program.AddInput({zero_points, ProgramTensorMetadataDependency::None, {(zero_points->Shape().Size() + 3) / 4}, 4});
+      program.AddInput({zero_points,
+                        ProgramTensorMetadataDependency::TypeAndRank,
+                        {ceil_div(zero_points->Shape().Size(), static_cast<int64_t>(4))},
+                        4});
    }
+    program.AddOutput({y,
+                       ProgramTensorMetadataDependency::TypeAndRank,
+                       onnxruntime::narrow<int>(components)});
+    program.AddUniformVariables({{batch_count},
+                                 {M},
+                                 {N},
+                                 {K_of_a},
+                                 {K_of_b},
+                                 {n_blocks_per_col},
+                                 {zero_blocks_per_col},
+                                 {num_N_tile},
+                                 {num_M_tile}});
+    program.CacheHint(nbits, has_zero_points);
 
    return context.RunProgram(program);
  }
```
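The tiling scheme itself is unchanged: each workgroup still covers one kTileM × kTileN output tile, and the flat `workgroup_idx` is decomposed into a (batch, row tile, column tile) triple using `num_M_tile` and `num_N_tile`, as in the removed inline shader above. A hedged C++ sketch of that index arithmetic, for illustration only; the actual logic now lives in `quantization/matmul_nbits_wide_tile.wgsl.template`, which is not shown in this diff, and the dispatch sizes below are hypothetical:

```cpp
// Mirrors the workgroup -> tile mapping from the removed WGSL main function
// on the CPU, so the decomposition can be checked in isolation.
#include <cstdint>
#include <iostream>

struct TileCoords {
  uint32_t batch;  // index into the batch dimension
  uint32_t row;    // first output row covered by this workgroup
  uint32_t col;    // first output column covered by this workgroup
};

TileCoords decompose(uint32_t workgroup_idx, uint32_t num_M_tile, uint32_t num_N_tile,
                     uint32_t tile_m, uint32_t tile_n) {
  TileCoords c;
  c.batch = workgroup_idx / (num_M_tile * num_N_tile);
  c.row = ((workgroup_idx / num_N_tile) % num_M_tile) * tile_m;
  c.col = (workgroup_idx % num_N_tile) * tile_n;
  return c;
}

int main() {
  // Hypothetical dispatch: 2 M-tiles x 3 N-tiles x 2 batches = 12 workgroups,
  // using the Tile16x128 configuration of this program.
  const uint32_t num_M_tile = 2, num_N_tile = 3, tile_m = 16, tile_n = 128;
  for (uint32_t wg = 0; wg < num_M_tile * num_N_tile * 2; ++wg) {
    const TileCoords c = decompose(wg, num_M_tile, num_N_tile, tile_m, tile_n);
    std::cout << "wg " << wg << " -> batch " << c.batch
              << ", row " << c.row << ", col " << c.col << "\n";
  }
  return 0;
}
```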

onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.h

Lines changed: 7 additions & 2 deletions
```diff
@@ -15,10 +15,15 @@ using namespace onnxruntime::webgpu;
 class MatMulNBitsWideTileProgram final : public Program<MatMulNBitsWideTileProgram> {
  public:
   MatMulNBitsWideTileProgram(bool has_zero_points, uint32_t tile_m, uint32_t tile_n, uint32_t nbits)
-      : Program{"MatMulNBitsWideTileProgram"}, has_zero_points_{has_zero_points}, tile_m_(tile_m), tile_n_(tile_n), nbits_(nbits) {}
+      : Program{"MatMulNBitsWideTile"}, has_zero_points_{has_zero_points}, tile_m_(tile_m), tile_n_(tile_n), nbits_(nbits) {}
 
   Status GenerateShaderCode(ShaderHelper& sh) const override;
-  WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"block_size", ProgramUniformVariableDataType::Uint32},
+  WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"Batch", ProgramUniformVariableDataType::Uint32},
+                                          {"M", ProgramUniformVariableDataType::Uint32},
+                                          {"N", ProgramUniformVariableDataType::Uint32},
+                                          {"K_of_a", ProgramUniformVariableDataType::Uint32},
+                                          {"K_of_b", ProgramUniformVariableDataType::Uint32},
+                                          {"n_blocks_per_col", ProgramUniformVariableDataType::Uint32},
                                           {"zero_blocks_per_col", ProgramUniformVariableDataType::Uint32},
                                           {"num_N_tile", ProgramUniformVariableDataType::Uint32},
                                           {"num_M_tile", ProgramUniformVariableDataType::Uint32});
```
