"""
Block FP8 Gemm benchmark
============================
-
This benchmark comes from SGLang kernels.
https://github.com/sgl-project/sglang/blob/07f944631e747d7489fde1f11de93e503afa90ba/python/sglang/srt/layers/quantization/fp8_kernel.py#L375
-
"""

- import functools
- import json
- import logging
- import os
- from typing import Any, Dict, List, Optional
+ from typing import List

import torch
import triton
import triton.language as tl

import triton_kernels_benchmark as benchmark_suit

- logger = logging.getLogger(__name__)
+ DEVICE_NAME = torch.xpu.get_device_name()
+ DEVICE_TOTAL_MEMORY = torch.xpu.get_device_properties().total_memory


@triton.jit
@@ -107,42 +102,6 @@ def _w8a8_block_fp8_matmul(
    tl.store(c_ptrs, c, mask=c_mask)


- @functools.lru_cache
- def get_w8a8_block_fp8_configs(N: int, K: int, block_n: int, block_k: int) -> Optional[Dict[int, Any]]:
-     """
-     Return optimized configurations for the w8a8 block fp8 kernel.
-
-     The return value will be a dictionary that maps an irregular grid of
-     batch sizes to configurations of the w8a8 block fp8 kernel. To evaluate the
-     kernel on a given batch size bs, the closest batch size in the grid should
-     be picked and the associated configuration chosen to invoke the kernel.
-     """
-
-     # First look up if an optimized configuration is available in the configs
-     # directory
-     device_name = torch.xpu.get_device_name(0).replace(" ", "_")
-     json_file_name = f"N={N},K={K},device_name={device_name},dtype=fp8_w8a8,block_shape=[{block_n}, {block_k}].json"
-
-     config_file_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name)
-     if os.path.exists(config_file_path):
-         with open(config_file_path, "r", encoding="utf-8") as f:
-             logger.info(
-                 "Using configuration from %s for W8A8 Block FP8 kernel.",
-                 config_file_path,
-             )
-             # If a configuration has been found, return it
-             return {int(key): val for key, val in json.load(f).items()}
-
-     # If no optimized configuration is available, we will use the default
-     # configuration
-     logger.warning(
-         ("Using default W8A8 Block FP8 kernel config. Performance might be sub-optimal! "
-          "Config file not found at %s"),
-         config_file_path,
-     )
-     return None
-
-
def w8a8_block_fp8_matmul(
    A: torch.Tensor,
    B: torch.Tensor,
@@ -152,18 +111,15 @@ def w8a8_block_fp8_matmul(
    output_dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
    """This function performs matrix multiplication with block-wise quantization.
-
    It takes two input tensors `A` and `B` with scales `As` and `Bs`.
    The output is returned in the specified `output_dtype`.
-
    Args:
        A: The input tensor, e.g., activation.
        B: The input tensor, e.g., weight.
        As: The per-token-group quantization scale for `A`.
        Bs: The per-block quantization scale for `B`.
        block_size: The block size for per-block quantization. It should be 2-dim, e.g., [128, 128].
        output_dtype: The dtype of the returned tensor.
-
    Returns:
        torch.Tensor: The result of matmul.
    """
@@ -183,22 +139,16 @@ def w8a8_block_fp8_matmul(
    C_shape = A.shape[:-1] + (N, )
    C = A.new_empty(C_shape, dtype=output_dtype)

-     configs = get_w8a8_block_fp8_configs(N, K, block_size[0], block_size[1])
-     if configs:
-         # If an optimal configuration map has been found, look up the
-         # optimal config
-         config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
-     else:
-         # Default config
-         # Block-wise quant: BLOCK_SIZE_K must be divisable by block_size[1]
-         config = {
-             "BLOCK_SIZE_M": 64,
-             "BLOCK_SIZE_N": block_size[0],
-             "BLOCK_SIZE_K": block_size[1],
-             "GROUP_SIZE_M": 32,
-             "num_warps": 4,
-             "num_stages": 3,
-         }
+     # Default config
+     # Block-wise quant: BLOCK_SIZE_K must be divisible by block_size[1]
+     config = {
+         "BLOCK_SIZE_M": 64,
+         "BLOCK_SIZE_N": block_size[0],
+         "BLOCK_SIZE_K": block_size[1],
+         "GROUP_SIZE_M": 32,
+         "num_warps": 4,
+         "num_stages": 3,
+     }

    def grid(META):
        return (triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]), )
@@ -232,7 +182,7 @@ def grid(META):
    return C


- # Reference path
+ # Reference implementation, used only for correctness testing
def native_w8a8_block_fp8_matmul(A, B, As, Bs, block_size, output_dtype=torch.float16):
    """This function performs matrix multiplication with block-wise quantization using native torch.

@@ -284,55 +234,51 @@ def native_w8a8_block_fp8_matmul(A, B, As, Bs, block_size, output_dtype=torch.fl
    return C


- X_VALS = [[1, 1024 * i, 1024 * i, 1024 * i] for i in [1, 2, 4, 8]] + [
-     [1, 1, 13824, 5120],
-     [1, 4, 12288, 4096],
-     [1, 512, 8192, 8192],
-     [1, 512, 8192, 32768],
-     [1, 512, 32768, 8192],
-     [1, 1024, 8192, 16384],
-     [1, 1024, 8192, 28672],
-     [1, 3072, 3072, 4096],  # FIXME: Remove this case when gemm_streamk_benchmark can get better performance
-     [1, 4096, 8192, 16384],
-     [1, 8192, 1024, 16384],
-     [1, 8192, 4096, 16384],
-     [1, 16384, 1024, 8192],
-     [1, 16384, 4096, 8192],
-     [1, 16384, 8192, 1024],
-     [1, 16384, 8192, 4096],
-     [4, 32768, 128, 4096],
-     [4, 32768, 4096, 128],
-     [32, 4096, 128, 4096],
-     [4096, 8, 128, 16384],
-     [4096, 8, 16384, 128],
- ]
-
- DEVICE_NAME = torch.xpu.get_device_name()
- DEVICE_TOTAL_MEMORY = torch.xpu.get_device_properties().total_memory
-
-
- def is_enough_memory(x_val):
-     # x_val: (B, M, N, K)
-     B, M, N, K = x_val
-     # a: (B, M, K) float8_e4m3
-     # b: (B, N, K) float8_e4m3
-     # c: (B, M, N) bfloat16
-     # pytorch reference: (B, M, N) float32
-     required_memory = B * M * K * 1 + B * N * K * 1 + B * M * N * 2 * 2
+ def has_enough_memory(x_val):
+     # x_val: (M, N, K)
+     M, N, K = x_val
+     # a: (M, K) float8_e4m3
+     # b: (N, K) float8_e4m3
+     # c: (M, N) bfloat16
+     # pytorch reference: (M, N) float32
+     required_memory = M * K * 1 + N * K * 1 + M * N * 2 * 2
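+     # Rough estimate only (scales and temporaries are ignored); e.g. for
+     # M = N = K = 8192 this gives 64 MiB + 64 MiB + 256 MiB = 384 MiB.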
    enough_memory = required_memory < DEVICE_TOTAL_MEMORY
    if not enough_memory:
        print(f"'{x_val}' combination skipped for '{DEVICE_NAME}'; {required_memory=} but {DEVICE_TOTAL_MEMORY=}")
    return enough_memory


- X_VALS = [x_val for x_val in X_VALS if is_enough_memory(x_val)]
+ X_VALS = [[1024 * i, 1024 * i, 1024 * i] for i in [1, 2, 4, 8]] + [
+     [1, 13824, 5120],
+     [4, 12288, 4096],
+     [512, 8192, 8192],
+     [512, 8192, 32768],
+     [512, 32768, 8192],
+     [1024, 8192, 16384],
+     [1024, 8192, 28672],
+     [3072, 3072, 4096],
+     [4096, 8192, 16384],
+     [8192, 1024, 16384],
+     [8192, 4096, 16384],
+     [16384, 1024, 8192],
+     [16384, 4096, 8192],
+     [16384, 8192, 1024],
+     [16384, 8192, 4096],
+     [32768, 128, 4096],
+     [32768, 4096, 128],
+     [4096, 128, 4096],
+     [8, 128, 16384],
+     [8, 16384, 128],
+ ]
+
+ X_VALS = [x_val for x_val in X_VALS if has_enough_memory(x_val)]


# Benchmark Performance
@benchmark_suit.perf_report(
    benchmark_suit.Benchmark(
        # argument names to use as an x-axis for the plot
-         x_names=["B", "M", "N", "K"],
+         x_names=["M", "N", "K"],
        # different possible values for `x_name`
        x_vals=X_VALS,
        line_arg="provider",
@@ -342,16 +288,14 @@ def is_enough_memory(x_val):
        line_names=["Triton"],
        # line styles
        ylabel=["GB/s", "TFlops"],  # label name for the y-axis
-         plot_name="matmul-performance",
+         plot_name="sglang-fp8-gemm-performance",
        # name for the plot. Used also as a file name for saving the plot.
        args={},
    ))
- def benchmark(B, M, N, K, provider):
-     assert provider == "triton"
+ def benchmark(M, N, K, provider):
+     torch.manual_seed(0)

    block_size = [128, 128]
-
-     torch.manual_seed(0)
    factor_for_scale = 1e-2
    fp8_info = torch.finfo(torch.float8_e4m3fn)
    fp8_max, fp8_min = fp8_info.max, fp8_info.min
@@ -371,15 +315,18 @@ def benchmark(B, M, N, K, provider):

    quantiles = [0.5, 0.0, 1.0]

-     triton_fn = lambda: w8a8_block_fp8_matmul(A_fp8, B_fp8, As, Bs, block_size)
-     torch_fn = lambda: native_w8a8_block_fp8_matmul(A_fp8, B_fp8, As, Bs, block_size)
-     rtol = 1e-2
-     atol = 3e-4
-     benchmark_suit.assert_close(triton_fn, torch_fn, atol=atol, rtol=rtol, err_msg="triton to torch")
-     _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_repeat=10, quantiles=quantiles)
+     if provider == "triton":
+         triton_fn = lambda: w8a8_block_fp8_matmul(A_fp8, B_fp8, As, Bs, block_size)
+         torch_fn = lambda: native_w8a8_block_fp8_matmul(A_fp8, B_fp8, As, Bs, block_size)
+         benchmark_suit.assert_close(triton_fn, torch_fn, atol=3e-4, rtol=1e-2, err_msg="triton to torch")
+         _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_repeat=10,
+                                                                  quantiles=quantiles)
+
+     else:
+         raise NotImplementedError(f"Unsupported provider {provider}")
-     tflops = lambda ms: 2 * B * M * N * K * (1e-12) / (ms * 1e-3)
-     gbps = lambda ms: B * ((M * K + K * N) + 2.0 * (M * N)) * (1e-9) / (ms * 1e-3)
+     tflops = lambda ms: 2 * M * N * K * (1e-12) / (ms * 1e-3)
+     gbps = lambda ms: ((M * K + K * N) + 2.0 * (M * N)) * (1e-9) / (ms * 1e-3)
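+     # 2*M*N*K FLOPs for the GEMM; bytes moved are the two 1-byte fp8 inputs
+     # (M*K and K*N) plus the 2-byte fp16 output (M*N).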

    return (gbps(mean_ms), gbps(max_ms), gbps(min_ms)), (tflops(mean_ms), tflops(max_ms), tflops(min_ms)), cv
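
Usage note: assuming benchmark_suit.perf_report follows the same interface as triton.testing.perf_report (the decorated function becomes a Mark object exposing run()), the benchmark could be driven as a script roughly like the sketch below; the show_plots and print_data flags are assumptions, not taken from this diff.

if __name__ == "__main__":
    benchmark.run(show_plots=False, print_data=True)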