use npu_moe_gating_top_k_softmax #1355

Open · wants to merge 3 commits into base: main
37 changes: 37 additions & 0 deletions tests/e2e/singlecard/ops/test_gating_top_k_softmax.py
@@ -0,0 +1,37 @@
import pytest
import torch
import torch_npu


@pytest.mark.parametrize(
    'B',
    [1, 16, 64, 128, 32768],
)
@pytest.mark.parametrize(
    'D',
    [8, 16, 32, 64, 128],
)
@pytest.mark.parametrize(
    'top_k',
    [1, 2, 4, 8],
)
@pytest.mark.parametrize(
    "dtype, atol, rtol",
    [
        (torch.float16, 1e-3, 1e-3),
        (torch.bfloat16, 1e-3, 1e-3),
    ],
)
def test_gating_top_k_softmax(B: int, D: int, top_k: int, dtype, atol, rtol):
    x = torch.rand((B, D), dtype=dtype).to("npu")
    # finished = torch.randint(1, size=(B,), dtype=torch.bool).to("npu")
    finished = None
    y, expert_idx, row_idx = torch_npu.npu_moe_gating_top_k_softmax(x,
                                                                    finished,
                                                                    k=top_k)

    # Reference implementation: plain softmax followed by top-k.
    topk_weights = x.softmax(dim=-1)
    topk_weights, topk_ids = topk_weights.topk(top_k, dim=-1)
    topk_ids = topk_ids.to(torch.int32)
    assert torch.allclose(y, topk_weights, atol=atol, rtol=rtol)
    assert torch.allclose(expert_idx, topk_ids, atol=atol, rtol=rtol)
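
For quick reference, a minimal standalone sketch of the op's interface, mirroring the shapes used in the test above (assumes torch_npu and an Ascend NPU device are available; `finished` is the optional per-token mask, left as None here exactly as in the test):

import torch
import torch_npu

logits = torch.rand((4, 8), dtype=torch.float16).to("npu")  # (num_tokens, num_experts)
# k selects the top-k experts per token; three tensors come back.
y, expert_idx, row_idx = torch_npu.npu_moe_gating_top_k_softmax(logits, None, k=2)
print(y.shape, expert_idx.shape)  # both (num_tokens, k) = (4, 2)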
2 changes: 2 additions & 0 deletions vllm_ascend/envs.py
@@ -116,6 +116,8 @@
    # value to False to disable the optimized model.
    "USE_OPTIMIZED_MODEL":
    lambda: bool(int(os.getenv('USE_OPTIMIZED_MODEL', '1'))),
    "SELECT_GATING_TOPK_SOTFMAX_EXPERTS":
    lambda: bool(int(os.getenv("SELECT_GATING_TOPK_SOTFMAX_EXPERTS", '0'))),
}

# end-env-vars-definition
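
Because the consuming modules snapshot this flag at import time (see the module-level assignments in the files below), it must be set before vllm_ascend is imported. A minimal sketch of enabling it from Python:

import os

# Enable the npu_moe_gating_top_k_softmax routing path ("0", the default, disables it).
os.environ["SELECT_GATING_TOPK_SOTFMAX_EXPERTS"] = "1"

import vllm_ascend.envs as envs_ascend
assert envs_ascend.SELECT_GATING_TOPK_SOTFMAX_EXPERTS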
40 changes: 26 additions & 14 deletions vllm_ascend/ops/common_fused_moe.py
@@ -21,10 +21,14 @@
from vllm.model_executor.layers.fused_moe.layer import \
    UnquantizedFusedMoEMethod

import vllm_ascend.envs as envs_ascend
from vllm_ascend.ops.fused_moe import (fused_experts, fused_experts_310p,
                                       select_experts)
                                       select_experts,
                                       select_gating_top_k_softmax_experts)
from vllm_ascend.utils import is_310p

SELECT_GATING_TOPK_SOTFMAX_EXPERTS: bool = envs_ascend.SELECT_GATING_TOPK_SOTFMAX_EXPERTS


def forward_oot(
    self,
@@ -44,19 +48,27 @@
    apply_router_weight_on_input: bool = False,
    activation: str = "silu",
) -> torch.Tensor:
    topk_weights, topk_ids = select_experts(
        global_num_experts=global_num_experts,
        hidden_states=x,
        router_logits=router_logits,
        top_k=top_k,
        use_grouped_topk=use_grouped_topk,
        renormalize=renormalize,
        topk_group=topk_group,
        num_expert_group=num_expert_group,
        custom_routing_function=custom_routing_function,
        scoring_func=scoring_func,
        e_score_correction_bias=e_score_correction_bias,
    )

    if SELECT_GATING_TOPK_SOTFMAX_EXPERTS:
        topk_weights, topk_ids = select_gating_top_k_softmax_experts(
            hidden_states=x,
            router_logits=router_logits,
            top_k=top_k,
            renormalize=renormalize)
    else:
        topk_weights, topk_ids = select_experts(
            global_num_experts=global_num_experts,
            hidden_states=x,
            router_logits=router_logits,
            top_k=top_k,
            use_grouped_topk=use_grouped_topk,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function,
            scoring_func=scoring_func,
            e_score_correction_bias=e_score_correction_bias,
        )

    if is_310p():
        assert global_num_experts is not None
40 changes: 40 additions & 0 deletions vllm_ascend/ops/fused_moe.py
@@ -43,6 +43,7 @@
npu_wait_tensor)

MOE_ALL2ALL_BUFFER: bool = envs_ascend.MOE_ALL2ALL_BUFFER
SELECT_GATING_TOPK_SOTFMAX_EXPERTS: bool = envs_ascend.SELECT_GATING_TOPK_SOTFMAX_EXPERTS


def process_topk_ids(topk_ids: torch.Tensor, expert_num: int, ep_size: int,
@@ -812,6 +813,39 @@
    return final_hidden_states


def select_gating_top_k_softmax_experts(
        hidden_states: torch.Tensor, router_logits: torch.Tensor, top_k: int,
        renormalize: bool) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Select top-k experts based on router logits.

    Only supports float16, bfloat16, and float32 router logits.

    Args:
        hidden_states: Hidden states of shape (num_tokens, hidden_size).
        router_logits: Router logits of shape (num_tokens, num_experts).
        top_k: Number of experts to select.
        renormalize: Whether to renormalize the routing weights.

    Returns:
        topk_weights: Routing weights of shape (num_tokens, top_k).
        topk_ids: Selected expert IDs of shape (num_tokens, top_k).
    """
    topk_weights, topk_ids, row_idx = torch_npu.npu_moe_gating_top_k_softmax(
        router_logits, None, k=top_k)

    # # Required by npu_moe_init_routing
    # topk_weights = topk_weights.to(hidden_states.dtype)
    # topk_ids = topk_ids.to(torch.int32)

    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)

    return topk_weights, topk_ids

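
For readers without an Ascend device, a hedged CPU-only reference of what this helper is expected to compute, reconstructed from the e2e test above (softmax over the expert axis, then top-k) plus the renormalize branch; the fused kernel's auxiliary row_idx output is omitted, and the _ref name is ours, not part of the PR:

import torch

def select_gating_top_k_softmax_experts_ref(
        router_logits: torch.Tensor, top_k: int,
        renormalize: bool) -> tuple[torch.Tensor, torch.Tensor]:
    # Softmax over the expert axis, then keep the k largest probabilities.
    topk_weights, topk_ids = router_logits.softmax(dim=-1).topk(top_k, dim=-1)
    topk_ids = topk_ids.to(torch.int32)  # int32 ids, matching the test's reference
    if renormalize:
        # Rescale the kept weights so they sum to 1 per token.
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_ids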


def native_grouped_topk(
    topk_weights: torch.Tensor,
    num_expert_group: Optional[int],
@@ -1003,6 +1037,12 @@
        # y2_flag=False,  # old api: whether to emit the third output
        routed_scaling_factor=1,
        eps=float(1e-20))
    elif SELECT_GATING_TOPK_SOTFMAX_EXPERTS:
        topk_weights, topk_ids = select_gating_top_k_softmax_experts(
            hidden_states=x,
            router_logits=router_logits,
            top_k=top_k,
            renormalize=renormalize)
    else:
        topk_weights, topk_ids = select_experts(
            hidden_states=x,