
Conversation

@zyongye
Member

@zyongye zyongye commented May 23, 2025

FIX #16294

Can only merge after migrating to PyTorch 2.8 and Triton 3.4.

Adds triton_kernels from the Triton repo.

Performance breakdown (benchmark plots): median_ttft_ms, output_throughput, median_tpot_ms.

Setup:
CPU: Intel Xeon Gold 6126 @ 2.60GHz
GPU: 8x NVIDIA RTX A6000, driver version 560.35.05, CUDA 12.6 (model sharded with TP where needed)

Dataset: ShareGPT with 100 requests (request rate: inf QPS)

Dependencies:

  • Triton fork here. The fork adds renormalize=False support for models like Qwen1.5 and Mixtral 8x7B.
  • PyTorch nightly build, to work around bugs introduced by the new Triton version

(Update 06/04/25: since Triton has accepted our PR, only nightly builds of PyTorch and Triton are needed; a quick version sanity check is sketched below.)
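
For reference, a quick sanity check of that toolchain (my own sketch, not part of the PR; it only asserts the PyTorch 2.8 / Triton 3.4 requirement stated above):

# Sketch: confirm the nightly toolchain this PR expects (PyTorch 2.8, Triton 3.4).
# Not part of the PR; adjust the minimum versions if the upstream requirement changes.
from packaging.version import Version

import torch
import triton

torch_ver = Version(torch.__version__.split("+")[0])    # drop local tag, e.g. "+cu126"
triton_ver = Version(triton.__version__.split("+")[0])

assert torch_ver.release[:2] >= (2, 8), f"need PyTorch >= 2.8 (nightly), found {torch.__version__}"
assert triton_ver.release[:2] >= (3, 4), f"need Triton >= 3.4 (nightly), found {triton.__version__}"
print(f"toolchain OK: torch {torch.__version__}, triton {triton.__version__}")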

To run the new kernel, set the environment variable VLLM_USE_EXP_MOE=1
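
For example, a minimal launch sketch (my own illustration; the model name and tensor-parallel size are placeholders, and the flag can just as well be exported in the shell before starting the server):

# Sketch: enable the experimental Triton MoE path before constructing the engine.
# The env var name comes from this PR; model and tensor_parallel_size are placeholders.
import os

os.environ["VLLM_USE_EXP_MOE"] = "1"  # set before vLLM reads its environment config

from vllm import LLM, SamplingParams

llm = LLM(model="mistralai/Mixtral-8x7B-Instruct-v0.1", tensor_parallel_size=8)
outputs = llm.generate(
    ["Explain mixture-of-experts routing in one sentence."],
    SamplingParams(max_tokens=64),
)
print(outputs[0].outputs[0].text)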

@github-actions

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small, essential subset of tests to catch errors quickly. You can run additional CI tests on top of those by going to your fastcheck build in the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run full CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

@mergify

mergify bot commented May 23, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @zyongye.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label May 23, 2025
@bnellnm
Collaborator

bnellnm commented May 27, 2025

@zyongye can you implement this using the framework described in modular_kernel.py? That way it'll be possible to easily use the new kernels with different communication mechanisms, e.g. pplx + deepep
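
To illustrate the kind of split being suggested here, a generic sketch of the prepare / expert-compute / finalize decomposition (illustrative names only, not vLLM's actual modular_kernel classes or signatures):

# Generic illustration of the modular MoE decomposition: a prepare/finalize stage
# that owns communication (all-to-all, pplx, DeepEP, ...) composed with an experts
# stage that owns the grouped GEMMs (e.g. the new Triton kernels).
# Names are illustrative only, not vLLM's modular_kernel API.
from abc import ABC, abstractmethod

import torch


class PrepareFinalize(ABC):
    @abstractmethod
    def prepare(self, hidden: torch.Tensor, topk_ids: torch.Tensor):
        """Dispatch tokens to the ranks/experts that will process them."""

    @abstractmethod
    def finalize(self, expert_out: torch.Tensor) -> torch.Tensor:
        """Combine per-expert outputs back into the original token order."""


class ExpertsCompute(ABC):
    @abstractmethod
    def apply(self, hidden: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor,
              topk_ids: torch.Tensor, topk_weights: torch.Tensor) -> torch.Tensor:
        """Run the grouped expert GEMMs and activation on dispatched tokens."""


class ModularFusedMoE:
    """Any communication backend can be paired with any expert-compute backend."""

    def __init__(self, prep: PrepareFinalize, experts: ExpertsCompute):
        self.prep, self.experts = prep, experts

    def forward(self, hidden, w1, w2, topk_ids, topk_weights):
        dispatched, dispatched_ids = self.prep.prepare(hidden, topk_ids)
        out = self.experts.apply(dispatched, w1, w2, dispatched_ids, topk_weights)
        return self.prep.finalize(out)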

@bnellnm
Collaborator

bnellnm commented Jun 4, 2025

I didn't realize that these kernels were so incompatible with the current modular architecture. I think we should actually leave the triton kernels as a standalone MoE alternative until we can come up with a more general framework to handle them.

@zyongye zyongye changed the title Porting triton_kernels for FusedMoE [Kernel] Porting triton_kernels for FusedMoE Jun 4, 2025
@mergify mergify bot removed the needs-rebase label Jun 4, 2025
@zyongye zyongye marked this pull request as ready for review June 4, 2025 22:31
@simon-mo simon-mo requested a review from zhuohan123 June 5, 2025 04:39
@mergify

mergify bot commented Jun 5, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @zyongye.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Jun 5, 2025
@mergify mergify bot removed the needs-rebase label Jun 10, 2025
@mergify mergify bot added the qwen Related to Qwen models label Jun 19, 2025
@mergify

mergify bot commented Jun 19, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @zyongye.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added needs-rebase performance Performance-related issues labels Jun 19, 2025
@yuan-luo

yuan-luo commented Jun 27, 2025

@zyongye I'm porting triton_kernels to SGLang. I managed to replace SGLang's fused_moe with the Triton kernel and added a unit test verifying that the new Triton kernel is invoked as expected. But the problem is that E2E fails with Triton 3.4.0 (nightly build). I'm curious how you ported this kernel: did you upgrade Triton to 3.4.0 directly, or back-port the new matmul_ogs kernels that require Triton 3.4.0 into 3.3.1 on demand?

I also tried porting Triton v3.4.0's new functions back to v3.3.1 to make E2E work, but eventually gave up on this approach because the scope of the work kept growing.
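
For context, this is roughly the call pattern the new path is built around; the triton_kernels import paths and keyword names below are my reading of the Triton 3.4 package and should be treated as assumptions (check the package source for the exact signatures):

# Rough sketch of the triton_kernels MoE call pattern (bf16, no bias); treat the
# routing()/matmul_ogs() signatures as assumptions and verify against the package.
import torch
from triton_kernels.matmul_ogs import matmul_ogs
from triton_kernels.routing import routing


def moe_forward(hidden, w1, w2, router_logits, topk):
    # routing() produces the routing metadata plus gather/scatter index objects
    routing_data, gather_idx, scatter_idx = routing(router_logits, topk)

    # expert GEMM 1: gather tokens per expert, compute the gate/up projection
    h1 = matmul_ogs(hidden, w1, None, routing_data, gather_indx=gather_idx)

    # activation (the real path fuses SwiGLU inside the kernel; a plain SiLU-gate
    # is shown here only for clarity, and the gate/up split may differ)
    gate, up = h1.chunk(2, dim=-1)
    h2 = torch.nn.functional.silu(gate) * up

    # expert GEMM 2: scatter results back into token order
    return matmul_ogs(h2, w2, None, routing_data, scatter_indx=scatter_idx)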

Update:
I fixed the issue in Triton 3.4.0. Now new Triton MoE kernel works in SGLang.

intermediate_cache2.shape:torch.Size([264, 384])
[DEBUG] ======= triton_kernel_moe_forward
hidden.shape:torch.Size([33, 2048])
w1.shape:torch.Size([128, 768, 2048])
w2.shape:torch.Size([128, 2048, 384])
intermediate_cache1.shape:torch.Size([264, 768])
intermediate_cache2.shape:torch.Size([264, 384])
[2025-06-27 19:18:26 TP0] Decode batch. #running-req: 1, #token: 66, token usage: 0.00, cuda graph: True, gen throughput (token/s): 1.79, #queue-req: 0
[2025-06-27 19:18:31 TP0] Decode batch. #running-req: 1, #token: 106, token usage: 0.00, cuda graph: True, gen throughput (token/s): 8.13, #queue-req: 0
[2025-06-27 19:18:36 TP0] Decode batch. #running-req: 1, #token: 146, token usage: 0.00, cuda graph: True, gen throughput (token/s): 8.12, #queue-req: 0
[2025-06-27 19:18:41 TP0] Decode batch. #running-req: 1, #token: 186, token usage: 0.00, cuda graph: True, gen throughput (token/s): 8.11, #queue-req: 0
[2025-06-27 19:18:46 TP0] Decode batch. #running-req: 1, #token: 226, token usage: 0.00, cuda graph: True, gen throughput (token/s): 8.11, #queue-req: 0
[2025-06-27 19:18:46] INFO:     127.0.0.1:48648 - "POST /v1/chat/completions HTTP/1.1" 200 OK

But with Triton 3.4.0 the performance is quite slow: gen throughput was previously about 200 token/s and is now about 8 token/s.

[2025-06-28 12:02:48 TP0] Load weight end. type=Qwen3MoeForCausalLM, dtype=torch.bfloat16, avail mem=65.44 GB, mem usage=28.57 GB.
[2025-06-28 12:02:49 TP0] KV Cache is allocated. #tokens: 1150305, K size: 26.33 GB, V size: 26.33 GB
[2025-06-28 12:02:49 TP0] Memory pool end. avail mem=12.13 GB
[2025-06-28 12:02:49 TP1] KV Cache is allocated. #tokens: 1150305, K size: 26.33 GB, V size: 26.33 GB
[2025-06-28 12:02:49 TP0] max_total_num_tokens=1150305, chunked_prefill_size=8192, max_prefill_tokens=16384, max_running_requests=4096, context_len=40960, available_gpu_mem=12.03 GB
[2025-06-28 12:02:49] INFO:     Started server process [269830]
[2025-06-28 12:02:49] INFO:     Waiting for application startup.
[2025-06-28 12:02:49] INFO:     Application startup complete.
[2025-06-28 12:02:49] INFO:     Uvicorn running on http://127.0.0.1:30000 (Press CTRL+C to quit)
[2025-06-28 12:02:50] INFO:     127.0.0.1:48128 - "GET /get_model_info HTTP/1.1" 200 OK
[2025-06-28 12:02:50 TP0] Prefill batch. #new-seq: 1, #new-token: 6, #cached-token: 0, token usage: 0.00, #running-req: 0, #queue-req: 0
[2025-06-28 12:03:06] INFO:     127.0.0.1:48136 - "POST /generate HTTP/1.1" 200 OK
[2025-06-28 12:03:06] The server is fired up and ready to roll!
[2025-06-28 12:03:28 TP0] Prefill batch. #new-seq: 1, #new-token: 33, #cached-token: 0, token usage: 0.00, #running-req: 0, #queue-req: 0
[2025-06-28 12:03:35 TP0] Decode batch. #running-req: 1, #token: 66, token usage: 0.00, cuda graph: False, gen throughput (token/s): 0.86, #queue-req: 0
[2025-06-28 12:03:40 TP0] Decode batch. #running-req: 1, #token: 106, token usage: 0.00, cuda graph: False, gen throughput (token/s): 8.01, #queue-req: 0
[2025-06-28 12:03:45 TP0] Decode batch. #running-req: 1, #token: 146, token usage: 0.00, cuda graph: False, gen throughput (token/s): 8.01, #queue-req: 0
[2025-06-28 12:03:50 TP0] Decode batch. #running-req: 1, #token: 186, token usage: 0.00, cuda graph: False, gen throughput (token/s): 7.99, #queue-req: 0
[2025-06-28 12:03:55 TP0] Decode batch. #running-req: 1, #token: 226, token usage: 0.00, cuda graph: False, gen throughput (token/s): 7.99, #queue-req: 0
[2025-06-28 12:03:56] INFO:     127.0.0.1:49608 - "POST /v1/chat/completions HTTP/1.1" 200 OK

@zyongye
Member Author

zyongye commented Jun 27, 2025

@yuan-luo Yes, you need both a Triton nightly build and PyTorch 2.8 to run these kernels.

@yuan-luo

@yuan-luo Yes, you need both a Triton nightly build and PyTorch 2.8 to run these kernels.

Hi @zyongye, I'm using PyTorch 2.8.0 and a Triton nightly build, but the performance is still slow, although the results are correct.

[2025-06-28 17:37:32 TP0] Load weight end. type=Qwen3MoeForCausalLM, dtype=torch.bfloat16, avail mem=65.44 GB, mem usage=28.57 GB.
[2025-06-28 17:37:32 TP1] KV Cache is allocated. #tokens: 1150305, K size: 26.33 GB, V size: 26.33 GB
[2025-06-28 17:37:32 TP0] KV Cache is allocated. #tokens: 1150305, K size: 26.33 GB, V size: 26.33 GB
[2025-06-28 17:37:32 TP0] Memory pool end. avail mem=12.13 GB
[2025-06-28 17:37:32 TP0] Capture cuda graph begin. This can take up to several minutes. avail mem=12.03 GB
Capture cuda graph bs [1, 2, 4, 8, 16, 24, 32, 40, 48, 56, 64, 72, 80, 88, 96, 104, 112, 120, 128, 136, 144, 152, 160, 168, 176, 184, 192, 200, 208, 216, 224, 232, 240, 248, 256]
Capturing batches (bs=1 avail_mem=7.54 GB): 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 35/35 [00:54<00:00,  1.56s/it]
[2025-06-28 17:38:27 TP0] Registering 3395 cuda graph addresses
[2025-06-28 17:38:27 TP1] Registering 3395 cuda graph addresses
[2025-06-28 17:38:27 TP0] Capture cuda graph end. Time elapsed: 54.89 s. mem usage=4.52 GB. avail mem=7.52 GB.
[2025-06-28 17:38:27 TP0] max_total_num_tokens=1150305, chunked_prefill_size=8192, max_prefill_tokens=16384, max_running_requests=4096, context_len=40960, available_gpu_mem=7.52 GB
[2025-06-28 17:38:27] INFO:     Started server process [293154]
[2025-06-28 17:38:27] INFO:     Waiting for application startup.
[2025-06-28 17:38:27] INFO:     Application startup complete.
[2025-06-28 17:38:27] INFO:     Uvicorn running on http://127.0.0.1:30000 (Press CTRL+C to quit)
[2025-06-28 17:38:28] INFO:     127.0.0.1:41556 - "GET /get_model_info HTTP/1.1" 200 OK
[2025-06-28 17:38:28 TP0] Prefill batch. #new-seq: 1, #new-token: 6, #cached-token: 0, token usage: 0.00, #running-req: 0, #queue-req: 0
[2025-06-28 17:38:31] INFO:     127.0.0.1:41558 - "POST /generate HTTP/1.1" 200 OK
[2025-06-28 17:38:31] The server is fired up and ready to roll!
[2025-06-28 17:38:51 TP0] Prefill batch. #new-seq: 1, #new-token: 33, #cached-token: 0, token usage: 0.00, #running-req: 0, #queue-req: 0
[2025-06-28 17:38:57 TP0] Decode batch. #running-req: 1, #token: 66, token usage: 0.00, cuda graph: True, gen throughput (token/s): 1.34, #queue-req: 0
[2025-06-28 17:39:02 TP0] Decode batch. #running-req: 1, #token: 106, token usage: 0.00, cuda graph: True, gen throughput (token/s): 8.12, #queue-req: 0
[2025-06-28 17:39:07 TP0] Decode batch. #running-req: 1, #token: 146, token usage: 0.00, cuda graph: True, gen throughput (token/s): 8.12, #queue-req: 0
[2025-06-28 17:39:12 TP0] Decode batch. #running-req: 1, #token: 186, token usage: 0.00, cuda graph: True, gen throughput (token/s): 8.12, #queue-req: 0
[2025-06-28 17:39:17 TP0] Decode batch. #running-req: 1, #token: 226, token usage: 0.00, cuda graph: True, gen throughput (token/s): 8.12, #queue-req: 0
[2025-06-28 17:39:18] INFO:     127.0.0.1:51480 - "POST /v1/chat/completions HTTP/1.1" 200 OK

@yuan-luo

These are the benchmark results for the kernels.

$python ./benchmark/kernels/fused_moe_triton/benchmark_sglang_fused_moe_triton.py
Failed to import deep_gemm, disable ENABLE_JIT_DEEPGEMM.
====imported _matmul_ogs
====imported _matmul_ogs
====imported SwizzlingType
shape_configs={'num_experts': 128, 'topk': 8, 'hidden_size': 2048, 'shard_intermediate_size': 768, 'dtype': torch.bfloat16, 'block_shape': None}
benchmark sglang_fused_moe_triton_v340 with batch_size=512
benchmark sglang_fused_moe_triton with batch_size=512
Config file not found at /opt/conda/lib/python3.10/site-packages/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=384,device_name=NVIDIA_H20.json. Fallback to triton version 3.2.0 and use MoE kernel config from /opt/conda/lib/python3.10/site-packages/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=128,N=384,device_name=NVIDIA_H20.json. Performance might be sub-optimal!
benchmark sglang_fused_moe_triton_v340 with batch_size=513
benchmark sglang_fused_moe_triton with batch_size=513
benchmark sglang_fused_moe_triton_v340 with batch_size=514
benchmark sglang_fused_moe_triton with batch_size=514
benchmark sglang_fused_moe_triton_v340 with batch_size=515
benchmark sglang_fused_moe_triton with batch_size=515
benchmark sglang_fused_moe_triton_v340 with batch_size=516
benchmark sglang_fused_moe_triton with batch_size=516
benchmark sglang_fused_moe_triton_v340 with batch_size=517
benchmark sglang_fused_moe_triton with batch_size=517
benchmark sglang_fused_moe_triton_v340 with batch_size=518
benchmark sglang_fused_moe_triton with batch_size=518
benchmark sglang_fused_moe_triton_v340 with batch_size=519
benchmark sglang_fused_moe_triton with batch_size=519
fused-moe-performance:
   batch_size  sglang_fused_moe_triton_v340  sglang_fused_moe_triton
0       512.0                      2.833792                 0.404320
1       513.0                      2.831680                 0.401216
2       514.0                      2.829920                 0.395728
3       515.0                      2.830240                 0.395168
4       516.0                      2.831904                 0.385728
5       517.0                      2.842752                 0.388368
6       518.0                      2.832480                 0.391424
7       519.0                      2.833728                 0.393760
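
For anyone reproducing numbers like these, a minimal timing harness using triton.testing.do_bench (the two callables are placeholders standing in for the legacy and v3.4.0 SGLang MoE entry points, and the shapes mirror the config printed above):

# Sketch: time two MoE implementations the same way a benchmark script would.
# fused_moe_legacy / fused_moe_triton_v340 are placeholder callables, not the
# real SGLang entry points; shapes follow the printed config (E=128, topk=8, ...).
import torch
from triton.testing import do_bench

E, topk, hidden, inter = 128, 8, 2048, 384
batch = 512

x = torch.randn(batch, hidden, device="cuda", dtype=torch.bfloat16)
w1 = torch.randn(E, 2 * inter, hidden, device="cuda", dtype=torch.bfloat16)
w2 = torch.randn(E, hidden, inter, device="cuda", dtype=torch.bfloat16)
logits = torch.randn(batch, E, device="cuda", dtype=torch.bfloat16)

def bench(moe_fn):
    # do_bench returns the measured runtime in milliseconds
    return do_bench(lambda: moe_fn(x, w1, w2, logits, topk))

# ms_old = bench(fused_moe_legacy)         # placeholder callable
# ms_new = bench(fused_moe_triton_v340)    # placeholder callable
# print(f"legacy: {ms_old:.3f} ms   triton_kernels: {ms_new:.3f} ms")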

@yuan-luo

yuan-luo commented Jul 1, 2025

After fixing the logic issue, the performance improved.

b75ba454e0f7:342801:345121 [0] NCCL INFO threadThresholds 8/8/64 | 16/8/64 | 512 | 512
b75ba454e0f7:342801:345121 [0] NCCL INFO 16 coll channels, 16 collnet channels, 0 nvls channels, 16 p2p channels, 16 p2p channels per peer
b75ba454e0f7:342801:345121 [0] NCCL INFO CC Off, workFifoBytes 1048576
b75ba454e0f7:342802:345122 [1] NCCL INFO ncclCommInitRankConfig comm 0x7fdaf466e6d0 rank 1 nranks 2 cudaDev 1 nvmlDev 1 busId 7e000 commId 0x3fa59530b29bedc0 - Init COMPLETE
b75ba454e0f7:342801:345121 [0] NCCL INFO ncclCommInitRankConfig comm 0x7f4188b94110 rank 0 nranks 2 cudaDev 0 nvmlDev 0 busId 8000 commId 0x3fa59530b29bedc0 - Init COMPLETE
b75ba454e0f7:342801:345121 [0] NCCL INFO Init timings - ncclCommInitRankConfig: rank 0 nranks 2 total 0.17 (kernels 0.00, alloc 0.00, bootstrap 0.01, allgathers 0.00, topo 0.06, graphs 0.00, connections 0.10, rest 0.00)
b75ba454e0f7:342802:345122 [1] NCCL INFO Init timings - ncclCommInitRankConfig: rank 1 nranks 2 total 0.16 (kernels 0.00, alloc 0.00, bootstrap 0.00, allgathers 0.00, topo 0.06, graphs 0.00, connections 0.10, rest 0.00)
b75ba454e0f7:342801:344793 [0] NCCL INFO AllGather: 151936 Bytes -> Algo RING proto LL channel{Lo..Hi}={0..14}
[2025-07-01 06:06:05] INFO:     127.0.0.1:60674 - "POST /generate HTTP/1.1" 200 OK
[2025-07-01 06:06:05] The server is fired up and ready to roll!
[2025-07-01 06:06:29 TP0] Prefill batch. #new-seq: 1, #new-token: 33, #cached-token: 0, #token: 0, token usage: 0.00, #running-req: 0, #queue-req: 0
b75ba454e0f7:342801:344793 [0] NCCL INFO AllGather: 151936 Bytes -> Algo RING proto LL channel{Lo..Hi}={0..14}
[2025-07-01 06:06:30 TP0] Decode batch. #running-req: 1, #token: 66, token usage: 0.00, cuda graph: True, gen throughput (token/s): 1.41, #queue-req: 0
[2025-07-01 06:06:30 TP0] Decode batch. #running-req: 1, #token: 106, token usage: 0.00, cuda graph: True, gen throughput (token/s): 182.12, #queue-req: 0
[2025-07-01 06:06:30 TP0] Decode batch. #running-req: 1, #token: 146, token usage: 0.00, cuda graph: True, gen throughput (token/s): 181.22, #queue-req: 0
[2025-07-01 06:06:30 TP0] Decode batch. #running-req: 1, #token: 186, token usage: 0.00, cuda graph: True, gen throughput (token/s): 180.17, #queue-req: 0
[2025-07-01 06:06:31 TP0] Decode batch. #running-req: 1, #token: 226, token usage: 0.00, cuda graph: True, gen throughput (token/s): 179.99, #queue-req: 0
[2025-07-01 06:06:31] INFO:     127.0.0.1:32264 - "POST /v1/chat/completions HTTP/1.1" 200 OK
#python ./benchmark/kernels/fused_moe_triton/benchmark_sglang_fused_moe_triton.py
INFO 07-01 06:08:26 [__init__.py:244] Automatically detected platform cuda.
shape_configs={'num_experts': 128, 'topk': 8, 'hidden_size': 2048, 'shard_intermediate_size': 768, 'dtype': torch.bfloat16, 'block_shape': None}
benchmark sglang_fused_moe_triton_v340 with batch_size=64
benchmark sglang_fused_moe_triton with batch_size=64
Using default MoE kernel config. Performance might be sub-optimal! Config file not found at /usr/local/lib/python3.10/dist-packages/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=384,device_name=NVIDIA_L20Y.json, you can create them with https://github.com/sgl-project/sglang/tree/main/benchmark/kernels/fused_moe_triton
benchmark sglang_fused_moe_triton_v340 with batch_size=128
benchmark sglang_fused_moe_triton with batch_size=128
benchmark sglang_fused_moe_triton_v340 with batch_size=256
benchmark sglang_fused_moe_triton with batch_size=256
benchmark sglang_fused_moe_triton_v340 with batch_size=512
benchmark sglang_fused_moe_triton with batch_size=512
benchmark sglang_fused_moe_triton_v340 with batch_size=1024
benchmark sglang_fused_moe_triton with batch_size=1024
benchmark sglang_fused_moe_triton_v340 with batch_size=2048
benchmark sglang_fused_moe_triton with batch_size=2048
benchmark sglang_fused_moe_triton_v340 with batch_size=4096
benchmark sglang_fused_moe_triton with batch_size=4096
benchmark sglang_fused_moe_triton_v340 with batch_size=8192
benchmark sglang_fused_moe_triton with batch_size=8192
benchmark sglang_fused_moe_triton_v340 with batch_size=16384
benchmark sglang_fused_moe_triton with batch_size=16384
fused-moe-performance:
   batch_size  sglang_fused_moe_triton_v340  sglang_fused_moe_triton
0        64.0                      1.486816                 0.435488
1       128.0                      1.493936                 0.432544
2       256.0                      1.378224                 0.432896
3       512.0                      1.387248                 0.416352
4      1024.0                      1.392960                 0.428576
5      2048.0                      1.394656                 0.524800
6      4096.0                      1.395104                 0.893536
7      8192.0                      1.400384                 1.618304
8     16384.0                      1.687072                 3.071136

Finally got the results. But for batch sizes smaller than 256, performance is still lower than the legacy fused_moe.

   batch_size  sglang_fused_moe_triton_v340  sglang_fused_moe_triton
0       128.0                      0.259136                 0.241552
1       256.0                      0.255424                 0.303200
2       512.0                      0.269712                 0.342176
3      1024.0                      0.303376                 0.395168
4      2048.0                      0.366592                 0.516064
5      4096.0                      0.530272                 0.881728
6      8192.0                      0.891840                 1.615424

@zhuohan123 zhuohan123 closed this Sep 26, 2025

Labels

needs-rebase performance Performance-related issues qwen Related to Qwen models

Development

Successfully merging this pull request may close these issues.

[Feature]: Integrate Triton MoE Kernel
