@@ -34,6 +34,7 @@ def _vllm_layout_trans_kernel(
     v_buffer_ptr,
     k_values_ptr,
     v_values_ptr,
+    b_query_lens_loc,
     b_seq_lens_loc,
     block_table,
     block_table_stride_0,
@@ -46,6 +47,13 @@ def _vllm_layout_trans_kernel(
                                       tl.arange(0, 2))
     batch_token_start, batch_token_end = tl.split(batch_token_indexes)
     seq_len = batch_token_end - batch_token_start
+
+    batch_query_indexes = tl.load(b_query_lens_loc + batch_idx +
+                                  tl.arange(0, 2))
+    batch_query_start, batch_query_end = tl.split(batch_query_indexes)
+    query_len = batch_query_end - batch_query_start
+    if query_len <= 1:
+        return
     if block_idx * BLOCK_SIZE < seq_len:
         block_mask = (block_idx * BLOCK_SIZE +
                       tl.arange(0, BLOCK_SIZE)[:, None]) < seq_len
@@ -69,8 +77,8 @@ def _vllm_layout_trans_kernel(
         tl.store(k_values_ptr + kv_values_off, k_vals, mask=block_mask)
         tl.store(v_values_ptr + kv_values_off, v_vals, mask=block_mask)

-def vllm_layout_trans(b_seq_lens_loc, block_table, k_buffer, v_buffer,
-                      max_seq_len, total_tokens):
+def vllm_layout_trans(b_query_lens_loc, b_seq_lens_loc, block_table,
+                      k_buffer, v_buffer, max_seq_len, total_tokens):
     H_KV = v_buffer.shape[2]
     D = v_buffer.shape[3]
     BLOCK_SIZE = v_buffer.shape[1]
@@ -89,6 +97,7 @@ def vllm_layout_trans(b_seq_lens_loc, block_table, k_buffer, v_buffer,
         v_buffer,
         k_values,
         v_values,
+        b_query_lens_loc,
         b_seq_lens_loc,
         block_table,
         block_table.stride(0),
@@ -112,8 +121,8 @@ def flash_attn_varlen_func_impl(
     alibi_slopes: Optional[list[float]],
     block_table: torch.Tensor,
 ) -> torch.Tensor:
-    k, v = vllm_layout_trans(cu_seqlens_k, block_table, k_cache, v_cache,
-                             max_seqlen_k, total_tokens)
+    k, v = vllm_layout_trans(cu_seqlens_q, cu_seqlens_k, block_table,
+                             k_cache, v_cache, max_seqlen_k, total_tokens)
     output = aiter.flash_attn_varlen_func(
         q=q,
         k=k,
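In plain language: the added guard derives each sequence's query length from two consecutive entries of the cumulative query-length array (b_query_lens_loc, i.e. cu_seqlens_q at the call site) and returns early for decode-only sequences (query_len <= 1), so the KV layout transpose only runs for sequences with a prefill portion. A minimal PyTorch sketch of the same arithmetic outside the kernel (the helper name and example values are hypothetical, not part of this patch):

import torch

def sequences_needing_transpose(cu_seqlens_q: torch.Tensor) -> torch.Tensor:
    # Per-sequence query lengths: difference of consecutive cumulative
    # offsets, mirroring `query_len = batch_query_end - batch_query_start`
    # in the kernel above.
    query_lens = cu_seqlens_q[1:] - cu_seqlens_q[:-1]
    # The kernel returns early when query_len <= 1, so only sequences
    # with a real prefill portion keep their KV blocks transposed.
    return query_lens > 1

# Example: a batch of three sequences with query lengths [5, 1, 3];
# the middle (decode-only) sequence is skipped.
cu_seqlens_q = torch.tensor([0, 5, 6, 9], dtype=torch.int32)
print(sequences_needing_transpose(cu_seqlens_q))  # tensor([ True, False,  True])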