Skip to content

Commit ead512a

Browse files
Address review comments
1 parent 211e82d commit ead512a

File tree

4 files changed

+73
-39
lines changed

4 files changed

+73
-39
lines changed

.github/workflows/third-party-benchmarks.yml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@ jobs:
139139
python prefill_attention_benchmark.py --reports $REPORTS
140140
141141
source ../../../scripts/capture-hw-details.sh
142-
python ../../triton_kernels_benchmark/build_report.py $REPORTS/prefill-attn-performance.csv $REPORTS/attn-prefill-triton-report.csv --benchmark sglang-prefill-attn --compiler triton --param_cols "B,SEQ_LENS,H_Q,H_KV,D,CAUSAL" --tflops_col Triton-TFlops --hbm_col "Triton-GB/s" --tag $TAG
142+
python ../../triton_kernels_benchmark/build_report.py $REPORTS/sglang-prefill-attn-performance.csv $REPORTS/sglang-prefill-attn-triton-report.csv --benchmark sglang-prefill-attn --compiler triton --param_cols "B,SEQ_LENS,H_Q,H_KV,D,CAUSAL" --tflops_col Triton-TFlops --hbm_col "Triton-GB/s" --tag $TAG
143143
144144
- name: Run SGLANG attention decode stage benchmark
145145
if: ${{ steps.install.outcome == 'success' && !cancelled() }}
@@ -149,7 +149,7 @@ jobs:
149149
python decode_attention_benchmark.py --reports $REPORTS
150150
151151
source ../../../scripts/capture-hw-details.sh
152-
python ../../triton_kernels_benchmark/build_report.py $REPORTS/decode-attn-performance.csv $REPORTS/attn-decode-triton-report.csv --benchmark sglang-decode-attn --compiler triton --param_cols "B,SEQ_LENS,H_Q,H_KV,D" --tflops_col Triton-TFlops --hbm_col "Triton-GB/s" --tag $TAG
152+
python ../../triton_kernels_benchmark/build_report.py $REPORTS/sglang-decode-attn-performance.csv $REPORTS/sglang-decode-attn-triton-report.csv --benchmark sglang-decode-attn --compiler triton --param_cols "B,SEQ_LENS,H_Q,H_KV,D" --tflops_col Triton-TFlops --hbm_col "Triton-GB/s" --tag $TAG
153153
154154
- name: Run SGLANG attention append stage benchmark
155155
if: ${{ steps.install.outcome == 'success' && !cancelled() }}
@@ -159,7 +159,7 @@ jobs:
159159
python extended_attention_benchmark.py --reports $REPORTS
160160
161161
source ../../../scripts/capture-hw-details.sh
162-
python ../../triton_kernels_benchmark/build_report.py $REPORTS/extended-attn-performance.csv $REPORTS/attn-append-triton-report.csv --benchmark sglang-extended-attn --compiler triton --param_cols "B,Q_LEN,PREFIX_LEN,KV_LEN,H_Q,H_KV,D" --tflops_col Triton-TFlops --hbm_col "Triton-GB/s" --tag $TAG
162+
python ../../triton_kernels_benchmark/build_report.py $REPORTS/sglang-extended-attn-performance.csv $REPORTS/sglang-append-attn-triton-report.csv --benchmark sglang-extended-attn --compiler triton --param_cols "B,Q_LEN,PREFIX_LEN,KV_LEN,H_Q,H_KV,D" --tflops_col Triton-TFlops --hbm_col "Triton-GB/s" --tag $TAG
163163
164164
- name: Run SGLANG Block FP8 GEMM benchmark
165165
if: ${{ steps.install.outcome == 'success' && !cancelled() && !contains(fromJson(inputs.skip_benchmarks || '[]'), 'block_fp8_gemm_benchmark.py') }}

benchmarks/third_party/sglang/decode_attention_benchmark.py

Lines changed: 24 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -43,18 +43,29 @@ def gen_args(B, N_CTX, H_Q, H_KV, D, dtype, device):
4343
sm_scale)
4444

4545

46+
def get_dtype(dtype_str: str):
47+
if dtype_str == 'bfloat16':
48+
return torch.bfloat16
49+
if dtype_str == 'float16':
50+
return torch.float16
51+
if dtype_str == 'float32':
52+
return torch.float32
53+
raise ValueError(f'Unsupported dtype: {dtype_str}')
54+
55+
56+
X_VALS = [[bs, *sizes, mode, dtype]
57+
for sizes in [(1024 + 64, 32, 8, 128), (1024 + 64, 32, 32, 96), (1024 + 64, 28, 4, 128)]
58+
for bs in [1, 16, 32, 64, 128]
59+
for mode in ['fwd']
60+
for dtype in ['bfloat16']]
61+
62+
4663
# pylint: disable=unused-argument
4764
@benchmark_suit.perf_report(
4865
benchmark_suit.Benchmark(
4966
# argument names to use as an x-axis for the plot
50-
x_names=['B', 'SEQ_LENS', 'H_Q', 'H_KV', 'D', 'MODE'],
51-
x_vals=[ #
52-
[bs, 1024 + 64, 32, 8, 128, 'fwd'] for bs in [1, 16, 32, 64, 128]
53-
] + [ #
54-
[bs, 1024 + 64, 32, 32, 96, 'fwd'] for bs in [1, 16, 32, 64, 128]
55-
] + [ #
56-
[bs, 1024 + 64, 28, 4, 128, 'fwd'] for bs in [1, 16, 32, 64, 128]
57-
],
67+
x_names=['B', 'SEQ_LENS', 'H_Q', 'H_KV', 'D', 'MODE', 'DTYPE'],
68+
x_vals=X_VALS,
5869
line_arg='provider',
5970
# argument name whose value corresponds to a different line in the plot
6071
# possible values for `line_arg``
@@ -68,19 +79,19 @@ def gen_args(B, N_CTX, H_Q, H_KV, D, dtype, device):
6879
# line styles
6980
styles=[('green', '-'), ('green', '--'), ('blue', '-'), ('blue', '--')],
7081
ylabel=['GB/s', 'TFlops'], # label name for the y-axis
71-
plot_name='decode-attn-performance',
82+
plot_name='sglang-decode-attn-performance',
7283
# name for the plot. Used also as a file name for saving the plot.
7384
args={},
7485
))
75-
def benchmark(B, SEQ_LENS, H_Q, H_KV, D, MODE, provider):
86+
def benchmark(B, SEQ_LENS, H_Q, H_KV, D, MODE, DTYPE, provider):
7687
torch.manual_seed(0)
77-
dtype = torch.bfloat16
78-
quantiles = [0.5, 0.0, 1.0]
79-
N_CTX = SEQ_LENS
88+
dtype = get_dtype(DTYPE)
8089

90+
N_CTX = SEQ_LENS
8191
q, k_buffer, v_buffer, o, kv_indptr, kv_indices, attn_logits, attn_lse, num_kv_splits, max_kv_splits, sm_scale = gen_args(
8292
B, N_CTX, H_Q, H_KV, D, dtype, 'xpu')
8393

94+
quantiles = [0.5, 0.0, 1.0]
8495
if provider == 'triton' and MODE == 'fwd':
8596
triton_fn = lambda: decode_attention_fwd(q, k_buffer, v_buffer, o, kv_indptr, kv_indices, attn_logits, attn_lse,
8697
num_kv_splits, max_kv_splits, sm_scale)

benchmarks/third_party/sglang/extended_attention_benchmark.py

Lines changed: 23 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -55,18 +55,30 @@ def gen_args(B, Q_LEN, PREFIX_LEN, KV_LEN, H_Q, H_KV, D, dtype, device):
5555
return params
5656

5757

58+
def get_dtype(dtype_str: str):
59+
if dtype_str == 'bfloat16':
60+
return torch.bfloat16
61+
if dtype_str == 'float16':
62+
return torch.float16
63+
if dtype_str == 'float32':
64+
return torch.float32
65+
raise ValueError(f'Unsupported dtype: {dtype_str}')
66+
67+
68+
X_VALS = [[bs, *sizes, mode, dtype]
69+
for sizes in [(512, 1024 + 128, 512, 32, 8, 128), (512, 1024 + 128, 512, 32, 32,
70+
96), (512, 1024 + 128, 512, 28, 4, 128)]
71+
for bs in [1, 16, 32, 64, 128]
72+
for mode in ['fwd']
73+
for dtype in ['bfloat16']]
74+
75+
5876
# pylint: disable=unused-argument
5977
@benchmark_suit.perf_report(
6078
benchmark_suit.Benchmark(
6179
# argument names to use as an x-axis for the plot
62-
x_names=['B', 'Q_LEN', 'PREFIX_LEN', 'KV_LEN', 'H_Q', 'H_KV', 'D', 'MODE'],
63-
x_vals=[ #
64-
[bs, 512, 1024 + 128, 512, 32, 8, 128, 'fwd'] for bs in [1, 16, 32, 64, 128]
65-
] + [ #
66-
[bs, 512, 1024 + 128, 512, 32, 32, 96, 'fwd'] for bs in [1, 16, 32, 64, 128]
67-
] + [ #
68-
[bs, 512, 1024 + 128, 512, 28, 4, 128, 'fwd'] for bs in [1, 16, 32, 64, 128]
69-
],
80+
x_names=['B', 'Q_LEN', 'PREFIX_LEN', 'KV_LEN', 'H_Q', 'H_KV', 'D', 'MODE', 'DTYPE'],
81+
x_vals=X_VALS,
7082
line_arg='provider',
7183
# argument name whose value corresponds to a different line in the plot
7284
# possible values for `line_arg``
@@ -80,14 +92,13 @@ def gen_args(B, Q_LEN, PREFIX_LEN, KV_LEN, H_Q, H_KV, D, dtype, device):
8092
# line styles
8193
styles=[('green', '-'), ('green', '--'), ('blue', '-'), ('blue', '--')],
8294
ylabel=['GB/s', 'TFlops'], # label name for the y-axis
83-
plot_name='extended-attn-performance',
95+
plot_name='sglang-extended-attn-performance',
8496
# name for the plot. Used also as a file name for saving the plot.
8597
args={},
8698
))
87-
def benchmark(B, Q_LEN, PREFIX_LEN, KV_LEN, H_Q, H_KV, D, MODE, provider):
99+
def benchmark(B, Q_LEN, PREFIX_LEN, KV_LEN, H_Q, H_KV, D, MODE, DTYPE, provider):
88100
torch.manual_seed(0)
89-
90-
dtype = torch.bfloat16
101+
dtype = get_dtype(DTYPE)
91102

