                                               AttentionMetadata,
                                               MLAAttentionImpl)
 from vllm.attention.backends.utils import PAD_SLOT_ID
+from vllm.config import get_current_vllm_config
 from vllm.model_executor.layers.linear import (LinearBase,
                                                UnquantizedLinearMethod)
 
@@ -83,6 +84,7 @@ class AscendMLADecodeMetadata:
     seq_lens: torch.Tensor
     max_seq_lens: int
     seq_lens_list: list[int]
+    attn_mask: torch.Tensor
 
 
 @dataclass
@@ -170,11 +172,13 @@ def reorder_batch(self, input_batch: "InputBatch",
         for i, req_id in enumerate(input_batch.req_ids):
             num_tokens = scheduler_output.num_scheduled_tokens[req_id]
+            num_spec_tokens = len(
+                scheduler_output.scheduled_spec_decode_tokens.get(req_id, []))
             # for now treat 1 scheduled token as "decode" even if its not,
             # we should update this to something like < 8 in the future but
             # currently the TritonMLA._forward_decode only supports
             # num_tokens = 1
-            if num_tokens == 1:
+            if num_tokens - num_spec_tokens == 1:
                 decodes.append(i)
                 num_decode_tokens += num_tokens
             else:
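For reference, a minimal self-contained sketch (not from this patch) of how the adjusted condition classifies requests once speculative decoding is enabled. The request IDs and token counts below are made up, and the dicts stand in for the SchedulerOutput fields used above; the assumption is that a decode step carrying draft tokens schedules 1 + num_spec_tokens tokens, so subtracting the speculative count restores the original == 1 test while prefill chunks still fall through to the else branch.

    # Illustrative stand-ins for scheduler_output fields.
    num_scheduled_tokens = {"req-a": 3, "req-b": 17}      # req-a: 1 verified token + 2 drafts
    scheduled_spec_decode_tokens = {"req-a": [101, 102]}  # draft token ids for req-a

    for req_id, num_tokens in num_scheduled_tokens.items():
        num_spec_tokens = len(scheduled_spec_decode_tokens.get(req_id, []))
        is_decode = (num_tokens - num_spec_tokens == 1)
        print(req_id, "decode" if is_decode else "prefill")
    # req-a decode
    # req-b prefill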
@@ -317,7 +321,7 @@ def build(
         seq_lens = seq_lens_cpu
         max_query_len = query_lens.max().item()
         max_seq_lens = seq_lens.max().item()
-        query_start_loc = None
+        query_start_loc = common_attn_metadata.query_start_loc
 
         prefill_metadata = None
         if self._num_prefills > 0:
@@ -382,7 +386,8 @@ def build(
                 block_table=block_table,
                 seq_lens=seq_lens,
                 seq_lens_list=seq_lens.tolist(),
-                max_seq_lens=max_seq_lens)
+                max_seq_lens=max_seq_lens,
+                attn_mask=self.runner.spec_attn_mask)
 
         return self.metadata_cls(  # type: ignore
             num_actual_tokens=num_actual_tokens,
@@ -445,6 +450,17 @@ def __init__(
 
         ascend_config = get_ascend_config()
         self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
+        # Adapt torch air graph mode with spec decoding.
+        speculative_config = get_current_vllm_config().speculative_config
+        self.fia_sparse_mode = 0
+        self.use_spec_decode = False
+        # We need to set the sparse_mode of fused_infer_attention op to 3
+        # in spec decoding scenario in order to pass in attention mask.
+        if speculative_config is not None:
+            self.fia_sparse_mode = 3
+            self.use_spec_decode = True
+            self.spec_token_num = speculative_config.num_speculative_tokens
+            assert self.spec_token_num > 0
 
     def _v_up_proj_and_o_proj(self, x):
         # Convert from (B, N, L) to (N, B, L)
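A small standalone sketch (an illustration mirroring the __init__ branch above, not the class itself) of how the new flags behave with and without a speculative config; SimpleNamespace stands in for the real SpeculativeConfig object.

    from types import SimpleNamespace

    def spec_decode_flags(speculative_config):
        # Defaults match the non-speculative path: sparse_mode 0, no spec decode.
        fia_sparse_mode, use_spec_decode, spec_token_num = 0, False, 0
        if speculative_config is not None:
            # sparse_mode 3 is what lets the fused_infer_attention op take an
            # externally supplied attention mask, per the comment in the diff.
            fia_sparse_mode = 3
            use_spec_decode = True
            spec_token_num = speculative_config.num_speculative_tokens
            assert spec_token_num > 0
        return fia_sparse_mode, use_spec_decode, spec_token_num

    print(spec_decode_flags(None))                                       # (0, False, 0)
    print(spec_decode_flags(SimpleNamespace(num_speculative_tokens=2)))  # (3, True, 2)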
@@ -646,9 +662,24 @@ def _forward_decode(
                 dtype=q.dtype,
                 device=q.device)
         if self.running_in_graph:
-            # TorchAir's shape is [bs, num_heads_per_rank, seq_len, dim]
-            q_nope = q_nope.view(num_tokens, self.num_heads, 1, -1)
-            q_pe = q_pe.view(num_tokens, self.num_heads, 1, -1)
+            # TorchAir's shape is [bs, num_heads_per_rank, q_seq_len, dim]
+            if self.use_spec_decode:
+                assert num_tokens % self.spec_token_num == 0
+                q_nope = (q_nope.view(
+                    num_tokens // (self.spec_token_num + 1),
+                    self.spec_token_num + 1,
+                    self.num_heads,
+                    -1,
+                ).transpose(1, 2).contiguous())
+                q_pe = (q_pe.view(
+                    num_tokens // (self.spec_token_num + 1),
+                    self.spec_token_num + 1,
+                    self.num_heads,
+                    -1,
+                ).transpose(1, 2).contiguous())
+            else:
+                q_nope = q_nope.view(num_tokens, self.num_heads, 1, -1)
+                q_pe = q_pe.view(num_tokens, self.num_heads, 1, -1)
             # shape of knope/k_pe for npu graph mode should be:
             # [num_blocks, num_kv_heads, block_size, self.kv_lora_rank/self.qk_rope_head_dim]
             block_size = kv_c_and_k_pe_cache[0].shape[1]
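A quick shape check (illustrative, plain CPU tensors) of the spec-decode reshape above: num_tokens is assumed to be batch_size * (spec_token_num + 1), and the view/transpose pair turns the flat per-token query into the [bs, num_heads_per_rank, q_seq_len, dim] layout the TorchAir comment describes.

    import torch

    bs, spec_token_num, num_heads, dim = 4, 2, 8, 64
    num_tokens = bs * (spec_token_num + 1)           # 1 verified position + k drafts per request
    q_nope = torch.randn(num_tokens, num_heads, dim)

    q_nope = (q_nope.view(
        num_tokens // (spec_token_num + 1),
        spec_token_num + 1,
        num_heads,
        -1,
    ).transpose(1, 2).contiguous())

    print(q_nope.shape)  # torch.Size([4, 8, 3, 64]) == [bs, num_heads, q_seq_len, dim]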
@@ -666,7 +697,8 @@ def _forward_decode(
                 num_heads=self.num_heads,
                 num_key_value_heads=self.num_kv_heads,
                 input_layout="BNSD",
-                atten_mask=attn_metadata.attn_mask,
+                atten_mask=attn_metadata.decode.attn_mask,  # type:ignore
+                sparse_mode=self.fia_sparse_mode,
                 scale=self.scale,
                 antiquant_mode=0,
                 antiquant_scale=None,
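The mask routed into atten_mask above comes from self.runner.spec_attn_mask (attached to the decode metadata in build()); this diff does not show how the runner builds it. Purely as an illustrative assumption, a causal mask over the 1 + num_speculative_tokens query positions of one request could look like the following, with True marking positions that must not be attended to.

    import torch

    spec_token_num = 2
    q_len = spec_token_num + 1   # one verified token plus the draft tokens
    # Hypothetical mask: strictly upper-triangular, i.e. no peeking at later drafts.
    spec_attn_mask = torch.triu(torch.ones(q_len, q_len, dtype=torch.bool), diagonal=1)
    print(spec_attn_mask)
    # tensor([[False,  True,  True],
    #         [False, False,  True],
    #         [False, False, False]])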