Skip to content

Commit f2a8926

Browse files
Integrate sglang prefill/decode/extend kernel to benchmarks
Port prefill attn and decode attn from sglang Add validation temp add extend attention disable debug ir dump Update three stage attention benchmark Add sglang kernel benchmark to action use 1e-3 atol remove sglang benchmark from triton-benchmarks Fix setup bdist_wheel
1 parent 03e391f commit f2a8926

File tree

6 files changed

+346
-3
lines changed

6 files changed

+346
-3
lines changed

.github/pins/sglang.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
0.4.4.post1

.github/workflows/third-party-benchmarks.yml

Lines changed: 40 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -78,8 +78,14 @@ jobs:
7878
run: |
7979
pip install python/dist/*.whl
8080
81-
- name: Install benchmark dependencies
81+
- name: Install benchmarks
8282
id: install
83+
run: |
84+
cd benchmarks
85+
pip install .
86+
87+
- name: Install benchmark dependencies
88+
id: install_deps
8389
run: |
8490
pip install transformers pandas pytest
8591
@@ -89,7 +95,7 @@ jobs:
8995
echo "REPORTS=$PWD/reports" >> $GITHUB_ENV
9096
9197
- name: Run Liger-Kernel benchmarks
92-
if: ${{ steps.install.outcome == 'success' && !cancelled() }}
98+
if: ${{ steps.install_deps.outcome == 'success' && !cancelled() }}
9399
run: |
94100
source ./scripts/capture-hw-details.sh
95101
@@ -108,6 +114,38 @@ jobs:
108114
# Return the captured return code at the end
109115
exit "$RET_CODE"
110116
117+
- name: Install SGLANG
118+
run: |
119+
SGLANG_PIN_ID="$(<.github/pins/sglang.txt)"
120+
pip install sglang==$SGLANG_PIN_ID
121+
122+
- name: Run SGLANG attention prefill stage benchmark
123+
if: ${{ steps.install.outcome == 'success' && !cancelled() && !contains(fromJson(inputs.skip_benchmarks || '[]'), 'prefill_attention_benchmark.py') }}
124+
run: |
125+
cd benchmarks/third_party/sglang
126+
python prefill_attention_benchmark --reports $REPORTS
127+
128+
source ../../scripts/capture-hw-details.sh
129+
python ../../scripts/build_report.py $REPORTS/prefill-attn-performance.csv $REPORTS/attn-prefill-triton-report.csv --benchmark attn --compiler triton --param_cols "B,N_CTX,H_Q,H_KV,D,CAUSAL" --tflops_col Triton-TFlops --hbm_col "Triton-GB/s" --tag $TAG
130+
131+
- name: Run SGLANG attention decode stage benchmark
132+
if: ${{ steps.install.outcome == 'success' && !cancelled() && !contains(fromJson(inputs.skip_benchmarks || '[]'), 'decode_attention_benchmark.py') }}
133+
run: |
134+
cd benchmarks/third_party/sglang
135+
python decode_attention_benchmark --reports $REPORTS
136+
137+
source ../../scripts/capture-hw-details.sh
138+
python ../../scripts/build_report.py $REPORTS/decode-attn-performance.csv $REPORTS/attn-decode-triton-report.csv --benchmark attn --compiler triton --param_cols "B,N_CTX,H_Q,H_KV,D" --tflops_col Triton-TFlops --hbm_col "Triton-GB/s" --tag $TAG
139+
140+
- name: Run SGLANG attention append stage benchmark
141+
if: ${{ steps.install.outcome == 'success' && !cancelled() && !contains(fromJson(inputs.skip_benchmarks || '[]'), 'decode_attention_benchmark.py') }}
142+
run: |
143+
cd benchmarks/third_party/sglang
144+
python extended_attention_benchmark --reports $REPORTS
145+
146+
source ../../scripts/capture-hw-details.sh
147+
python ../../scripts/build_report.py $REPORTS/extended-attn-performance.csv $REPORTS/attn-append-triton-report.csv --benchmark attn --compiler triton --param_cols "B,N_CTX,H_Q,H_KV,D" --tflops_col Triton-TFlops --hbm_col "Triton-GB/s" --tag $TAG
148+
111149
- name: Upload benchmark reports
112150
if: ${{ steps.install.outcome == 'success' && !cancelled() }}
113151
uses: actions/upload-artifact@v4
Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
import torch
2+
3+
from sglang.srt.layers.attention.triton_ops.decode_attention import decode_attention_fwd
4+
5+
import triton_kernels_benchmark as benchmark_suit
6+
7+
8+
# pylint: disable=unused-argument
9+
@benchmark_suit.perf_report(
10+
benchmark_suit.Benchmark(
11+
# argument names to use as an x-axis for the plot
12+
x_names=['BATCH', 'SEQ_LENS', 'Q_HEAD_NUM', 'KV_HEAD_NUM', 'HEAD_DIM', 'MODE', 'VALIDATE'],
13+
x_vals=[ #
14+
[bs, [1024, 64], 32, 8, 128, 'fwd', False] for bs in [1, 16, 32, 64, 128]
15+
] + [ #
16+
[bs, [1024, 64], 32, 32, 96, 'fwd', False] for bs in [1, 16, 32, 64, 128]
17+
] + [ #
18+
[bs, [1024, 64], 28, 4, 128, 'fwd', False] for bs in [1, 16, 32, 64, 128]
19+
],
20+
line_arg='provider',
21+
# argument name whose value corresponds to a different line in the plot
22+
# possible values for `line_arg``
23+
line_vals=[
24+
'triton',
25+
],
26+
# label name for the lines
27+
line_names=[
28+
'Triton',
29+
],
30+
# line styles
31+
styles=[('green', '-'), ('green', '--'), ('blue', '-'), ('blue', '--')],
32+
ylabel=['GB/s', 'TFlops'], # label name for the y-axis
33+
plot_name='decode-attn-performance',
34+
# name for the plot. Used also as a file name for saving the plot.
35+
args={},
36+
))
37+
def benchmark(BATCH, SEQ_LENS, Q_HEAD_NUM, KV_HEAD_NUM, HEAD_DIM, MODE, VALIDATE, provider):
38+
torch.manual_seed(0)
39+
dtype = torch.bfloat16
40+
N_CTX = sum(SEQ_LENS)
41+
total_tokens = BATCH * N_CTX
42+
sm_scale = 1.0 / (HEAD_DIM**0.5)
43+
num_kv_splits = 8
44+
45+
# q represents the new token being generated, one per batch
46+
q = torch.randn(BATCH, Q_HEAD_NUM, HEAD_DIM, dtype=dtype, device='xpu')
47+
48+
# k_buffer and v_buffer represent all previous tokens
49+
k_buffer = torch.randn(total_tokens, KV_HEAD_NUM, HEAD_DIM, dtype=dtype, device='xpu')
50+
v_buffer = torch.randn(total_tokens, KV_HEAD_NUM, HEAD_DIM, dtype=dtype, device='xpu')
51+
52+
# o will have the same shape as q
53+
o = torch.zeros(BATCH, Q_HEAD_NUM, HEAD_DIM, dtype=dtype, device='xpu')
54+
55+
b_seq_len = torch.full((BATCH, ), N_CTX, device='xpu')
56+
57+
kv_indptr = torch.zeros((BATCH + 1, ), dtype=torch.int32, device='xpu')
58+
kv_indptr[1:BATCH + 1] = torch.cumsum(b_seq_len[:BATCH], dim=0)
59+
kv_indices = torch.arange(total_tokens, device='xpu')
60+
61+
attn_logits = torch.empty(
62+
(BATCH, Q_HEAD_NUM, num_kv_splits, HEAD_DIM + 1),
63+
dtype=torch.float32,
64+
device='xpu',
65+
)
66+
67+
quantiles = [0.5, 0.0, 1.0]
68+
69+
if provider == 'triton':
70+
triton_fn = lambda: decode_attention_fwd(q, k_buffer, v_buffer, o, kv_indptr, kv_indices, attn_logits,
71+
num_kv_splits, sm_scale, logit_cap=0.0)
72+
# TODO: decode attention do not have validation function
73+
if VALIDATE:
74+
raise NotImplementedError('Validation is not implemented for decode stage')
75+
76+
_, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_repeat=10, quantiles=quantiles)
77+
78+
else:
79+
raise NotImplementedError(f'Unsupported provider {provider}')
80+
81+
tflops = lambda ms: 2 * BATCH * (Q_HEAD_NUM + KV_HEAD_NUM * N_CTX) * N_CTX * HEAD_DIM * (1e-12) / (ms * 1e-3)
82+
83+
gbps = lambda ms: 2 * BATCH * (Q_HEAD_NUM + KV_HEAD_NUM * N_CTX) * HEAD_DIM * 2 * (1e-9) / (ms * 1e-3)
84+
85+
return (gbps(mean), gbps(max_ms), gbps(min_ms)), (tflops(mean), tflops(max_ms), tflops(min_ms)), cv
86+
87+
88+
if __name__ == '__main__':
89+
benchmark.run(show_plots=False, print_data=True)
Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
import torch
2+
from sglang.srt.layers.attention.triton_ops.extend_attention import (
3+
extend_attention_fwd,
4+
redundant_attention,
5+
)
6+
import triton_kernels_benchmark as benchmark_suit
7+
8+
9+
# pylint: disable=unused-argument
10+
@benchmark_suit.perf_report(
11+
benchmark_suit.Benchmark(
12+
# argument names to use as an x-axis for the plot
13+
x_names=['BATCH', 'SEQ_LENS', 'Q_HEAD_NUM', 'KV_HEAD_NUM', 'HEAD_DIM', 'MODE', 'VALIDATE'],
14+
x_vals=[ #
15+
[bs, [1024, 128, 512], 32, 8, 128, 'fwd', True] for bs in [1, 16, 32, 64, 128]
16+
] + [ #
17+
[bs, [1024, 128, 512], 32, 32, 96, 'fwd', True] for bs in [1, 16, 32, 64, 128]
18+
] + [ #
19+
[bs, [1024, 128, 512], 28, 4, 128, 'fwd', True] for bs in [1, 16, 32, 64, 128]
20+
],
21+
line_arg='provider',
22+
# argument name whose value corresponds to a different line in the plot
23+
# possible values for `line_arg``
24+
line_vals=[
25+
'triton',
26+
],
27+
# label name for the lines
28+
line_names=[
29+
'Triton',
30+
],
31+
# line styles
32+
styles=[('green', '-'), ('green', '--'), ('blue', '-'), ('blue', '--')],
33+
ylabel=['GB/s', 'TFlops'], # label name for the y-axis
34+
plot_name='extended-attn-performance',
35+
# name for the plot. Used also as a file name for saving the plot.
36+
args={},
37+
))
38+
def benchmark(BATCH, SEQ_LENS, Q_HEAD_NUM, KV_HEAD_NUM, HEAD_DIM, MODE, VALIDATE, provider):
39+
torch.manual_seed(0)
40+
dtype = torch.bfloat16
41+
N_CTX = sum(SEQ_LENS)
42+
43+
b_seq_len_prefix = torch.randint(1, N_CTX // 2, (BATCH, ), dtype=torch.int32, device='xpu')
44+
b_seq_len_extend = torch.randint(1, N_CTX // 2, (BATCH, ), dtype=torch.int32, device='xpu')
45+
b_seq_len = b_seq_len_prefix + b_seq_len_extend
46+
max_len_in_batch = torch.max(b_seq_len, 0)[0].item()
47+
48+
b_req_idx = torch.arange(BATCH, dtype=torch.int32, device='xpu')
49+
b_start_loc = torch.zeros((BATCH, ), dtype=torch.int32, device='xpu')
50+
b_start_loc[1:] = torch.cumsum(b_seq_len[:-1], 0)
51+
b_start_loc_extend = torch.zeros((BATCH, ), dtype=torch.int32, device='xpu')
52+
b_start_loc_extend[1:] = torch.cumsum(b_seq_len_extend[:-1], 0)
53+
54+
kv_indptr = torch.zeros((BATCH + 1, ), dtype=torch.int32, device='xpu')
55+
kv_indptr[1:BATCH + 1] = torch.cumsum(b_seq_len_prefix[:BATCH], dim=0)
56+
kv_indices = torch.zeros((b_seq_len_prefix.sum().item(), ), dtype=torch.int32, device='xpu')
57+
58+
for i in range(BATCH):
59+
kv_indices[kv_indptr[i]:kv_indptr[i + 1]] = torch.arange(b_start_loc[i], b_start_loc[i] + b_seq_len_prefix[i])
60+
61+
total_token_num = torch.sum(b_seq_len).item()
62+
extend_token_num = torch.sum(b_seq_len_extend).item()
63+
k_buffer = torch.empty((total_token_num, KV_HEAD_NUM, HEAD_DIM), dtype=dtype,
64+
device='xpu').normal_(mean=0.1, std=0.2)
65+
v_buffer = torch.empty((total_token_num, KV_HEAD_NUM, HEAD_DIM), dtype=dtype,
66+
device='xpu').normal_(mean=0.1, std=0.2)
67+
68+
k_extend = torch.empty((extend_token_num, KV_HEAD_NUM, HEAD_DIM), dtype=dtype, device='xpu')
69+
v_extend = torch.empty((extend_token_num, KV_HEAD_NUM, HEAD_DIM), dtype=dtype, device='xpu')
70+
q_extend = torch.empty((extend_token_num, Q_HEAD_NUM, HEAD_DIM), dtype=dtype, device='xpu')
71+
for i in range(BATCH):
72+
extend_start_in_buffer = b_start_loc[i] + b_seq_len_prefix[i]
73+
extend_end_in_buffer = b_start_loc[i] + b_seq_len[i]
74+
extend_start = b_start_loc_extend[i]
75+
extend_end = b_start_loc_extend[i] + b_seq_len_extend[i]
76+
k_extend[extend_start:extend_end] = k_buffer[extend_start_in_buffer:extend_end_in_buffer]
77+
v_extend[extend_start:extend_end] = v_buffer[extend_start_in_buffer:extend_end_in_buffer]
78+
q_extend[extend_start:extend_end] = torch.empty((b_seq_len_extend[i], Q_HEAD_NUM, HEAD_DIM), dtype=dtype,
79+
device='xpu').normal_(mean=0.1, std=0.2)
80+
81+
o_extend = torch.empty((extend_token_num, Q_HEAD_NUM, HEAD_DIM), dtype=dtype, device='xpu')
82+
o_redundant = torch.empty((extend_token_num, Q_HEAD_NUM, HEAD_DIM), dtype=dtype, device='xpu')
83+
84+
b_seq_len_extend = b_seq_len - b_seq_len_prefix
85+
max_len_extend = torch.max(b_seq_len_extend, 0)[0].item()
86+
qo_indptr = torch.zeros((BATCH + 1, ), dtype=torch.int32, device='xpu')
87+
qo_indptr[1:BATCH + 1] = torch.cumsum(b_seq_len_extend[:BATCH], dim=0)
88+
89+
custom_mask = None
90+
mask_indptr = None
91+
92+
quantiles = [0.5, 0.0, 1.0]
93+
if provider == 'triton':
94+
95+
def triton_fn():
96+
extend_attention_fwd(q_extend, k_extend, v_extend, o_extend, k_buffer, v_buffer, qo_indptr, kv_indptr,
97+
kv_indices, custom_mask, mask_indptr, max_len_extend)
98+
return o_extend
99+
100+
# TODO: decode attention do not have validation function
101+
if VALIDATE:
102+
103+
def refer_fn():
104+
redundant_attention(q_extend, o_redundant, k_buffer, v_buffer, b_req_idx, b_start_loc, b_seq_len,
105+
b_seq_len_prefix, max_len_in_batch)
106+
return o_redundant
107+
108+
benchmark_suit.assert_close(triton_fn, refer_fn, atol=1e-3, rtol=1e-2, err_msg='extend to refer')
109+
110+
_, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_repeat=10, quantiles=quantiles)
111+
112+
else:
113+
raise NotImplementedError(f'Unsupported provider {provider}')
114+
115+
tflops = lambda ms: 2 * BATCH * (Q_HEAD_NUM + KV_HEAD_NUM * N_CTX) * N_CTX * HEAD_DIM * (1e-12) / (ms * 1e-3)
116+
117+
gbps = lambda ms: 2 * BATCH * (Q_HEAD_NUM + KV_HEAD_NUM * N_CTX) * HEAD_DIM * 2 * (1e-9) / (ms * 1e-3)
118+
119+
return (gbps(mean), gbps(max_ms), gbps(min_ms)), (tflops(mean), tflops(max_ms), tflops(min_ms)), cv
120+
121+
122+
if __name__ == '__main__':
123+
benchmark.run(show_plots=False, print_data=True)
Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
import torch
2+
3+
from sglang.srt.layers.attention.triton_ops.prefill_attention import context_attention_fwd
4+
5+
import triton_kernels_benchmark as benchmark_suit
6+
7+
8+
# pylint: disable=unused-argument
9+
@benchmark_suit.perf_report(
10+
benchmark_suit.Benchmark(
11+
# argument names to use as an x-axis for the plot
12+
x_names=['BATCH', 'SEQ_LENS', 'Q_HEAD_NUM', 'KV_HEAD_NUM', 'HEAD_DIM', 'CAUSAL', 'MODE', 'VALIDATE'],
13+
x_vals=[ #
14+
[bs, [1024], 32, 8, 128, causal, 'fwd', False] for causal in [True, False] for bs in [1, 16, 32, 64, 128]
15+
] + [ # noqa
16+
[bs, [1024], 32, 32, 96, causal, 'fwd', False] for causal in [True, False] for bs in [1, 16, 32, 64, 128]
17+
] + [ # noqa
18+
[bs, [1024], 28, 4, 128, causal, 'fwd', False] for causal in [True, False] for bs in [1, 16, 32, 64, 128]
19+
],
20+
line_arg='provider',
21+
# argument name whose value corresponds to a different line in the plot
22+
# possible values for `line_arg``
23+
line_vals=[
24+
'triton',
25+
],
26+
# label name for the lines
27+
line_names=[
28+
'Triton',
29+
],
30+
# line styles
31+
styles=[('green', '-'), ('green', '--'), ('blue', '-'), ('blue', '--')],
32+
ylabel=['GB/s', 'TFlops'], # label name for the y-axis
33+
plot_name='prefill-attn-performance',
34+
# name for the plot. Used also as a file name for saving the plot.
35+
args={},
36+
))
37+
def benchmark(BATCH, SEQ_LENS, Q_HEAD_NUM, KV_HEAD_NUM, HEAD_DIM, CAUSAL, MODE, VALIDATE, provider):
38+
torch.manual_seed(0)
39+
dtype = torch.bfloat16
40+
device = 'xpu'
41+
N_CTX = sum(SEQ_LENS)
42+
max_seq_len = max(SEQ_LENS)
43+
44+
# Create random input tensors
45+
q = torch.randn((BATCH * N_CTX, Q_HEAD_NUM, HEAD_DIM), device=device, dtype=dtype)
46+
k = torch.randn((BATCH * N_CTX, KV_HEAD_NUM, HEAD_DIM), device=device, dtype=dtype)
47+
v = torch.randn((BATCH * N_CTX, KV_HEAD_NUM, HEAD_DIM), device=device, dtype=dtype)
48+
o = torch.zeros((BATCH * N_CTX, Q_HEAD_NUM, HEAD_DIM), device=device, dtype=dtype)
49+
50+
# Create b_start_loc and b_seq_len tensors
51+
b_start_loc = torch.tensor([0, SEQ_LENS[0]], device=device)
52+
b_seq_len = torch.tensor(SEQ_LENS, device=device)
53+
54+
quantiles = [0.5, 0.0, 1.0]
55+
if provider == 'triton':
56+
57+
triton_fn = lambda: context_attention_fwd(q, k, v, o, b_start_loc, b_seq_len, max_seq_len, is_causal=CAUSAL)
58+
59+
if VALIDATE:
60+
# FIXME: torch sdpa does not support different Q_HEAD_NUM and KV_HEAD_NUM
61+
cu_seq_lens = [0] * (len(SEQ_LENS) + 1)
62+
for i, seq_len in enumerate(SEQ_LENS):
63+
cu_seq_lens[i + 1] = cu_seq_lens[i] + seq_len
64+
65+
for i in range(len(SEQ_LENS)):
66+
start, end = cu_seq_lens[i], cu_seq_lens[i + 1]
67+
o_torch = torch.nn.functional.scaled_dot_product_attention(
68+
q[start:end].permute(1, 0, 2),
69+
k[start:end].permute(1, 0, 2),
70+
v[start:end].permute(1, 0, 2),
71+
is_causal=CAUSAL,
72+
).permute(1, 0, 2)
73+
74+
cos_sim = torch.nn.functional.cosine_similarity(o[start:end].flatten(), o_torch.flatten(), dim=0)
75+
assert cos_sim.item() > 1 - (1e-5)
76+
assert torch.allclose(o[start:end], o_torch, atol=1e-2)
77+
78+
_, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_repeat=10,
79+
quantiles=quantiles)
80+
81+
else:
82+
raise NotImplementedError(f'Unsupported provider {provider}')
83+
84+
tflops = lambda ms: 2 * BATCH * (Q_HEAD_NUM + KV_HEAD_NUM) * N_CTX * N_CTX * HEAD_DIM * (1e-12) / (ms * 1e-3)
85+
86+
gbps = lambda ms: 2 * BATCH * (Q_HEAD_NUM + KV_HEAD_NUM) * N_CTX * HEAD_DIM * 2 * (1e-9) / (ms * 1e-3)
87+
88+
return (gbps(mean_ms), gbps(max_ms), gbps(min_ms)), (tflops(mean_ms), tflops(max_ms), tflops(min_ms)), cv
89+
90+
91+
if __name__ == '__main__':
92+
benchmark.run(show_plots=False, print_data=True)

python/setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
from distutils.command.install import install
2525
from setuptools.command.develop import develop
2626
from setuptools.command.egg_info import egg_info
27-
from wheel.bdist_wheel import bdist_wheel
27+
from setuptools.command.bdist_wheel import bdist_wheel
2828

2929
import pybind11
3030

0 commit comments

Comments
 (0)