Commit c56ed8b

[Bugfix][Nixl] Fix full prefix cache hit bug (#18632)
Signed-off-by: [email protected] <[email protected]>
Signed-off-by: Nick Hill <[email protected]>
Co-authored-by: Nick Hill <[email protected]>
1 parent 78dcf56 commit c56ed8b

4 files changed (+97, -81 lines)

tests/v1/kv_connector/unit/test_multi_connector.py

Lines changed: 47 additions & 29 deletions
```diff
@@ -12,6 +12,7 @@
     KVConnectorFactory)
 from vllm.distributed.kv_transfer.kv_connector.v1.shared_storage_connector import (  # noqa
     SharedStorageConnector)
+from vllm.v1.core.kv_cache_manager import KVCacheBlocks

 MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct"

@@ -32,7 +33,7 @@ def __init__(self, config: VllmConfig, role):
         self.call_record: dict[str, int] = defaultdict(int)
         # Use a unique temp file per connector
         self._event_file = tempfile.gettempdir(
-        ) + f"/connector_{self.name}_events.log"
+        ) + f"/connector_{self.name}-{self.role.name}_events.log"
         # Start with an empty file
         with open(self._event_file, "w") as _:
             pass
@@ -52,10 +53,19 @@ def __getattribute__(self, name):

         def wrapper(*args, **kwargs):
             self.call_record[name] += 1
+
+            # Include args that we're interested in
+            to_log = [name]
+            for arg in args:
+                if isinstance(arg, int):
+                    to_log.append(str(arg))
+                elif isinstance(arg, KVCacheBlocks):
+                    to_log.append(f"num_blocks={len(arg.blocks)}")
+
             # Log the event as a line to the file
             try:
                 with open(self._event_file, "a") as f:
-                    f.write(name + "\n")
+                    f.write(' '.join(to_log) + "\n")
             except Exception as e:
                 print(f"[ERROR] Could not log event {name} "
                       f"for {self.name}: {e}")
@@ -162,15 +172,23 @@ def test_multi_shared_storage_connector_consistency():
                 f"{storage_1_path} and {storage_2_path}")

    events = get_connector_events()
-    # get_num_new_matched_tokens will be called on each connector in turn.
-    # neither of them have hits so update_state_after_alloc won't be called.
-    assert events["storage1"][:3] == [
-        'get_num_new_matched_tokens', 'build_connector_meta',
-        'bind_connector_metadata'
+    # get_num_new_matched_tokens and update_state_after_alloc will be called
+    # on each connector in turn.
+    assert events["storage1-SCHEDULER"][:3] == [
+        'get_num_new_matched_tokens 0',
+        'update_state_after_alloc num_blocks=0 0', 'build_connector_meta'
+    ]
+    assert events["storage1-WORKER"][:5] == [
+        'register_kv_caches', 'bind_connector_metadata', 'start_load_kv',
+        'wait_for_layer_load', 'save_kv_layer'
+    ]
+    assert events["storage2-SCHEDULER"][:3] == [
+        'get_num_new_matched_tokens 0',
+        'update_state_after_alloc num_blocks=0 0', 'build_connector_meta'
     ]
-    assert events["storage2"][:3] == [
-        'get_num_new_matched_tokens', 'build_connector_meta',
-        'bind_connector_metadata'
+    assert events["storage2-WORKER"][:5] == [
+        'register_kv_caches', 'bind_connector_metadata', 'start_load_kv',
+        'wait_for_layer_load', 'save_kv_layer'
     ]

     # Reset prefix cache or else we'll just get the tokens back from there.
@@ -182,16 +200,16 @@ def test_multi_shared_storage_connector_consistency():

     events = get_connector_events()
     # get_num_new_matched_tokens will return new tokens from the first
-    # connector so update_state_after_alloc will be called once blocks
-    # are allocated for the first connector.
-    # get_num_new_matched_tokens *won't* be called on the second connector
-    # in this case.
-    assert events["storage1"][:4] == [
-        'get_num_new_matched_tokens', 'update_state_after_alloc',
-        'build_connector_meta', 'bind_connector_metadata'
+    # connector so update_state_after_alloc will be with allocated blocks
+    # on that one but with zero blocks for others (first nonzero match is
+    # chosen).
+    assert events["storage1-SCHEDULER"][:3] == [
+        'get_num_new_matched_tokens 0',
+        'update_state_after_alloc num_blocks=7 96', 'build_connector_meta'
     ]
-    assert events["storage2"][:2] == [
-        'build_connector_meta', 'bind_connector_metadata'
+    assert events["storage2-SCHEDULER"][:3] == [
+        'get_num_new_matched_tokens 0',
+        'update_state_after_alloc num_blocks=0 0', 'build_connector_meta'
     ]

     # Delete storage1 connector state
@@ -205,17 +223,17 @@ def test_multi_shared_storage_connector_consistency():
     _ = llm.generate(PROMPTS, SAMPLING_PARAMS)

     events = get_connector_events()
-    # get_num_new_matched_tokens will be called for the first connector but it
-    # won't have a hit so update_state_after_alloc won't be called.
-    # get_num_new_matched_tokens will also be called on the second connector,
-    # but it should have a hit so update_state_after_alloc will be called.
-    assert events["storage1"][:3] == [
-        'get_num_new_matched_tokens', 'build_connector_meta',
-        'bind_connector_metadata'
+    # get_num_new_matched_tokens will be called for both connectors but will
+    # return 0 from the first connector, but the second connector should have
+    # a hit, so update_state_after_alloc will only be called with allocated
+    # blocks for the second connector.
+    assert events["storage1-SCHEDULER"][:3] == [
+        'get_num_new_matched_tokens 0',
+        'update_state_after_alloc num_blocks=0 0', 'build_connector_meta'
     ]
-    assert events["storage2"][:4] == [
-        'get_num_new_matched_tokens', 'update_state_after_alloc',
-        'build_connector_meta', 'bind_connector_metadata'
+    assert events["storage2-SCHEDULER"][:3] == [
+        'get_num_new_matched_tokens 0',
+        'update_state_after_alloc num_blocks=7 96', 'build_connector_meta'
     ]

     # Clean up
```
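
The test wraps each shared-storage connector so that every scheduler- and worker-side call is appended, together with selected arguments, to a per-connector, per-role event log; the assertions above then check both the call order and the block counts passed to `update_state_after_alloc`. The following standalone sketch shows the same interception idea with a hypothetical `FakeConnector` class (not the vLLM test class), so the pattern can be tried in isolation:

```python
# Minimal sketch of the call-recording pattern used by the test wrapper.
# FakeConnector and its methods are illustrative stand-ins, not vLLM APIs.
from collections import defaultdict


class FakeConnector:

    def __init__(self, name: str, role: str):
        self.name = name
        self.role = role
        self.call_record: dict[str, int] = defaultdict(int)
        self.events: list[str] = []

    def __getattribute__(self, attr):
        value = object.__getattribute__(self, attr)
        # Only wrap public methods; plain attribute reads pass through.
        if not callable(value) or attr.startswith("_"):
            return value

        def wrapper(*args, **kwargs):
            self.call_record[attr] += 1
            # Record the call name plus any int args (e.g. token counts).
            to_log = [attr] + [str(a) for a in args if isinstance(a, int)]
            self.events.append(" ".join(to_log))
            return value(*args, **kwargs)

        return wrapper

    def get_num_new_matched_tokens(self, num_computed_tokens: int):
        return 0, False


c = FakeConnector("storage1", "SCHEDULER")
c.get_num_new_matched_tokens(0)
assert c.events == ["get_num_new_matched_tokens 0"]
```

Intercepting calls in `__getattribute__` keeps the recording transparent to callers, which is why the test can assert on exact call sequences without modifying the connectors themselves.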

vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py

Lines changed: 26 additions & 15 deletions
```diff
@@ -12,12 +12,12 @@
 from vllm.distributed.kv_transfer.kv_connector.v1.base import (
     KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
 from vllm.logger import init_logger
+from vllm.v1.core.kv_cache_manager import KVCacheBlocks
 from vllm.v1.core.sched.output import SchedulerOutput

 if TYPE_CHECKING:
     from vllm.attention.backends.abstract import AttentionMetadata
     from vllm.forward_context import ForwardContext
-    from vllm.v1.core.kv_cache_manager import KVCacheBlocks
     from vllm.v1.request import Request

 logger = init_logger(__name__)
@@ -51,8 +51,9 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
             self._connectors.append(
                 KVConnectorFactory.create_connector_v1(temp_config, role))

-        # A mapping from request id to the connector that is assigned to it.
-        self._requests_to_connector: dict[str, KVConnectorBase_V1] = {}
+        # A mapping from request id to the index of the connector chosen to
+        # load the request from (if any).
+        self._requests_to_connector: dict[str, int] = {}

         # Keeps track of *additional* remaining async saves (beyond 1) to be
         # finished per request. Not needed for async loads since we only allow
@@ -136,25 +137,31 @@ def get_num_new_matched_tokens(
         request: "Request",
         num_computed_tokens: int,
     ) -> tuple[int, bool]:
-        for c in self._connectors:
+        to_return = (0, False)
+        for i, c in enumerate(self._connectors):
             toks, load_async = c.get_num_new_matched_tokens(
                 request, num_computed_tokens)
             # The first connector that has new matched tokens will be assigned
             # to this request.
-            if toks > 0:
-                self._requests_to_connector[request.request_id] = c
-                return toks, load_async
-        return 0, False
+            if to_return[0] == 0 and toks > 0:
+                self._requests_to_connector[request.request_id] = i
+                to_return = (toks, load_async)
+        return to_return

     def update_state_after_alloc(self, request: "Request",
                                  blocks: "KVCacheBlocks",
                                  num_external_tokens: int):
-        # If the request is not assigned to any connector, we do nothing.
-        if request.request_id not in self._requests_to_connector:
-            return
-        # We assume that the request is assigned to only one connector.
-        c = self._requests_to_connector.pop(request.request_id)
-        c.update_state_after_alloc(request, blocks, num_external_tokens)
+        chosen_connector = self._requests_to_connector.get(
+            request.request_id, -1)
+        for i, c in enumerate(self._connectors):
+            if i == chosen_connector:
+                # Forward call to the chosen connector (if any).
+                c.update_state_after_alloc(request, blocks,
+                                           num_external_tokens)
+            else:
+                # Call with empty blocks for other connectors.
+                c.update_state_after_alloc(request,
+                                           KVCacheBlocks.create_empty(), 0)

     def build_connector_meta(
         self,
@@ -170,7 +177,7 @@ def build_connector_meta(
     def request_finished(
         self,
         request: "Request",
-        blocks: "KVCacheBlocks",
+        blocks: list[int],
     ) -> tuple[bool, Optional[dict[str, Any]]]:
         async_saves = 0
         kv_txfer_params = None
@@ -187,4 +194,8 @@
                 kv_txfer_params = txfer_params
         if async_saves > 1:
             self._extra_async_saves[request.request_id] = async_saves - 1
+
+        # Clean up other state for this request.
+        self._requests_to_connector.pop(request.request_id, None)
+
         return async_saves > 0, kv_txfer_params
```
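
After this change, `MultiConnector` remembers only the index of the connector that first reported matched tokens, and `update_state_after_alloc` fans out to every child: the chosen connector receives the real allocated blocks, while the others are called with empty blocks so their per-request bookkeeping stays consistent. A minimal sketch of that dispatch, using simplified stand-ins (`Blocks`, `ChildConnector`, and `MultiDispatch` are illustrative, not vLLM classes):

```python
# Sketch of the "chosen connector gets real blocks, others get empty blocks"
# dispatch introduced by this commit. All types here are simplified stand-ins.
from dataclasses import dataclass, field


@dataclass
class Blocks:
    block_ids: list[int] = field(default_factory=list)


@dataclass
class ChildConnector:
    name: str
    seen: list[tuple[str, int, int]] = field(default_factory=list)

    def update_state_after_alloc(self, request_id: str, blocks: Blocks,
                                 num_external_tokens: int) -> None:
        self.seen.append((request_id, len(blocks.block_ids),
                          num_external_tokens))


class MultiDispatch:

    def __init__(self, connectors: list[ChildConnector]):
        self._connectors = connectors
        # request_id -> index of the connector chosen to load from (if any)
        self._requests_to_connector: dict[str, int] = {}

    def update_state_after_alloc(self, request_id: str, blocks: Blocks,
                                 num_external_tokens: int) -> None:
        chosen = self._requests_to_connector.get(request_id, -1)
        for i, c in enumerate(self._connectors):
            if i == chosen:
                c.update_state_after_alloc(request_id, blocks,
                                           num_external_tokens)
            else:
                # Other connectors still see the call, but with no blocks.
                c.update_state_after_alloc(request_id, Blocks(), 0)


connectors = [ChildConnector("storage1"), ChildConnector("storage2")]
multi = MultiDispatch(connectors)
multi._requests_to_connector["req-0"] = 0
multi.update_state_after_alloc("req-0", Blocks([1, 2, 3]), 96)
assert connectors[0].seen == [("req-0", 3, 96)]
assert connectors[1].seen == [("req-0", 0, 0)]
```

Calling every child, rather than skipping the non-chosen ones, is what lets a connector such as the Nixl connector observe a request even when it contributes no blocks.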

vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py

Lines changed: 20 additions & 33 deletions
```diff
@@ -221,15 +221,6 @@ def get_num_new_matched_tokens(
         if count > 0:
             return count, True

-        # NOTE: if count is 0 here, we have less than block_size
-        # tokens to pull after subtracting the local prefix cache hit.
-        # The remote only sends fully computed blocks, so there is
-        # nothing to transfer but we still need to notify the
-        # prefill worker so that the remote blocks are freed.
-        if all(p in params for p in ("remote_engine_id", "remote_host",
-                                     "remote_port")):
-            self._reqs_need_recv[request.request_id] = (request, [])
-
         # No remote prefill for this request.
         return 0, False

@@ -247,9 +238,14 @@ def update_state_after_alloc(self, request: "Request",
         if params.get("remote_block_ids"):
             if all(p in params for p in ("remote_engine_id", "remote_host",
                                          "remote_port")):
+                # If remote_blocks and num_external_tokens = 0, we have
+                # a full prefix cache hit on the D worker. We need to call
+                # send_notif in _read_blocks to free the memory on the P.
+                local_block_ids = (blocks.get_unhashed_block_ids()
+                                   if num_external_tokens > 0 else [])
                 # Get unhashed blocks to pull from remote.
                 self._reqs_need_recv[request.request_id] = (
-                    request, blocks.get_unhashed_block_ids())
+                    request, local_block_ids)
             else:
                 logger.warning(
                     "Got invalid KVTransferParams: %s. This "
@@ -268,15 +264,6 @@ def build_connector_meta(
         # Loop through scheduled reqs and convert to ReqMeta.
         for req_id, (req, block_ids) in self._reqs_need_recv.items():
             assert req.kv_transfer_params is not None
-            # For the case where there are no remote blocks to pull
-            # (block_ids is empty), we don't need to schedule
-            # an async read on the worker side.
-            if not block_ids:
-                logger.debug(
-                    "Skipping adding request %s to NixlConnectorMetadata, "
-                    "as there are no remote blocks to pull", req_id)
-                continue
-
             meta.add_new_req(
                 request_id=req_id,
                 local_block_ids=block_ids,
@@ -660,26 +647,26 @@ def add_remote_agent(self,

         # Number of D TP workers reading from a single P TP worker. This is
         # 1 when P and D `--tensor-parallel-size` match.
-        assert self._tp_size[self.engine_id] % self._tp_size[engine_id] == 0, \
-            "Local TP size must be divisible by remote TP size."
+        assert self._tp_size[self.engine_id] % self._tp_size[engine_id] == 0, (
+            "Local TP size must be divisible by remote TP size.")
         tp_ratio = self._tp_size[self.engine_id] // self._tp_size[engine_id]
-        assert tp_ratio > 0, "Decode TP cannot be smaller than"
-        " prefill TP"
+        assert tp_ratio > 0, "Decode TP cannot be smaller than prefill TP"
         if self.use_mla:
             # With MLA the only difference is in the number of blocks.
-            remote_block_size = nixl_agent_meta.block_len / (
+            remote_block_size = nixl_agent_meta.block_len // (
                 self.slot_size_bytes)
             assert self.block_len == nixl_agent_meta.block_len
         else:
-            remote_block_size = nixl_agent_meta.block_len / (
+            remote_block_size = nixl_agent_meta.block_len // (
                 self.slot_size_bytes * tp_ratio)

-        assert nixl_agent_meta.block_len == self.block_len * tp_ratio, \
-            "Remote P worker KV layer cache must be of shape [2, N, \
-            local_kv_heads*tp_ratio, block_size, head_dim] and same dtype."
+        assert nixl_agent_meta.block_len == self.block_len * tp_ratio, (
+            "Remote P worker KV layer cache must be of shape [2, N, "
+            "local_kv_heads*tp_ratio, block_size, head_dim] and same dtype."
+        )

-        assert self.block_size == remote_block_size, "Remote P worker with \
-            different block size is not supported"
+        assert self.block_size == remote_block_size, "Remote P worker with "
+        "different block size is not supported"

         assert self.num_blocks >= nixl_agent_meta.num_blocks

@@ -712,9 +699,9 @@ def add_remote_agent(self,
                # (addr, len, device id)
                blocks_data.append((addr, self.block_len, remote_tp_rank))
            logger.debug(
-                "Created %s blocks for dst engine %s with remote rank %s and " \
-                "local rank %s",
-                len(blocks_data), engine_id, remote_tp_rank, self.tp_rank)
+                "Created %s blocks for dst engine %s with remote rank %s and "
+                "local rank %s", len(blocks_data), engine_id, remote_tp_rank,
+                self.tp_rank)

            # Register with NIXL.
            descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM")
```
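
This is the heart of the bugfix: when the decode instance has a full local prefix cache hit, `num_external_tokens` is 0 even though valid remote transfer params are present. Instead of dropping the request (as the removed code paths effectively did), the connector now records it with an empty local block list so the worker still processes it and can notify the prefill instance to free its blocks. A rough sketch of that decision, under simplified assumptions (`record_request`, the plain dicts, and `REQUIRED_PARAMS` are illustrative, not the connector's real API):

```python
# Sketch of the scheduler-side decision after the fix. The dicts and the
# reqs_need_recv structure are simplified stand-ins for the connector state.
from typing import Any

REQUIRED_PARAMS = ("remote_engine_id", "remote_host", "remote_port")


def record_request(reqs_need_recv: dict[str, list[int]],
                   request_id: str,
                   params: dict[str, Any],
                   unhashed_block_ids: list[int],
                   num_external_tokens: int) -> None:
    """Decide which local blocks (if any) the worker should pull into."""
    if not params.get("remote_block_ids"):
        return
    if not all(p in params for p in REQUIRED_PARAMS):
        return
    # Full prefix cache hit on the decode side: nothing to pull, but the
    # request must still reach the worker so it can notify the prefill
    # instance to free the remote blocks.
    local_block_ids = unhashed_block_ids if num_external_tokens > 0 else []
    reqs_need_recv[request_id] = local_block_ids


reqs: dict[str, list[int]] = {}
params = {"remote_block_ids": [7, 8], "remote_engine_id": "p0",
          "remote_host": "10.0.0.1", "remote_port": 5555}

record_request(reqs, "req-hit", params, unhashed_block_ids=[3, 4],
               num_external_tokens=0)
record_request(reqs, "req-pull", params, unhashed_block_ids=[3, 4],
               num_external_tokens=128)
assert reqs == {"req-hit": [], "req-pull": [3, 4]}
```

In the real connector the empty block list then flows into `build_connector_meta`, which no longer skips such requests, so the worker's read path can send the free-notification even though there is nothing to transfer.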

vllm/v1/core/sched/scheduler.py

Lines changed: 4 additions & 4 deletions
```diff
@@ -424,11 +424,11 @@ def schedule(self) -> SchedulerOutput:
                    # The request cannot be scheduled.
                    break

-                # KVConnector: update internal state after allocation.
+                # KVTransfer: the connector uses this info to determine
+                # if a load is needed. Note that
                # This information is used to determine if a load is
                # needed for this request.
-                if num_external_computed_tokens:
-                    assert self.connector is not None
+                if self.connector is not None:
                    self.connector.update_state_after_alloc(
                        request,
                        new_computed_blocks + new_blocks,
@@ -841,7 +841,7 @@ def update_from_output(
            }

        finished_req_ids = self.finished_req_ids_dict
-        if finished_req_ids is not None:
+        if finished_req_ids:
            # Include ids of requests that finished since last outputs
            # were sent.
            for client_index, finished_set in finished_req_ids.items():
```
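
On the scheduler side, `update_state_after_alloc` is now invoked whenever a KV connector is configured, not only when `num_external_computed_tokens` is nonzero, so connectors also see requests that were served entirely from the local prefix cache. A small sketch of the changed call site, with hypothetical `StubConnector` and `after_allocation` names standing in for the scheduler internals:

```python
# Sketch of the changed scheduler call site: the connector hook now fires
# whenever a connector exists, even with zero external tokens.
from typing import Optional


class StubConnector:

    def __init__(self):
        self.calls: list[tuple[str, int]] = []

    def update_state_after_alloc(self, request_id: str, blocks: list[int],
                                 num_external_tokens: int) -> None:
        self.calls.append((request_id, num_external_tokens))


def after_allocation(connector: Optional[StubConnector], request_id: str,
                     blocks: list[int], num_external_tokens: int) -> None:
    # Before the fix: only called when num_external_tokens was nonzero.
    # After the fix: called whenever a connector is configured.
    if connector is not None:
        connector.update_state_after_alloc(request_id, blocks,
                                           num_external_tokens)


conn = StubConnector()
after_allocation(conn, "req-hit", blocks=[1, 2], num_external_tokens=0)
after_allocation(conn, "req-pull", blocks=[3], num_external_tokens=64)
assert conn.calls == [("req-hit", 0), ("req-pull", 64)]
```

This matches what the updated test asserts: each connector logs an `update_state_after_alloc` event for every scheduled request, with `num_blocks=0` when it was not the chosen loader.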
