
Conversation


@wzhao18 wzhao18 commented Jan 24, 2026

Purpose

This PR fixes a bug in the CPU offloading logic: weights re-initialized during quantization processing are not re-offloaded to the CPU when Unified Virtual Addressing (UVA) is used.

Bug:

During process_weights_after_loading, certain quantization methods replace the original model weights with newly initialized tensors (e.g., here). These replacement tensors are allocated on the device, and the current logic never re-offloads them to CPU memory. As a result, weights that were supposed to be offloaded remain on the GPU, undoing the effect of offloading.

In contrast, the non-UVA code path (currently disabled) handles this correctly by re-offloading weights by name. This PR brings the same behavior to the UVA path, as sketched below.
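
The intended flow is roughly the following (a minimal sketch for illustration, not the actual loader code: model is assumed to be already loaded, should_offload is a hypothetical predicate, and get_cuda_view_from_cpu_tensor / process_weights_after_loading are the helpers referenced in this PR and in the review below):

import torch

uva_offloaded_parameters: set[str] = set()

# 1) Initial offload: move selected weights to pinned CPU memory and expose
#    them to the GPU as zero-copy UVA views; remember the offloaded names.
for name, p in model.named_parameters():
    if should_offload(name):  # hypothetical predicate, for illustration only
        cpu_data = p.data.to("cpu").pin_memory()
        p.data = get_cuda_view_from_cpu_tensor(cpu_data)
        p._vllm_is_uva_offloaded = True
        uva_offloaded_parameters.add(name)

# 2) Quantization post-processing may replace weights with freshly allocated
#    device tensors, silently dropping the UVA views created above.
for module in model.modules():
    quant_method = getattr(module, "quant_method", None)
    if quant_method is not None:
        quant_method.process_weights_after_loading(module)

# 3) Re-offload by name (what this PR adds for the UVA path): any parameter
#    that was offloaded before post-processing is moved back to pinned CPU
#    memory and re-wrapped as a UVA view.
for name, p in model.named_parameters():
    if name in uva_offloaded_parameters:
        cpu_data = p.data.to("cpu").pin_memory()
        p.data = get_cuda_view_from_cpu_tensor(cpu_data)
        p._vllm_is_uva_offloaded = True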

Test Plan

As an example of the bug, running DeepSeek-R1 NVFP4 (~350 GB) on a GB300 (288 GB) with 200 GB of CPU offloading results in an OOM:

Current main:

python3 -m vllm.entrypoints.openai.api_server  \
    --model nvidia/DeepSeek-R1-NVFP4     \
    --trust-remote-code    \
    --cpu-offload-gb 200 \
    --load-format dummy
File "vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py", line 539, in prepare_nvfp4_moe_layer_for_fi_or_cutlass
     w13, w13_scale, w2, w2_scale = prepare_static_weights_for_trtllm_fp4_moe(
                                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
   File "vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py", line 278, in prepare_static_weights_for_trtllm_fp4_moe
     gemm1_weights_fp4_shuffled = torch.stack(gemm1_weights_fp4_shuffled)
                                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 3.50 GiB. GPU 0 has a total capacity of 276.62 GiB of which 1.18 GiB is free.

Test Result

Using this PR allows the above configuration to run.


Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

Signed-off-by: wzhao18 <wzhao18.sz@gmail.com>
@wzhao18 wzhao18 requested a review from 22quinn as a code owner January 24, 2026 03:13
@mergify mergify bot added the nvidia and bug (Something isn't working) labels on Jan 24, 2026

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request addresses a bug with CPU offloading using Unified Virtual Addressing (UVA), where weights re-initialized during quantization were not being correctly re-offloaded. The changes introduce a mechanism to track and re-offload these replaced tensors, ensuring the offloading works as expected. The modification in csrc/cuda_view.cu to use a custom deleter for managing tensor lifetimes is a robust improvement. My review includes one suggestion to optimize the re-offloading logic to avoid unnecessary data copies for non-replaced parameters, which will improve efficiency.
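
For context on the deleter, the lifetime hazard it guards against can be sketched in Python (a conceptual illustration only; the actual change lives in the C++/CUDA binding, and get_cuda_view_from_cpu_tensor is assumed here to return a CUDA tensor that aliases the pinned CPU buffer without owning it):

import torch

cpu_data = torch.empty(1024, device="cpu").pin_memory()
gpu_view = get_cuda_view_from_cpu_tensor(cpu_data)  # zero-copy UVA view

# If nothing else references cpu_data, its pinned storage could be freed
# while gpu_view still points at it. Attaching a custom deleter to the
# view's storage (one that holds a reference to the CPU tensor) keeps the
# host buffer alive for as long as the view exists.
del cpu_data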

Comment on lines +171 to +175
if name in uva_offloaded_parameters:
    assert pin_memory, "UVA offloaded parameters must be pinned"
    cpu_data = p.data.to("cpu").pin_memory()
    p.data = get_cuda_view_from_cpu_tensor(cpu_data)
    p._vllm_is_uva_offloaded = True

Severity: high

The current logic for re-offloading UVA parameters is correct but could be more efficient. It re-offloads any parameter that was originally UVA-offloaded, even if it was not replaced within the device_loading_context. This results in unnecessary data copies for parameters that are already correctly offloaded.

To optimize this, we should only re-offload parameters that were originally UVA-offloaded but have since been replaced (i.e., they no longer have the _vllm_is_uva_offloaded attribute). This makes the check more specific and avoids unnecessary work.

Suggested change
-if name in uva_offloaded_parameters:
+if name in uva_offloaded_parameters and not getattr(
+        p, "_vllm_is_uva_offloaded", False):
     assert pin_memory, "UVA offloaded parameters must be pinned"
     cpu_data = p.data.to("cpu").pin_memory()
     p.data = get_cuda_view_from_cpu_tensor(cpu_data)
     p._vllm_is_uva_offloaded = True
