Skip to content

Commit b2fac67

Browse files
authored
[P/D] Heterogeneous TP (#18833)
Signed-off-by: nicklucche <[email protected]>
1 parent 23027e2 commit b2fac67

File tree

6 files changed

+287
-100
lines changed

6 files changed

+287
-100
lines changed

tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,9 @@ MODELS=(
88

99
# Number of prefill and decode instances to create
1010
NUM_PREFILL_INSTANCES=${NUM_PREFILL_INSTANCES:-1} # Default to 1
11-
NUM_DECODE_INSTANCES=${NUM_DECODE_INSTANCES:-2} # Default to 2
11+
NUM_DECODE_INSTANCES=${NUM_DECODE_INSTANCES:-1} # Default to 1
12+
PREFILLER_TP_SIZE=${PREFILLER_TP_SIZE:-1}
13+
DECODER_TP_SIZE=${DECODER_TP_SIZE:-1}
1214

1315
# Find the git repository root directory
1416
GIT_ROOT=$(git rev-parse --show-toplevel)
@@ -74,9 +76,10 @@ run_tests_for_model() {
7476
for i in $(seq 0 $((NUM_PREFILL_INSTANCES-1))); do
7577
# Calculate GPU ID - we'll distribute across available GPUs
7678
GPU_ID=$((i % $(get_num_gpus)))
79+
7780
# Calculate port number (base port + instance number)
7881
PORT=$((8100 + i))
79-
# Calculate side channel port
82+
# Calculate side channel port. Avoid clash with TP workers.
8083
SIDE_CHANNEL_PORT=$((5559 + i))
8184

8285
echo "Starting prefill instance $i on GPU $GPU_ID, port $PORT"
@@ -87,6 +90,7 @@ run_tests_for_model() {
8790
--enforce-eager \
8891
--disable-log-requests \
8992
--gpu-memory-utilization 0.2 \
93+
--tensor-parallel-size $PREFILLER_TP_SIZE \
9094
--kv-transfer-config '{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_both\"}'"
9195

9296
if [ -n "$model_args" ]; then
@@ -109,7 +113,7 @@ run_tests_for_model() {
109113
# Calculate port number (base port + instance number)
110114
PORT=$((8200 + i))
111115
# Calculate side channel port
112-
SIDE_CHANNEL_PORT=$((5659 + i))
116+
SIDE_CHANNEL_PORT=$((5659 + i * $DECODER_TP_SIZE))
113117

114118
echo "Starting decode instance $i on GPU $GPU_ID, port $PORT"
115119

@@ -119,6 +123,7 @@ run_tests_for_model() {
119123
--enforce-eager \
120124
--disable-log-requests \
121125
--gpu-memory-utilization 0.2 \
126+
--tensor-parallel-size $DECODER_TP_SIZE \
122127
--kv-transfer-config '{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_both\"}'"
123128

124129
if [ -n "$model_args" ]; then

tests/v1/kv_connector/nixl_integration/test_accuracy.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
# Model-specific expected values
1515
EXPECTED_VALUES = {
1616
"Qwen/Qwen3-0.6B": 0.41,
17+
"deepseek-ai/deepseek-vl2-small": 0.59
1718
}
1819

1920
SIMPLE_PROMPT = "The best part about working on vLLM is that I got to meet so many people across various different organizations like UCB, Google, and Meta which means", # noqa: E501

vllm/distributed/kv_transfer/kv_connector/utils.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,12 @@
33
"""
44
KV cache helper for store.
55
"""
6+
67
import torch
78

89
import vllm.envs as envs
910
from vllm import _custom_ops as ops
10-
from vllm.config import VllmConfig
11+
from vllm.config import VllmConfig, get_current_vllm_config
1112
from vllm.logger import init_logger
1213

1314
logger = init_logger(__name__)
@@ -90,3 +91,18 @@ def put_kv_to_cache(self, model_executable: torch.nn.Module, keys, values,
9091
layer.self_attn.attn._k_scale,
9192
layer.self_attn.attn._v_scale,
9293
)
94+
95+
96+
def get_kv_connector_cache_layout():
    """Return the KV cache layout preferred by the configured KV connector.

    Reads the current vLLM config and returns "HND" when a NixlConnector is
    configured for a non-MLA model (HND gives better transfer performance
    for NIXL), and "NHD" in every other case — including when the model
    config or the KV-transfer config is unavailable.

    Returns:
        str: either "HND" or "NHD".
    """
    vllm_config = get_current_vllm_config()
    kv_config = vllm_config.kv_transfer_config
    # kv_transfer_config is optional; guard it alongside model_config so we
    # never dereference None below and always fall back to the default.
    if vllm_config.model_config is None or kv_config is None:
        logger.warning("Unable to detect current VLLM config. "
                       "Defaulting to NHD kv cache layout.")
    else:
        use_mla = vllm_config.model_config.use_mla
        # MLA caches use their own layout; only regular attention benefits
        # from switching to HND for NIXL transfers.
        if not use_mla and kv_config.kv_connector == "NixlConnector":
            logger.info("NixlConnector detected. Setting KV cache "
                        "layout to HND for better xfer performance.")
            return "HND"
    return "NHD"

0 commit comments

Comments
 (0)