Commit b5a63e4

Support for advanced KV caching and batch generation
1 parent e3088e6 commit b5a63e4

34 files changed: +2663 additions, −654 deletions

litgpt/adapter.py

Lines changed: 64 additions & 37 deletions

@@ -15,7 +15,9 @@
 import torch.nn as nn
 from typing_extensions import Self
 
+from litgpt.attention import DefaultKeysAndValues, MultiHeadSelfAttention
 from litgpt.config import Config as BaseConfig
+from litgpt.kvcache.base import KVCache
 from litgpt.model import GPT as BaseModel
 from litgpt.model import Block as BaseBlock
 from litgpt.model import CausalSelfAttention as BaseCausalSelfAttention
@@ -42,8 +44,9 @@ def __init__(self, config: Config) -> None:
                 ln_f=config.norm_class(config.n_embd, eps=config.norm_eps),
             )
         )
-        self.mask_cache: Optional[torch.Tensor] = None
+        self.mha = MultiHeadSelfAttention(config)
         self.max_seq_length = self.config.block_size
+        self._default_kv_cache = False
 
     @classmethod
     def from_name(cls, name: str, **kwargs: Any) -> Self:
@@ -57,56 +60,80 @@ def _init_weights(self, module: nn.Module) -> None:
 
 
 class Block(BaseBlock):
-    def __init__(self, config: Config, block_idx: int) -> None:
-        super().__init__(config, block_idx)
-        self.attn = CausalSelfAttention(config, block_idx)
+    def __init__(
+        self,
+        config: Config,
+        block_idx: int,
+        kv_cache: Optional[KVCache] = None,
+    ) -> None:
+        super().__init__(config, block_idx, kv_cache)
+        self.attn = CausalSelfAttention(config, block_idx, kv_cache=kv_cache)
 
 
 class CausalSelfAttention(BaseCausalSelfAttention):
     """A modification of `litgpt.model.CausalSelfAttention` that adds the attention
     over the adaption prompt."""
 
-    def __init__(self, config: Config, block_idx: int) -> None:
-        super().__init__(config, block_idx)
-        if block_idx >= config.adapter_start_layer:
+    def __init__(
+        self,
+        config: Config,
+        block_idx: int,
+        kv_cache: Optional[KVCache] = None,
+    ) -> None:
+        super().__init__(
+            config=config,
+            block_idx=block_idx,
+            kv_cache=kv_cache,
+        )
+        self._extend_forward = block_idx >= config.adapter_start_layer
+        if self._extend_forward:
             # adapter embedding layer
             self.adapter_wte = nn.Embedding(config.adapter_prompt_length, config.n_embd)
             # gate for adaption
             self.gating_factor = torch.nn.Parameter(torch.zeros(1, 1, config.n_head, 1))
             # kv cache for inference
             self.adapter_kv_cache: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
 
-    def scaled_dot_product_attention(
-        self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, mask: Optional[torch.Tensor] = None
+    def _transform_output(
+        self,
+        y: torch.Tensor,
+        query: torch.Tensor,
+        mha: MultiHeadSelfAttention,
     ) -> torch.Tensor:
-        y = super().scaled_dot_product_attention(q, k, v, mask)
-        if self.block_idx < self.config.adapter_start_layer:
-            return y
-
-        aT = self.config.adapter_prompt_length
-        if self.adapter_kv_cache is not None:
-            # since this uses the wte weights as the prefix and the kv cache is only used during inference, ak and av
-            # are the same every call
-            ak, av = self.adapter_kv_cache
-        else:
-            prefix = self.adapter_wte.weight.reshape(1, aT, self.config.n_embd)
-            aqkv = self.qkv(prefix)
-            q_per_kv = self.config.n_head // self.config.n_query_groups
-            aqkv = aqkv.view(1, aT, self.config.n_query_groups, q_per_kv + 2, self.config.head_size)
-            aqkv = aqkv.permute(0, 2, 3, 1, 4)
-            _, ak, av = aqkv.split((q_per_kv, 1, 1), dim=2)
-            if self.config.n_query_groups != 1:
-                # for MHA this is a no-op
-                ak = ak.repeat_interleave(q_per_kv, dim=2)
-                av = av.repeat_interleave(q_per_kv, dim=2)
-            ak = ak.view(1, -1, aT, self.config.head_size)  # (1, nh_ak, aT, hs)
-            av = av.view(1, -1, aT, self.config.head_size)  # (1, nh_av, aT, hs)
-            self.adapter_kv_cache = (ak, av)
-
-        T = q.size(2)
-        amask = torch.ones(T, aT, dtype=torch.bool, device=q.device)
-        ay = super().scaled_dot_product_attention(q, ak, av, amask)
-        return y + self.gating_factor * ay
+        if self._extend_forward:
+            B, T, _ = y.shape
+            y = y.view(B, T, self.config.n_head, self.config.head_size)
+            aT = self.config.adapter_prompt_length
+            if self.adapter_kv_cache is not None:
+                # since this uses the wte weights as the prefix and the kv cache is only used during inference, ak and av
+                # are the same every call
+                ak, av = self.adapter_kv_cache
+            else:
+                prefix = self.adapter_wte.weight.reshape(1, aT, self.config.n_embd)
+                aqkv = self.qkv(prefix)
+                q_per_kv = self.config.n_head // self.config.n_query_groups
+                aqkv = aqkv.view(1, aT, self.config.n_query_groups, q_per_kv + 2, self.config.head_size)
+                aqkv = aqkv.permute(0, 2, 3, 1, 4)
+                _, ak, av = aqkv.split((q_per_kv, 1, 1), dim=2)
+                if self.config.n_query_groups != 1:
+                    # for MHA this is a no-op
+                    ak = ak.repeat_interleave(q_per_kv, dim=2)
+                    av = av.repeat_interleave(q_per_kv, dim=2)
+                ak = ak.view(1, -1, aT, self.config.head_size)  # (1, nh_ak, aT, hs)
+                av = av.view(1, -1, aT, self.config.head_size)  # (1, nh_av, aT, hs)
+                self.adapter_kv_cache = (ak, av)
+
+            amask = torch.ones(T, aT, dtype=torch.bool, device=query.device)
+            a_k_and_v = DefaultKeysAndValues(keys=ak, values=av)
+            ay, _ = mha.scaled_dot_product_attention(
+                query=query,
+                k_and_v=a_k_and_v,
+                mask=amask,
+                is_causal=False,
+            )
+            y = (y + self.gating_factor * ay).view(B, T, -1)
+
+        return y
 
     def reset_parameters(self) -> None:
         if hasattr(self, "gating_factor"):

litgpt/adapter_v2.py

Lines changed: 19 additions & 6 deletions

@@ -19,6 +19,8 @@
 from litgpt.adapter import GPT as BaseModel
 from litgpt.adapter import CausalSelfAttention as BaseCausalSelfAttention
 from litgpt.adapter import Config as BaseConfig
+from litgpt.attention import MultiHeadSelfAttention
+from litgpt.kvcache.base import KVCache
 from litgpt.model import Block as BaseBlock
 from litgpt.scripts.convert_hf_checkpoint import qkv_reassemble
 from litgpt.utils import map_old_state_dict_weights
@@ -77,8 +79,9 @@ def __init__(self, config: Config) -> None:
                 ln_f=config.norm_class(config.n_embd, eps=config.norm_eps),
             )
         )
-        self.mask_cache: Optional[torch.Tensor] = None
+        self.mha = MultiHeadSelfAttention(config)
         self.max_seq_length = self.config.block_size
+        self._default_kv_cache = False
 
     @classmethod
     def from_name(cls, name: str, **kwargs: Any) -> Self:
@@ -98,18 +101,28 @@ def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwa
 
 
 class Block(BaseBlock):
-    def __init__(self, config: Config, block_idx: int) -> None:
-        super().__init__(config, block_idx)
-        self.attn = CausalSelfAttention(config, block_idx)
+    def __init__(
+        self,
+        config: Config,
+        block_idx: int,
+        kv_cache: Optional[KVCache] = None,
+    ) -> None:
+        super().__init__(config, block_idx, kv_cache)
+        self.attn = CausalSelfAttention(config, block_idx, kv_cache=kv_cache)
         self.mlp = config.mlp_class(config)
 
 
 class CausalSelfAttention(BaseCausalSelfAttention):
     """A modification of `litgpt.adapter.CausalSelfAttention` that uses the Adapter V2 Linear class"""
 
     # Copy&paste from :class:`model.CausalSelfAttention`
-    def __init__(self, config: Config, block_idx: int) -> None:
-        super().__init__(config, block_idx)
+    def __init__(
+        self,
+        config: Config,
+        block_idx: int,
+        kv_cache: Optional[KVCache] = None,
+    ) -> None:
+        super().__init__(config, block_idx, kv_cache)
         # key, query, value projections for all heads, but in a batch
         shape = (config.n_head + 2 * config.n_query_groups) * config.head_size
         self.qkv = AdapterV2Linear(in_features=config.n_embd, out_features=shape, bias=config.bias or config.attn_bias)
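
The unchanged projection line above is the grouped-query bookkeeping behind the fused qkv matrix: `n_head` query heads plus one key head and one value head per query group. A worked example with hypothetical sizes:

    # Hypothetical sizes, only to illustrate the `shape` formula shown above.
    n_head, n_query_groups, head_size = 32, 8, 128
    shape = (n_head + 2 * n_query_groups) * head_size  # 32 query + 8 key + 8 value heads
    print(shape)  # (32 + 16) * 128 = 6144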

litgpt/api.py

Lines changed: 10 additions & 5 deletions

@@ -377,7 +377,9 @@ def distribute(
                 kv_cache_size = model.max_seq_length
             else:
                 kv_cache_size = fixed_kv_cache_size
-            model.set_kv_cache(batch_size=1, max_seq_length=kv_cache_size, device=fabric.device)
+            model.set_kv_cache(
+                batch_size=1, max_seq_length=kv_cache_size, device=fabric.device,
+            )
             self.kv_cache_initialized = True
             self.fixed_kv_cache_size = fixed_kv_cache_size
 
@@ -504,15 +506,18 @@ def generate(
                 device = self.fabric.device
             else:
                 device = self.preprocessor.device
-            self.model.set_kv_cache(batch_size=1, max_seq_length=max_returned_tokens, device=device)
+            self.model.set_kv_cache(
+                batch_size=1, max_seq_length=max_returned_tokens, device=device,
+            )
             self.kv_cache_initialized = True
 
         # Dynamically grow the kv cache size if necessary
         if not self.fixed_kv_cache_size and self.prev_generated_seq_length < max_returned_tokens:
-            tmp_device = self.model.mask_cache.device
+            tmp_device = self.model.mha.mask_cache.device
             self.model.clear_kv_cache()
-            self.model.set_kv_cache(batch_size=1, max_seq_length=max_returned_tokens, device=tmp_device)
-
+            self.model.set_kv_cache(
+                batch_size=1, max_seq_length=max_returned_tokens, device=tmp_device,
+            )
         else:
             for block in self.model.transformer.h:
                 block.attn.kv_cache.reset_parameters()
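
Taken together, the two `generate` hunks implement a small prepare-or-grow-or-reset policy for the KV cache. A condensed sketch of that flow is given below as a free function over an LLM-like object; it uses only the attributes and calls visible in this diff, and `prepare_kv_cache` itself is a hypothetical helper rather than part of the litgpt API.

    def prepare_kv_cache(llm, max_returned_tokens: int, device) -> None:
        # Sketch of the flow in LLM.generate above; `llm` stands for the LLM instance.
        if not llm.kv_cache_initialized:
            # First call: allocate a cache sized for this request.
            llm.model.set_kv_cache(
                batch_size=1, max_seq_length=max_returned_tokens, device=device,
            )
            llm.kv_cache_initialized = True

        # Dynamically grow the kv cache if this request needs more room.
        if not llm.fixed_kv_cache_size and llm.prev_generated_seq_length < max_returned_tokens:
            tmp_device = llm.model.mha.mask_cache.device
            llm.model.clear_kv_cache()
            llm.model.set_kv_cache(
                batch_size=1, max_seq_length=max_returned_tokens, device=tmp_device,
            )
        else:
            # The existing cache is large enough: just reset its contents.
            for block in llm.model.transformer.h:
                block.attn.kv_cache.reset_parameters()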
