13
13
# limitations under the License.
14
14
15
15
from abc import ABC , abstractmethod
16
- from typing import Iterable
16
+ from typing import Tuple
17
17
18
- import numpy as np
19
- from farmhash import FarmHash64
20
-
21
- from .memory import MemoryRegion
22
- from .utils import np_array_concat
18
+ from .common import CachedPyObjectBase
23
19
24
20
25
21
class KVCacheHashable (ABC ):
@@ -40,119 +36,34 @@ def __len__(self) -> int:
40
36
raise NotImplementedError
41
37
42
38
43
- class BaseKVCacheHashable (KVCacheHashable ):
39
+ class BaseKVCacheHashable (KVCacheHashable , CachedPyObjectBase ):
44
40
"""
45
41
Base class for a hashable object that uses all tokens to compute the hash.
46
42
"""
47
43
48
- @abstractmethod
49
- def all_tokens_memoryview (self ) -> memoryview :
50
- """Memoryview of the PACKED bytes representation of all tokens."""
51
- raise NotImplementedError
44
+ def __init__ (self , prefix : Tuple [int , ...] | None , tokens : Tuple [int , ...]):
45
+ self .prefix = prefix or tuple ()
46
+ self .tokens = tokens
52
47
53
48
def __hash__ (self ) -> int :
54
- return FarmHash64 ( self .all_tokens_memoryview ( ))
49
+ return hash (( self .prefix , self . tokens ))
55
50
56
51
def __eq__ (self , other ) -> bool :
57
52
if not isinstance (other , BaseKVCacheHashable ):
58
53
return False
59
- return self .all_tokens_memoryview ( ) == other .all_tokens_memoryview ( )
54
+ return ( self .prefix , self . tokens ) == ( other .prefix , other . tokens )
60
55
61
56
62
57
class TokenCacheKey (BaseKVCacheHashable ):
63
58
"""
64
- A cache key that use numpy ndarray's bytes representation to compute
65
- the hash.
59
+ A cache key that compounds prefix and tokens.
66
60
Args:
67
61
prefix (Tuple[int, ...] | None): The prefix tokens of the kv tensors.
68
62
tokens (Tuple[int, ...]): The tokens of the kv tensors.
69
63
"""
70
64
71
- def __init__ (self , * , tokens : np .ndarray , prefix : np .ndarray | None = None ):
72
- self ._storage = np_array_concat (prefix , tokens )
73
-
74
- self .prefix = prefix
75
- self .tokens = tokens
76
-
77
- def all_tokens (self ) -> np .ndarray :
78
- return self ._storage
79
-
80
- def all_tokens_memoryview (self ) -> memoryview :
81
- return memoryview (self ._storage ) # type: ignore
82
-
83
- def __len__ (self ) -> int :
84
- return len (self ._storage )
85
-
86
- def shift (self , shift : int ) -> "TokenCacheKey" :
87
- """
88
- Shifts the split of prefix and tokens.
89
- Args:
90
- shift (int): The number of tokens to shift.
91
- Returns:
92
- TokenCacheKey: The cache key with shifted tokens.
93
- """
94
- orig_prefix_len = len (self .prefix ) if self .prefix is not None else 0
95
- new_prefix_len = min (max (0 , orig_prefix_len + shift ), len (self ))
96
- new_prefix = self ._storage [:new_prefix_len ]
97
- new_tokens = self ._storage [new_prefix_len :]
98
- return TokenCacheKey (prefix = new_prefix , tokens = new_tokens )
99
-
100
- def shrink (self , n : int ) -> "TokenCacheKey" :
101
- """
102
- Shrinks n from tokens.
103
- Args:
104
- n (int): The number of tokens to shrink.
105
- Returns:
106
- TokenCacheKey: The new cache key.
107
- """
108
- new_tokens = self .tokens [:- n ]
109
- return TokenCacheKey (prefix = self .prefix , tokens = new_tokens )
110
-
111
- def batched (self , batch_size : int ) -> Iterable ["TokenCacheKey" ]:
112
- """
113
- Batches the tokens into a list of TokenCacheKey.
114
- Args:
115
- batch_size (int): The batch size.
116
- Returns:
117
- Iterable[TokenCacheKey]: The batched tokens.
118
- """
119
- prefix_len = len (self .prefix ) if self .prefix is not None else 0
120
- num_batches = len (self .tokens ) // batch_size
121
- for _ in range (num_batches ):
122
- yield (
123
- TokenCacheKey (
124
- prefix = self ._storage [:prefix_len ],
125
- tokens = self ._storage [prefix_len : prefix_len + batch_size ],
126
- )
127
- )
128
- prefix_len += batch_size
129
-
130
-
131
- class MemoryRegionCacheEntry (BaseKVCacheHashable ):
132
- """
133
- A cache entry that stores a memory region and uses MR's all tokens
134
- to compute the hash.
135
- """
136
-
137
- def __init__ (self , mr : MemoryRegion ):
138
- # Take the ownership of the memory region
139
- self ._mr : MemoryRegion | None = mr
140
- assert mr .is_sealed , "Memory region must be sealed"
141
-
142
- self .ref_up = self ._mr .ref_up
143
- self .ref_down = self ._mr .ref_down
144
-
145
- def all_tokens_memoryview (self ) -> memoryview :
146
- assert self ._mr is not None
147
- return memoryview (np_array_concat (* self ._mr .unpack_tokens ())) # type: ignore
65
+ def __init__ (self , prefix : Tuple [int , ...] | None , tokens : Tuple [int , ...]):
66
+ super ().__init__ (prefix , tokens )
148
67
149
68
def __len__ (self ) -> int :
150
- assert self ._mr is not None
151
- return self ._mr .length
152
-
153
- def cache_key (self ) -> TokenCacheKey :
154
- assert self ._mr is not None
155
- prefix , tokens = self ._mr .unpack_tokens ()
156
- prefix = prefix .copy () if prefix is not None else None
157
- tokens = tokens .copy ()
158
- return TokenCacheKey (prefix = prefix , tokens = tokens )
69
+ return len (self .prefix ) + len (self .tokens )
0 commit comments