Skip to content

Commit 32f62fe

Browse files
LiyangLingIntelLiyang Ling
authored and
Liyang Ling
committed
Integrate sglang prefill/decode/extend kernel to benchmarks
Port prefill attn and decode attn from sglang Add validation temp add extend attention disable debug ir dump Update three stage attention benchmark Add sglang kernel benchmark to action use 1e-3 atol remove sglang benchmark from triton-benchmarks Fix setup bdist_wheel
1 parent 9856962 commit 32f62fe

File tree

6 files changed

+346
-3
lines changed

6 files changed

+346
-3
lines changed

.github/pins/sglang.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
0.4.4.post1

.github/workflows/third-party-benchmarks.yml

Lines changed: 40 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -72,8 +72,14 @@ jobs:
7272
- name: Setup Triton
7373
uses: ./.github/actions/setup-triton
7474

75-
- name: Install benchmark dependencies
75+
- name: Install benchmarks
7676
id: install
77+
run: |
78+
cd benchmarks
79+
pip install .
80+
81+
- name: Install benchmark dependencies
82+
id: install_deps
7783
run: |
7884
pip install transformers pandas pytest
7985
@@ -83,7 +89,7 @@ jobs:
8389
echo "REPORTS=$PWD/reports" >> $GITHUB_ENV
8490
8591
- name: Run Liger-Kernel benchmarks
86-
if: ${{ steps.install.outcome == 'success' && !cancelled() }}
92+
if: ${{ steps.install_deps.outcome == 'success' && !cancelled() }}
8793
run: |
8894
source ./scripts/capture-hw-details.sh
8995
@@ -102,6 +108,38 @@ jobs:
102108
# Return the captured return code at the end
103109
exit "$RET_CODE"
104110
111+
- name: Install SGLANG
112+
run: |
113+
SGLANG_PIN_ID="$(<.github/pins/sglang.txt)"
114+
pip install sglang==$SGLANG_PIN_ID
115+
116+
- name: Run SGLANG attention prefill stage benchmark
117+
if: ${{ steps.install.outcome == 'success' && !cancelled() && !contains(fromJson(inputs.skip_benchmarks || '[]'), 'prefill_attention_benchmark.py') }}
118+
run: |
119+
cd benchmarks/third_party/sglang
120+
python prefill_attention_benchmark --reports $REPORTS
121+
122+
source ../../scripts/capture-hw-details.sh
123+
python ../../scripts/build_report.py $REPORTS/prefill-attn-performance.csv $REPORTS/attn-prefill-triton-report.csv --benchmark attn --compiler triton --param_cols "B,N_CTX,H_Q,H_KV,D,CAUSAL" --tflops_col Triton-TFlops --hbm_col "Triton-GB/s" --tag $TAG
124+
125+
- name: Run SGLANG attention decode stage benchmark
126+
if: ${{ steps.install.outcome == 'success' && !cancelled() && !contains(fromJson(inputs.skip_benchmarks || '[]'), 'decode_attention_benchmark.py') }}
127+
run: |
128+
cd benchmarks/third_party/sglang
129+
python decode_attention_benchmark --reports $REPORTS
130+
131+
source ../../scripts/capture-hw-details.sh
132+
python ../../scripts/build_report.py $REPORTS/decode-attn-performance.csv $REPORTS/attn-decode-triton-report.csv --benchmark attn --compiler triton --param_cols "B,N_CTX,H_Q,H_KV,D" --tflops_col Triton-TFlops --hbm_col "Triton-GB/s" --tag $TAG
133+
134+
- name: Run SGLANG attention append stage benchmark
135+
if: ${{ steps.install.outcome == 'success' && !cancelled() && !contains(fromJson(inputs.skip_benchmarks || '[]'), 'decode_attention_benchmark.py') }}
136+
run: |
137+
cd benchmarks/third_party/sglang
138+
python extended_attention_benchmark --reports $REPORTS
139+
140+
source ../../scripts/capture-hw-details.sh
141+
python ../../scripts/build_report.py $REPORTS/extended-attn-performance.csv $REPORTS/attn-append-triton-report.csv --benchmark attn --compiler triton --param_cols "B,N_CTX,H_Q,H_KV,D" --tflops_col Triton-TFlops --hbm_col "Triton-GB/s" --tag $TAG
142+
105143
- name: Upload benchmark reports
106144
if: ${{ steps.install.outcome == 'success' && !cancelled() }}
107145
uses: actions/upload-artifact@v4
Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
import torch
2+
3+
from sglang.srt.layers.attention.triton_ops.decode_attention import decode_attention_fwd
4+
5+
import triton_kernels_benchmark as benchmark_suit
6+
7+
8+
# pylint: disable=unused-argument
9+
@benchmark_suit.perf_report(
10+
benchmark_suit.Benchmark(
11+
# argument names to use as an x-axis for the plot
12+
x_names=['BATCH', 'SEQ_LENS', 'Q_HEAD_NUM', 'KV_HEAD_NUM', 'HEAD_DIM', 'MODE', 'VALIDATE'],
13+
x_vals=[ #
14+
[bs, [1024, 64], 32, 8, 128, 'fwd', False] for bs in [1, 16, 32, 64, 128]
15+
] + [ #
16+
[bs, [1024, 64], 32, 32, 96, 'fwd', False] for bs in [1, 16, 32, 64, 128]
17+
] + [ #
18+
[bs, [1024, 64], 28, 4, 128, 'fwd', False] for bs in [1, 16, 32, 64, 128]
19+
],
20+
line_arg='provider',
21+
# argument name whose value corresponds to a different line in the plot
22+
# possible values for `line_arg``
23+
line_vals=[
24+
'triton',
25+
],
26+
# label name for the lines
27+
line_names=[
28+
'Triton',
29+
],
30+
# line styles
31+
styles=[('green', '-'), ('green', '--'), ('blue', '-'), ('blue', '--')],
32+
ylabel=['GB/s', 'TFlops'], # label name for the y-axis
33+
plot_name='decode-attn-performance',
34+
# name for the plot. Used also as a file name for saving the plot.
35+
args={},
36+
))
37+
def benchmark(BATCH, SEQ_LENS, Q_HEAD_NUM, KV_HEAD_NUM, HEAD_DIM, MODE, VALIDATE, provider):
38+
torch.manual_seed(0)
39+
dtype = torch.bfloat16
40+
N_CTX = sum(SEQ_LENS)
41+
total_tokens = BATCH * N_CTX
42+
sm_scale = 1.0 / (HEAD_DIM**0.5)
43+
num_kv_splits = 8
44+
45+
# q represents the new token being generated, one per batch
46+
q = torch.randn(BATCH, Q_HEAD_NUM, HEAD_DIM, dtype=dtype, device='xpu')
47+
48+
# k_buffer and v_buffer represent all previous tokens
49+
k_buffer = torch.randn(total_tokens, KV_HEAD_NUM, HEAD_DIM, dtype=dtype, device='xpu')
50+
v_buffer = torch.randn(total_tokens, KV_HEAD_NUM, HEAD_DIM, dtype=dtype, device='xpu')
51+
52+
# o will have the same shape as q
53+
o = torch.zeros(BATCH, Q_HEAD_NUM, HEAD_DIM, dtype=dtype, device='xpu')
54+
55+
b_seq_len = torch.full((BATCH, ), N_CTX, device='xpu')
56+
57+
kv_indptr = torch.zeros((BATCH + 1, ), dtype=torch.int32, device='xpu')
58+
kv_indptr[1:BATCH + 1] = torch.cumsum(b_seq_len[:BATCH], dim=0)
59+
kv_indices = torch.arange(total_tokens, device='xpu')
60+
61+
attn_logits = torch.empty(
62+
(BATCH, Q_HEAD_NUM, num_kv_splits, HEAD_DIM + 1),
63+
dtype=torch.float32,
64+
device='xpu',
65+
)
66+
67+
quantiles = [0.5, 0.0, 1.0]
68+
69+
if provider == 'triton':
70+
triton_fn = lambda: decode_attention_fwd(q, k_buffer, v_buffer, o, kv_indptr, kv_indices, attn_logits,
71+
num_kv_splits, sm_scale, logit_cap=0.0)
72+
# TODO: decode attention do not have validation function
73+
if VALIDATE:
74+
raise NotImplementedError('Validation is not implemented for decode stage')
75+
76+
_, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_repeat=10, quantiles=quantiles)
77+
78+
else:
79+
raise NotImplementedError(f'Unsupported provider {provider}')
80+
81+
tflops = lambda ms: 2 * BATCH * (Q_HEAD_NUM + KV_HEAD_NUM * N_CTX) * N_CTX * HEAD_DIM * (1e-12) / (ms * 1e-3)
82+
83+
gbps = lambda ms: 2 * BATCH * (Q_HEAD_NUM + KV_HEAD_NUM * N_CTX) * HEAD_DIM * 2 * (1e-9) / (ms * 1e-3)
84+
85+
return (gbps(mean), gbps(max_ms), gbps(min_ms)), (tflops(mean), tflops(max_ms), tflops(min_ms)), cv
86+
87+
88+
if __name__ == '__main__':
89+
benchmark.run(show_plots=False, print_data=True)
Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
import torch
2+
from sglang.srt.layers.attention.triton_ops.extend_attention import (
3+
extend_attention_fwd,
4+
redundant_attention,
5+
)
6+
import triton_kernels_benchmark as benchmark_suit
7+
8+
9+
# pylint: disable=unused-argument
10+
@benchmark_suit.perf_report(
11+
benchmark_suit.Benchmark(
12+
# argument names to use as an x-axis for the plot
13+
x_names=['BATCH', 'SEQ_LENS', 'Q_HEAD_NUM', 'KV_HEAD_NUM', 'HEAD_DIM', 'MODE', 'VALIDATE'],
14+
x_vals=[ #
15+
[bs, [1024, 128, 512], 32, 8, 128, 'fwd', True] for bs in [1, 16, 32, 64, 128]
16+
] + [ #
17+
[bs, [1024, 128, 512], 32, 32, 96, 'fwd', True] for bs in [1, 16, 32, 64, 128]
18+
] + [ #
19+
[bs, [1024, 128, 512], 28, 4, 128, 'fwd', True] for bs in [1, 16, 32, 64, 128]
20+
],
21+
line_arg='provider',
22+
# argument name whose value corresponds to a different line in the plot
23+
# possible values for `line_arg``
24+
line_vals=[
25+
'triton',
26+
],
27+
# label name for the lines
28+
line_names=[
29+
'Triton',
30+
],
31+
# line styles
32+
styles=[('green', '-'), ('green', '--'), ('blue', '-'), ('blue', '--')],
33+
ylabel=['GB/s', 'TFlops'], # label name for the y-axis
34+
plot_name='extended-attn-performance',
35+
# name for the plot. Used also as a file name for saving the plot.
36+
args={},
37+
))
38+
def benchmark(BATCH, SEQ_LENS, Q_HEAD_NUM, KV_HEAD_NUM, HEAD_DIM, MODE, VALIDATE, provider):
39+
torch.manual_seed(0)
40+
dtype = torch.bfloat16
41+
N_CTX = sum(SEQ_LENS)
42+
43+
b_seq_len_prefix = torch.randint(1, N_CTX // 2, (BATCH, ), dtype=torch.int32, device='xpu')
44+
b_seq_len_extend = torch.randint(1, N_CTX // 2, (BATCH, ), dtype=torch.int32, device='xpu')
45+
b_seq_len = b_seq_len_prefix + b_seq_len_extend
46+
max_len_in_batch = torch.max(b_seq_len, 0)[0].item()
47+
48+
b_req_idx = torch.arange(BATCH, dtype=torch.int32, device='xpu')
49+
b_start_loc = torch.zeros((BATCH, ), dtype=torch.int32, device='xpu')
50+
b_start_loc[1:] = torch.cumsum(b_seq_len[:-1], 0)
51+
b_start_loc_extend = torch.zeros((BATCH, ), dtype=torch.int32, device='xpu')
52+
b_start_loc_extend[1:] = torch.cumsum(b_seq_len_extend[:-1], 0)
53+
54+
kv_indptr = torch.zeros((BATCH + 1, ), dtype=torch.int32, device='xpu')
55+
kv_indptr[1:BATCH + 1] = torch.cumsum(b_seq_len_prefix[:BATCH], dim=0)
56+
kv_indices = torch.zeros((b_seq_len_prefix.sum().item(), ), dtype=torch.int32, device='xpu')
57+
58+
for i in range(BATCH):
59+
kv_indices[kv_indptr[i]:kv_indptr[i + 1]] = torch.arange(b_start_loc[i], b_start_loc[i] + b_seq_len_prefix[i])
60+
61+
total_token_num = torch.sum(b_seq_len).item()
62+
extend_token_num = torch.sum(b_seq_len_extend).item()
63+
k_buffer = torch.empty((total_token_num, KV_HEAD_NUM, HEAD_DIM), dtype=dtype,
64+
device='xpu').normal_(mean=0.1, std=0.2)
65+
v_buffer = torch.empty((total_token_num, KV_HEAD_NUM, HEAD_DIM), dtype=dtype,
66+
device='xpu').normal_(mean=0.1, std=0.2)
67+
68+
k_extend = torch.empty((extend_token_num, KV_HEAD_NUM, HEAD_DIM), dtype=dtype, device='xpu')
69+
v_extend = torch.empty((extend_token_num, KV_HEAD_NUM, HEAD_DIM), dtype=dtype, device='xpu')
70+
q_extend = torch.empty((extend_token_num, Q_HEAD_NUM, HEAD_DIM), dtype=dtype, device='xpu')
71+
for i in range(BATCH):
72+
extend_start_in_buffer = b_start_loc[i] + b_seq_len_prefix[i]
73+
extend_end_in_buffer = b_start_loc[i] + b_seq_len[i]
74+
extend_start = b_start_loc_extend[i]
75+
extend_end = b_start_loc_extend[i] + b_seq_len_extend[i]
76+
k_extend[extend_start:extend_end] = k_buffer[extend_start_in_buffer:extend_end_in_buffer]
77+
v_extend[extend_start:extend_end] = v_buffer[extend_start_in_buffer:extend_end_in_buffer]
78+
q_extend[extend_start:extend_end] = torch.empty((b_seq_len_extend[i], Q_HEAD_NUM, HEAD_DIM), dtype=dtype,
79+
device='xpu').normal_(mean=0.1, std=0.2)
80+
81+
o_extend = torch.empty((extend_token_num, Q_HEAD_NUM, HEAD_DIM), dtype=dtype, device='xpu')
82+
o_redundant = torch.empty((extend_token_num, Q_HEAD_NUM, HEAD_DIM), dtype=dtype, device='xpu')
83+
84+
b_seq_len_extend = b_seq_len - b_seq_len_prefix
85+
max_len_extend = torch.max(b_seq_len_extend, 0)[0].item()
86+
qo_indptr = torch.zeros((BATCH + 1, ), dtype=torch.int32, device='xpu')
87+
qo_indptr[1:BATCH + 1] = torch.cumsum(b_seq_len_extend[:BATCH], dim=0)
88+
89+
custom_mask = None
90+
mask_indptr = None
91+
92+
quantiles = [0.5, 0.0, 1.0]
93+
if provider == 'triton':
94+
95+
def triton_fn():
96+
extend_attention_fwd(q_extend, k_extend, v_extend, o_extend, k_buffer, v_buffer, qo_indptr, kv_indptr,
97+
kv_indices, custom_mask, mask_indptr, max_len_extend)
98+
return o_extend
99+
100+
# TODO: decode attention do not have validation function
101+
if VALIDATE:
102+
103+
def refer_fn():
104+
redundant_attention(q_extend, o_redundant, k_buffer, v_buffer, b_req_idx, b_start_loc, b_seq_len,
105+
b_seq_len_prefix, max_len_in_batch)
106+
return o_redundant
107+
108+
benchmark_suit.assert_close(triton_fn, refer_fn, atol=1e-3, rtol=1e-2, err_msg='extend to refer')
109+
110+
_, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_repeat=10, quantiles=quantiles)
111+
112+
else:
113+
raise NotImplementedError(f'Unsupported provider {provider}')
114+
115+
tflops = lambda ms: 2 * BATCH * (Q_HEAD_NUM + KV_HEAD_NUM * N_CTX) * N_CTX * HEAD_DIM * (1e-12) / (ms * 1e-3)
116+
117+
gbps = lambda ms: 2 * BATCH * (Q_HEAD_NUM + KV_HEAD_NUM * N_CTX) * HEAD_DIM * 2 * (1e-9) / (ms * 1e-3)
118+
119+
return (gbps(mean), gbps(max_ms), gbps(min_ms)), (tflops(mean), tflops(max_ms), tflops(min_ms)), cv
120+
121+
122+
if __name__ == '__main__':
123+
benchmark.run(show_plots=False, print_data=True)
Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
import torch
2+
3+
from sglang.srt.layers.attention.triton_ops.prefill_attention import context_attention_fwd
4+
5+
import triton_kernels_benchmark as benchmark_suit
6+
7+
8+
# pylint: disable=unused-argument
9+
@benchmark_suit.perf_report(
10+
benchmark_suit.Benchmark(
11+
# argument names to use as an x-axis for the plot
12+
x_names=['BATCH', 'SEQ_LENS', 'Q_HEAD_NUM', 'KV_HEAD_NUM', 'HEAD_DIM', 'CAUSAL', 'MODE', 'VALIDATE'],
13+
x_vals=[ #
14+
[bs, [1024], 32, 8, 128, causal, 'fwd', False] for causal in [True, False] for bs in [1, 16, 32, 64, 128]
15+
] + [ # noqa
16+
[bs, [1024], 32, 32, 96, causal, 'fwd', False] for causal in [True, False] for bs in [1, 16, 32, 64, 128]
17+
] + [ # noqa
18+
[bs, [1024], 28, 4, 128, causal, 'fwd', False] for causal in [True, False] for bs in [1, 16, 32, 64, 128]
19+
],
20+
line_arg='provider',
21+
# argument name whose value corresponds to a different line in the plot
22+
# possible values for `line_arg``
23+
line_vals=[
24+
'triton',
25+
],
26+
# label name for the lines
27+
line_names=[
28+
'Triton',
29+
],
30+
# line styles
31+
styles=[('green', '-'), ('green', '--'), ('blue', '-'), ('blue', '--')],
32+
ylabel=['GB/s', 'TFlops'], # label name for the y-axis
33+
plot_name='prefill-attn-performance',
34+
# name for the plot. Used also as a file name for saving the plot.
35+
args={},
36+
))
37+
def benchmark(BATCH, SEQ_LENS, Q_HEAD_NUM, KV_HEAD_NUM, HEAD_DIM, CAUSAL, MODE, VALIDATE, provider):
38+
torch.manual_seed(0)
39+
dtype = torch.bfloat16
40+
device = 'xpu'
41+
N_CTX = sum(SEQ_LENS)
42+
max_seq_len = max(SEQ_LENS)
43+
44+
# Create random input tensors
45+
q = torch.randn((BATCH * N_CTX, Q_HEAD_NUM, HEAD_DIM), device=device, dtype=dtype)
46+
k = torch.randn((BATCH * N_CTX, KV_HEAD_NUM, HEAD_DIM), device=device, dtype=dtype)
47+
v = torch.randn((BATCH * N_CTX, KV_HEAD_NUM, HEAD_DIM), device=device, dtype=dtype)
48+
o = torch.zeros((BATCH * N_CTX, Q_HEAD_NUM, HEAD_DIM), device=device, dtype=dtype)
49+
50+
# Create b_start_loc and b_seq_len tensors
51+
b_start_loc = torch.tensor([0, SEQ_LENS[0]], device=device)
52+
b_seq_len = torch.tensor(SEQ_LENS, device=device)
53+
54+
quantiles = [0.5, 0.0, 1.0]
55+
if provider == 'triton':
56+
57+
triton_fn = lambda: context_attention_fwd(q, k, v, o, b_start_loc, b_seq_len, max_seq_len, is_causal=CAUSAL)
58+
59+
if VALIDATE:
60+
# FIXME: torch sdpa does not support different Q_HEAD_NUM and KV_HEAD_NUM
61+
cu_seq_lens = [0] * (len(SEQ_LENS) + 1)
62+
for i, seq_len in enumerate(SEQ_LENS):
63+
cu_seq_lens[i + 1] = cu_seq_lens[i] + seq_len
64+
65+
for i in range(len(SEQ_LENS)):
66+
start, end = cu_seq_lens[i], cu_seq_lens[i + 1]
67+
o_torch = torch.nn.functional.scaled_dot_product_attention(
68+
q[start:end].permute(1, 0, 2),
69+
k[start:end].permute(1, 0, 2),
70+
v[start:end].permute(1, 0, 2),
71+
is_causal=CAUSAL,
72+
).permute(1, 0, 2)
73+
74+
cos_sim = torch.nn.functional.cosine_similarity(o[start:end].flatten(), o_torch.flatten(), dim=0)
75+
assert cos_sim.item() > 1 - (1e-5)
76+
assert torch.allclose(o[start:end], o_torch, atol=1e-2)
77+
78+
_, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_repeat=10,
79+
quantiles=quantiles)
80+
81+
else:
82+
raise NotImplementedError(f'Unsupported provider {provider}')
83+
84+
tflops = lambda ms: 2 * BATCH * (Q_HEAD_NUM + KV_HEAD_NUM) * N_CTX * N_CTX * HEAD_DIM * (1e-12) / (ms * 1e-3)
85+
86+
gbps = lambda ms: 2 * BATCH * (Q_HEAD_NUM + KV_HEAD_NUM) * N_CTX * HEAD_DIM * 2 * (1e-9) / (ms * 1e-3)
87+
88+
return (gbps(mean_ms), gbps(max_ms), gbps(min_ms)), (tflops(mean_ms), tflops(max_ms), tflops(min_ms)), cv
89+
90+
91+
if __name__ == '__main__':
92+
benchmark.run(show_plots=False, print_data=True)

python/setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
from distutils.command.install import install
2525
from setuptools.command.develop import develop
2626
from setuptools.command.egg_info import egg_info
27-
from wheel.bdist_wheel import bdist_wheel
27+
from setuptools.command.bdist_wheel import bdist_wheel
2828

2929
import pybind11
3030

0 commit comments

Comments
 (0)