diff --git a/tests/e2e/singlecard/ops/test_gating_top_k_softmax.py b/tests/e2e/singlecard/ops/test_gating_top_k_softmax.py
new file mode 100644
index 000000000..4edcdfdea
--- /dev/null
+++ b/tests/e2e/singlecard/ops/test_gating_top_k_softmax.py
@@ -0,0 +1,37 @@
+import pytest
+import torch
+import torch_npu
+
+
+@pytest.mark.parametrize(
+    'B',
+    [1, 16, 64, 128, 32768],
+)
+@pytest.mark.parametrize(
+    'D',
+    [8, 16, 32, 64, 128],
+)
+@pytest.mark.parametrize(
+    'top_k',
+    [1, 2, 4, 8],
+)
+@pytest.mark.parametrize(
+    "dtype, atol, rtol",
+    [
+        (torch.float16, 1e-3, 1e-3),
+        (torch.bfloat16, 1e-3, 1e-3),
+    ],
+)
+def test_gating_top_k_softmax(B: int, D: int, top_k: int, dtype, atol, rtol):
+    x = torch.rand((B, D), dtype=dtype).to("npu")
+    # finished = torch.randint(1, size=(B,), dtype=torch.bool).to("npu")
+    finished = None
+    y, expert_idx, row_idx = torch_npu.npu_moe_gating_top_k_softmax(x,
+                                                                    finished,
+                                                                    k=top_k)
+
+    topk_weights = x.softmax(dim=-1)
+    topk_weights, topk_ids = topk_weights.topk(top_k, dim=-1)
+    topk_ids = topk_ids.to(torch.int32)
+    assert torch.allclose(y, topk_weights, atol=atol, rtol=rtol)
+    assert torch.allclose(expert_idx, topk_ids, atol=atol, rtol=rtol)
diff --git a/vllm_ascend/envs.py b/vllm_ascend/envs.py
index 659924162..50769f4bd 100644
--- a/vllm_ascend/envs.py
+++ b/vllm_ascend/envs.py
@@ -121,6 +121,10 @@
     # value to False to disable the optimized model.
     "USE_OPTIMIZED_MODEL":
    lambda: bool(int(os.getenv('USE_OPTIMIZED_MODEL', '1'))),
+    # Whether to select MoE experts with the fused
+    # npu_moe_gating_top_k_softmax kernel instead of select_experts.
+    "SELECT_GATING_TOPK_SOTFMAX_EXPERTS":
+    lambda: bool(int(os.getenv("SELECT_GATING_TOPK_SOTFMAX_EXPERTS", '0'))),
 }
 
 # end-env-vars-definition
diff --git a/vllm_ascend/ops/common_fused_moe.py b/vllm_ascend/ops/common_fused_moe.py
index 4e21c744a..cb5fb4080 100644
--- a/vllm_ascend/ops/common_fused_moe.py
+++ b/vllm_ascend/ops/common_fused_moe.py
@@ -21,10 +21,14 @@
 from vllm.model_executor.layers.fused_moe.layer import \
     UnquantizedFusedMoEMethod
 
+import vllm_ascend.envs as envs_ascend
 from vllm_ascend.ops.fused_moe import (fused_experts, fused_experts_moge,
-                                       select_experts)
+                                       select_experts,
+                                       select_gating_top_k_softmax_experts)
 from vllm_ascend.utils import is_310p
 
+SELECT_GATING_TOPK_SOTFMAX_EXPERTS: bool = envs_ascend.SELECT_GATING_TOPK_SOTFMAX_EXPERTS
+
 
 def forward_oot(
         self,
@@ -44,19 +48,27 @@ def forward_oot(
         apply_router_weight_on_input: bool = False,
         activation: str = "silu",
 ) -> torch.Tensor:
-    topk_weights, topk_ids = select_experts(
-        global_num_experts=global_num_experts,
-        hidden_states=x,
-        router_logits=router_logits,
-        top_k=top_k,
-        use_grouped_topk=use_grouped_topk,
-        renormalize=renormalize,
-        topk_group=topk_group,
-        num_expert_group=num_expert_group,
-        custom_routing_function=custom_routing_function,
-        scoring_func=scoring_func,
-        e_score_correction_bias=e_score_correction_bias,
-    )
+
+    if SELECT_GATING_TOPK_SOTFMAX_EXPERTS:
+        topk_weights, topk_ids = select_gating_top_k_softmax_experts(
+            hidden_states=x,
+            router_logits=router_logits,
+            top_k=top_k,
+            renormalize=renormalize)
+    else:
+        topk_weights, topk_ids = select_experts(
+            global_num_experts=global_num_experts,
+            hidden_states=x,
+            router_logits=router_logits,
+            top_k=top_k,
+            use_grouped_topk=use_grouped_topk,
+            renormalize=renormalize,
+            topk_group=topk_group,
+            num_expert_group=num_expert_group,
+            custom_routing_function=custom_routing_function,
+            scoring_func=scoring_func,
+            e_score_correction_bias=e_score_correction_bias,
+        )
 
     if topk_ids.shape[1] < top_k or is_310p():
         assert global_num_experts is not None
diff --git a/vllm_ascend/ops/fused_moe.py b/vllm_ascend/ops/fused_moe.py
index bea0dc555..97ccd9b7e 100644
--- a/vllm_ascend/ops/fused_moe.py
+++ b/vllm_ascend/ops/fused_moe.py
@@ -43,6 +43,7 @@
                          npu_wait_tensor)
 
 MOE_ALL2ALL_BUFFER: bool = envs_ascend.MOE_ALL2ALL_BUFFER
+SELECT_GATING_TOPK_SOTFMAX_EXPERTS: bool = envs_ascend.SELECT_GATING_TOPK_SOTFMAX_EXPERTS
 
 
 def process_topk_ids(topk_ids: torch.Tensor, expert_num: int, ep_size: int,
@@ -812,6 +813,39 @@ def fused_experts(
     return final_hidden_states
 
 
+def select_gating_top_k_softmax_experts(
+        hidden_states: torch.Tensor, router_logits: torch.Tensor, top_k: int,
+        renormalize: bool) -> tuple[torch.Tensor, torch.Tensor]:
+    """
+    Select top-k experts with a fused softmax-then-top-k over router logits.
+    Only supports float16, bfloat16 and float32 router logits.
+
+    Args:
+        hidden_states: Hidden states of shape (num_tokens, hidden_size).
+        router_logits: Router logits of shape (num_tokens, num_experts).
+        top_k: Number of experts to select.
+        renormalize: Whether to renormalize the routing weights.
+
+    Returns:
+        topk_weights: Routing weights of shape (num_tokens, top_k).
+        topk_ids: Selected expert IDs of shape (num_tokens, top_k).
+
+    Note:
+        Delegates to torch_npu.npu_moe_gating_top_k_softmax.
+    """
+    topk_weights, topk_ids, row_idx = torch_npu.npu_moe_gating_top_k_softmax(
+        router_logits, None, k=top_k)
+
+    # Required by npu_moe_init_routing:
+    # topk_weights = topk_weights.to(hidden_states.dtype)
+    # topk_ids = topk_ids.to(torch.int32)
+
+    if renormalize:
+        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
+
+    return topk_weights, topk_ids
+
+
 def native_grouped_topk(
     topk_weights: torch.Tensor,
     num_expert_group: Optional[int],
@@ -1004,6 +1038,12 @@ def apply(
                 # y2_flag=False,  # old api; whether to return the third output
                 routed_scaling_factor=1,
                 eps=float(1e-20))
+        elif SELECT_GATING_TOPK_SOTFMAX_EXPERTS:
+            topk_weights, topk_ids = select_gating_top_k_softmax_experts(
+                hidden_states=x,
+                router_logits=router_logits,
+                top_k=top_k,
+                renormalize=renormalize)
         else:
             topk_weights, topk_ids = select_experts(
                 hidden_states=x,
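For context, the path gated by SELECT_GATING_TOPK_SOTFMAX_EXPERTS replaces the separate softmax and top-k of select_experts with a single torch_npu.npu_moe_gating_top_k_softmax kernel call. Below is a minimal pure-PyTorch sketch of the math that path computes; the helper name gating_top_k_softmax_reference is hypothetical and not part of this change, it only restates softmax-then-top-k plus the optional renormalization:

import torch


def gating_top_k_softmax_reference(
        router_logits: torch.Tensor, top_k: int,
        renormalize: bool) -> tuple[torch.Tensor, torch.Tensor]:
    # Softmax over the expert dimension of the router logits.
    weights = router_logits.softmax(dim=-1)
    # Keep the top_k weights per token and the ids of their experts.
    topk_weights, topk_ids = weights.topk(top_k, dim=-1)
    if renormalize:
        # Rescale the kept weights so they sum to 1 for each token.
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_ids.to(torch.int32)

To exercise the new path, export SELECT_GATING_TOPK_SOTFMAX_EXPERTS=1 before launching; the flag defaults to 0, and the fused kernel only supports float16, bfloat16 and float32 router logits.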