[v1] Re-init input batch for multiple kv cache groups #18654

Open · wants to merge 2 commits into base: main
29 changes: 3 additions & 26 deletions tests/v1/worker/test_gpu_input_batch.py
@@ -9,8 +9,6 @@

from vllm.sampling_params import SamplingParams
from vllm.utils import is_pin_memory_available, make_tensor_with_pad
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
KVCacheGroupSpec, KVCacheTensor)
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.worker.block_table import BlockTable, MultiGroupBlockTable
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
@@ -24,27 +22,6 @@
MAX_NUM_PROMPT_TOKENS = 64


def get_kv_cache_config() -> KVCacheConfig:
return KVCacheConfig(
num_blocks=10,
tensors={
"layer.0": KVCacheTensor(size=1024),
},
kv_cache_groups=[
KVCacheGroupSpec(
layer_names=["layer.0"],
kv_cache_spec=FullAttentionSpec(
block_size=1,
num_kv_heads=1,
head_size=16,
dtype=torch.float16,
use_mla=False,
),
),
],
)


def _compare_objs(obj1, obj2):
attrs = inspect.getmembers(obj1, lambda a: not (inspect.isroutine(a)))
attr_names = set([
@@ -251,7 +228,7 @@ def test_sampling_metadata_in_input_batch(device: str, batch_size: int):
device=torch.device(device),
pin_memory=is_pin_memory_available(),
vocab_size=1024,
block_size=1,
block_sizes=[1],
)
reqs: list[CachedRequestState] = []
req_id_reqs = {}
@@ -341,7 +318,7 @@ def test_swap_states_in_input_batch(device: str, batch_size: int,
device=torch.device(device),
pin_memory=is_pin_memory_available(),
vocab_size=1024,
block_size=1,
block_sizes=[1],
)
ref_input_batch: InputBatch = InputBatch(
max_num_reqs=batch_size,
@@ -350,7 +327,7 @@
device=torch.device(device),
pin_memory=is_pin_memory_available(),
vocab_size=1024,
block_size=1,
block_sizes=[1],
)

reqs: list[CachedRequestState] = []
4 changes: 3 additions & 1 deletion tests/v1/worker/test_gpu_model_runner.py
@@ -43,7 +43,9 @@ def initialize_kv_cache(runner: GPUModelRunner):
device=runner.device,
pin_memory=runner.pin_memory,
vocab_size=runner.model_config.get_vocab_size(),
block_size=kv_cache_config.kv_cache_groups[0].kv_cache_spec.block_size,
block_sizes=[
kv_cache_config.kv_cache_groups[0].kv_cache_spec.block_size
],
)
runner.initialize_attn_backend(kv_cache_config)

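The test above still wraps only the first group's block size in a list, matching the current single-group setup. For a config with several KV cache groups, the same list is derived per group, as the runner-side change later in this diff does; a minimal sketch of that derivation (hypothetical helper name, field names as used in this PR):

def block_sizes_from_config(kv_cache_config) -> list[int]:
    # One block_size per KV cache group, in group order.
    return [
        group.kv_cache_spec.block_size
        for group in kv_cache_config.kv_cache_groups
    ]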
3 changes: 2 additions & 1 deletion vllm/v1/worker/block_table.py
@@ -104,10 +104,11 @@ class MultiGroupBlockTable:

def __init__(self, max_num_reqs: int, max_model_len: int,
max_num_batched_tokens: int, pin_memory: bool,
device: torch.device, block_size: int) -> None:
device: torch.device, block_sizes: list[int]) -> None:
self.block_tables = [
BlockTable(max_num_reqs, cdiv(max_model_len, block_size),
max_num_batched_tokens, pin_memory, device)
for block_size in block_sizes
]

def append_row(self, block_ids: list[list[int]], row_idx: int) -> None:
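To make the new constructor concrete: `MultiGroupBlockTable` now builds one `BlockTable` per entry in `block_sizes`, each sized by `cdiv(max_model_len, block_size)` for its own group. A self-contained toy sketch of that pattern (the `_Toy*` classes are illustrative stand-ins, not the vLLM classes):

from math import ceil

class _ToyBlockTable:
    def __init__(self, max_num_reqs: int, max_num_blocks_per_req: int):
        # Each request row holds up to max_num_blocks_per_req block IDs.
        self.table = [[0] * max_num_blocks_per_req for _ in range(max_num_reqs)]

class _ToyMultiGroupBlockTable:
    def __init__(self, max_num_reqs: int, max_model_len: int,
                 block_sizes: list[int]) -> None:
        # One per-group table; larger block sizes need fewer blocks per request.
        self.block_tables = [
            _ToyBlockTable(max_num_reqs, ceil(max_model_len / block_size))
            for block_size in block_sizes
        ]

# Two groups with hypothetical block sizes 16 and 64:
mgbt = _ToyMultiGroupBlockTable(max_num_reqs=4, max_model_len=2048,
                                block_sizes=[16, 64])
assert [len(bt.table[0]) for bt in mgbt.block_tables] == [128, 32]

A larger block size therefore yields fewer block slots per request in that group's table, which is what the list comprehension in the diff above encodes.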
18 changes: 9 additions & 9 deletions vllm/v1/worker/gpu_input_batch.py
@@ -55,14 +55,14 @@ def get_token_id(self, idx: int) -> int:
class InputBatch:

def __init__(
self,
max_num_reqs: int,
max_model_len: int,
max_num_batched_tokens: int,
device: torch.device,
pin_memory: bool,
vocab_size: int,
block_size: int,
self,
max_num_reqs: int,
max_model_len: int,
max_num_batched_tokens: int,
device: torch.device,
pin_memory: bool,
vocab_size: int,
block_sizes: list[int], # The block_size of each kv cache group
):
self.max_num_reqs = max_num_reqs
self.max_model_len = max_model_len
@@ -104,7 +104,7 @@ def __init__(
max_num_batched_tokens=max_num_batched_tokens,
pin_memory=pin_memory,
device=device,
block_size=block_size,
block_sizes=block_sizes,
)

# Sampling-related.
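With the new signature, callers pass one block size per KV cache group rather than a single int. A hedged usage sketch, assuming a vLLM checkout that includes this change (the numeric values are made up):

import torch

from vllm.v1.worker.gpu_input_batch import InputBatch

input_batch = InputBatch(
    max_num_reqs=32,
    max_model_len=4096,
    max_num_batched_tokens=8192,
    device=torch.device("cpu"),
    pin_memory=False,
    vocab_size=32_000,
    # One entry per KV cache group; e.g. [16, 64] for a hypothetical
    # two-group hybrid model.
    block_sizes=[16],
)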
46 changes: 40 additions & 6 deletions vllm/v1/worker/gpu_model_runner.py
@@ -142,7 +142,6 @@ def __init__(
self.attn_metadata_builders: list[AttentionMetadataBuilder] = []
self.attn_backends: list[type[AttentionBackend]] = []
# self.kv_cache_config: KVCacheConfig
# self.input_batch: InputBatch # Persistent batch.

# req_id -> (input_id -> encoder_output)
self.encoder_cache: dict[str, dict[int, torch.Tensor]] = {}
@@ -172,14 +171,23 @@ def __init__(
# Request states.
self.requests: dict[str, CachedRequestState] = {}

# Input Batch
# NOTE(Chen): Ideally, we should initialize the input batch inside
# `initialize_kv_cache` based on the kv cache config. However, as in
# https://github.com/vllm-project/vllm/pull/18298, for reasons not yet
# understood, we have to initialize the input batch before `load_model`;
# otherwise quantization + weight offloading will fail. As a temporary
# workaround, we initialize the input batch here and re-initialize it in
# `initialize_kv_cache` if the block_sizes here differ from the
# block_sizes in the kv cache config.
self.input_batch = InputBatch(
max_num_reqs=self.max_num_reqs,
max_model_len=self.max_model_len,
max_num_batched_tokens=self.max_num_tokens,
device=self.device,
pin_memory=self.pin_memory,
vocab_size=self.model_config.get_vocab_size(),
block_size=self.cache_config.block_size,
block_sizes=[self.cache_config.block_size],
)

self.use_cuda_graph = (self.vllm_config.compilation_config.level
@@ -1998,18 +2006,44 @@ def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None:
self.attn_backends.append(attn_backend_i)
self.attn_metadata_builders.append(attn_metadata_builder_i)

def may_reinitialize_input_batch(self,
kv_cache_config: KVCacheConfig) -> None:
"""
Re-initialize the input batch if the block sizes are different from
`[self.cache_config.block_size]`. This usually happens when there
are multiple KV cache groups.

Args:
kv_cache_config: The KV cache configuration.
"""
block_sizes = [
kv_cache_group.kv_cache_spec.block_size
for kv_cache_group in kv_cache_config.kv_cache_groups
]
if block_sizes != [self.cache_config.block_size]:
assert self.cache_config.cpu_offload_gb == 0, (
"Cannot re-initialize the input batch when CPU weight "
"offloading is enabled. See https://github.com/vllm-project/vllm/pull/18298 " # noqa: E501
"for more details.")
self.input_batch = InputBatch(
max_num_reqs=self.max_num_reqs,
max_model_len=self.max_model_len,
max_num_batched_tokens=self.max_num_tokens,
device=self.device,
pin_memory=self.pin_memory,
vocab_size=self.model_config.get_vocab_size(),
block_sizes=block_sizes,
)

def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
"""
Initialize KV cache based on `kv_cache_config`.
Args:
kv_cache_config: Configuration for the KV cache, including the KV
cache size of each layer
"""
if len(kv_cache_config.kv_cache_groups) > 1:
raise NotImplementedError(
"Hybrid models with more than one KV cache type are not "
"supported yet.")
self.kv_cache_config = kv_cache_config
self.may_reinitialize_input_batch(kv_cache_config)
self.initialize_attn_backend(kv_cache_config)

kv_caches: dict[str, torch.Tensor] = {}
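The heart of `may_reinitialize_input_batch` is a comparison between the per-group block sizes from the KV cache config and the single default `[self.cache_config.block_size]` used when the batch was first built in `__init__`. A standalone sketch of just that decision (hypothetical helper, not the vLLM method itself):

def needs_reinit(group_block_sizes: list[int], default_block_size: int) -> bool:
    # Rebuild only when the KV cache config disagrees with the single
    # default block size the input batch was first constructed with.
    return group_block_sizes != [default_block_size]

assert needs_reinit([16], 16) is False      # single default group: keep the batch
assert needs_reinit([16, 16], 16) is True   # multiple groups: rebuild
assert needs_reinit([32], 16) is True       # one group, different block size: rebuild

The actual method additionally refuses to rebuild when `cpu_offload_gb > 0`, because of the load-order constraint described in the `__init__` comment.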
5 changes: 3 additions & 2 deletions vllm/v1/worker/tpu_model_runner.py
@@ -1293,8 +1293,9 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
device=self.device,
pin_memory=self.pin_memory,
vocab_size=self.model_config.get_vocab_size(),
block_size=kv_cache_config.kv_cache_groups[0].kv_cache_spec.
block_size,
block_sizes=[
kv_cache_config.kv_cache_groups[0].kv_cache_spec.block_size
],
)
assert self.block_table_cpu.dtype == self.input_batch.block_table[
0].get_cpu_tensor().dtype