
Commit a0fa375

Move fp8 gemm to sglang benchmark
1 parent 41b9e0d commit a0fa375

4 files changed (+73, -136 lines)

.github/workflows/third-party-benchmarks.yml

Lines changed: 10 additions & 1 deletion
@@ -111,7 +111,7 @@ jobs:
       - name: Install SGLANG
         run: |
           git clone https://github.com/sgl-project/sglang.git
-          pip install sglang/python[srt_xpu]
+          pip install sglang/python[dev_xpu]

       - name: Run SGLANG attention prefill stage benchmark
         if: ${{ steps.install.outcome == 'success' && !cancelled() }}
@@ -140,6 +140,15 @@ jobs:
           source ../../../scripts/capture-hw-details.sh
           python ../../triton_kernels_benchmark/build_report.py $REPORTS/extended-attn-performance.csv $REPORTS/attn-append-triton-report.csv --benchmark sglang-extended-attn --compiler triton --param_cols "B,SEQ_LENS,H_Q,H_KV,D" --tflops_col Triton-TFlops --hbm_col "Triton-GB/s" --tag $TAG

+      - name: Run SGLANG Block FP8 GEMM benchmark
+        if: ${{ steps.install.outcome == 'success' && !cancelled() && !contains(fromJson(inputs.skip_benchmarks || '[]'), 'block_fp8_gemm_benchmark.py') }}
+        run: |
+          cd benchmarks/third_party/sglang
+          python block_fp8_gemm_benchmark.py --reports $REPORTS
+
+          source ../../../scripts/capture-hw-details.sh
+          python ../../../scripts/build_report.py $REPORTS/sglang-fp8-gemm-performance.csv $REPORTS/sglang-fp8-gemm-triton-report.csv --benchmark sglang-block-fp8-gemm --compiler triton --param_cols "M,N,K" --tflops_col Triton-TFlops --hbm_col "Triton-GB/s" --tag $TAG
+
       - name: Upload benchmark reports
         if: ${{ steps.install.outcome == 'success' && !cancelled() }}
         uses: actions/upload-artifact@v4
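
The new step is gated the same way as the other third-party benchmark steps: it runs only when the SGLANG install succeeded, the job has not been cancelled, and block_fp8_gemm_benchmark.py is not listed in the skip_benchmarks workflow input. A rough Python analogue of the fromJson/contains gating, for illustration only (the authoritative check is the GitHub Actions expression in the step above):

import json

def should_run(skip_benchmarks_input: str, benchmark_file: str,
               install_ok: bool = True, cancelled: bool = False) -> bool:
    # Mirrors: steps.install.outcome == 'success' && !cancelled()
    #          && !contains(fromJson(inputs.skip_benchmarks || '[]'), benchmark_file)
    skipped = json.loads(skip_benchmarks_input or "[]")
    return install_ok and not cancelled and benchmark_file not in skipped

print(should_run("[]", "block_fp8_gemm_benchmark.py"))                                # True
print(should_run('["block_fp8_gemm_benchmark.py"]', "block_fp8_gemm_benchmark.py"))   # False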

.github/workflows/triton-benchmarks.yml

Lines changed: 0 additions & 8 deletions
@@ -332,14 +332,6 @@ jobs:
           source ../../scripts/capture-hw-details.sh
           python build_report.py $REPORTS/prefix-sums.csv $REPORTS/prefix_sums-triton-report.csv --benchmark prefix_sums --compiler triton --param_cols "N" --tflops_col Triton-TFlops --hbm_col "Triton-GB/s" --tag $TAG

-      - name: Run SGLang FP8 GEMM benchmark
-        if: ${{ steps.install.outcome == 'success' && !cancelled() && !contains(fromJson(inputs.skip_benchmarks || '[]'), 'prefix_sums.py') }}
-        run: |
-          cd benchmarks/triton_kernels_benchmark/sglang
-          python block_fp8_matmul.py --reports $REPORTS
-          source ../../scripts/capture-hw-details.sh
-          python ../../scripts/build_report.py $REPORTS/block_fp8_matmul.csv $REPORTS/block_fp8_matmul-triton-report.csv --benchmark block_fp8_matmul --compiler triton --param_cols "N" --tflops_col Triton-TFlops --hbm_col "Triton-GB/s" --tag $TAG
-
       - name: Run micro benchmark
         if: ${{ steps.install.outcome == 'success' && !cancelled() && (inputs.benchmarks == '' || contains(fromJson(inputs.benchmarks || '[]'), 'micro_benchmarks.py')) && !contains(fromJson(inputs.skip_benchmarks || '[]'), 'micro_benchmarks') }}
         run: |

benchmarks/triton_kernels_benchmark/sglang/block_fp8_matmul.py renamed to benchmarks/third_party/sglang/block_fp8_gemm_benchmark.py

Lines changed: 61 additions & 114 deletions
@@ -1,25 +1,20 @@
 """
 Block FP8 Gemm benchmark
 ============================
-
 This benchmark is come from SGLang kernels.
 https://github.com/sgl-project/sglang/blob/07f944631e747d7489fde1f11de93e503afa90ba/python/sglang/srt/layers/quantization/fp8_kernel.py#L375
-
 """

-import functools
-import json
-import logging
-import os
-from typing import Any, Dict, List, Optional
+from typing import List

 import torch
 import triton
 import triton.language as tl

 import triton_kernels_benchmark as benchmark_suit

-logger = logging.getLogger(__name__)
+DEVICE_NAME = torch.xpu.get_device_name()
+DEVICE_TOTAL_MEMORY = torch.xpu.get_device_properties().total_memory


 @triton.jit
@@ -107,42 +102,6 @@ def _w8a8_block_fp8_matmul(
     tl.store(c_ptrs, c, mask=c_mask)


-@functools.lru_cache
-def get_w8a8_block_fp8_configs(N: int, K: int, block_n: int, block_k: int) -> Optional[Dict[int, Any]]:
-    """
-    Return optimized configurations for the w8a8 block fp8 kernel.
-
-    The return value will be a dictionary that maps an irregular grid of
-    batch sizes to configurations of the w8a8 block fp8 kernel. To evaluate the
-    kernel on a given batch size bs, the closest batch size in the grid should
-    be picked and the associated configuration chosen to invoke the kernel.
-    """
-
-    # First look up if an optimized configuration is available in the configs
-    # directory
-    device_name = torch.xpu.get_device_name(0).replace(" ", "_")
-    json_file_name = f"N={N},K={K},device_name={device_name},dtype=fp8_w8a8,block_shape=[{block_n}, {block_k}].json"
-
-    config_file_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name)
-    if os.path.exists(config_file_path):
-        with open(config_file_path, "r", encoding="utf-8") as f:
-            logger.info(
-                "Using configuration from %s for W8A8 Block FP8 kernel.",
-                config_file_path,
-            )
-            # If a configuration has been found, return it
-            return {int(key): val for key, val in json.load(f).items()}
-
-    # If no optimized configuration is available, we will use the default
-    # configuration
-    logger.warning(
-        ("Using default W8A8 Block FP8 kernel config. Performance might be sub-optimal! "
-         "Config file not found at %s"),
-        config_file_path,
-    )
-    return None
-
-
 def w8a8_block_fp8_matmul(
     A: torch.Tensor,
     B: torch.Tensor,
@@ -152,18 +111,15 @@ def w8a8_block_fp8_matmul(
     output_dtype: torch.dtype = torch.float16,
 ) -> torch.Tensor:
     """This function performs matrix multiplication with block-wise quantization.
-
     It takes two input tensors `A` and `B` with scales `As` and `Bs`.
     The output is returned in the specified `output_dtype`.
-
     Args:
         A: The input tensor, e.g., activation.
         B: The input tensor, e.g., weight.
         As: The per-token-group quantization scale for `A`.
         Bs: The per-block quantization scale for `B`.
         block_size: The block size for per-block quantization. It should be 2-dim, e.g., [128, 128].
         output_dytpe: The dtype of the returned tensor.
-
     Returns:
         torch.Tensor: The result of matmul.
     """
@@ -183,22 +139,16 @@ def w8a8_block_fp8_matmul(
     C_shape = A.shape[:-1] + (N, )
     C = A.new_empty(C_shape, dtype=output_dtype)

-    configs = get_w8a8_block_fp8_configs(N, K, block_size[0], block_size[1])
-    if configs:
-        # If an optimal configuration map has been found, look up the
-        # optimal config
-        config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
-    else:
-        # Default config
-        # Block-wise quant: BLOCK_SIZE_K must be divisable by block_size[1]
-        config = {
-            "BLOCK_SIZE_M": 64,
-            "BLOCK_SIZE_N": block_size[0],
-            "BLOCK_SIZE_K": block_size[1],
-            "GROUP_SIZE_M": 32,
-            "num_warps": 4,
-            "num_stages": 3,
-        }
+    # Default config
+    # Block-wise quant: BLOCK_SIZE_K must be divisable by block_size[1]
+    config = {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": block_size[0],
+        "BLOCK_SIZE_K": block_size[1],
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 3,
+    }

     def grid(META):
         return (triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]), )
@@ -232,7 +182,7 @@ def grid(META):
     return C


-# Reference path
+# For test
 def native_w8a8_block_fp8_matmul(A, B, As, Bs, block_size, output_dtype=torch.float16):
     """This function performs matrix multiplication with block-wise quantization using native torch.
@@ -284,55 +234,51 @@ def native_w8a8_block_fp8_matmul(A, B, As, Bs, block_size, output_dtype=torch.fl
     return C


-X_VALS = [[1, 1024 * i, 1024 * i, 1024 * i] for i in [1, 2, 4, 8]] + [
-    [1, 1, 13824, 5120],
-    [1, 4, 12288, 4096],
-    [1, 512, 8192, 8192],
-    [1, 512, 8192, 32768],
-    [1, 512, 32768, 8192],
-    [1, 1024, 8192, 16384],
-    [1, 1024, 8192, 28672],
-    [1, 3072, 3072, 4096],  # FIXME: Remove this case when gemm_streamk_benchmark can get better performance
-    [1, 4096, 8192, 16384],
-    [1, 8192, 1024, 16384],
-    [1, 8192, 4096, 16384],
-    [1, 16384, 1024, 8192],
-    [1, 16384, 4096, 8192],
-    [1, 16384, 8192, 1024],
-    [1, 16384, 8192, 4096],
-    [4, 32768, 128, 4096],
-    [4, 32768, 4096, 128],
-    [32, 4096, 128, 4096],
-    [4096, 8, 128, 16384],
-    [4096, 8, 16384, 128],
-]
-
-DEVICE_NAME = torch.xpu.get_device_name()
-DEVICE_TOTAL_MEMORY = torch.xpu.get_device_properties().total_memory
-
-
-def is_enough_memory(x_val):
-    # x_val: (B, M, N, K)
-    B, M, N, K = x_val
-    # a: (B, M, K) float8_e4m3
-    # b: (B, N, K) float8_e4m3
-    # c: (B, M, N) bfloat16
-    # pytorch reference: (B, M, N) float32
-    required_memory = B * M * K * 1 + B * N * K * 1 + B * M * N * 2 * 2
+def has_enough_memory(x_val):
+    # x_val: (M, N, K)
+    M, N, K = x_val
+    # a: (M, K) float8_e4m3
+    # b: (N, K) float8_e4m3
+    # c: (M, N) bfloat16
+    # pytorch reference: (M, N) float32
+    required_memory = M * K * 1 + N * K * 1 + M * N * 2 * 2
     enough_memory = required_memory < DEVICE_TOTAL_MEMORY
     if not enough_memory:
         print(f"'{x_val}' combination skipped for '{DEVICE_NAME}'; {required_memory=} but {DEVICE_TOTAL_MEMORY=}")
     return enough_memory


-X_VALS = [x_val for x_val in X_VALS if is_enough_memory(x_val)]
+X_VALS = [[1024 * i, 1024 * i, 1024 * i] for i in [1, 2, 4, 8]] + [
+    [1, 13824, 5120],
+    [4, 12288, 4096],
+    [512, 8192, 8192],
+    [512, 8192, 32768],
+    [512, 32768, 8192],
+    [1024, 8192, 16384],
+    [1024, 8192, 28672],
+    [3072, 3072, 4096],
+    [4096, 8192, 16384],
+    [8192, 1024, 16384],
+    [8192, 4096, 16384],
+    [16384, 1024, 8192],
+    [16384, 4096, 8192],
+    [16384, 8192, 1024],
+    [16384, 8192, 4096],
+    [32768, 128, 4096],
+    [32768, 4096, 128],
+    [4096, 128, 4096],
+    [8, 128, 16384],
+    [8, 16384, 128],
+]
+
+X_VALS = [x_val for x_val in X_VALS if has_enough_memory(x_val)]


 # Benchmark Performance
 @benchmark_suit.perf_report(
     benchmark_suit.Benchmark(
         # argument names to use as an x-axis for the plot
-        x_names=["B", "M", "N", "K"],
+        x_names=["M", "N", "K"],
         # different possible values for `x_name`
         x_vals=X_VALS,
         line_arg="provider",
@@ -342,16 +288,14 @@ def is_enough_memory(x_val):
         line_names=["Triton"],
         # line styles
         ylabel=["GB/s", "TFlops"],  # label name for the y-axis
-        plot_name="matmul-performance",
+        plot_name="sglang-fp8-gemm-performance",
         # name for the plot. Used also as a file name for saving the plot.
         args={},
     ))
-def benchmark(B, M, N, K, provider):
-    assert provider == "triton"
+def benchmark(M, N, K, provider):
+    torch.manual_seed(0)

     block_size = [128, 128]
-
-    torch.manual_seed(0)
     factor_for_scale = 1e-2
     fp8_info = torch.finfo(torch.float8_e4m3fn)
     fp8_max, fp8_min = fp8_info.max, fp8_info.min
@@ -371,15 +315,18 @@ def benchmark(B, M, N, K, provider):

     quantiles = [0.5, 0.0, 1.0]

-    triton_fn = lambda: w8a8_block_fp8_matmul(A_fp8, B_fp8, As, Bs, block_size)
-    torch_fn = lambda: native_w8a8_block_fp8_matmul(A_fp8, B_fp8, As, Bs, block_size)
-    rtol = 1e-2
-    atol = 3e-4
-    benchmark_suit.assert_close(triton_fn, torch_fn, atol=atol, rtol=rtol, err_msg="triton to torch")
-    _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_repeat=10, quantiles=quantiles)
+    if provider == "triton":
+        triton_fn = lambda: w8a8_block_fp8_matmul(A_fp8, B_fp8, As, Bs, block_size)
+        torch_fn = lambda: native_w8a8_block_fp8_matmul(A_fp8, B_fp8, As, Bs, block_size)
+        benchmark_suit.assert_close(triton_fn, torch_fn, atol=3e-4, rtol=1e-2, err_msg="triton to torch")
+        _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_repeat=10,
+                                                                 quantiles=quantiles)
+
+    else:
+        raise NotImplementedError(f"Unsupported provider {provider}")

-    tflops = lambda ms: 2 * B * M * N * K * (1e-12) / (ms * 1e-3)
-    gbps = lambda ms: B * ((M * K + K * N) + 2.0 * (M * N)) * (1e-9) / (ms * 1e-3)
+    tflops = lambda ms: 2 * M * N * K * (1e-12) / (ms * 1e-3)
+    gbps = lambda ms: (M * K + K * N) + 2.0 * (M * N) * (1e-9) / (ms * 1e-3)

     return (gbps(mean_ms), gbps(max_ms), gbps(min_ms)), (tflops(mean_ms), tflops(max_ms), tflops(min_ms)), cv
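
For reference, a minimal sketch of how the relocated kernel entry point is driven, assuming an XPU device, that it is run from benchmarks/third_party/sglang, and the scale layouts implied by the docstring (As: one scale per row of A per K-block; Bs: one scale per (N-block, K-block) tile). The tensor setup below is illustrative and may differ from the benchmark code that is not shown in this diff:

import torch
import triton

from block_fp8_gemm_benchmark import native_w8a8_block_fp8_matmul, w8a8_block_fp8_matmul

M, N, K = 512, 8192, 8192            # one of the (M, N, K) shapes in X_VALS
block_n, block_k = 128, 128          # block_size = [128, 128], as in benchmark()

fp8_info = torch.finfo(torch.float8_e4m3fn)

# A is (M, K) and B is (N, K), both quantized to float8_e4m3fn.
A_fp8 = ((torch.rand(M, K, device="xpu") - 0.5) * 2 * fp8_info.max).to(torch.float8_e4m3fn)
B_fp8 = ((torch.rand(N, K, device="xpu") - 0.5) * 2 * fp8_info.max).to(torch.float8_e4m3fn)

# Per-token-group scales for A and per-block scales for B (assumed shapes).
k_tiles = triton.cdiv(K, block_k)
n_tiles = triton.cdiv(N, block_n)
As = torch.rand(M, k_tiles, dtype=torch.float32, device="xpu") * 1e-2
Bs = torch.rand(n_tiles, k_tiles, dtype=torch.float32, device="xpu") * 1e-2

C = w8a8_block_fp8_matmul(A_fp8, B_fp8, As, Bs, [block_n, block_k])              # Triton kernel path
C_ref = native_w8a8_block_fp8_matmul(A_fp8, B_fp8, As, Bs, [block_n, block_k])   # torch reference path
torch.testing.assert_close(C.float(), C_ref.float(), atol=3e-4, rtol=1e-2)

As a rough sense of the memory gate: for the [16384, 8192, 4096] shape, has_enough_memory estimates 16384*4096*1 + 8192*4096*1 + 16384*8192*2*2, roughly 0.6 GB, so the shape passes the X_VALS filter on any reasonably sized device.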

benchmarks/third_party/sglang/decode_attention_benchmark.py

Lines changed: 2 additions & 13 deletions
@@ -82,19 +82,8 @@ def benchmark(B, SEQ_LENS, H_Q, H_KV, D, MODE, VALIDATE, provider):
                                                        B, N_CTX, H_Q, H_KV, D, dtype, 'xpu')

     if provider == 'triton':
-        triton_fn = lambda: decode_attention_fwd(
-            q,
-            k_buffer,
-            v_buffer,
-            o,
-            kv_indptr,
-            kv_indices,
-            attn_logits,
-            attn_lse,
-            num_kv_splits,
-            max_kv_splits,
-            sm_scale,
-        )
+        triton_fn = lambda: decode_attention_fwd(q, k_buffer, v_buffer, o, kv_indptr, kv_indices, attn_logits, attn_lse,
+                                                 num_kv_splits, max_kv_splits, sm_scale)

         # TODO: decode attention should have the validation function
         if VALIDATE:
