
Commit 29e2711

Liyang Ling (LiyangLingIntel) authored and Liyang Ling committed
Add sglang to thirdparty test
1 parent 32f62fe commit 29e2711

File tree

8 files changed: +389 -109 lines changed


.github/pins/sglang.txt

Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-0.4.4.post1
+0.4.5

.github/workflows/third-party-benchmarks.yml

Lines changed: 2 additions & 2 deletions
@@ -110,8 +110,8 @@ jobs:
 
       - name: Install SGLANG
         run: |
-          SGLANG_PIN_ID="$(<.github/pins/sglang.txt)"
-          pip install sglang==$SGLANG_PIN_ID
+          SGLANG_PIN="$(<.github/pins/sglang.txt)"
+          pip install sglang==$SGLANG_PIN
 
       - name: Run SGLANG attention prefill stage benchmark
         if: ${{ steps.install.outcome == 'success' && !cancelled() && !contains(fromJson(inputs.skip_benchmarks || '[]'), 'prefill_attention_benchmark.py') }}
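
Both workflows read the pinned version from .github/pins/sglang.txt and splice it into the pip install command. A rough Python equivalent of that shell step, purely as an illustration and not part of the commit:

# Illustrative sketch only: mirrors the workflow's shell step in Python.
import subprocess
from pathlib import Path

sglang_pin = Path(".github/pins/sglang.txt").read_text().strip()  # e.g. "0.4.5"
subprocess.run(["pip", "install", f"sglang=={sglang_pin}"], check=True)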

.github/workflows/third-party-tests.yml

Lines changed: 13 additions & 0 deletions
@@ -96,6 +96,19 @@ jobs:
 
           pytest Liger-Kernel/test/
 
+      - name: Run SGLANG tests
+        if: ${{ steps.install.outcome == 'success' && !cancelled() }}
+        run: |
+          pip install transformers pandas pytest openai
+
+          SGLANG_PIN="$(<.github/pins/sglang.txt)"
+          pip install datasets decord sglang==$SGLANG_PIN
+          git clone https://github.com/sgl-project/sglang.git --branch $SGLANG_PIN --single-branch
+
+          cd sglang
+          git apply ../benchmarks/third_party/sglang/sglang.patch
+          pytest sglang/test/srt/test_triton_attention_kernels.py
+
       - name: Upload test report
         if: ${{ steps.install.outcome == 'success' && !cancelled() }}
         uses: actions/upload-artifact@v4

benchmarks/third_party/sglang/decode_attention_benchmark.py

Lines changed: 60 additions & 34 deletions
@@ -5,11 +5,49 @@
 import triton_kernels_benchmark as benchmark_suit
 
 
+def gen_args(BATCH, N_CTX, Q_HEAD_NUM, KV_HEAD_NUM, HEAD_DIM, dtype, device):
+
+    total_tokens = BATCH * N_CTX
+    sm_scale = 1.0 / (HEAD_DIM**0.5)
+    max_kv_splits = 8
+    num_kv_splits = torch.full((BATCH, ), 4, dtype=torch.int32, device=device)
+
+    # q represents the new token being generated, one per batch
+    q = torch.randn(BATCH, Q_HEAD_NUM, HEAD_DIM, dtype=dtype, device=device)
+
+    # k_buffer and v_buffer represent all previous tokens
+    k_buffer = torch.randn(total_tokens, KV_HEAD_NUM, HEAD_DIM, dtype=dtype, device=device)
+    v_buffer = torch.randn(total_tokens, KV_HEAD_NUM, HEAD_DIM, dtype=dtype, device=device)
+
+    # o will have the same shape as q
+    o = torch.zeros(BATCH, Q_HEAD_NUM, HEAD_DIM, dtype=dtype, device=device)
+
+    b_seq_len = torch.full((BATCH, ), N_CTX, device=device)
+
+    kv_indptr = torch.zeros((BATCH + 1, ), dtype=torch.int32, device=device)
+    kv_indptr[1:BATCH + 1] = torch.cumsum(b_seq_len[:BATCH], dim=0)
+    kv_indices = torch.arange(total_tokens, device=device)
+
+    attn_logits = torch.empty(
+        (BATCH, Q_HEAD_NUM, max_kv_splits, HEAD_DIM),
+        dtype=torch.float32,
+        device=device,
+    )
+    attn_lse = torch.empty(
+        (BATCH, Q_HEAD_NUM, max_kv_splits),
+        dtype=torch.float32,
+        device=device,
+    )
+
+    return (q, k_buffer, v_buffer, o, kv_indptr, kv_indices, attn_logits, attn_lse, num_kv_splits, max_kv_splits,
+            sm_scale)
+
+
 # pylint: disable=unused-argument
 @benchmark_suit.perf_report(
     benchmark_suit.Benchmark(
         # argument names to use as an x-axis for the plot
-        x_names=['BATCH', 'SEQ_LENS', 'Q_HEAD_NUM', 'KV_HEAD_NUM', 'HEAD_DIM', 'MODE', 'VALIDATE'],
+        x_names=['B', 'SEQ_LENS', 'H_Q', 'H_KV', 'D', 'MODE', 'VALIDATE'],
         x_vals=[  #
             [bs, [1024, 64], 32, 8, 128, 'fwd', False] for bs in [1, 16, 32, 64, 128]
         ] + [  #

@@ -34,42 +72,31 @@
         # name for the plot. Used also as a file name for saving the plot.
         args={},
     ))
-def benchmark(BATCH, SEQ_LENS, Q_HEAD_NUM, KV_HEAD_NUM, HEAD_DIM, MODE, VALIDATE, provider):
+def benchmark(B, SEQ_LENS, H_Q, H_KV, D, MODE, VALIDATE, provider):
     torch.manual_seed(0)
     dtype = torch.bfloat16
+    quantiles = [0.5, 0.0, 1.0]
     N_CTX = sum(SEQ_LENS)
-    total_tokens = BATCH * N_CTX
-    sm_scale = 1.0 / (HEAD_DIM**0.5)
-    num_kv_splits = 8
-
-    # q represents the new token being generated, one per batch
-    q = torch.randn(BATCH, Q_HEAD_NUM, HEAD_DIM, dtype=dtype, device='xpu')
-
-    # k_buffer and v_buffer represent all previous tokens
-    k_buffer = torch.randn(total_tokens, KV_HEAD_NUM, HEAD_DIM, dtype=dtype, device='xpu')
-    v_buffer = torch.randn(total_tokens, KV_HEAD_NUM, HEAD_DIM, dtype=dtype, device='xpu')
-
-    # o will have the same shape as q
-    o = torch.zeros(BATCH, Q_HEAD_NUM, HEAD_DIM, dtype=dtype, device='xpu')
 
-    b_seq_len = torch.full((BATCH, ), N_CTX, device='xpu')
-
-    kv_indptr = torch.zeros((BATCH + 1, ), dtype=torch.int32, device='xpu')
-    kv_indptr[1:BATCH + 1] = torch.cumsum(b_seq_len[:BATCH], dim=0)
-    kv_indices = torch.arange(total_tokens, device='xpu')
-
-    attn_logits = torch.empty(
-        (BATCH, Q_HEAD_NUM, num_kv_splits, HEAD_DIM + 1),
-        dtype=torch.float32,
-        device='xpu',
-    )
-
-    quantiles = [0.5, 0.0, 1.0]
+    q, k_buffer, v_buffer, o, kv_indptr, kv_indices, attn_logits, attn_lse, num_kv_splits, max_kv_splits, sm_scale = gen_args(
+        B, N_CTX, H_Q, H_KV, D, dtype, 'xpu')
 
     if provider == 'triton':
-        triton_fn = lambda: decode_attention_fwd(q, k_buffer, v_buffer, o, kv_indptr, kv_indices, attn_logits,
-                                                 num_kv_splits, sm_scale, logit_cap=0.0)
-        # TODO: decode attention do not have validation function
+        triton_fn = lambda: decode_attention_fwd(
+            q,
+            k_buffer,
+            v_buffer,
+            o,
+            kv_indptr,
+            kv_indices,
+            attn_logits,
+            attn_lse,
+            num_kv_splits,
+            max_kv_splits,
+            sm_scale,
+        )
+
+        # TODO: decode attention should have the validation function
         if VALIDATE:
             raise NotImplementedError('Validation is not implemented for decode stage')
 
@@ -78,9 +105,8 @@ def benchmark(BATCH, SEQ_LENS, Q_HEAD_NUM, KV_HEAD_NUM, HEAD_DIM, MODE, VALIDATE
     else:
         raise NotImplementedError(f'Unsupported provider {provider}')
 
-    tflops = lambda ms: 2 * BATCH * (Q_HEAD_NUM + KV_HEAD_NUM * N_CTX) * N_CTX * HEAD_DIM * (1e-12) / (ms * 1e-3)
-
-    gbps = lambda ms: 2 * BATCH * (Q_HEAD_NUM + KV_HEAD_NUM * N_CTX) * HEAD_DIM * 2 * (1e-9) / (ms * 1e-3)
+    tflops = lambda ms: 2 * B * (H_Q + H_KV * N_CTX) * N_CTX * D * (1e-12) / (ms * 1e-3)
+    gbps = lambda ms: 2 * B * (H_Q + H_KV * N_CTX) * D * 2 * (1e-9) / (ms * 1e-3)
 
     return (gbps(mean), gbps(max_ms), gbps(min_ms)), (tflops(mean), tflops(max_ms), tflops(min_ms)), cv
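
To make the refactor concrete: the new gen_args helper packs every input the decode kernel needs, and the benchmark unpacks the whole tuple in one call. A minimal driving sketch, assuming decode_attention_fwd is imported from SGLang's Triton decode-attention ops (the import sits above the visible hunk) and that an XPU device is available:

# Sketch only: the import path is an assumption; gen_args is the helper added above.
import torch
from sglang.srt.layers.attention.triton_ops.decode_attention import decode_attention_fwd

B, N_CTX, H_Q, H_KV, D = 16, 1024 + 64, 32, 8, 128  # mirrors one x_vals entry
(q, k_buffer, v_buffer, o, kv_indptr, kv_indices, attn_logits, attn_lse,
 num_kv_splits, max_kv_splits, sm_scale) = gen_args(B, N_CTX, H_Q, H_KV, D, torch.bfloat16, 'xpu')

# One decode step: each batch element attends with its single new query token over N_CTX cached tokens.
decode_attention_fwd(q, k_buffer, v_buffer, o, kv_indptr, kv_indices, attn_logits, attn_lse,
                     num_kv_splits, max_kv_splits, sm_scale)
print(o.shape)  # torch.Size([16, 32, 128])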

benchmarks/third_party/sglang/extended_attention_benchmark.py

Lines changed: 66 additions & 52 deletions
@@ -6,11 +6,67 @@
 import triton_kernels_benchmark as benchmark_suit
 
 
+def gen_args(BATCH, N_CTX, Q_HEAD_NUM, KV_HEAD_NUM, HEAD_DIM, dtype, device):
+
+    b_seq_len_prefix = torch.randint(1, N_CTX // 2, (BATCH, ), dtype=torch.int32, device=device)
+    b_seq_len_extend = torch.randint(1, N_CTX // 2, (BATCH, ), dtype=torch.int32, device=device)
+    b_seq_len = b_seq_len_prefix + b_seq_len_extend
+    max_len_in_batch = torch.max(b_seq_len, 0)[0].item()
+
+    b_req_idx = torch.arange(BATCH, dtype=torch.int32, device=device)
+    b_start_loc = torch.zeros((BATCH, ), dtype=torch.int32, device=device)
+    b_start_loc[1:] = torch.cumsum(b_seq_len[:-1], 0)
+    b_start_loc_extend = torch.zeros((BATCH, ), dtype=torch.int32, device=device)
+    b_start_loc_extend[1:] = torch.cumsum(b_seq_len_extend[:-1], 0)
+
+    kv_indptr = torch.zeros((BATCH + 1, ), dtype=torch.int32, device=device)
+    kv_indptr[1:BATCH + 1] = torch.cumsum(b_seq_len_prefix[:BATCH], dim=0)
+    kv_indices = torch.zeros((b_seq_len_prefix.sum().item(), ), dtype=torch.int32, device=device)
+
+    for i in range(BATCH):
+        kv_indices[kv_indptr[i]:kv_indptr[i + 1]] = torch.arange(b_start_loc[i], b_start_loc[i] + b_seq_len_prefix[i])
+
+    total_token_num = torch.sum(b_seq_len).item()
+    extend_token_num = torch.sum(b_seq_len_extend).item()
+    k_buffer = torch.empty((total_token_num, KV_HEAD_NUM, HEAD_DIM), dtype=dtype,
+                           device=device).normal_(mean=0.1, std=0.2)
+    v_buffer = torch.empty((total_token_num, KV_HEAD_NUM, HEAD_DIM), dtype=dtype,
+                           device=device).normal_(mean=0.1, std=0.2)
+
+    k_extend = torch.empty((extend_token_num, KV_HEAD_NUM, HEAD_DIM), dtype=dtype, device=device)
+    v_extend = torch.empty((extend_token_num, KV_HEAD_NUM, HEAD_DIM), dtype=dtype, device=device)
+    q_extend = torch.empty((extend_token_num, Q_HEAD_NUM, HEAD_DIM), dtype=dtype, device=device)
+    for i in range(BATCH):
+        extend_start_in_buffer = b_start_loc[i] + b_seq_len_prefix[i]
+        extend_end_in_buffer = b_start_loc[i] + b_seq_len[i]
+        extend_start = b_start_loc_extend[i]
+        extend_end = b_start_loc_extend[i] + b_seq_len_extend[i]
+        k_extend[extend_start:extend_end] = k_buffer[extend_start_in_buffer:extend_end_in_buffer]
+        v_extend[extend_start:extend_end] = v_buffer[extend_start_in_buffer:extend_end_in_buffer]
+        q_extend[extend_start:extend_end] = torch.empty((b_seq_len_extend[i], Q_HEAD_NUM, HEAD_DIM), dtype=dtype,
+                                                        device=device).normal_(mean=0.1, std=0.2)
+
+    o_extend = torch.empty((extend_token_num, Q_HEAD_NUM, HEAD_DIM), dtype=dtype, device=device)
+    o_redundant = torch.empty((extend_token_num, Q_HEAD_NUM, HEAD_DIM), dtype=dtype, device=device)
+
+    b_seq_len_extend = b_seq_len - b_seq_len_prefix
+    max_len_extend = torch.max(b_seq_len_extend, 0)[0].item()
+    qo_indptr = torch.zeros((BATCH + 1, ), dtype=torch.int32, device=device)
+    qo_indptr[1:BATCH + 1] = torch.cumsum(b_seq_len_extend[:BATCH], dim=0)
+
+    params = []
+    params.append((q_extend, k_extend, v_extend, o_extend, o_redundant))
+    params.append((k_buffer, v_buffer))
+    params.append((qo_indptr, kv_indptr, kv_indices, max_len_extend))
+    params.append((b_req_idx, b_start_loc, b_seq_len, b_seq_len_prefix, max_len_in_batch))
+    return params
+
+
 # pylint: disable=unused-argument
 @benchmark_suit.perf_report(
     benchmark_suit.Benchmark(
         # argument names to use as an x-axis for the plot
-        x_names=['BATCH', 'SEQ_LENS', 'Q_HEAD_NUM', 'KV_HEAD_NUM', 'HEAD_DIM', 'MODE', 'VALIDATE'],
+        x_names=['B', 'SEQ_LENS', 'H_Q', 'H_KV', 'D', 'MODE', 'VALIDATE'],
         x_vals=[  #
             [bs, [1024, 128, 512], 32, 8, 128, 'fwd', True] for bs in [1, 16, 32, 64, 128]
         ] + [  #

@@ -35,57 +91,17 @@
         # name for the plot. Used also as a file name for saving the plot.
         args={},
     ))
-def benchmark(BATCH, SEQ_LENS, Q_HEAD_NUM, KV_HEAD_NUM, HEAD_DIM, MODE, VALIDATE, provider):
+def benchmark(B, SEQ_LENS, H_Q, H_KV, D, MODE, VALIDATE, provider):
     torch.manual_seed(0)
+
     dtype = torch.bfloat16
     N_CTX = sum(SEQ_LENS)
 
-    b_seq_len_prefix = torch.randint(1, N_CTX // 2, (BATCH, ), dtype=torch.int32, device='xpu')
-    b_seq_len_extend = torch.randint(1, N_CTX // 2, (BATCH, ), dtype=torch.int32, device='xpu')
-    b_seq_len = b_seq_len_prefix + b_seq_len_extend
-    max_len_in_batch = torch.max(b_seq_len, 0)[0].item()
-
-    b_req_idx = torch.arange(BATCH, dtype=torch.int32, device='xpu')
-    b_start_loc = torch.zeros((BATCH, ), dtype=torch.int32, device='xpu')
-    b_start_loc[1:] = torch.cumsum(b_seq_len[:-1], 0)
-    b_start_loc_extend = torch.zeros((BATCH, ), dtype=torch.int32, device='xpu')
-    b_start_loc_extend[1:] = torch.cumsum(b_seq_len_extend[:-1], 0)
-
-    kv_indptr = torch.zeros((BATCH + 1, ), dtype=torch.int32, device='xpu')
-    kv_indptr[1:BATCH + 1] = torch.cumsum(b_seq_len_prefix[:BATCH], dim=0)
-    kv_indices = torch.zeros((b_seq_len_prefix.sum().item(), ), dtype=torch.int32, device='xpu')
-
-    for i in range(BATCH):
-        kv_indices[kv_indptr[i]:kv_indptr[i + 1]] = torch.arange(b_start_loc[i], b_start_loc[i] + b_seq_len_prefix[i])
-
-    total_token_num = torch.sum(b_seq_len).item()
-    extend_token_num = torch.sum(b_seq_len_extend).item()
-    k_buffer = torch.empty((total_token_num, KV_HEAD_NUM, HEAD_DIM), dtype=dtype,
-                           device='xpu').normal_(mean=0.1, std=0.2)
-    v_buffer = torch.empty((total_token_num, KV_HEAD_NUM, HEAD_DIM), dtype=dtype,
-                           device='xpu').normal_(mean=0.1, std=0.2)
-
-    k_extend = torch.empty((extend_token_num, KV_HEAD_NUM, HEAD_DIM), dtype=dtype, device='xpu')
-    v_extend = torch.empty((extend_token_num, KV_HEAD_NUM, HEAD_DIM), dtype=dtype, device='xpu')
-    q_extend = torch.empty((extend_token_num, Q_HEAD_NUM, HEAD_DIM), dtype=dtype, device='xpu')
-    for i in range(BATCH):
-        extend_start_in_buffer = b_start_loc[i] + b_seq_len_prefix[i]
-        extend_end_in_buffer = b_start_loc[i] + b_seq_len[i]
-        extend_start = b_start_loc_extend[i]
-        extend_end = b_start_loc_extend[i] + b_seq_len_extend[i]
-        k_extend[extend_start:extend_end] = k_buffer[extend_start_in_buffer:extend_end_in_buffer]
-        v_extend[extend_start:extend_end] = v_buffer[extend_start_in_buffer:extend_end_in_buffer]
-        q_extend[extend_start:extend_end] = torch.empty((b_seq_len_extend[i], Q_HEAD_NUM, HEAD_DIM), dtype=dtype,
-                                                        device='xpu').normal_(mean=0.1, std=0.2)
-
-    o_extend = torch.empty((extend_token_num, Q_HEAD_NUM, HEAD_DIM), dtype=dtype, device='xpu')
-    o_redundant = torch.empty((extend_token_num, Q_HEAD_NUM, HEAD_DIM), dtype=dtype, device='xpu')
-
-    b_seq_len_extend = b_seq_len - b_seq_len_prefix
-    max_len_extend = torch.max(b_seq_len_extend, 0)[0].item()
-    qo_indptr = torch.zeros((BATCH + 1, ), dtype=torch.int32, device='xpu')
-    qo_indptr[1:BATCH + 1] = torch.cumsum(b_seq_len_extend[:BATCH], dim=0)
-
+    params = gen_args(B, N_CTX, H_Q, H_KV, D, dtype, 'xpu')
+    q_extend, k_extend, v_extend, o_extend, o_redundant = params[0]
+    k_buffer, v_buffer = params[1]
+    qo_indptr, kv_indptr, kv_indices, max_len_extend = params[2]
+    b_req_idx, b_start_loc, b_seq_len, b_seq_len_prefix, max_len_in_batch = params[3]
     custom_mask = None
     mask_indptr = None
 
@@ -97,7 +113,6 @@ def triton_fn():
                                       kv_indices, custom_mask, mask_indptr, max_len_extend)
         return o_extend
 
-    # TODO: decode attention do not have validation function
     if VALIDATE:
 
         def refer_fn():

@@ -112,9 +127,8 @@ def refer_fn():
     else:
         raise NotImplementedError(f'Unsupported provider {provider}')
 
-    tflops = lambda ms: 2 * BATCH * (Q_HEAD_NUM + KV_HEAD_NUM * N_CTX) * N_CTX * HEAD_DIM * (1e-12) / (ms * 1e-3)
-
-    gbps = lambda ms: 2 * BATCH * (Q_HEAD_NUM + KV_HEAD_NUM * N_CTX) * HEAD_DIM * 2 * (1e-9) / (ms * 1e-3)
+    tflops = lambda ms: 2 * B * (H_Q + H_KV * N_CTX) * N_CTX * D * (1e-12) / (ms * 1e-3)
+    gbps = lambda ms: 2 * B * (H_Q + H_KV * N_CTX) * D * 2 * (1e-9) / (ms * 1e-3)
 
     return (gbps(mean), gbps(max_ms), gbps(min_ms)), (tflops(mean), tflops(max_ms), tflops(min_ms)), cv
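
In this file gen_args returns its outputs grouped into four tuples rather than one flat tuple, and the benchmark indexes params[0] through params[3] as shown above. A short sketch of that unpacking under one of the x_vals configurations (illustrative, not part of the commit):

# Sketch only: mirrors the unpacking in the new benchmark body.
import torch

B, SEQ_LENS, H_Q, H_KV, D = 16, [1024, 128, 512], 32, 8, 128
N_CTX = sum(SEQ_LENS)

params = gen_args(B, N_CTX, H_Q, H_KV, D, torch.bfloat16, 'xpu')
q_extend, k_extend, v_extend, o_extend, o_redundant = params[0]                     # extend-stage tensors
k_buffer, v_buffer = params[1]                                                      # full KV cache
qo_indptr, kv_indptr, kv_indices, max_len_extend = params[2]                        # indexing for the Triton kernel
b_req_idx, b_start_loc, b_seq_len, b_seq_len_prefix, max_len_in_batch = params[3]   # metadata for the reference path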

benchmarks/third_party/sglang/prefill_attention_benchmark.py

Lines changed: 25 additions & 19 deletions
@@ -5,16 +5,33 @@
 import triton_kernels_benchmark as benchmark_suit
 
 
+def gen_args(B, SEQ_LENS, H_Q, H_KV, D, dtype, device):
+    max_seq_len = max(SEQ_LENS)
+    N_CTX = sum(SEQ_LENS)
+
+    # Create random input tensors
+    q = torch.randn((B * N_CTX, H_Q, D), device=device, dtype=dtype)
+    k = torch.randn((B * N_CTX, H_KV, D), device=device, dtype=dtype)
+    v = torch.randn((B * N_CTX, H_KV, D), device=device, dtype=dtype)
+    o = torch.zeros((B * N_CTX, H_Q, D), device=device, dtype=dtype)
+
+    # Create b_start_loc and b_seq_len tensors
+    b_start_loc = torch.tensor([0, SEQ_LENS[0]], device=device)
+    b_seq_len = torch.tensor(SEQ_LENS, device=device)
+
+    return (q, k, v, o, b_start_loc, b_seq_len, max_seq_len)
+
+
 # pylint: disable=unused-argument
 @benchmark_suit.perf_report(
     benchmark_suit.Benchmark(
         # argument names to use as an x-axis for the plot
-        x_names=['BATCH', 'SEQ_LENS', 'Q_HEAD_NUM', 'KV_HEAD_NUM', 'HEAD_DIM', 'CAUSAL', 'MODE', 'VALIDATE'],
+        x_names=['B', 'SEQ_LENS', 'H_Q', 'H_KV', 'D', 'CAUSAL', 'MODE', 'VALIDATE'],
         x_vals=[  #
             [bs, [1024], 32, 8, 128, causal, 'fwd', False] for causal in [True, False] for bs in [1, 16, 32, 64, 128]
-        ] + [  # noqa
+        ] + [  #
             [bs, [1024], 32, 32, 96, causal, 'fwd', False] for causal in [True, False] for bs in [1, 16, 32, 64, 128]
-        ] + [  # noqa
+        ] + [  #
             [bs, [1024], 28, 4, 128, causal, 'fwd', False] for causal in [True, False] for bs in [1, 16, 32, 64, 128]
         ],
         line_arg='provider',

@@ -34,30 +51,20 @@
         # name for the plot. Used also as a file name for saving the plot.
         args={},
     ))
-def benchmark(BATCH, SEQ_LENS, Q_HEAD_NUM, KV_HEAD_NUM, HEAD_DIM, CAUSAL, MODE, VALIDATE, provider):
+def benchmark(B, SEQ_LENS, H_Q, H_KV, D, CAUSAL, MODE, VALIDATE, provider):
     torch.manual_seed(0)
     dtype = torch.bfloat16
-    device = 'xpu'
     N_CTX = sum(SEQ_LENS)
-    max_seq_len = max(SEQ_LENS)
 
-    # Create random input tensors
-    q = torch.randn((BATCH * N_CTX, Q_HEAD_NUM, HEAD_DIM), device=device, dtype=dtype)
-    k = torch.randn((BATCH * N_CTX, KV_HEAD_NUM, HEAD_DIM), device=device, dtype=dtype)
-    v = torch.randn((BATCH * N_CTX, KV_HEAD_NUM, HEAD_DIM), device=device, dtype=dtype)
-    o = torch.zeros((BATCH * N_CTX, Q_HEAD_NUM, HEAD_DIM), device=device, dtype=dtype)
-
-    # Create b_start_loc and b_seq_len tensors
-    b_start_loc = torch.tensor([0, SEQ_LENS[0]], device=device)
-    b_seq_len = torch.tensor(SEQ_LENS, device=device)
+    q, k, v, o, b_start_loc, b_seq_len, max_seq_len = gen_args(B, SEQ_LENS, H_Q, H_KV, D, dtype, 'xpu')
 
     quantiles = [0.5, 0.0, 1.0]
     if provider == 'triton':
 
         triton_fn = lambda: context_attention_fwd(q, k, v, o, b_start_loc, b_seq_len, max_seq_len, is_causal=CAUSAL)
 
         if VALIDATE:
-            # FIXME: torch sdpa does not support different Q_HEAD_NUM and KV_HEAD_NUM
+            # FIXME: torch sdpa does not support different H_Q and H_KV
            cu_seq_lens = [0] * (len(SEQ_LENS) + 1)
            for i, seq_len in enumerate(SEQ_LENS):
                cu_seq_lens[i + 1] = cu_seq_lens[i] + seq_len

@@ -81,9 +88,8 @@ def benchmark(BATCH, SEQ_LENS, Q_HEAD_NUM, KV_HEAD_NUM, HEAD_DIM, CAUSAL, MODE,
     else:
         raise NotImplementedError(f'Unsupported provider {provider}')
 
-    tflops = lambda ms: 2 * BATCH * (Q_HEAD_NUM + KV_HEAD_NUM) * N_CTX * N_CTX * HEAD_DIM * (1e-12) / (ms * 1e-3)
-
-    gbps = lambda ms: 2 * BATCH * (Q_HEAD_NUM + KV_HEAD_NUM) * N_CTX * HEAD_DIM * 2 * (1e-9) / (ms * 1e-3)
+    tflops = lambda ms: 2 * B * (H_Q + H_KV) * N_CTX * N_CTX * D * (1e-12) / (ms * 1e-3)
+    gbps = lambda ms: 2 * B * (H_Q + H_KV) * N_CTX * D * 2 * (1e-9) / (ms * 1e-3)
 
     return (gbps(mean_ms), gbps(max_ms), gbps(min_ms)), (tflops(mean_ms), tflops(max_ms), tflops(min_ms)), cv
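
For the prefill case, gen_args takes the raw SEQ_LENS list (it derives N_CTX and max_seq_len itself) and returns everything context_attention_fwd needs. A minimal sketch, assuming the kernel is imported from SGLang's Triton prefill-attention ops (the import is outside this hunk):

# Sketch only: the import path is an assumption; the call signature is the one shown in the diff.
import torch
from sglang.srt.layers.attention.triton_ops.prefill_attention import context_attention_fwd

B, SEQ_LENS, H_Q, H_KV, D = 1, [1024], 32, 8, 128  # first x_vals configuration
q, k, v, o, b_start_loc, b_seq_len, max_seq_len = gen_args(B, SEQ_LENS, H_Q, H_KV, D, torch.bfloat16, 'xpu')

# Causal prefill over the whole prompt; o is written in place.
context_attention_fwd(q, k, v, o, b_start_loc, b_seq_len, max_seq_len, is_causal=True)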
