Skip to content

Commit a501ff0

Browse files
Commit message: "optimize performance"
Signed-off-by: fsx950223 <[email protected]>
Parent commit: e995fdc · This commit: a501ff0

File tree

1 file changed

+13
-4
lines changed

1 file changed

+13
-4
lines changed

vllm/v1/attention/backends/rocm_aiter_fa.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ def _vllm_layout_trans_kernel(
3434
v_buffer_ptr,
3535
k_values_ptr,
3636
v_values_ptr,
37+
b_query_lens_loc,
3738
b_seq_lens_loc,
3839
block_table,
3940
block_table_stride_0,
@@ -46,6 +47,13 @@ def _vllm_layout_trans_kernel(
4647
tl.arange(0, 2))
4748
batch_token_start, batch_token_end = tl.split(batch_token_indexes)
4849
seq_len = batch_token_end - batch_token_start
50+
51+
batch_query_indexes = tl.load(b_query_lens_loc + batch_idx +
52+
tl.arange(0, 2))
53+
batch_query_start, batch_query_end = tl.split(batch_query_indexes)
54+
query_len = batch_query_end - batch_query_start
55+
if query_len <= 1:
56+
return
4957
if block_idx * BLOCK_SIZE < seq_len:
5058
block_mask = (block_idx * BLOCK_SIZE +
5159
tl.arange(0, BLOCK_SIZE)[:, None]) < seq_len
@@ -69,8 +77,8 @@ def _vllm_layout_trans_kernel(
6977
tl.store(k_values_ptr + kv_values_off, k_vals, mask=block_mask)
7078
tl.store(v_values_ptr + kv_values_off, v_vals, mask=block_mask)
7179

72-
def vllm_layout_trans(b_seq_lens_loc, block_table, k_buffer, v_buffer,
73-
max_seq_len, total_tokens):
80+
def vllm_layout_trans(b_query_lens_loc, b_seq_lens_loc, block_table,
81+
k_buffer, v_buffer, max_seq_len, total_tokens):
7482
H_KV = v_buffer.shape[2]
7583
D = v_buffer.shape[3]
7684
BLOCK_SIZE = v_buffer.shape[1]
@@ -89,6 +97,7 @@ def vllm_layout_trans(b_seq_lens_loc, block_table, k_buffer, v_buffer,
8997
v_buffer,
9098
k_values,
9199
v_values,
100+
b_query_lens_loc,
92101
b_seq_lens_loc,
93102
block_table,
94103
block_table.stride(0),
@@ -112,8 +121,8 @@ def flash_attn_varlen_func_impl(
112121
alibi_slopes: Optional[list[float]],
113122
block_table: torch.Tensor,
114123
) -> torch.Tensor:
115-
k, v = vllm_layout_trans(cu_seqlens_k, block_table, k_cache, v_cache,
116-
max_seqlen_k, total_tokens)
124+
k, v = vllm_layout_trans(cu_seqlens_q, cu_seqlens_k, block_table,
125+
k_cache, v_cache, max_seqlen_k, total_tokens)
117126
output = aiter.flash_attn_varlen_func(
118127
q=q,
119128
k=k,

0 commit comments

Comments (0)