diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py
index 292294b93..e6a237678 100644
--- a/vllm_ascend/attention/attention_v1.py
+++ b/vllm_ascend/attention/attention_v1.py
@@ -100,6 +100,7 @@ class AscendAttentionState(Enum):
     PrefillCacheHit = 1
     DecodeOnly = 2
     ChunkedPrefill = 3
+    SpecDecoding = 4


 @dataclass
diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py
index 91ddf4388..226a5705c 100644
--- a/vllm_ascend/attention/mla_v1.py
+++ b/vllm_ascend/attention/mla_v1.py
@@ -8,6 +8,7 @@
                                               AttentionMetadata,
                                               MLAAttentionImpl)
 from vllm.attention.backends.utils import PAD_SLOT_ID
+from vllm.config import get_current_vllm_config
 from vllm.model_executor.layers.linear import (LinearBase,
                                                UnquantizedLinearMethod)
@@ -86,6 +87,7 @@ class AscendMLADecodeMetadata:
     seq_lens: torch.Tensor
     max_seq_lens: int
     seq_lens_list: list[int]
+    attn_mask: Optional[torch.Tensor] = None


 @dataclass
@@ -169,6 +171,8 @@ def __init__(self,
         self.runner = runner
         scheduler_config = runner.scheduler_config
         self.chunked_prefill_enabled = scheduler_config.chunked_prefill_enabled
+        ascend_config = get_ascend_config()
+        self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled

     def reorder_batch(self, input_batch: "InputBatch",
                       scheduler_output: "SchedulerOutput") -> bool:
@@ -185,16 +189,24 @@ def reorder_batch(self, input_batch: "InputBatch",
         for i, req_id in enumerate(input_batch.req_ids):
             num_tokens = scheduler_output.num_scheduled_tokens[req_id]
-            # for now treat 1 scheduled token as "decode" even if its not,
-            # we should update this to something like < 8 in the future but
-            # currently the TritonMLA._forward_decode only supports
-            # num_tokens = 1
-            if num_tokens == 1:
-                decodes.append(i)
-                num_decode_tokens += num_tokens
+            num_spec_tokens = len(
+                scheduler_output.scheduled_spec_decode_tokens.get(req_id, []))
+            # For torch air graph mode we treat spec decoding as decode.
+            if self.torchair_graph_enabled:
+                if num_tokens - num_spec_tokens == 1:
+                    decodes.append(i)
+                    num_decode_tokens += num_tokens
+                else:
+                    prefills.append(i)
+                    num_prefill_tokens += num_tokens
+            # For eager mode we treat spec decoding as chunked prefill.
             else:
-                prefills.append(i)
-                num_prefill_tokens += num_tokens
+                if num_tokens == 1:
+                    decodes.append(i)
+                    num_decode_tokens += num_tokens
+                else:
+                    prefills.append(i)
+                    num_prefill_tokens += num_tokens

         # We hope that this is fairly minimal since decodes
         # should be around for a number of iterations so hopefully they are
@@ -284,7 +296,8 @@ def build_dummy(self, num_reqs: int,
             block_table=block_table,
             seq_lens=seq_lens,
             seq_lens_list=seq_lens.tolist(),
-            max_seq_lens=1)
+            max_seq_lens=1,
+            attn_mask=self.runner.spec_attn_mask)

         return self.metadata_cls(  # type: ignore
             num_input_tokens=num_actual_tokens,
             num_actual_tokens=num_actual_tokens,
@@ -332,7 +345,7 @@ def build(
             seq_lens = seq_lens_cpu
         max_query_len = query_lens.max().item()
         max_seq_lens = seq_lens.max().item()
-        query_start_loc = None
+        query_start_loc = common_attn_metadata.query_start_loc

         prefill_metadata = None
         if self._num_prefills > 0:
@@ -397,7 +410,8 @@ def build(
             block_table=block_table,
             seq_lens=seq_lens,
             seq_lens_list=seq_lens.tolist(),
-            max_seq_lens=max_seq_lens)
+            max_seq_lens=max_seq_lens,
+            attn_mask=self.runner.spec_attn_mask)

         return self.metadata_cls(  # type: ignore
             num_actual_tokens=num_actual_tokens,
@@ -461,6 +475,11 @@ def __init__(
         ascend_config = get_ascend_config()
         self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
+        # Adapt torch air graph mode with spec decoding.
+        speculative_config = get_current_vllm_config().speculative_config
+        if speculative_config is not None:
+            self.spec_token_num = speculative_config.num_speculative_tokens
+            assert self.spec_token_num > 0

     def _v_up_proj_and_o_proj(self, x):
         # Convert from (B, N, L) to (N, B, L)
@@ -550,7 +569,10 @@ def _forward_prefill(
         num_tokens = query.size(0)
         attn_output = None
         # Here is only 2 possibility of input, ChunkedPrefill or PrefillNoCache
-        if attn_metadata.attn_state == AscendAttentionState.ChunkedPrefill:
+        if attn_metadata.attn_state in [
+                AscendAttentionState.ChunkedPrefill,
+                AscendAttentionState.SpecDecoding
+        ]:
             attn_output = torch.empty(num_tokens,
                                       self.num_heads * self.v_head_dim,
                                       dtype=query.dtype,
                                       device=query.device)
@@ -597,7 +619,7 @@ def _forward_prefill(
             attn_output = attn_output.view(-1, self.num_heads, self.v_head_dim)
         else:
             raise RuntimeError(
-                "Unexpected path reached, AscendMLAImpl should only have PrefillNoCache and ChunkedPrefill scenario in forward prefill, please file a bug to vllm-ascend !"
+                "Unexpected path reached, AscendMLAImpl should only have PrefillNoCache, ChunkedPrefill and SpecDecoding scenario in forward prefill, please file a bug to vllm-ascend !"
             )
         attn_output = attn_output.reshape(
             [num_tokens, self.num_heads * self.v_head_dim])
@@ -670,9 +692,28 @@ def _forward_decode(
             dtype=q.dtype,
             device=q.device)
         if self.running_in_graph:
-            # TorchAir's shape is [bs, num_heads_per_rank, seq_len, dim]
-            q_nope = q_nope.view(num_tokens, self.num_heads, 1, -1)
-            q_pe = q_pe.view(num_tokens, self.num_heads, 1, -1)
+            # TorchAir's shape is [bs, num_heads_per_rank, q_seq_len, dim]
+            if attn_metadata.attn_state == AscendAttentionState.SpecDecoding:
+                assert num_tokens % self.spec_token_num == 0
+                q_nope = (q_nope.view(
+                    num_tokens // (self.spec_token_num + 1),
+                    self.spec_token_num + 1,
+                    self.num_heads,
+                    -1,
+                ).transpose(1, 2).contiguous())
+                q_pe = (q_pe.view(
+                    num_tokens // (self.spec_token_num + 1),
+                    self.spec_token_num + 1,
+                    self.num_heads,
+                    -1,
+                ).transpose(1, 2).contiguous())
+                sparse_mode = 3
+                spec_attn_mask = attn_metadata.decode.attn_mask  # type:ignore
+            else:
+                q_nope = q_nope.view(num_tokens, self.num_heads, 1, -1)
+                q_pe = q_pe.view(num_tokens, self.num_heads, 1, -1)
+                sparse_mode = 0
+                spec_attn_mask = None
             # shape of knope/k_pe for npu graph mode should be:
             # [num_blocks, num_kv_heads, block_size, self.kv_lora_rank/self.qk_rope_head_dim]
             block_size = kv_c_and_k_pe_cache[0].shape[1]
@@ -690,7 +731,8 @@
                 num_heads=self.num_heads,
                 num_key_value_heads=self.num_kv_heads,
                 input_layout="BNSD",
-                atten_mask=attn_metadata.attn_mask,
+                atten_mask=spec_attn_mask,
+                sparse_mode=sparse_mode,
                 scale=self.scale,
                 antiquant_mode=0,
                 antiquant_scale=None,
@@ -732,7 +774,9 @@ def forward(
         if attn_metadata is None:
             # Profiling run.
             return output
-        self.running_in_graph = self.torchair_graph_enabled and attn_metadata.attn_state == AscendAttentionState.DecodeOnly
+        self.running_in_graph = self.torchair_graph_enabled and attn_metadata.attn_state in [
+            AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding
+        ]
         num_actual_toks = attn_metadata.num_actual_tokens
         if k_pe is None and not self.running_in_graph:
             kv_c, k_pe = self.kv_a_proj_with_mqa(
diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index 2590212b5..03be38f31 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -203,8 +203,13 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
         # Set up speculative decoding.
         self.use_spec_decode = False
+        self.spec_attn_mask = None
         if self.speculative_config:
             self.use_spec_decode = True
+            self.spec_attn_mask = torch.triu(torch.ones(2048,
+                                                        2048,
+                                                        dtype=torch.bool),
+                                             diagonal=1).to("npu")
             if get_pp_group().is_last_rank:
                 if self.speculative_config.method == "ngram":
                     self.drafter = NgramProposer(self.vllm_config)
@@ -779,10 +784,13 @@ def _process_reqs(
         # Get the number of scheduled tokens for each request.
         # TODO: The Python loop can be slow. Optimize.
         num_scheduled_tokens = np.empty(num_reqs, dtype=np.int32)
+        num_valid_tokens = np.empty(num_reqs, dtype=np.int32)
         max_num_scheduled_tokens = 0
         for i, req_id in enumerate(self.input_batch.req_ids):
             num_tokens = scheduler_output.num_scheduled_tokens[req_id]
             num_scheduled_tokens[i] = num_tokens
+            num_valid_tokens[i] = num_tokens - \
+                len(scheduler_output.scheduled_spec_decode_tokens.get(req_id, []))
             max_num_scheduled_tokens = max(max_num_scheduled_tokens,
                                            num_tokens)
@@ -838,11 +846,16 @@ def _process_reqs(
             out=self.slot_mapping_np[:total_num_scheduled_tokens])

         ascend_config = get_ascend_config()
+        use_spec_decode = len(
+            scheduler_output.scheduled_spec_decode_tokens) > 0
         if np.array_equal(self.seq_lens_np[:num_reqs], num_scheduled_tokens):
             attn_state = AscendAttentionState.PrefillNoCache
         # We assume it is the decode stage, where prefill occurs but only one token is not hit in cache.
         elif np.all(num_scheduled_tokens == 1):
             attn_state = AscendAttentionState.DecodeOnly
+        # Speculative decoding.
+        elif np.all(num_valid_tokens == 1):
+            attn_state = AscendAttentionState.SpecDecoding
         # splitfuse
         elif not ascend_config.ascend_scheduler_config.enabled or self.chunked_prefill_enabled:
             attn_state = AscendAttentionState.ChunkedPrefill
@@ -873,7 +886,9 @@ def _process_reqs(
         seq_lens = self.seq_lens[:num_reqs]
         common_attn_metadata = CommonAttentionMetadata(
             query_start_loc=query_start_loc, seq_lens=seq_lens)
-        with_prefill = attn_state != AscendAttentionState.DecodeOnly
+        with_prefill = attn_state not in [
+            AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding
+        ]

         if self.dp_size > 1:
             max_num_tokens, with_prefill = self._get_forward_metadata_across_dp(
@@ -883,14 +898,14 @@
         # Add graph_pad_size here
         if envs_ascend.VLLM_ENABLE_MC2 or (self.torchair_graph_enabled
                                            and not with_prefill):
-            batch_size = len(seq_lens)
             if self.dp_size > 1:
                 padded_batch_size = self.select_torchair_padded_batch_size(
                     max_num_tokens)
             else:
                 padded_batch_size = self.select_torchair_padded_batch_size(
-                    batch_size)
-            graph_pad_size = padded_batch_size - batch_size
+                    total_num_scheduled_tokens)
+            graph_pad_size = padded_batch_size - total_num_scheduled_tokens
+
             extra_builder_kwargs['graph_pad_size'] = graph_pad_size

         if self.vllm_config.model_config.use_mla:
diff --git a/vllm_ascend/worker/mtp_proposer_v1.py b/vllm_ascend/worker/mtp_proposer_v1.py
index 8782df181..ba8406fa0 100644
--- a/vllm_ascend/worker/mtp_proposer_v1.py
+++ b/vllm_ascend/worker/mtp_proposer_v1.py
@@ -4,7 +4,8 @@
                          set_current_vllm_config)
 from vllm.forward_context import set_forward_context
 from vllm.model_executor.model_loader import get_model_loader
-from vllm.model_executor.model_loader.utils import set_default_torch_dtype
+from vllm.model_executor.model_loader.utils import (
+    process_weights_after_loading, set_default_torch_dtype)
 from vllm.v1.sample.metadata import SamplingMetadata

 from vllm_ascend.attention.mla_v1 import CommonAttentionMetadata
@@ -199,6 +200,8 @@ def load_model(self) -> None:
             loader.get_all_weights(
                 self.vllm_config.speculative_config.draft_model_config,
                 self.model))
+        process_weights_after_loading(self.model, draft_model_config,
+                                      target_device)

         # TODO Using torch instead of triton may result in poor performance
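
Reviewer note (not part of the patch): the new SpecDecoding state is picked in _process_reqs from per-request token counts, where a request whose scheduled tokens minus its draft tokens equal 1 still counts as a decode. Below is a minimal, self-contained sketch of that selection logic using only NumPy; the AttnState enum copy, the classify_attn_state helper and the example numbers are illustrative assumptions rather than code from this diff, and the PrefillCacheHit branch of the real selection is omitted.

from enum import Enum

import numpy as np


class AttnState(Enum):
    # Mirrors AscendAttentionState for illustration only.
    PrefillNoCache = 0
    PrefillCacheHit = 1
    DecodeOnly = 2
    ChunkedPrefill = 3
    SpecDecoding = 4


def classify_attn_state(num_scheduled_tokens: np.ndarray,
                        num_spec_tokens: np.ndarray,
                        seq_lens: np.ndarray) -> AttnState:
    # Scheduled tokens minus draft tokens: the "real" tokens of each request.
    num_valid_tokens = num_scheduled_tokens - num_spec_tokens

    if np.array_equal(seq_lens, num_scheduled_tokens):
        # Every request is prefilling its whole sequence, nothing cached.
        return AttnState.PrefillNoCache
    if np.all(num_scheduled_tokens == 1):
        # Ordinary decode: exactly one token per request.
        return AttnState.DecodeOnly
    if np.all(num_valid_tokens == 1):
        # One real token per request plus its draft tokens: spec decoding.
        return AttnState.SpecDecoding
    # Remaining mixed batches fall back to chunked prefill (the
    # PrefillCacheHit case of the real code is not modelled here).
    return AttnState.ChunkedPrefill


# Two requests, each verifying one MTP draft token (num_speculative_tokens=1).
scheduled = np.array([2, 2], dtype=np.int32)
spec = np.array([1, 1], dtype=np.int32)
seq_lens = np.array([37, 51], dtype=np.int32)
print(classify_attn_state(scheduled, spec, seq_lens))  # AttnState.SpecDecoding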
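
Reviewer note (not part of the patch): in torchair graph mode the decode kernel expects a BNSD layout, so under SpecDecoding the flat [num_tokens, num_heads, head_dim] query is regrouped per request into [bs, num_heads, 1 + num_speculative_tokens, head_dim] and paired with sparse_mode=3 plus the strictly upper-triangular spec_attn_mask built in the model runner. The CPU-only PyTorch sketch below reproduces just that tensor handling; the head count, head size and batch size are made-up example values and the npu_fused_infer_attention_score call itself is left out.

import torch

num_spec_tokens = 1              # speculative_config.num_speculative_tokens
q_seq_len = num_spec_tokens + 1  # one real token plus the draft tokens
num_heads, head_dim = 16, 192    # example sizes, not taken from the diff
num_tokens = 4 * q_seq_len       # 4 requests in the batch

# Flat per-token query as produced by the projections: [num_tokens, N, D].
q = torch.randn(num_tokens, num_heads, head_dim)

# Regroup tokens by request, then move heads ahead of the query length to
# obtain the BNSD layout [bs, num_heads, q_seq_len, head_dim].
assert num_tokens % q_seq_len == 0
bs = num_tokens // q_seq_len
q_bnsd = q.view(bs, q_seq_len, num_heads,
                head_dim).transpose(1, 2).contiguous()
print(q_bnsd.shape)  # torch.Size([4, 16, 2, 192])

# Mask used together with sparse_mode=3: strictly upper-triangular booleans
# (True means masked), the same construction as spec_attn_mask in the model
# runner, only kept on CPU here.
spec_attn_mask = torch.triu(torch.ones(2048, 2048, dtype=torch.bool),
                            diagonal=1)
print(spec_attn_mask[:3, :3])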