@@ -22,6 +22,11 @@ namespace {
 
 constexpr unsigned int kMinMForTileOptimization = 4;
 
+template <typename T>
+inline T ceil_div(T numerator, T denominator) {
+  return (numerator + denominator - 1) / denominator;
+}
+
 }  // namespace
 
 ONNX_OPERATOR_KERNEL_EX(
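For reference, the new ceil_div helper is the standard integer ceiling-division identity (n + d - 1) / d. Below is a minimal standalone sketch of its behavior (not part of the patch); the usual caveat is that numerator + denominator - 1 can wrap for values near the maximum of T, which does not matter for the small tile counts used here.

#include <cassert>

// Same helper as the one added in the anonymous namespace above.
template <typename T>
inline T ceil_div(T numerator, T denominator) {
  return (numerator + denominator - 1) / denominator;
}

int main() {
  assert(ceil_div(256u, 128u) == 2u);  // exact multiple: no extra tile
  assert(ceil_div(257u, 128u) == 3u);  // one extra element needs one extra tile
  assert(ceil_div(1u, 128u) == 1u);    // anything non-zero needs at least one tile
  return 0;
}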
@@ -37,165 +42,24 @@ ONNX_OPERATOR_KERNEL_EX(
     MatMulNBits);
 
 Status MatMulNBitsWideTileProgram::GenerateShaderCode(ShaderHelper& shader) const {
-  const auto& a = shader.AddInput("input_a", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias);
-  const auto& b = shader.AddInput("input_b", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias);
-  shader.AddInput("scales", ShaderUsage::UseUniform);
+  shader.AddInput("input_a", ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
+  shader.AddInput("input_b", ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
+  shader.AddInput("scales", ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
   if (has_zero_points_) {
-    shader.AddInput("zero_points", ShaderUsage::UseUniform);
+    shader.AddInput("zero_points", ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
   }
-  const auto& y = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias | ShaderUsage::UseIndicesTypeAlias);
-
-  // Block size 32, `a` component size 4, 8 `a` components per block.
-  constexpr uint32_t kAComponentsForBlock32 = 8;
+  shader.AddOutput("output", ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
 
   const uint32_t workgroup_size = WorkgroupSizeX() * WorkgroupSizeY();
   ORT_ENFORCE(tile_m_ == workgroup_size / 8, "tile_m must be workgroup_size / 8.");
   ORT_ENFORCE(tile_n_ == workgroup_size, "tile_n must be workgroup_size.");
+  ORT_ENFORCE(nbits_ == 4 || nbits_ == 8, "Only 4/8 bits are supported for webgpu matmulnbits.");
 
-  // memory read/write helpers
-  shader.AdditionalImplementation() << "fn mm_read_a(batch : u32, row : u32, col : u32) -> input_a_value_t {\n"
-                                    << "  if (batch < uniforms.input_a_shape[0] && row < uniforms.input_a_shape[1] && col < uniforms.input_a_shape[2]) {\n"
-                                    << "    return " << a.GetByIndices("input_a_indices_t(batch, row, col)") << ";\n"
-                                    << "  }\n"
-                                    << "  return input_a_value_t(0);\n"
-                                    << "}\n";
-  if (nbits_ == 4) {
-    shader.AdditionalImplementation() << "\n"
-                                      << "fn mm_read_b(row : u32, col : u32) -> input_b_value_t {\n"
-                                      << "  if (row < uniforms.input_b_shape[0] && col < uniforms.input_b_shape[1]) {\n"
-                                      << "    return " << b.GetByIndices("input_b_indices_t(row, col, 0)") << ";\n"
-                                      << "  }\n"
-                                      << "  return input_b_value_t(0);\n"
-                                      << "}\n";
-
-    shader.AdditionalImplementation() << R"(
-fn dequantize_packed8xU4(packed_value : u32, zero_point : output_element_t, scale : output_element_t) -> mat2x4<output_element_t> {
-  let lower_values: vec4<u32> = unpack4xU8(packed_value & 0x0F0F0F0Fu);
-  let upper_values: vec4<u32> = unpack4xU8((packed_value >> 4u) & 0x0F0F0F0Fu);
-
-  let zero_matrix: mat2x4<output_element_t> = mat2x4<output_element_t>(
-      zero_point, zero_point, zero_point, zero_point,
-      zero_point, zero_point, zero_point, zero_point
-  );
-
-  var dequantized_values: mat2x4<output_element_t> = mat2x4<output_element_t>(
-      output_element_t(lower_values[0]), output_element_t(upper_values[0]),
-      output_element_t(lower_values[1]), output_element_t(upper_values[1]),
-      output_element_t(lower_values[2]), output_element_t(upper_values[2]),
-      output_element_t(lower_values[3]), output_element_t(upper_values[3])
-  );
-
-  dequantized_values = (dequantized_values - zero_matrix) * scale;
-  return dequantized_values;
-}
-)";
-  }
-
-  shader.AdditionalImplementation() << "\n"
-                                    << "fn mm_read_scale(row : u32, col : u32) -> output_element_t {\n"
-                                    << "  if (row < uniforms.input_b_shape[0] && col < uniforms.input_b_shape[1]) {\n"
-                                    << "    return scales[row * uniforms.input_b_shape[1] + col];\n"
-                                    << "  }\n"
-                                    << "  return output_element_t(0);\n"
-                                    << "}\n"
-                                    << GenerateZeroPointReadingCode(nbits_, has_zero_points_);
-
-  shader.AdditionalImplementation() << "\n"
-                                    << "fn mm_write_y(batch : u32, row : u32, col : u32, value : output_value_t) {\n"
-                                    << "  if (row < uniforms.output_shape[1] && col < uniforms.output_shape[2]) {\n"
-                                    << "    " << y.SetByIndices("output_indices_t(batch, row, col)", "value") << "\n"
-                                    << "  }\n"
-                                    << "}\n";
-
-  // declare const variables
-  shader.AdditionalImplementation() << "\n"
-                                    << "// A block32 containing 8 components of `a`." << "\n"
-                                    << "const kAComponentsForBlock32 = " << kAComponentsForBlock32 << "u;\n"
-                                    << "const kTileM = " << tile_m_ << "u;\n"
-                                    << "const kTileN = " << tile_n_ << "u;\n";
-
-  // declare workgroup memory
-  shader.AdditionalImplementation() << "\n"
-                                    << "var<workgroup> a_data_tile: array<array<input_a_value_t, kAComponentsForBlock32>, kTileM>;\n"
-                                    << "\n";
-
-  // main
-  shader.MainFunctionBody() << R"MAIN_FN(
-  let batch = workgroup_idx / (uniforms.num_M_tile * uniforms.num_N_tile);
-  let row = ((workgroup_idx / uniforms.num_N_tile) % uniforms.num_M_tile) * kTileM;
-  let col = (workgroup_idx % uniforms.num_N_tile) * kTileN;
-
-  let a_elements_per_col = uniforms.input_a_shape[2];
-  let a_blocks_per_col = (a_elements_per_col + kAComponentsForBlock32 - 1) / kAComponentsForBlock32;
-
-  // Utilizing an f32 accumulator mitigated precision loss with minimal
-  // performance impact compared to an f16 accumulator.
-  var results : array<f32, kTileM>;
-  for (var a_block_idx = 0u; a_block_idx < a_blocks_per_col; a_block_idx++) {
-    // Load `a` elements into workgroup memory, TileM x kAComponentsForBlock32 (block32)
-    let a_row_idx = local_idx / kAComponentsForBlock32;
-    let a_col_idx = local_idx % kAComponentsForBlock32;
-    a_data_tile[a_row_idx][a_col_idx] = mm_read_a(batch, row + a_row_idx, a_block_idx * kAComponentsForBlock32 + a_col_idx);
-    workgroupBarrier();
-
-    let b_row = col + local_idx;
-    let b_col = a_block_idx;
-
-    let scale = mm_read_scale(b_row, b_col);
-    let zero_point = mm_read_zero(b_row, b_col, uniforms.input_b_shape[0], uniforms.zero_blocks_per_col);
-)MAIN_FN";
-
-  if (nbits_ == 4) {
-    shader.MainFunctionBody() << R"MAIN_FN(
-    let b_data = mm_read_b(b_row, b_col);
-    // `b` component size is 4.
-    for (var b_idx = 0u; b_idx < 4u; b_idx++) {
-      let b_dequantized = dequantize_packed8xU4(b_data[b_idx], zero_point, scale);
-      for (var m_idx = 0u; m_idx < kTileM; m_idx++) {
-        let a_data0 = a_data_tile[m_idx][b_idx * 2u];
-        let a_data1 = a_data_tile[m_idx][b_idx * 2u + 1u];
-
-        results[m_idx] += f32(dot(a_data0, b_dequantized[0])) + f32(dot(a_data1, b_dequantized[1]));
-      }
-    }
-)MAIN_FN";
-  } else {
-    shader.MainFunctionBody() << "    var b_data0 = vec4<u32>(0);\n"
-                                 "    var b_data1 = vec4<u32>(0);\n"
-                                 "    if (b_row < uniforms.input_b_shape[0] && b_col < uniforms.input_b_shape[1]) {\n"
-                              << "      b_data0 = " << b.GetByIndices("input_b_indices_t(b_row, b_col, 0)") << ";\n"
-                              << "      b_data1 = " << b.GetByIndices("input_b_indices_t(b_row, b_col, 1)") << ";\n"
-                                 "    }"
-                              << R"MAIN_FN(
-    for (var b_idx = 0u; b_idx < 4u; b_idx++) {
-      let b_dequantized0 = (vec4<output_element_t>(unpack4xU8(b_data0[b_idx])) - vec4<output_element_t>(zero_point)) * scale;
-      let b_dequantized1 = (vec4<output_element_t>(unpack4xU8(b_data1[b_idx])) - vec4<output_element_t>(zero_point)) * scale;
-      for (var m_idx = 0u; m_idx < kTileM; m_idx++) {
-        let a_data0 = a_data_tile[m_idx][b_idx];
-        let a_data1 = a_data_tile[m_idx][b_idx + 4u];
-
-        results[m_idx] += f32(dot(a_data0, b_dequantized0)) + f32(dot(a_data1, b_dequantized1));
-      }
-    }
-)MAIN_FN";
-  }
-
-  shader.MainFunctionBody() << R"MAIN_FN(
-
-    workgroupBarrier();
-  }
-
-  if (batch >= uniforms.input_a_shape[0]) {
-    return;
-  }
-
-  // Write the results.
-  for (var m_idx = 0u; m_idx < kTileM; m_idx++) {
-    mm_write_y(batch, row + m_idx, col + local_idx, output_value_t(results[m_idx]));
-  }
-)MAIN_FN";
-
-  return Status::OK();
+  return WGSL_TEMPLATE_APPLY(shader, "quantization/matmul_nbits_wide_tile.wgsl.template",
+                             WGSL_TEMPLATE_PARAMETER(has_zero_points, has_zero_points_),
+                             WGSL_TEMPLATE_PARAMETER(nbits, nbits_),
+                             WGSL_TEMPLATE_PARAMETER(tile_m, tile_m_),
+                             WGSL_TEMPLATE_PARAMETER(tile_n, tile_n_));
 }
 
 // Apply similar idea with DP4AMatMulNBitsSmallMProgram algorithm.
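The inline WGSL removed above packed eight unsigned 4-bit weights into one u32, took the low nibbles via packed & 0x0F0F0F0F and the high nibbles via (packed >> 4) & 0x0F0F0F0F, and dequantized each as (q - zero_point) * scale; that logic now lives in matmul_nbits_wide_tile.wgsl.template. Below is a minimal standalone C++ sketch of the same unpacking, assuming the template keeps the removed helper's per-byte low/high nibble order.

#include <array>
#include <cstdint>
#include <cstdio>

// CPU mirror of the removed dequantize_packed8xU4 WGSL helper:
// byte i of `packed` carries weights 2i (low nibble) and 2i+1 (high nibble).
std::array<float, 8> dequantize_packed8xU4(uint32_t packed, float zero_point, float scale) {
  std::array<float, 8> out{};
  for (int byte = 0; byte < 4; ++byte) {
    const uint32_t lower = (packed >> (8 * byte)) & 0x0Fu;      // packed & 0x0F0F0F0F
    const uint32_t upper = (packed >> (8 * byte + 4)) & 0x0Fu;  // (packed >> 4) & 0x0F0F0F0F
    out[2 * byte] = (static_cast<float>(lower) - zero_point) * scale;
    out[2 * byte + 1] = (static_cast<float>(upper) - zero_point) * scale;
  }
  return out;
}

int main() {
  // 0x21 packs q0 = 1 (low nibble) and q1 = 2 (high nibble).
  const auto v = dequantize_packed8xU4(0x00000021u, /*zero_point=*/8.0f, /*scale=*/0.5f);
  std::printf("%g %g\n", v[0], v[1]);  // prints -3.5 -3
  return 0;
}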
@@ -408,38 +272,55 @@ Status MatMulNBits::ComputeInternal(onnxruntime::webgpu::ComputeContext& context
 
   // WideTileProgram
   // This program is optimized for Block32 prefill using Tile16x128.
-  const bool use_wide_tile_program = block_size == 32 && components_a == 4 && components_b == 4 && M >= kMinMForTileOptimization;
+  const bool use_wide_tile_program = block_size == 32 &&
+                                     components_a == 4 &&
+                                     components_b == 4 &&
+                                     M >= kMinMForTileOptimization;
   if (use_wide_tile_program) {
     // Enforce output components to 1.
     components = 1;
 
     constexpr uint32_t workgroup_size = 128;
     constexpr uint32_t tile_m = workgroup_size / 8;
     constexpr uint32_t tile_n = workgroup_size;
-    uint32_t num_N_tile = (N + tile_n - 1) / tile_n;
-    uint32_t num_M_tile = (M + tile_m - 1) / tile_m;
+    const uint32_t num_N_tile = ceil_div(N, tile_n);
+    const uint32_t num_M_tile = ceil_div(M, tile_m);
 
     MatMulNBitsWideTileProgram program{has_zero_points, tile_m, tile_n, nbits};
     program.SetWorkgroupSize(workgroup_size);
-    program.SetDispatchGroupSize((N + tile_n - 1) / tile_n,
-                                 (M + tile_m - 1) / tile_m,
-                                 batch_count);
-    program.CacheHint("Tile" + std::to_string(tile_m) + "x" + std::to_string(tile_n) + "_Block32");
-
-    TensorShape reshaped_a_shape{batch_count, M, K / components_a};
-    TensorShape reshaped_b_shape{N, n_blocks_per_col, blob_size_in_words / components_b};
-    TensorShape reshaped_y_shape{batch_count, M, N / components};
-
-    program
-        .AddInputs({{a, ProgramTensorMetadataDependency::TypeAndRank, reshaped_a_shape, onnxruntime::narrow<int>(components_a)},
-                    {b, ProgramTensorMetadataDependency::TypeAndRank, reshaped_b_shape, onnxruntime::narrow<int>(components_b * 4)},
-                    {scales, ProgramTensorMetadataDependency::None}})
-        .AddOutput({y, ProgramTensorMetadataDependency::TypeAndRank, reshaped_y_shape, onnxruntime::narrow<int>(components)})
-        .AddUniformVariables({{block_size}, {zero_blocks_per_col}, {num_N_tile}, {num_M_tile}})
-        .CacheHint(nbits, has_zero_points);
+    program.SetDispatchGroupSize(num_N_tile, num_M_tile, batch_count);
+
+    constexpr uint32_t kU32Components = 4;
+    const uint32_t components_b_with_u32 = components_b * kU32Components;
+    const uint32_t K_of_b = n_blocks_per_col * blob_size / components_b_with_u32;
+    const uint32_t K_of_a = K / components_a;
+
+    program.AddInput({a,
+                      ProgramTensorMetadataDependency::TypeAndRank,
+                      onnxruntime::narrow<int>(components_a)});
+    program.AddInput({b,
+                      ProgramTensorMetadataDependency::TypeAndRank,
+                      onnxruntime::narrow<int>(components_b_with_u32)});
+    program.AddInput({scales, ProgramTensorMetadataDependency::TypeAndRank});
     if (has_zero_points) {
-      program.AddInput({zero_points, ProgramTensorMetadataDependency::None, {(zero_points->Shape().Size() + 3) / 4}, 4});
+      program.AddInput({zero_points,
+                        ProgramTensorMetadataDependency::TypeAndRank,
+                        {ceil_div(zero_points->Shape().Size(), static_cast<int64_t>(4))},
+                        4});
     }
+    program.AddOutput({y,
+                       ProgramTensorMetadataDependency::TypeAndRank,
+                       onnxruntime::narrow<int>(components)});
+    program.AddUniformVariables({{batch_count},
+                                 {M},
+                                 {N},
+                                 {K_of_a},
+                                 {K_of_b},
+                                 {n_blocks_per_col},
+                                 {zero_blocks_per_col},
+                                 {num_N_tile},
+                                 {num_M_tile}});
+    program.CacheHint(nbits, has_zero_points);
 
     return context.RunProgram(program);
   }
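The dispatch above launches num_N_tile x num_M_tile x batch_count workgroups, and the shader (as in the removed main function) recovers (batch, M-tile, N-tile) coordinates from the flat workgroup_idx by division and modulus over the num_M_tile and num_N_tile uniforms. A minimal standalone sketch checking that decomposition against nested loops:

#include <cassert>
#include <cstdint>

int main() {
  // Small illustrative tile counts (not the real kernel's values).
  const uint32_t num_M_tile = 3, num_N_tile = 5, batch_count = 2;
  uint32_t workgroup_idx = 0;
  for (uint32_t batch = 0; batch < batch_count; ++batch) {
    for (uint32_t m = 0; m < num_M_tile; ++m) {
      for (uint32_t n = 0; n < num_N_tile; ++n, ++workgroup_idx) {
        // Mirrors: batch = idx / (num_M_tile * num_N_tile);
        //          row tile = (idx / num_N_tile) % num_M_tile;
        //          col tile = idx % num_N_tile;
        assert(workgroup_idx / (num_M_tile * num_N_tile) == batch);
        assert((workgroup_idx / num_N_tile) % num_M_tile == m);
        assert(workgroup_idx % num_N_tile == n);
      }
    }
  }
  return 0;
}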