### Your current environment
<details>
<summary>The output of <code>python collect_env.py</code></summary>

```text
# python3 collect_env.py
INFO 06-20 21:32:16 [__init__.py:239] Automatically detected platform cuda.
Collecting environment information...
/usr/local/lib/python3.10/dist-packages/_distutils_hack/__init__.py:30: UserWarning: Setuptools is replacing distutils. Support for replacing an already imported distutils is deprecated. In the future, this condition will fail. Register concerns at https://github.com/pypa/setuptools/issues/new?template=distutils-deprecation.yml
  warnings.warn(
==============================
System Info
==============================
OS : Ubuntu 22.04.4 LTS (x86_64)
GCC version : (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version : Could not collect
CMake version : version 3.27.5
Libc version : glibc-2.35
==============================
PyTorch Info
==============================
PyTorch version : 2.6.0+cu124
Is debug build : False
CUDA used to build PyTorch : 12.4
ROCM used to build PyTorch : N/A
==============================
Python Environment
==============================
Python version : 3.10.12 (main, Sep 11 2024, 15:47:36) [GCC 11.4.0] (64-bit runtime)
Python platform : Linux-5.4.0-162-generic-x86_64-with-glibc2.35
==============================
CUDA / GPU Info
==============================
Is CUDA available : True
CUDA runtime version : 12.4.131
CUDA_MODULE_LOADING set to : LAZY
GPU models and configuration :
GPU 0: NVIDIA H20
GPU 1: NVIDIA H20
GPU 2: NVIDIA H20
GPU 3: NVIDIA H20
GPU 4: NVIDIA H20
GPU 5: NVIDIA H20
GPU 6: NVIDIA H20
GPU 7: NVIDIA H20
Nvidia driver version : 535.161.08
cuDNN version : Could not collect
HIP runtime version : N/A
MIOpen runtime version : N/A
Is XNNPACK available : True
==============================
CPU Info
==============================
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 52 bits physical, 57 bits virtual
Byte Order: Little Endian
CPU(s): 192
On-line CPU(s) list: 0-191
Vendor ID: GenuineIntel
Model name: Intel(R) Xeon(R) Platinum 8457C
CPU family: 6
Model: 143
Thread(s) per core: 2
Core(s) per socket: 48
Socket(s): 2
Stepping: 8
CPU max MHz: 3800.0000
CPU min MHz: 800.0000
BogoMIPS: 5200.00
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc art arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc cpuid aperfmperf tsc_known_freq pni pclmulqdq dtes64 monitor ds_cpl vmx smx est tm2 ssse3 sdbg fma cx16 xtpr pdcm pcid dca sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand lahf_lm abm 3dnowprefetch cpuid_fault epb cat_l3 cat_l2 cdp_l3 invpcid_single cdp_l2 ssbd mba ibrs ibpb stibp ibrs_enhanced tpr_shadow vnmi flexpriority ept vpid ept_ad fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb intel_pt avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local avx512_bf16 wbnoinvd dtherm ida arat pln pts hwp hwp_act_window hwp_epp hwp_pkg_req avx512vbmi umip pku ospke waitpkg avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg tme avx512_vpopcntdq rdpid cldemote movdiri movdir64b md_clear pconfig flush_l1d arch_capabilities
Virtualization: VT-x
L1d cache: 4.5 MiB (96 instances)
L1i cache: 3 MiB (96 instances)
L2 cache: 192 MiB (96 instances)
L3 cache: 195 MiB (2 instances)
NUMA node(s): 2
NUMA node0 CPU(s): 0-47,96-143
NUMA node1 CPU(s): 48-95,144-191
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Not affected
Vulnerability Retbleed: Not affected
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Enhanced IBRS, IBPB conditional, RSB filling, PBRSB-eIBRS SW sequence
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected
==============================
Versions of relevant libraries
==============================
[pip3] numpy==1.26.4
[pip3] nvidia-cublas-cu12==12.4.5.8
[pip3] nvidia-cuda-cupti-cu12==12.4.127
[pip3] nvidia-cuda-nvrtc-cu12==12.4.127
[pip3] nvidia-cuda-runtime-cu12==12.4.127
[pip3] nvidia-cudnn-cu12==9.1.0.70
[pip3] nvidia-cufft-cu12==11.2.1.3
[pip3] nvidia-cufile-cu12==1.11.1.6
[pip3] nvidia-curand-cu12==10.3.5.147
[pip3] nvidia-cusolver-cu12==11.6.1.9
[pip3] nvidia-cusparse-cu12==12.3.1.170
[pip3] nvidia-cusparselt-cu12==0.6.2
[pip3] nvidia-ml-py==12.560.30
[pip3] nvidia-modelopt==0.25.0
[pip3] nvidia-modelopt-core==0.25.0
[pip3] nvidia-nccl-cu12==2.21.5
[pip3] nvidia-nvjitlink-cu12==12.4.127
[pip3] nvidia-nvtx-cu12==12.4.127
[pip3] onnx==1.17.0
[pip3] onnx_graphsurgeon==0.5.8
[pip3] pynvml==12.0.0
[pip3] pyzmq==26.2.0
[pip3] torch==2.6.0
[pip3] torchaudio==2.6.0
[pip3] torchprofile==0.0.4
[pip3] torchvision==0.21.0
[pip3] transformers==4.51.2
[pip3] triton==3.2.0
[conda] Could not collect
==============================
vLLM Info
==============================
ROCM Version : Could not collect
Neuron SDK Version : N/A
vLLM Version : 0.8.5.post2.dev1+g35a4376c9.d20250618 (git sha: 35a4376c9, date: 20250618)
vLLM Build Flags:
CUDA Archs: Not Set; ROCm: Disabled; Neuron: Disabled
GPU Topology:
GPU0 GPU1 GPU2 GPU3 GPU4 GPU5 GPU6 GPU7 NIC0 NIC1 NIC2 NIC3 CPU Affinity NUMA Affinity GPU NUMA ID
GPU0 X NV18 NV18 NV18 NV18 NV18 NV18 NV18 PIX NODE SYS SYS 0-47,96-143 0 N/A
GPU1 NV18 X NV18 NV18 NV18 NV18 NV18 NV18 NODE NODE SYS SYS 0-47,96-143 0 N/A
GPU2 NV18 NV18 X NV18 NV18 NV18 NV18 NV18 NODE PIX SYS SYS 0-47,96-143 0 N/A
GPU3 NV18 NV18 NV18 X NV18 NV18 NV18 NV18 NODE NODE SYS SYS 0-47,96-143 0 N/A
GPU4 NV18 NV18 NV18 NV18 X NV18 NV18 NV18 SYS SYS PIX NODE 48-95,144-191 1 N/A
GPU5 NV18 NV18 NV18 NV18 NV18 X NV18 NV18 SYS SYS NODE NODE 48-95,144-191 1 N/A
GPU6 NV18 NV18 NV18 NV18 NV18 NV18 X NV18 SYS SYS NODE PIX 48-95,144-191 1 N/A
GPU7 NV18 NV18 NV18 NV18 NV18 NV18 NV18 X SYS SYS NODE NODE 48-95,144-191 1 N/A
NIC0 PIX NODE NODE NODE SYS SYS SYS SYS X NODE SYS SYS
NIC1 NODE NODE PIX NODE SYS SYS SYS SYS NODE X SYS SYS
NIC2 SYS SYS SYS SYS PIX NODE NODE NODE SYS SYS X NODE
NIC3 SYS SYS SYS SYS NODE NODE PIX NODE SYS SYS NODE X
Legend:
X = Self
SYS = Connection traversing PCIe as well as the SMP interconnect between NUMA nodes (e.g., QPI/UPI)
NODE = Connection traversing PCIe as well as the interconnect between PCIe Host Bridges within a NUMA node
PHB = Connection traversing PCIe as well as a PCIe Host Bridge (typically the CPU)
PXB = Connection traversing multiple PCIe bridges (without traversing the PCIe Host Bridge)
PIX = Connection traversing at most a single PCIe bridge
NV# = Connection traversing a bonded set of # NVLinks
NIC Legend:
NIC0: mlx5_1
NIC1: mlx5_2
NIC2: mlx5_3
NIC3: mlx5_4
==============================
Environment Variables
==============================
NVIDIA_VISIBLE_DEVICES=all
NVIDIA_REQUIRE_CUDA=cuda>=12.4 brand=tesla,driver>=470,driver<471 brand=unknown,driver>=470,driver<471 brand=nvidia,driver>=470,driver<471 brand=nvidiartx,driver>=470,driver<471 brand=geforce,driver>=470,driver<471 brand=geforcertx,driver>=470,driver<471 brand=quadro,driver>=470,driver<471 brand=quadrortx,driver>=470,driver<471 brand=titan,driver>=470,driver<471 brand=titanrtx,driver>=470,driver<471 brand=tesla,driver>=525,driver<526 brand=unknown,driver>=525,driver<526 brand=nvidia,driver>=525,driver<526 brand=nvidiartx,driver>=525,driver<526 brand=geforce,driver>=525,driver<526 brand=geforcertx,driver>=525,driver<526 brand=quadro,driver>=525,driver<526 brand=quadrortx,driver>=525,driver<526 brand=titan,driver>=525,driver<526 brand=titanrtx,driver>=525,driver<526 brand=tesla,driver>=535,driver<536 brand=unknown,driver>=535,driver<536 brand=nvidia,driver>=535,driver<536 brand=nvidiartx,driver>=535,driver<536 brand=geforce,driver>=535,driver<536 brand=geforcertx,driver>=535,driver<536 brand=quadro,driver>=535,driver<536 brand=quadrortx,driver>=535,driver<536 brand=titan,driver>=535,driver<536 brand=titanrtx,driver>=535,driver<536
NCCL_MIN_NCHANNELS=24
NCCL_VERSION=2.20.5-1
NVIDIA_DRIVER_CAPABILITIES=compute,utility
NCCL_DEBUG=INFO
VLLM_WORKER_MULTIPROC_METHOD=spawn
NVIDIA_PRODUCT_NAME=CUDA
CUDA_VERSION=12.4.0
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
NCCL_IB_QPS_PER_CONNECTION=8
LD_LIBRARY_PATH=/usr/local/nvidia/lib:/usr/local/nvidia/lib64
VLLM_HOST_IP=192.168.12.2
NCCL_CUMEM_ENABLE=0
PYTORCH_NVML_BASED_CUDA_CHECK=1
TORCHINDUCTOR_COMPILE_THREADS=1
CUDA_MODULE_LOADING=LAZY
```

</details>
### 🐛 Describe the bug
### Summary
We discovered a significant memory issue with the latest `cutlass_moe_fp8()` implementation on the `main` branch of vLLM. On this path, each rank now tries to allocate a single **24.50 GiB** workspace tensor (`workspace13`, created with `torch.zeros` in `_do_fused_experts()` in `fused_moe/modular_kernel.py`). This appears **abnormally high**, runs the GPU out of memory, and prevents us from rebasing our [PR #19843](https://github.com/vllm-project/vllm/pull/19843) onto `main`.
This issue did **not** exist in `v0.8.5.post1`, where the memory footprint of the same function was much smaller under otherwise identical conditions.
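To sanity-check a workspace of this size, it helps to convert between a tensor's shape, its bytes per element, and GiB (the unit PyTorch uses in the OOM message). The sketch below is a hypothetical helper, not part of vLLM; the idea is to print the `workspace13_shape` passed to `torch.zeros` in `_do_fused_experts()` on both versions and feed it in, together with the itemsize of the workspace dtype.

```python
import math

GIB = 1024 ** 3  # PyTorch reports allocation sizes in GiB (2**30 bytes)


def tensor_gib(shape: tuple[int, ...], itemsize: int) -> float:
    """Size in GiB of a dense tensor with `shape` and `itemsize` bytes per element."""
    return math.prod(shape) * itemsize / GIB


def elements_in_gib(gib: float, itemsize: int) -> int:
    """Number of elements of `itemsize` bytes that fit in `gib` GiB."""
    return int(gib * GIB) // itemsize


if __name__ == "__main__":
    # The failed request in the Error Log was 24.50 GiB. For a 2-byte dtype
    # (e.g. bfloat16 or float16) that is roughly 13.2 billion elements.
    print(elements_in_gib(24.5, 2))  # 13153337344
    # Hypothetical usage with the real shape printed from _do_fused_experts():
    # print(tensor_gib(workspace13_shape, 2))
```

Comparing the printed shape between `v0.8.5.post1` and `main` should show which dimension grew.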
### Error Log
<details>
<summary>Click to expand</summary>
```python
...
  File "/nvme0n1/jack/vllm-w8a8-cutlass/vllm/worker/model_runner.py", line 1426, in _dummy_run
    self.execute_model(model_input, kv_caches, intermediate_tensors)
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/nvme0n1/jack/vllm-w8a8-cutlass/vllm/worker/model_runner.py", line 1844, in execute_model
    hidden_or_intermediate_states = model_executable(
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1762, in _call_impl
    return forward_call(*args, **kwargs)
  File "/nvme0n1/jack/vllm-w8a8-cutlass/vllm/model_executor/models/deepseek_v2.py", line 714, in forward
    hidden_states = self.model(input_ids, positions, intermediate_tensors,
  File "/nvme0n1/jack/vllm-w8a8-cutlass/vllm/compilation/decorators.py", line 173, in __call__
    return self.forward(*args, **kwargs)
  File "/nvme0n1/jack/vllm-w8a8-cutlass/vllm/model_executor/models/deepseek_v2.py", line 672, in forward
    hidden_states, residual = layer(positions, hidden_states, residual)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1762, in _call_impl
    return forward_call(*args, **kwargs)
  File "/nvme0n1/jack/vllm-w8a8-cutlass/vllm/model_executor/models/deepseek_v2.py", line 592, in forward
    hidden_states = self.mlp(hidden_states)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1762, in _call_impl
    return forward_call(*args, **kwargs)
  File "/nvme0n1/jack/vllm-w8a8-cutlass/vllm/model_executor/models/deepseek_v2.py", line 160, in forward
    final_hidden_states = self.experts(
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1762, in _call_impl
    return forward_call(*args, **kwargs)
  File "/nvme0n1/jack/vllm-w8a8-cutlass/vllm/model_executor/layers/fused_moe/modular_kernel.py", line 426, in forward
    fused_out = self._do_fused_experts(
  File "/nvme0n1/jack/vllm-w8a8-cutlass/vllm/model_executor/layers/fused_moe/modular_kernel.py", line 318, in _do_fused_experts
    workspace13 = torch.zeros(workspace13_shape,
torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 24.50 GiB. GPU 0 has a total capacity of 95.22 GiB of which 13.66 GiB is free. Process 4078568 has 81.55 GiB memory in use. Of the allocated memory 79.79 GiB is allocated by PyTorch, and 108.60 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation. See documentation for Memory Management (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
[rank0]:[W618 00:18:45.548079750 ProcessGroupNCCL.cpp:1476] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator())
```

</details>
### 💡 Context
- Model: DeepSeek-R1
- GPUs: 8x NVIDIA H20
- Tensor Parallelism (TP): 8
- Quantization: `--quantization fp8`
- vLLM version (working): `v0.8.5.post1`
- vLLM version (issue): `main` branch (latest: `6bc7b57`)
- Backend: Cutlass MoE FP8 (`cutlass_moe_fp8()`)
### 🔁 Reproduction Steps
Reproduction is somewhat involved, but the steps are:
- Check out the latest `main` branch of vLLM (`6bc7b57`).
- Apply the changes from PR #19843.
- Follow the setup procedure described in PR #19843 to run the test, with `VLLM_USE_CUTLASS_MOE_FP8=1` set.
- Hit the CUDA out-of-memory error shown in the Error Log.
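For reference, a minimal sketch of how the failing configuration from the Context section might be launched. The exact test procedure lives in PR #19843; the `vllm serve` entry point, the `deepseek-ai/DeepSeek-R1` model ID, and the helper name `build_repro_command` are assumptions for illustration. The helper only builds the command line and environment; it does not start anything.

```python
import os


def build_repro_command(
    model: str = "deepseek-ai/DeepSeek-R1",  # assumed model ID
) -> tuple[list[str], dict[str, str]]:
    """Hypothetical helper: argv and environment for the failing setup.

    Mirrors the Context section (TP=8, --quantization fp8) and sets the
    VLLM_USE_CUTLASS_MOE_FP8=1 flag from the reproduction steps.
    """
    argv = [
        "vllm", "serve", model,
        "--tensor-parallel-size", "8",  # TP=8 across the 8x H20 GPUs
        "--quantization", "fp8",
    ]
    env = dict(os.environ)  # copy so the caller's environment is untouched
    env["VLLM_USE_CUTLASS_MOE_FP8"] = "1"  # flag named in the repro steps
    return argv, env


if __name__ == "__main__":
    argv, env = build_repro_command()
    print(" ".join(argv))
    # To actually launch (needs the GPUs and a build with PR #19843 applied):
    # import subprocess; subprocess.run(argv, env=env, check=True)
```

Keeping the command and environment in one place makes it easy to run the identical setup against `v0.8.5.post1` and `main`.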
### 🧨 Observed Behavior
- On the `cutlass_moe_fp8()` path, each rank tries to allocate an additional 24.50 GiB of GPU memory for a single workspace tensor (`workspace13` in `_do_fused_experts()`), which fails with a CUDA out-of-memory error (see Error Log).
- The same workload on `v0.8.5.post1` uses much less memory, with no functional degradation.
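To compare runs quickly (for example `main` versus `v0.8.5.post1`), the key numbers can be pulled out of PyTorch's OOM message. The sketch below is a hypothetical helper; its regular expression assumes the exact wording shown in the Error Log and only handles sizes reported in GiB.

```python
import re

# Assumes the PyTorch message wording from the Error Log, e.g.
# "Tried to allocate 24.50 GiB. GPU 0 has a total capacity of 95.22 GiB
#  of which 13.66 GiB is free."
_OOM_RE = re.compile(
    r"Tried to allocate (?P<req>[\d.]+) GiB\. "
    r"GPU (?P<gpu>\d+) has a total capacity of (?P<total>[\d.]+) GiB "
    r"of which (?P<free>[\d.]+) GiB is free"
)


def parse_oom(message: str) -> dict[str, float]:
    """Extract requested/total/free GiB from a CUDA OOM message."""
    m = _OOM_RE.search(message)
    if m is None:
        raise ValueError("message does not match the expected OOM format")
    req, total, free = (float(m[k]) for k in ("req", "total", "free"))
    return {
        "gpu": int(m["gpu"]),
        "requested_gib": req,
        "total_gib": total,
        "free_gib": free,
        # How much more memory the single allocation needed than was free.
        "shortfall_gib": round(req - free, 2),
    }


if __name__ == "__main__":
    log = ("torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 24.50 GiB. "
           "GPU 0 has a total capacity of 95.22 GiB of which 13.66 GiB is free.")
    print(parse_oom(log)["shortfall_gib"])  # 10.84
```

Applied to the log above, the workspace request exceeds the free memory on GPU 0 by about 10.84 GiB.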
### ✅ Expected Behavior
- Memory usage should remain similar to that of previous versions such as `v0.8.5.post1`.
### 🔗 Related PR
This issue blocks the validation and merging of our contribution:
- PR: #19843 (Add Cutlass backend for MoE FP8 workloads)
- Summary: We introduced a Cutlass kernel for MoE FP8 workloads with strong throughput gains. It is currently tested and validated on `v0.8.5.post1`, but we are unable to rebase it onto `main` due to the unexpected memory increase.
### 🙏 Request
We would really appreciate:
- A review or fix of the memory allocation behavior in the latest `cutlass_moe_fp8()` implementation.
- Any guidance that would help us rebase and verify our PR #19843 (Add Cutlass integration for MoE FP8) against `main`.

Thanks!