
[MTP][V1] Adapt mtp with graph mode in v1. #1023


Merged: 2 commits, Jun 9, 2025
1 change: 1 addition & 0 deletions vllm_ascend/attention/attention_v1.py
@@ -100,6 +100,7 @@ class AscendAttentionState(Enum):
PrefillCacheHit = 1
DecodeOnly = 2
ChunkedPrefill = 3
SpecDecoding = 4


@dataclass
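For reference, a minimal standalone sketch of the state enum after this change (assuming `PrefillNoCache = 0` as the first member, which is not visible in the hunk); the `is_decode_like` helper is purely illustrative, not part of the PR:

```python
from enum import Enum


class AscendAttentionState(Enum):
    PrefillNoCache = 0   # assumed first member, not shown in the hunk above
    PrefillCacheHit = 1
    DecodeOnly = 2
    ChunkedPrefill = 3
    SpecDecoding = 4     # new in this PR: multi-token speculative-decoding steps


def is_decode_like(state: AscendAttentionState) -> bool:
    # Illustrative helper: the states the TorchAir graph decode path accepts
    # after this PR (see the running_in_graph change in mla_v1.py below).
    return state in (AscendAttentionState.DecodeOnly,
                     AscendAttentionState.SpecDecoding)
```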
82 changes: 63 additions & 19 deletions vllm_ascend/attention/mla_v1.py
@@ -8,6 +8,7 @@
AttentionMetadata,
MLAAttentionImpl)
from vllm.attention.backends.utils import PAD_SLOT_ID
from vllm.config import get_current_vllm_config
from vllm.model_executor.layers.linear import (LinearBase,
UnquantizedLinearMethod)

@@ -86,6 +87,7 @@ class AscendMLADecodeMetadata:
seq_lens: torch.Tensor
max_seq_lens: int
seq_lens_list: list[int]
attn_mask: Optional[torch.Tensor] = None


@dataclass
@@ -169,6 +171,8 @@ def __init__(self,
self.runner = runner
scheduler_config = runner.scheduler_config
self.chunked_prefill_enabled = scheduler_config.chunked_prefill_enabled
ascend_config = get_ascend_config()
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled

def reorder_batch(self, input_batch: "InputBatch",
scheduler_output: "SchedulerOutput") -> bool:
@@ -185,16 +189,24 @@ def reorder_batch(self, input_batch: "InputBatch",

for i, req_id in enumerate(input_batch.req_ids):
num_tokens = scheduler_output.num_scheduled_tokens[req_id]
# for now treat 1 scheduled token as "decode" even if its not,
# we should update this to something like < 8 in the future but
# currently the TritonMLA._forward_decode only supports
# num_tokens = 1
if num_tokens == 1:
decodes.append(i)
num_decode_tokens += num_tokens
num_spec_tokens = len(
scheduler_output.scheduled_spec_decode_tokens.get(req_id, []))
# For TorchAir graph mode we treat spec decoding as decode.
if self.torchair_graph_enabled:
if num_tokens - num_spec_tokens == 1:
decodes.append(i)
num_decode_tokens += num_tokens
else:
prefills.append(i)
num_prefill_tokens += num_tokens
# For eager mode we treat spec decoding as chunked prefill.
else:
prefills.append(i)
num_prefill_tokens += num_tokens
if num_tokens == 1:
decodes.append(i)
num_decode_tokens += num_tokens
else:
prefills.append(i)
num_prefill_tokens += num_tokens

# We hope that this is fairly minimal since decodes
# should be around for a number of iterations so hopefully they are
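A standalone sketch of the decode/prefill split above, under the assumption that each request exposes its scheduled token count and its speculative token count (function and argument names are illustrative):

```python
def split_decodes_and_prefills(num_tokens_per_req: list[int],
                               num_spec_tokens_per_req: list[int],
                               torchair_graph_enabled: bool):
    """Graph mode counts a request as decode when all but one of its scheduled
    tokens are speculative; eager mode keeps the old rule (exactly one
    scheduled token), so spec decoding falls through to chunked prefill."""
    decodes, prefills = [], []
    for i, (num_tokens, num_spec) in enumerate(
            zip(num_tokens_per_req, num_spec_tokens_per_req)):
        if torchair_graph_enabled:
            is_decode = num_tokens - num_spec == 1
        else:
            is_decode = num_tokens == 1
        (decodes if is_decode else prefills).append(i)
    return decodes, prefills


# Request 0 schedules 1 token; request 1 schedules 3 tokens, 2 of them speculative.
assert split_decodes_and_prefills([1, 3], [0, 2], True) == ([0, 1], [])
assert split_decodes_and_prefills([1, 3], [0, 2], False) == ([0], [1])
```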
@@ -284,7 +296,8 @@ def build_dummy(self, num_reqs: int,
block_table=block_table,
seq_lens=seq_lens,
seq_lens_list=seq_lens.tolist(),
max_seq_lens=1)
max_seq_lens=1,
attn_mask=self.runner.spec_attn_mask)
return self.metadata_cls( # type: ignore
num_input_tokens=num_actual_tokens,
num_actual_tokens=num_actual_tokens,
@@ -332,7 +345,7 @@ def build(
seq_lens = seq_lens_cpu
max_query_len = query_lens.max().item()
max_seq_lens = seq_lens.max().item()
query_start_loc = None
query_start_loc = common_attn_metadata.query_start_loc

prefill_metadata = None
if self._num_prefills > 0:
@@ -397,7 +410,8 @@ def build(
block_table=block_table,
seq_lens=seq_lens,
seq_lens_list=seq_lens.tolist(),
max_seq_lens=max_seq_lens)
max_seq_lens=max_seq_lens,
attn_mask=self.runner.spec_attn_mask)

return self.metadata_cls( # type: ignore
num_actual_tokens=num_actual_tokens,
@@ -461,6 +475,11 @@ def __init__(

ascend_config = get_ascend_config()
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
# Adapt the TorchAir graph mode to speculative decoding.
speculative_config = get_current_vllm_config().speculative_config
if speculative_config is not None:
self.spec_token_num = speculative_config.num_speculative_tokens
assert self.spec_token_num > 0

def _v_up_proj_and_o_proj(self, x):
# Convert from (B, N, L) to (N, B, L)
@@ -550,7 +569,10 @@ def _forward_prefill(
num_tokens = query.size(0)
attn_output = None
# Only three input states are possible here: ChunkedPrefill, PrefillNoCache or SpecDecoding
if attn_metadata.attn_state == AscendAttentionState.ChunkedPrefill:
if attn_metadata.attn_state in [
AscendAttentionState.ChunkedPrefill,
AscendAttentionState.SpecDecoding
]:
attn_output = torch.empty(num_tokens,
self.num_heads * self.v_head_dim,
dtype=query.dtype,
@@ -597,7 +619,7 @@ def _forward_prefill(
attn_output = attn_output.view(-1, self.num_heads, self.v_head_dim)
else:
raise RuntimeError(
"Unexpected path reached, AscendMLAImpl should only have PrefillNoCache and ChunkedPrefill scenario in forward prefill, please file a bug to vllm-ascend !"
"Unexpected path reached, AscendMLAImpl should only have PrefillNoCache, ChunkedPrefill and SpecDecoding scenario in forward prefill, please file a bug to vllm-ascend !"
)
attn_output = attn_output.reshape(
[num_tokens, self.num_heads * self.v_head_dim])
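The change above routes `SpecDecoding` through the same masked-attention prefill path as `ChunkedPrefill`, while `PrefillNoCache` keeps its dedicated branch. A condensed sketch of that dispatch (reusing the `AscendAttentionState` sketch from the attention_v1.py section above; the returned labels are schematic):

```python
def select_prefill_path(attn_state: AscendAttentionState) -> str:
    # SpecDecoding shares the chunked-prefill kernel path introduced above.
    if attn_state in (AscendAttentionState.ChunkedPrefill,
                      AscendAttentionState.SpecDecoding):
        return "chunked_prefill_path"
    if attn_state == AscendAttentionState.PrefillNoCache:
        return "prefill_no_cache_path"
    raise RuntimeError(
        "Unexpected path reached: only PrefillNoCache, ChunkedPrefill and "
        "SpecDecoding are handled in _forward_prefill.")
```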
@@ -670,9 +692,28 @@ def _forward_decode(
dtype=q.dtype,
device=q.device)
if self.running_in_graph:
# TorchAir's shape is [bs, num_heads_per_rank, seq_len, dim]
q_nope = q_nope.view(num_tokens, self.num_heads, 1, -1)
q_pe = q_pe.view(num_tokens, self.num_heads, 1, -1)
# TorchAir's shape is [bs, num_heads_per_rank, q_seq_len, dim]
if attn_metadata.attn_state == AscendAttentionState.SpecDecoding:
assert num_tokens % self.spec_token_num == 0
q_nope = (q_nope.view(
num_tokens // (self.spec_token_num + 1),
self.spec_token_num + 1,
self.num_heads,
-1,
).transpose(1, 2).contiguous())
q_pe = (q_pe.view(
num_tokens // (self.spec_token_num + 1),
self.spec_token_num + 1,
self.num_heads,
-1,
).transpose(1, 2).contiguous())
sparse_mode = 3
spec_attn_mask = attn_metadata.decode.attn_mask # type:ignore
else:
q_nope = q_nope.view(num_tokens, self.num_heads, 1, -1)
q_pe = q_pe.view(num_tokens, self.num_heads, 1, -1)
sparse_mode = 0
spec_attn_mask = None
# shape of k_nope/k_pe for npu graph mode should be:
# [num_blocks, num_kv_heads, block_size, self.kv_lora_rank/self.qk_rope_head_dim]
block_size = kv_c_and_k_pe_cache[0].shape[1]
@@ -690,7 +731,8 @@
num_heads=self.num_heads,
num_key_value_heads=self.num_kv_heads,
input_layout="BNSD",
atten_mask=attn_metadata.attn_mask,
atten_mask=spec_attn_mask,
sparse_mode=sparse_mode,
scale=self.scale,
antiquant_mode=0,
antiquant_scale=None,
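The graph-mode decode path above packs the flat query tensors into TorchAir's `[bs, num_heads_per_rank, q_seq_len, dim]` (BNSD) layout, where `q_seq_len = spec_token_num + 1` during speculative decoding. A small worked example of that view/transpose, with assumed sizes (2 requests, 1 draft token per request, 4 heads, head dim 8):

```python
import torch

num_heads, head_dim, spec_token_num = 4, 8, 1
batch_size = 2
num_tokens = batch_size * (spec_token_num + 1)   # 4 query tokens in total

q_nope = torch.randn(num_tokens, num_heads, head_dim)

# Same reshape as in _forward_decode: split the token dim into
# (batch, q_seq_len) and move heads in front of the query positions.
q_bnsd = (q_nope.view(num_tokens // (spec_token_num + 1),
                      spec_token_num + 1,
                      num_heads,
                      head_dim).transpose(1, 2).contiguous())

assert q_bnsd.shape == (batch_size, num_heads, spec_token_num + 1, head_dim)
```

With this layout the draft positions of each request attend causally through the upper-triangular `spec_attn_mask` (hence `sparse_mode = 3`), whereas the plain decode path keeps a single query position per request and passes no mask (`sparse_mode = 0`).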
@@ -732,7 +774,9 @@ def forward(
if attn_metadata is None:
# Profiling run.
return output
self.running_in_graph = self.torchair_graph_enabled and attn_metadata.attn_state == AscendAttentionState.DecodeOnly
self.running_in_graph = self.torchair_graph_enabled and attn_metadata.attn_state in [
AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding
]
num_actual_toks = attn_metadata.num_actual_tokens
if k_pe is None and not self.running_in_graph:
kv_c, k_pe = self.kv_a_proj_with_mqa(
23 changes: 19 additions & 4 deletions vllm_ascend/worker/model_runner_v1.py
@@ -203,8 +203,13 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):

# Set up speculative decoding.
self.use_spec_decode = False
self.spec_attn_mask = None
if self.speculative_config:
self.use_spec_decode = True
self.spec_attn_mask = torch.triu(torch.ones(2048,
2048,
dtype=torch.bool),
diagonal=1).to("npu")
if get_pp_group().is_last_rank:
if self.speculative_config.method == "ngram":
self.drafter = NgramProposer(self.vllm_config)
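`spec_attn_mask` is a 2048×2048 strictly-upper-triangular boolean matrix, so each query position can only attend to itself and earlier positions (assuming the NPU kernels treat `True` as "masked out"). The same construction at a readable size:

```python
import torch

# Same construction as self.spec_attn_mask, just 4x4 and on CPU instead of "npu".
mask = torch.triu(torch.ones(4, 4, dtype=torch.bool), diagonal=1)
print(mask)
# tensor([[False,  True,  True,  True],
#         [False, False,  True,  True],
#         [False, False, False,  True],
#         [False, False, False, False]])
```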
@@ -779,10 +784,13 @@ def _process_reqs(
# Get the number of scheduled tokens for each request.
# TODO: The Python loop can be slow. Optimize.
num_scheduled_tokens = np.empty(num_reqs, dtype=np.int32)
num_valid_tokens = np.empty(num_reqs, dtype=np.int32)
max_num_scheduled_tokens = 0
for i, req_id in enumerate(self.input_batch.req_ids):
num_tokens = scheduler_output.num_scheduled_tokens[req_id]
num_scheduled_tokens[i] = num_tokens
num_valid_tokens[i] = num_tokens - \
len(scheduler_output.scheduled_spec_decode_tokens.get(req_id, []))
max_num_scheduled_tokens = max(max_num_scheduled_tokens,
num_tokens)

@@ -838,11 +846,16 @@ def _process_reqs(
out=self.slot_mapping_np[:total_num_scheduled_tokens])

ascend_config = get_ascend_config()
use_spec_decode = len(
scheduler_output.scheduled_spec_decode_tokens) > 0
if np.array_equal(self.seq_lens_np[:num_reqs], num_scheduled_tokens):
attn_state = AscendAttentionState.PrefillNoCache
# We assume this is the decode stage, where prefill occurs but only one token per request misses the cache.
elif np.all(num_scheduled_tokens == 1):
attn_state = AscendAttentionState.DecodeOnly
# Speculative decoding.
elif np.all(num_valid_tokens == 1):
attn_state = AscendAttentionState.SpecDecoding
# splitfuse
elif not ascend_config.ascend_scheduler_config.enabled or self.chunked_prefill_enabled:
attn_state = AscendAttentionState.ChunkedPrefill
Expand Down Expand Up @@ -873,7 +886,9 @@ def _process_reqs(
seq_lens = self.seq_lens[:num_reqs]
common_attn_metadata = CommonAttentionMetadata(
query_start_loc=query_start_loc, seq_lens=seq_lens)
with_prefill = attn_state != AscendAttentionState.DecodeOnly
with_prefill = attn_state not in [
AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding
]

if self.dp_size > 1:
max_num_tokens, with_prefill = self._get_forward_metadata_across_dp(
@@ -883,14 +898,14 @@
# Add graph_pad_size here
if envs_ascend.VLLM_ENABLE_MC2 or (self.torchair_graph_enabled
and not with_prefill):
batch_size = len(seq_lens)
if self.dp_size > 1:
padded_batch_size = self.select_torchair_padded_batch_size(
max_num_tokens)
else:
padded_batch_size = self.select_torchair_padded_batch_size(
batch_size)
graph_pad_size = padded_batch_size - batch_size
total_num_scheduled_tokens)
graph_pad_size = padded_batch_size - total_num_scheduled_tokens

extra_builder_kwargs['graph_pad_size'] = graph_pad_size

if self.vllm_config.model_config.use_mla:
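The padding change above sizes the TorchAir graph input by `total_num_scheduled_tokens` instead of the sequence count, so draft tokens are included. A sketch of the arithmetic with a hypothetical `select_padded_size` standing in for `select_torchair_padded_batch_size` (assumed to round up to the next captured graph size):

```python
import bisect

CAPTURED_SIZES = [8, 16, 32, 64, 128]   # assumed pre-captured graph sizes


def select_padded_size(num_tokens: int) -> int:
    # Hypothetical stand-in: round up to the next captured size.
    idx = bisect.bisect_left(CAPTURED_SIZES, num_tokens)
    return CAPTURED_SIZES[min(idx, len(CAPTURED_SIZES) - 1)]


# Two requests, each scheduling 1 real token plus 3 draft tokens:
total_num_scheduled_tokens = 2 * (1 + 3)                      # 8 query tokens
graph_pad_size = select_padded_size(total_num_scheduled_tokens) \
    - total_num_scheduled_tokens                              # 0 here

# The old code padded relative to len(seq_lens) (2 here), which no longer
# matches the real token count once each request carries speculative tokens.
```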
5 changes: 4 additions & 1 deletion vllm_ascend/worker/mtp_proposer_v1.py
@@ -4,7 +4,8 @@
set_current_vllm_config)
from vllm.forward_context import set_forward_context
from vllm.model_executor.model_loader import get_model_loader
from vllm.model_executor.model_loader.utils import set_default_torch_dtype
from vllm.model_executor.model_loader.utils import (
process_weights_after_loading, set_default_torch_dtype)
from vllm.v1.sample.metadata import SamplingMetadata

from vllm_ascend.attention.mla_v1 import CommonAttentionMetadata
@@ -199,6 +200,8 @@ def load_model(self) -> None:
loader.get_all_weights(
self.vllm_config.speculative_config.draft_model_config,
self.model))
process_weights_after_loading(self.model, draft_model_config,
target_device)


# TODO Using torch instead of triton may result in poor performance
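For context, the added `process_weights_after_loading` call applies the same post-load weight processing to the MTP drafter that the target model already gets (e.g. quantization-related repacking on the target device). A schematic of the load flow, assuming `draft_model_config` and `target_device` are set up as in the real `load_model`; only the final call is taken verbatim from this diff:

```python
from vllm.model_executor.model_loader import get_model_loader
from vllm.model_executor.model_loader.utils import (
    process_weights_after_loading, set_default_torch_dtype)


def load_drafter_weights(vllm_config, model) -> None:
    # Schematic only: mirrors the shape of the proposer's load_model above.
    loader = get_model_loader(vllm_config.load_config)
    draft_model_config = vllm_config.speculative_config.draft_model_config
    target_device = vllm_config.device_config.device
    with set_default_torch_dtype(draft_model_config.dtype):
        model.load_weights(
            loader.get_all_weights(draft_model_config, model))
    # New in this PR: run post-load processing on the drafter as well.
    process_weights_after_loading(model, draft_model_config, target_device)
```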