diff --git a/examples/disaggregate_prefill_v1/README.md b/examples/disaggregate_prefill_v1/README.md
index f096dcd1c..d35472382 100644
--- a/examples/disaggregate_prefill_v1/README.md
+++ b/examples/disaggregate_prefill_v1/README.md
@@ -30,15 +30,14 @@ Execution Sequence
 
 * Run prefill server P1 on first node
 ```shell
-export HCCL_IF_IP=`hostname -I|awk -F " " '{print$1}'`
-export GLOO_SOCKET_IFNAME="eth0"
+export HCCL_IF_IP=172.19.32.175 # node ip
+export GLOO_SOCKET_IFNAME="eth0" # network card name
 export TP_SOCKET_IFNAME="eth0"
 export HCCL_SOCKET_IFNAME="eth0"
 export DISAGGREGATED_PREFILL_RANK_TABLE_PATH=/vllm-workspace/vllm-ascend/examples/disaggregate_prefill_v1/ranktable.json
 export OMP_PROC_BIND=false
 export OMP_NUM_THREADS=100
 export VLLM_USE_V1=1
-export VLLM_VERSION=0.9.1
 vllm serve /data01/deepseek_r1_w8a8_zhw \
   --host 0.0.0.0 \
   --port 20002 \
@@ -71,7 +70,7 @@ vllm serve /data01/deepseek_r1_w8a8_zhw \
 
 * Run prefill server P2 on second node
 ```shell
-export HCCL_IF_IP=`hostname -I|awk -F " " '{print$1}'`
+export HCCL_IF_IP=172.19.241.49
 export GLOO_SOCKET_IFNAME="eth0"
 export TP_SOCKET_IFNAME="eth0"
 export HCCL_SOCKET_IFNAME="eth0"
@@ -79,7 +78,6 @@ export DISAGGREGATED_PREFILL_RANK_TABLE_PATH=/vllm-workspace/vllm-ascend/example
 export OMP_PROC_BIND=false
 export OMP_NUM_THREADS=100
 export VLLM_USE_V1=1
-export VLLM_VERSION=0.9.1
 vllm serve /data01/deepseek_r1_w8a8_zhw \
   --host 0.0.0.0 \
   --port 20002 \
@@ -113,7 +111,7 @@ vllm serve /data01/deepseek_r1_w8a8_zhw \
 
 * Run decode server d1 on third node
 ```shell
-export HCCL_IF_IP=`hostname -I|awk -F " " '{print$1}'`
+export HCCL_IF_IP=172.19.123.51
 export GLOO_SOCKET_IFNAME="eth0"
 export TP_SOCKET_IFNAME="eth0"
 export HCCL_SOCKET_IFNAME="eth0"
@@ -121,7 +119,6 @@ export DISAGGREGATED_PREFILL_RANK_TABLE_PATH=/vllm-workspace/vllm-ascend/example
 export OMP_PROC_BIND=false
 export OMP_NUM_THREADS=100
 export VLLM_USE_V1=1
-export VLLM_VERSION=0.9.1
 vllm serve /data01/deepseek_r1_w8a8_zhw \
   --host 0.0.0.0 \
   --port 20002 \
@@ -154,7 +151,7 @@ vllm serve /data01/deepseek_r1_w8a8_zhw \
 
 * Run decode server d2 on last node
 ```shell
-export HCCL_IF_IP=`hostname -I|awk -F " " '{print$1}'`
+export HCCL_IF_IP=172.19.190.36
 export GLOO_SOCKET_IFNAME="eth0"
 export TP_SOCKET_IFNAME="eth0"
 export HCCL_SOCKET_IFNAME="eth0"
@@ -162,7 +159,6 @@ export DISAGGREGATED_PREFILL_RANK_TABLE_PATH=/vllm-workspace/vllm-ascend/example
 export OMP_PROC_BIND=false
 export OMP_NUM_THREADS=100
 export VLLM_USE_V1=1
-export VLLM_VERSION=0.9.1
 vllm serve /data01/deepseek_r1_w8a8_zhw \
   --host 0.0.0.0 \
   --port 20002 \
diff --git a/examples/disaggregate_prefill_v1/gen_ranktable.sh b/examples/disaggregate_prefill_v1/gen_ranktable.sh
index 516bea956..33d4a32e8 100644
--- a/examples/disaggregate_prefill_v1/gen_ranktable.sh
+++ b/examples/disaggregate_prefill_v1/gen_ranktable.sh
@@ -8,7 +8,6 @@ while [[ $# -gt 0 ]]; do
     case "$1" in
         --ips)
            shift
-           # 收集所有后续参数直到遇到下一个选项或结束
            while [[ $# -gt 0 && ! "$1" == --* ]]; do
                IPs+=("$1")
                shift
diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py
index a4e712540..6e478e82e 100644
--- a/vllm_ascend/attention/mla_v1.py
+++ b/vllm_ascend/attention/mla_v1.py
@@ -13,6 +13,7 @@
                                                UnquantizedLinearMethod)
 from vllm.utils import cdiv, round_down
 
+from vllm_ascend import envs
 from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.attention.attention import _ALLOWED_NUM_QUERIES_PER_KV
 from vllm_ascend.attention.attention_v1 import AscendAttentionState
@@ -933,18 +934,12 @@ def _forward_decode(
         q_pe: torch.Tensor,
         k_nope: torch.Tensor,
         k_pe: torch.Tensor,
-        kv_c_and_k_pe_cache: torch.Tensor,
+        kv_c_and_k_pe_cache: Tuple[torch.Tensor],
         attn_metadata: AscendMLAMetadata,
     ) -> torch.Tensor:
         decode_meta = attn_metadata.decode
         assert decode_meta is not None
-
-        q = torch.cat([q_nope, q_pe], dim=-1)
-        num_tokens = q.size(0)
-        attn_output = torch.empty(
-            [num_tokens, self.num_heads, self.kv_lora_rank],
-            dtype=q.dtype,
-            device=q.device)
+        num_tokens = q_nope.size(0)
         if self.running_in_graph:
             # TorchAir's shape is [bs, num_heads_per_rank, q_seq_len, dim]
             if attn_metadata.attn_state == AscendAttentionState.SpecDecoding:
@@ -1003,16 +998,35 @@ def _forward_decode(
                 actual_seq_lengths_kv=decode_meta.seq_lens_list,
             )
         else:
-            torch_npu._npu_paged_attention_mla(
-                query=q,
-                key_cache=kv_c_and_k_pe_cache,
-                num_kv_heads=self.num_kv_heads,
-                num_heads=self.num_heads,
-                scale_value=self.scale,
-                block_table=attn_metadata.decode.block_table,  # type:ignore
-                context_lens=attn_metadata.decode.seq_lens,  # type:ignore
-                mla_vheadsize=self.kv_lora_rank,
-                out=attn_output)
+            # The MLA_PA path will become the default in the future;
+            # `_npu_paged_attention_mla` will be removed once the torch_npu release that
+            # provides `torch_npu.atb.npu_multi_head_latent_attention` is publicly available.
+            assert len(kv_c_and_k_pe_cache) > 1
+            if envs.VLLM_ASCEND_MLA_PA:
+                attn_output = torch_npu.atb.npu_multi_head_latent_attention(
+                    q_nope, q_pe, kv_c_and_k_pe_cache[0],
+                    kv_c_and_k_pe_cache[1], attn_metadata.decode.block_table,
+                    attn_metadata.decode.seq_lens, self.num_heads, self.scale,
+                    self.num_kv_heads)
+            else:
+                q = torch.cat([q_nope, q_pe], dim=-1)
+                attn_output = torch.empty(
+                    [num_tokens, self.num_heads, self.kv_lora_rank],
+                    dtype=q.dtype,
+                    device=q.device)
+                k_cache = torch.cat(
+                    [kv_c_and_k_pe_cache[0], kv_c_and_k_pe_cache[1]], dim=-1)
+                torch_npu._npu_paged_attention_mla(
+                    query=q,
+                    key_cache=k_cache,
+                    num_kv_heads=self.num_kv_heads,
+                    num_heads=self.num_heads,
+                    scale_value=self.scale,
+                    block_table=attn_metadata.decode.
+                    block_table,  # type:ignore
+                    context_lens=attn_metadata.decode.seq_lens,  # type:ignore
+                    mla_vheadsize=self.kv_lora_rank,
+                    out=attn_output)
         current_ms_metadata = get_multistream_comm_context()
         if current_ms_metadata is None:
             return self._v_up_proj_and_o_proj(attn_output)
@@ -1193,10 +1207,11 @@ def forward(
                     decode_k_nope, decode_k_pe, kv_cache, attn_metadata)
             else:
-                combined_cache = torch.cat([kv_cache[0], kv_cache[1]], dim=-1)
-                output_decode = self._forward_decode(
-                    decode_ql_nope, decode_q_pe, decode_k_nope, decode_k_pe,
-                    combined_cache, attn_metadata)
+                output_decode = self._forward_decode(decode_ql_nope,
+                                                     decode_q_pe,
+                                                     decode_k_nope,
+                                                     decode_k_pe, kv_cache,
+                                                     attn_metadata)
             current_ms_metadata = get_multistream_comm_context()
             if current_ms_metadata is not None:
                 with torch.npu.stream(current_ms_metadata.comm_stream):
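With this change, `_forward_decode` receives the KV cache as a pair of tensors (the compressed latent part and the rotary part) instead of a pre-concatenated tensor, so the concatenation cost is paid only on the legacy fallback path. Below is a minimal standalone sketch of that dispatch; the tensor shapes and the `use_mla_pa` argument are illustrative stand-ins, since the real code reads `envs.VLLM_ASCEND_MLA_PA` and invokes torch_npu kernels that require Ascend NPUs.

```python
# Illustrative sketch only (not the code from mla_v1.py): how the tuple-valued
# kv_c_and_k_pe_cache is consumed by the two decode branches. Shapes are made up.
import torch


def decode_cache_dispatch_sketch(use_mla_pa: bool = False):
    num_blocks, block_size, kv_lora_rank, rope_dim = 4, 128, 512, 64
    kv_c_cache = torch.zeros(num_blocks, block_size, 1, kv_lora_rank)  # latent ("nope") part
    k_pe_cache = torch.zeros(num_blocks, block_size, 1, rope_dim)  # rotary ("rope") part
    kv_c_and_k_pe_cache = (kv_c_cache, k_pe_cache)

    if use_mla_pa:
        # npu_multi_head_latent_attention consumes the two caches separately,
        # so no concatenation happens on this path (kernel call omitted here).
        return kv_c_and_k_pe_cache
    # The legacy _npu_paged_attention_mla kernel expects a single cache of width
    # kv_lora_rank + rope_dim, so the fallback concatenates on the last dim.
    k_cache = torch.cat([kv_c_and_k_pe_cache[0], kv_c_and_k_pe_cache[1]], dim=-1)
    assert k_cache.shape[-1] == kv_lora_rank + rope_dim
    return k_cache


decode_cache_dispatch_sketch()
```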
"$1" == --* ]]; do IPs+=("$1") shift diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py index a4e712540..6e478e82e 100644 --- a/vllm_ascend/attention/mla_v1.py +++ b/vllm_ascend/attention/mla_v1.py @@ -13,6 +13,7 @@ UnquantizedLinearMethod) from vllm.utils import cdiv, round_down +from vllm_ascend import envs from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.attention.attention import _ALLOWED_NUM_QUERIES_PER_KV from vllm_ascend.attention.attention_v1 import AscendAttentionState @@ -933,18 +934,12 @@ def _forward_decode( q_pe: torch.Tensor, k_nope: torch.Tensor, k_pe: torch.Tensor, - kv_c_and_k_pe_cache: torch.Tensor, + kv_c_and_k_pe_cache: Tuple[torch.Tensor], attn_metadata: AscendMLAMetadata, ) -> torch.Tensor: decode_meta = attn_metadata.decode assert decode_meta is not None - - q = torch.cat([q_nope, q_pe], dim=-1) - num_tokens = q.size(0) - attn_output = torch.empty( - [num_tokens, self.num_heads, self.kv_lora_rank], - dtype=q.dtype, - device=q.device) + num_tokens = q_nope.size(0) if self.running_in_graph: # TorchAir's shape is [bs, num_heads_per_rank, q_seq_len, dim] if attn_metadata.attn_state == AscendAttentionState.SpecDecoding: @@ -1003,16 +998,35 @@ def _forward_decode( actual_seq_lengths_kv=decode_meta.seq_lens_list, ) else: - torch_npu._npu_paged_attention_mla( - query=q, - key_cache=kv_c_and_k_pe_cache, - num_kv_heads=self.num_kv_heads, - num_heads=self.num_heads, - scale_value=self.scale, - block_table=attn_metadata.decode.block_table, # type:ignore - context_lens=attn_metadata.decode.seq_lens, # type:ignore - mla_vheadsize=self.kv_lora_rank, - out=attn_output) + # The MLA_PA path will be used as default path in the future, `_npu_paged_attention_mla` will + # be removed after the torch_npu contains `torch_npu.atb.npu_multi_head_latent_attention` become + # public available + assert len(kv_c_and_k_pe_cache) > 1 + if envs.VLLM_ASCEND_MLA_PA: + attn_output = torch_npu.atb.npu_multi_head_latent_attention( + q_nope, q_pe, kv_c_and_k_pe_cache[0], + kv_c_and_k_pe_cache[1], attn_metadata.decode.block_table, + attn_metadata.decode.seq_lens, self.num_heads, self.scale, + self.num_kv_heads) + else: + q = torch.cat([q_nope, q_pe], dim=-1) + attn_output = torch.empty( + [num_tokens, self.num_heads, self.kv_lora_rank], + dtype=q.dtype, + device=q.device) + k_cache = torch.cat( + [kv_c_and_k_pe_cache[0], kv_c_and_k_pe_cache[1]], dim=-1) + torch_npu._npu_paged_attention_mla( + query=q, + key_cache=k_cache, + num_kv_heads=self.num_kv_heads, + num_heads=self.num_heads, + scale_value=self.scale, + block_table=attn_metadata.decode. 
diff --git a/vllm_ascend/models/deepseek_v2.py b/vllm_ascend/models/deepseek_v2.py
index f92275e23..da12a6509 100644
--- a/vllm_ascend/models/deepseek_v2.py
+++ b/vllm_ascend/models/deepseek_v2.py
@@ -314,7 +314,7 @@ def forward(
         is_prefill = is_prefill or attn_metadata.with_prefill_across_dp
         # If this node is kv_consumer, we force the moe always runs in decode path to make sure
         # the behaviour aligned between dummy_run and normal model_execute.
-        if self.kv_consumer is not None:
+        if self.kv_consumer:
             is_prefill = False
             enable_force_load_balance = False
 
diff --git a/vllm_ascend/quantization/w8a8_dynamic.py b/vllm_ascend/quantization/w8a8_dynamic.py
index 6949abada..d2618d69f 100644
--- a/vllm_ascend/quantization/w8a8_dynamic.py
+++ b/vllm_ascend/quantization/w8a8_dynamic.py
@@ -121,7 +121,10 @@ def fused_experts_with_mc2(
     if log2phy:
         topk_ids = log2phy[topk_ids]
     global_bs = 0
-    moe_expert_num = len(expert_map) + global_redundant_expert_num
+    if (expert_map is not None):
+        moe_expert_num = len(expert_map) + global_redundant_expert_num
+    else:
+        moe_expert_num = global_redundant_expert_num
     # hidden_states = hidden_states.bfloat16()
     kwargs_mc2 = {
         "x": hidden_states,
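The guard in `fused_experts_with_mc2` matters because `expert_map` can legitimately be `None`, in which case `len(expert_map)` would raise a `TypeError`; the expert count then falls back to the redundant-expert count alone. A minimal standalone sketch of the resulting logic, with made-up counts:

```python
# Standalone sketch of the expert-count guard added above; the values are illustrative.
from typing import Optional

import torch


def moe_expert_num_sketch(expert_map: Optional[torch.Tensor],
                          global_redundant_expert_num: int) -> int:
    if expert_map is not None:
        return len(expert_map) + global_redundant_expert_num
    # Without an expert map, only the redundant experts contribute to the count.
    return global_redundant_expert_num


assert moe_expert_num_sketch(torch.arange(64), 2) == 66
assert moe_expert_num_sketch(None, 2) == 2
```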