Skip to content

Commit 6d76d82

Browse files
author
Haiyang Shi
committed
[Feature] KVCache layout: compact laytout
Signed-off-by: Haiyang Shi <[email protected]>
1 parent e0c4511 commit 6d76d82

34 files changed

+1364
-1113
lines changed

python/aibrix_kvcache/aibrix_kvcache/__init__.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
# limitations under the License.
1414

1515
from .cache_handle import KVCacheHandle, MemoryRegionKVCacheHandle
16-
from .cache_hashable import TokenCacheKey
1716
from .cache_manager import (
1817
BaseKVCacheManager,
1918
GroupAwareKVCacheManager,
@@ -26,7 +25,6 @@
2625

2726
__all__ = [
2827
"KVCacheHandle",
29-
"TokenCacheKey",
3028
"MemoryRegionKVCacheHandle",
3129
"BaseKVCacheManager",
3230
"GroupAwareKVCacheManager",

python/aibrix_kvcache/aibrix_kvcache/cache_hashable.py

Lines changed: 12 additions & 101 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,9 @@
1313
# limitations under the License.
1414

1515
from abc import ABC, abstractmethod
16-
from typing import Iterable
16+
from typing import Tuple
1717

18-
import numpy as np
19-
from farmhash import FarmHash64
20-
21-
from .memory import MemoryRegion
22-
from .utils import np_array_concat
18+
from .common import CachedPyObjectBase
2319

2420

2521
class KVCacheHashable(ABC):
@@ -40,119 +36,34 @@ def __len__(self) -> int:
4036
raise NotImplementedError
4137

4238

43-
class BaseKVCacheHashable(KVCacheHashable):
39+
class BaseKVCacheHashable(KVCacheHashable, CachedPyObjectBase):
4440
"""
4541
Base class for a hashable object that uses all tokens to compute the hash.
4642
"""
4743

48-
@abstractmethod
49-
def all_tokens_memoryview(self) -> memoryview:
50-
"""Memoryview of the PACKED bytes representation of all tokens."""
51-
raise NotImplementedError
44+
def __init__(self, prefix: Tuple[int, ...] | None, tokens: Tuple[int, ...]):
45+
self.prefix = prefix or tuple()
46+
self.tokens = tokens
5247

5348
def __hash__(self) -> int:
54-
return FarmHash64(self.all_tokens_memoryview())
49+
return hash((self.prefix, self.tokens))
5550

5651
def __eq__(self, other) -> bool:
5752
if not isinstance(other, BaseKVCacheHashable):
5853
return False
59-
return self.all_tokens_memoryview() == other.all_tokens_memoryview()
54+
return (self.prefix, self.tokens) == (other.prefix, other.tokens)
6055

6156

6257
class TokenCacheKey(BaseKVCacheHashable):
6358
"""
64-
A cache key that use numpy ndarray's bytes representation to compute
65-
the hash.
59+
A cache key that compounds prefix and tokens.
6660
Args:
6761
prefix (np.ndarray | None): The prefix tokens of the kv tensors.
6862
tokens (np.ndarray): The tokens of the kv tensors.
6963
"""
7064

71-
def __init__(self, *, tokens: np.ndarray, prefix: np.ndarray | None = None):
72-
self._storage = np_array_concat(prefix, tokens)
73-
74-
self.prefix = prefix
75-
self.tokens = tokens
76-
77-
def all_tokens(self) -> np.ndarray:
78-
return self._storage
79-
80-
def all_tokens_memoryview(self) -> memoryview:
81-
return memoryview(self._storage) # type: ignore
82-
83-
def __len__(self) -> int:
84-
return len(self._storage)
85-
86-
def shift(self, shift: int) -> "TokenCacheKey":
87-
"""
88-
Shifts the split of prefix and tokens.
89-
Args:
90-
shift (int): The number of tokens to shift.
91-
Returns:
92-
TokenCacheKey: The cache key with shifted tokens.
93-
"""
94-
orig_prefix_len = len(self.prefix) if self.prefix is not None else 0
95-
new_prefix_len = min(max(0, orig_prefix_len + shift), len(self))
96-
new_prefix = self._storage[:new_prefix_len]
97-
new_tokens = self._storage[new_prefix_len:]
98-
return TokenCacheKey(prefix=new_prefix, tokens=new_tokens)
99-
100-
def shrink(self, n: int) -> "TokenCacheKey":
101-
"""
102-
Shrinks n from tokens.
103-
Args:
104-
n (int): The number of tokens to shrink.
105-
Returns:
106-
TokenCacheKey: The new cache key.
107-
"""
108-
new_tokens = self.tokens[:-n]
109-
return TokenCacheKey(prefix=self.prefix, tokens=new_tokens)
110-
111-
def batched(self, batch_size: int) -> Iterable["TokenCacheKey"]:
112-
"""
113-
Batches the tokens into a list of TokenCacheKey.
114-
Args:
115-
batch_size (int): The batch size.
116-
Returns:
117-
Iterable[TokenCacheKey]: The batched tokens.
118-
"""
119-
prefix_len = len(self.prefix) if self.prefix is not None else 0
120-
num_batches = len(self.tokens) // batch_size
121-
for _ in range(num_batches):
122-
yield (
123-
TokenCacheKey(
124-
prefix=self._storage[:prefix_len],
125-
tokens=self._storage[prefix_len : prefix_len + batch_size],
126-
)
127-
)
128-
prefix_len += batch_size
129-
130-
131-
class MemoryRegionCacheEntry(BaseKVCacheHashable):
132-
"""
133-
A cache entry that stores a memory region and uses MR's all tokens
134-
to compute the hash.
135-
"""
136-
137-
def __init__(self, mr: MemoryRegion):
138-
# Take the ownership of the memory region
139-
self._mr: MemoryRegion | None = mr
140-
assert mr.is_sealed, "Memory region must be sealed"
141-
142-
self.ref_up = self._mr.ref_up
143-
self.ref_down = self._mr.ref_down
144-
145-
def all_tokens_memoryview(self) -> memoryview:
146-
assert self._mr is not None
147-
return memoryview(np_array_concat(*self._mr.unpack_tokens())) # type: ignore
65+
def __init__(self, prefix: Tuple[int, ...] | None, tokens: Tuple[int, ...]):
66+
super().__init__(prefix, tokens)
14867

14968
def __len__(self) -> int:
150-
assert self._mr is not None
151-
return self._mr.length
152-
153-
def cache_key(self) -> TokenCacheKey:
154-
assert self._mr is not None
155-
prefix, tokens = self._mr.unpack_tokens()
156-
prefix = prefix.copy() if prefix is not None else None
157-
tokens = tokens.copy()
158-
return TokenCacheKey(prefix=prefix, tokens=tokens)
69+
return len(self.prefix) + len(self.tokens)

0 commit comments

Comments
 (0)