92103
params = gen_args(B, Q_LEN, PREFIX_LEN, KV_LEN, H_Q, H_KV, D, dtype, 'xpu')
93104
q_extend, k_extend, v_extend, o_extend = params[0]

benchmarks/third_party/sglang/prefill_attention_benchmark.py

Lines changed: 23 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -22,18 +22,30 @@ def gen_args(B, SEQ_LENS, H_Q, H_KV, D, dtype, device):
2222
return (q, k, v, o, b_start_loc, b_seq_len, max_seq_len)
2323

2424

25+
def get_dtype(dtype_str: str):
26+
if dtype_str == 'bfloat16':
27+
return torch.bfloat16
28+
if dtype_str == 'float16':
29+
return torch.float16
30+
if dtype_str == 'float32':
31+
return torch.float32
32+
raise ValueError(f'Unsupported dtype: {dtype_str}')
33+
34+
35+
X_VALS = [[bs, *sizes, causal, mode, dtype]
36+
for sizes in [(1024, 32, 8, 128), (1024, 32, 32, 96), (1024, 28, 4, 128)]
37+
for bs in [1, 16, 32, 64, 128]
38+
for causal in [True, False]
39+
for mode in ['fwd']
40+
for dtype in ['bfloat16']]
41+
42+
2543
# pylint: disable=unused-argument
2644
@benchmark_suit.perf_report(
2745
benchmark_suit.Benchmark(
2846
# argument names to use as an x-axis for the plot
29-
x_names=['B', 'SEQ_LENS', 'H_Q', 'H_KV', 'D', 'CAUSAL', 'MODE'],
30-
x_vals=[ #
31-
[bs, 1024, 32, 8, 128, causal, 'fwd'] for causal in [True, False] for bs in [1, 16, 32, 64, 128]
32-
] + [ #
33-
[bs, 1024, 32, 32, 96, causal, 'fwd'] for causal in [True, False] for bs in [1, 16, 32, 64, 128]
34-
] + [ #
35-
[bs, 1024, 28, 4, 128, causal, 'fwd'] for causal in [True, False] for bs in [1, 16, 32, 64, 128]
36-
],
47+
x_names=['B', 'SEQ_LENS', 'H_Q', 'H_KV', 'D', 'CAUSAL', 'MODE', 'DTYPE'],
48+
x_vals=X_VALS,
3749
line_arg='provider',
3850
# argument name whose value corresponds to a different line in the plot
3951
# possible values for `line_arg``
@@ -47,13 +59,13 @@ def gen_args(B, SEQ_LENS, H_Q, H_KV, D, dtype, device):
4759
# line styles
4860
styles=[('green', '-'), ('green', '--'), ('blue', '-'), ('blue', '--')],
4961
ylabel=['GB/s', 'TFlops'], # label name for the y-axis
50-
plot_name='prefill-attn-performance',
62+
plot_name='sglang-prefill-attn-performance',
5163
# name for the plot. Used also as a file name for saving the plot.
5264
args={},
5365
))
54-
def benchmark(B, SEQ_LENS, H_Q, H_KV, D, CAUSAL, MODE, provider):
66+
def benchmark(B, SEQ_LENS, H_Q, H_KV, D, CAUSAL, MODE, DTYPE, provider):
5567
torch.manual_seed(0)
56-
dtype = torch.bfloat16
68+
dtype = get_dtype(DTYPE)
5769

5870
q, k, v, o, b_start_loc, b_seq_len, max_seq_len = gen_args(B, SEQ_LENS, H_Q, H_KV, D, dtype, 'xpu')
5971

0 commit comments

Comments
 (0)