
[WIP][Kernel]FusedMoE LoRA #21229


Open
wants to merge 4 commits into base: main

Conversation

@CNTRYROA (Contributor) commented Jul 19, 2025

Essential Elements of an Effective PR Description Checklist

  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing a test command.
  • The test results, such as pasting a results comparison before and after, or e2e results.
  • (Optional) Any necessary documentation updates, such as updating supported_models.md and examples for a new model.

Purpose

This PR is an experimental extension to FusedMoE, aiming to support parallel inference with multiple LoRA models.
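
For orientation, this is the computation being fused, in unfused form (a minimal PyTorch sketch with illustrative names; it shows one LoRA pair per expert, whereas this PR additionally maps each token to one of several adapters): each routed token's expert output is adjusted by that expert's low-rank pair, i.e. y = (W_e + B_e A_e) x, weighted by the router.

    import torch

    def moe_lora_reference(x, w_experts, lora_a, lora_b,
                           topk_ids, topk_weights, scaling=1.0):
        # x: (num_tokens, hidden); w_experts: (num_experts, out, hidden)
        # lora_a: (num_experts, rank, hidden); lora_b: (num_experts, out, rank)
        # topk_ids / topk_weights: (num_tokens, topk)
        num_tokens, topk = topk_ids.shape
        out = torch.zeros(num_tokens, w_experts.shape[1],
                          dtype=x.dtype, device=x.device)
        for t in range(num_tokens):
            for k in range(topk):
                e = int(topk_ids[t, k])
                base = w_experts[e] @ x[t]                          # frozen expert
                delta = scaling * (lora_b[e] @ (lora_a[e] @ x[t]))  # low-rank update
                out[t] += topk_weights[t, k] * (base + delta)
        return out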

✅ Completed:

  • Verified to work under both TP and torch.compile modes.
  • Verified correctness on the DeepSeek-V2-Lite-Chat model.

❌ Work in progress:

  • Clear and comprehensive code comments are still missing.
  • Kernel generalization: support for quantization, multiple data types, and compatibility with various GPU architectures.
  • Extend FusedMoEPackedLoRALayerWeights to support the new packed kind.
  • Add a create_dummy_lora_weights method to support FusedMoEPackedLoRALayerWeights.
  • No test coverage yet.

Any suggestions are welcome!

Test Plan

vllm serve deepseek-ai/DeepSeek-V2-Lite-Chat --trust-remote-code -tp 2 --enable-lora --max-loras 6 --lora-modules lora1=wuchen01/DeepSeek-V2-Lite-Chat-All-LoRA --max-lora-rank 64
curl --location 'http://127.0.0.1:8000/v1/chat/completions' \
--header 'Content-Type: application/json' \
--data '{
    "stream":false,
    "messages": [
        {"role": "system", "content": "You are a helpful assistant."},
        {
            "role": "user",
            "content": "who are you?"
        }
    ],
    "model": "lora1"
    }'
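
The same request from Python, convenient for scripted before/after checks (a minimal sketch using requests; endpoint and adapter name as in the commands above):

    import requests

    resp = requests.post(
        "http://127.0.0.1:8000/v1/chat/completions",
        json={
            "stream": False,
            "messages": [
                {"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": "who are you?"},
            ],
            "model": "lora1",  # the adapter registered via --lora-modules
        },
        timeout=60,
    )
    print(resp.json()["choices"][0]["message"]["content"])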

Test Result

(Optional) Documentation Update


👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small and essential subset of tests to catch errors quickly. You can run additional CI tests on top of those by going to your fastcheck build in the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

@mergify mergify bot added the ci/build label Jul 19, 2025

mergify bot commented Jul 19, 2025

This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @CNTRYROA.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the deepseek (Related to DeepSeek models) and needs-rebase labels Jul 19, 2025
@gemini-code-assist bot left a comment


Code Review

This pull request introduces an experimental extension for FusedMoE to support parallel inference with multiple LoRA models. The changes are extensive, touching CUDA kernels, Triton kernels, and various Python layers for model execution and LoRA management. The core idea is to enable LoRA adapters for Mixture-of-Experts layers, which is a significant feature enhancement.

My review focuses on the correctness and robustness of the new kernels and their integration. I've identified a few critical and high-severity issues that should be addressed:

  • A race condition in the new CUDA kernel (moe_lora_align_sum_kernels.cu) could lead to incorrect behavior or memory corruption.
  • Incorrect shared memory calculation in the same CUDA kernel could lead to resource exhaustion and launch failures.
  • Hardcoded values for the number of experts and LoRA rank in the Triton kernels and Python wrappers limit the general applicability of this feature.

These issues are important to fix to ensure the stability and correctness of this new feature. The PR is a work-in-progress, and addressing these points will significantly improve its quality and readiness for merging.

Comment on lines 121 to 146
for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
  int32_t expert_id = topk_ids[i];
  /** The cumsum[expert_id] stores the starting index of the tokens that the
   * expert with expert_id needs to process, and
   * tokens_cnts[threadIdx.x][expert_id] stores the indices of the tokens
   * processed by the expert with expert_id within the current thread's token
   * shard.
   */
  int32_t rank_post_pad =
      tokens_cnts[dim3_index(lora_id, blockDim.x + 1, num_experts,
                             threadIdx.x, expert_id)] +
      cumsum[dim2_index(blockDim.x + 1, lora_id, expert_id)];
  int mask = token_lora_mapping[i / topk_num] == lora_id;
  // sorted_token_ids[dim2_index(max_num_tokens_padded, lora_id, rank_post_pad)]
  // = i;
  // ++tokens_cnts[dim3_index(lora_id, blockDim.x + 1, num_experts,
  //                          threadIdx.x, expert_id)];
  sorted_token_ids[dim2_index(max_num_tokens_padded, lora_id, rank_post_pad)] =
      sorted_token_ids[dim2_index(max_num_tokens_padded, lora_id,
                                  rank_post_pad)] *
          (1 - mask) +
      i * mask;
  tokens_cnts[dim3_index(lora_id, blockDim.x + 1, num_experts, threadIdx.x,
                         expert_id)] += mask;
}
Contributor

critical

There appears to be a race condition in this loop. When mask is 0, tokens_cnts is not incremented for the current expert_id. This means that multiple tokens within the same thread that have mask=0 and map to the same expert will compute the same rank_post_pad. This leads to a write race on sorted_token_ids, as multiple iterations will attempt to write to the same memory location.

The multiplication trick ... * (1 - mask) + ... * mask is not safe here because the write location is not unique for each iteration when mask is 0.

To fix this, the write to sorted_token_ids and the increment of tokens_cnts should be inside a conditional block.

    if (token_lora_mapping[i / topk_num] == lora_id) {
      int32_t expert_id = topk_ids[i];
      /** The cumsum[expert_id] stores the starting index of the tokens that the
       * expert with expert_id needs to process, and
       * tokens_cnts[threadIdx.x][expert_id] stores the indices of the tokens
       * processed by the expert with expert_id within the current thread's token
       * shard.
       */
      int32_t rank_post_pad =
          tokens_cnts[dim3_index(lora_id, blockDim.x + 1, num_experts,
                                 threadIdx.x, expert_id)] +
          cumsum[dim2_index(blockDim.x + 1, lora_id, expert_id)];
      sorted_token_ids[dim2_index(max_num_tokens_padded, lora_id,
                                  rank_post_pad)] = i;
      tokens_cnts[dim3_index(lora_id, blockDim.x + 1, num_experts, threadIdx.x,
                             expert_id)]++;
    }

Contributor Author

I don't understand — if it's the same thread accessing the same memory region multiple times, why would there be a race condition? The accesses within a thread should be executed serially, right?

Comment on lines 171 to 173
const int32_t shared_mem_i16 =
    ((num_thread + 1) * num_experts) * max_loras * sizeof(uint16_t) +
    (num_experts + 1) * max_loras * sizeof(int32_t);
Contributor

high

The shared memory size calculation for tokens_cnts appears to be incorrect. It uses num_thread, which is max((int32_t)num_experts, WARP_SIZE), but the kernel launch configuration uses num_experts for blockDim.x, and the kernel code itself uses blockDim.x for indexing.

If num_experts < WARP_SIZE, this will allocate more shared memory than necessary. This is inefficient and could needlessly exceed device limits, causing launch failures. The calculation should consistently use num_experts instead of num_thread to match the kernel's logic.

  const int32_t shared_mem_i16 =
      ((num_experts + 1) * num_experts) * max_loras * sizeof(uint16_t) +
      (num_experts + 1) * max_loras * sizeof(int32_t);
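
The corrected sizing can be mirrored on the host side and sanity-checked before launch (a hypothetical Python helper with illustrative names; 48 KiB is the default per-block shared memory limit on most NVIDIA GPUs, and larger sizes need an explicit opt-in such as cudaFuncSetAttribute with cudaFuncAttributeMaxDynamicSharedMemorySize):

    def moe_lora_align_shared_mem_bytes(num_experts: int, max_loras: int) -> int:
        """Bytes of shared memory under the reviewer's corrected formula."""
        tokens_cnts = (num_experts + 1) * num_experts * max_loras * 2  # uint16 counters
        cumsum = (num_experts + 1) * max_loras * 4                     # int32 prefix sums
        return tokens_cnts + cumsum

    # DeepSeek-V2-Lite routes over 64 experts; with --max-loras 6 this formula
    # gives 51480 bytes, above the 48 KiB (49152 B) default, so the launch would
    # need the opt-in above or a smaller max_loras.
    print(moe_lora_align_shared_mem_bytes(num_experts=64, max_loras=6))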

Comment on lines 94 to 95
if expert_id >= 64:
    return
Contributor

high

The number of experts is hardcoded to 64. This limits the kernel's applicability to models with exactly 64 experts (like DeepSeek-V2). For this kernel to be general, the number of experts should be passed as a parameter.

You should add num_experts: tl.constexpr to the kernel signature and use it for the check.

Suggested change
- if expert_id >= 64:
-     return
+ if expert_id >= num_experts:
+     return
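
A self-contained illustration of that parameterization (a minimal Triton sketch; the kernel and names are hypothetical, not the PR's kernel): the expert count is passed as a tl.constexpr so the same compiled guard works for any model, and grid slots beyond num_experts exit early instead of writing out of bounds.

    import torch
    import triton
    import triton.language as tl

    @triton.jit
    def expert_guard_kernel(out_ptr, num_experts: tl.constexpr, BLOCK: tl.constexpr):
        expert_id = tl.program_id(axis=0)
        # Early-exit for grid slots beyond the model's expert count,
        # instead of hardcoding `if expert_id >= 64`.
        if expert_id >= num_experts:
            return
        offs = tl.arange(0, BLOCK)
        tl.store(out_ptr + expert_id * BLOCK + offs, 1.0)

    num_experts, BLOCK = 64, 128
    out = torch.zeros(num_experts * BLOCK, device="cuda")
    # Launch with a deliberately oversized grid; the guard skips the excess.
    expert_guard_kernel[(num_experts + 8,)](out, num_experts=num_experts, BLOCK=BLOCK)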

@jeejeelee jeejeelee self-assigned this Jul 20, 2025
@CNTRYROA CNTRYROA requested a review from WoosukKwon as a code owner July 22, 2025 10:45
@mergify mergify bot removed the needs-rebase label Jul 22, 2025

mergify bot commented Jul 22, 2025

This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @CNTRYROA.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Jul 22, 2025
Labels: ci/build, deepseek (Related to DeepSeek models), needs-rebase