
Commit 5b994c3

Support for advanced KV caching and batch generation

1 parent f99ca4e · commit 5b994c3
34 files changed: +2812 -721 lines

litgpt/adapter.py

Lines changed: 71 additions & 38 deletions
@@ -15,7 +15,9 @@
 import torch.nn as nn
 from typing_extensions import Self

+from litgpt.attention import DefaultKeysAndValues, MultiHeadSelfAttention
 from litgpt.config import Config as BaseConfig
+from litgpt.kvcache.base import KVCache
 from litgpt.model import GPT as BaseModel
 from litgpt.model import Block as BaseBlock
 from litgpt.model import CausalSelfAttention as BaseCausalSelfAttention
@@ -34,16 +36,23 @@ def __init__(self, config: Config) -> None:
         assert config.padded_vocab_size is not None
         self.config = config

-        self.lm_head = nn.Linear(config.n_embd, config.padded_vocab_size, bias=config.lm_head_bias)
+        self.lm_head = nn.Linear(
+            config.n_embd,
+            config.padded_vocab_size,
+            bias=config.lm_head_bias,
+        )
         self.transformer = nn.ModuleDict(
             dict(
                 wte=nn.Embedding(config.padded_vocab_size, config.n_embd),
                 h=nn.ModuleList(Block(config, block_idx) for block_idx in range(config.n_layer)),
                 ln_f=config.norm_class(config.n_embd, eps=config.norm_eps),
             )
         )
-        self.mask_cache: Optional[torch.Tensor] = None
+        self.mha = MultiHeadSelfAttention(config)
         self.max_seq_length = self.config.block_size
+        self._start_of_layer_hook = config.start_of_layer_hook
+        # Have dense KV caches been created by `set_kv_cache`?
+        self._default_kv_cache = False

     @classmethod
     def from_name(cls, name: str, **kwargs: Any) -> Self:
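Compared to the previous version, the constructor above drops the `mask_cache` tensor in favour of a shared `MultiHeadSelfAttention` module, an optional start-of-layer hook, and a `_default_kv_cache` flag recording whether dense per-layer caches were created by `set_kv_cache`. As a rough illustration of what such a dense per-layer cache holds (a toy sketch only; the actual `litgpt.kvcache.base.KVCache` interface is not shown in this diff):

```python
import torch


class ToyDenseKVCache:
    """Toy dense KV cache for one attention layer: preallocates buffers of shape
    (batch, n_heads, max_seq_length, head_size) and fills them token by token."""

    def __init__(self, batch_size: int, n_heads: int, max_seq_length: int,
                 head_size: int, device=None, dtype=None) -> None:
        shape = (batch_size, n_heads, max_seq_length, head_size)
        self.k = torch.zeros(shape, device=device, dtype=dtype)
        self.v = torch.zeros(shape, device=device, dtype=dtype)
        self.next_pos = 0  # number of token positions already cached

    def append(self, k_new: torch.Tensor, v_new: torch.Tensor):
        # k_new, v_new: (batch, n_heads, t_new, head_size) for the tokens of this step
        t_new = k_new.size(2)
        self.k[:, :, self.next_pos : self.next_pos + t_new] = k_new
        self.v[:, :, self.next_pos : self.next_pos + t_new] = v_new
        self.next_pos += t_new
        # attention then runs against everything cached so far
        return self.k[:, :, : self.next_pos], self.v[:, :, : self.next_pos]

    def reset_parameters(self) -> None:
        # reuse the buffers for a new sequence without reallocating
        self.next_pos = 0
```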
@@ -57,56 +66,80 @@ def _init_weights(self, module: nn.Module) -> None:


 class Block(BaseBlock):
-    def __init__(self, config: Config, block_idx: int) -> None:
-        super().__init__(config, block_idx)
-        self.attn = CausalSelfAttention(config, block_idx)
+    def __init__(
+        self,
+        config: Config,
+        block_idx: int,
+        kv_cache: Optional[KVCache] = None,
+    ) -> None:
+        super().__init__(config, block_idx, kv_cache)
+        self.attn = CausalSelfAttention(config, block_idx, kv_cache=kv_cache)


 class CausalSelfAttention(BaseCausalSelfAttention):
     """A modification of `litgpt.model.CausalSelfAttention` that adds the attention
     over the adaption prompt."""

-    def __init__(self, config: Config, block_idx: int) -> None:
-        super().__init__(config, block_idx)
-        if block_idx >= config.adapter_start_layer:
+    def __init__(
+        self,
+        config: Config,
+        block_idx: int,
+        kv_cache: Optional[KVCache] = None,
+    ) -> None:
+        super().__init__(
+            config=config,
+            block_idx=block_idx,
+            kv_cache=kv_cache,
+        )
+        self._extend_forward = block_idx >= config.adapter_start_layer
+        if self._extend_forward:
             # adapter embedding layer
             self.adapter_wte = nn.Embedding(config.adapter_prompt_length, config.n_embd)
             # gate for adaption
             self.gating_factor = torch.nn.Parameter(torch.zeros(1, 1, config.n_head, 1))
             # kv cache for inference
             self.adapter_kv_cache: Optional[Tuple[torch.Tensor, torch.Tensor]] = None

-    def scaled_dot_product_attention(
-        self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, mask: Optional[torch.Tensor] = None
+    def _transform_output(
+        self,
+        y: torch.Tensor,
+        query: torch.Tensor,
+        mha: MultiHeadSelfAttention,
     ) -> torch.Tensor:
-        y = super().scaled_dot_product_attention(q, k, v, mask)
-        if self.block_idx < self.config.adapter_start_layer:
-            return y
-
-        aT = self.config.adapter_prompt_length
-        if self.adapter_kv_cache is not None:
-            # since this uses the wte weights as the prefix and the kv cache is only used during inference, ak and av
-            # are the same every call
-            ak, av = self.adapter_kv_cache
-        else:
-            prefix = self.adapter_wte.weight.reshape(1, aT, self.config.n_embd)
-            aqkv = self.qkv(prefix)
-            q_per_kv = self.config.n_head // self.config.n_query_groups
-            aqkv = aqkv.view(1, aT, self.config.n_query_groups, q_per_kv + 2, self.config.head_size)
-            aqkv = aqkv.permute(0, 2, 3, 1, 4)
-            _, ak, av = aqkv.split((q_per_kv, 1, 1), dim=2)
-            if self.config.n_query_groups != 1:
-                # for MHA this is a no-op
-                ak = ak.repeat_interleave(q_per_kv, dim=2)
-                av = av.repeat_interleave(q_per_kv, dim=2)
-            ak = ak.view(1, -1, aT, self.config.head_size)  # (1, nh_ak, aT, hs)
-            av = av.view(1, -1, aT, self.config.head_size)  # (1, nh_av, aT, hs)
-            self.adapter_kv_cache = (ak, av)
-
-        T = q.size(2)
-        amask = torch.ones(T, aT, dtype=torch.bool, device=q.device)
-        ay = super().scaled_dot_product_attention(q, ak, av, amask)
-        return y + self.gating_factor * ay
+        if self._extend_forward:
+            B, T, _ = y.shape
+            y = y.view(B, T, self.config.n_head, self.config.head_size)
+            aT = self.config.adapter_prompt_length
+            if self.adapter_kv_cache is not None:
+                # since this uses the wte weights as the prefix and the kv cache is only used during inference, ak and av
+                # are the same every call
+                ak, av = self.adapter_kv_cache
+            else:
+                prefix = self.adapter_wte.weight.reshape(1, aT, self.config.n_embd)
+                aqkv = self.qkv(prefix)
+                q_per_kv = self.config.n_head // self.config.n_query_groups
+                aqkv = aqkv.view(1, aT, self.config.n_query_groups, q_per_kv + 2, self.config.head_size)
+                aqkv = aqkv.permute(0, 2, 3, 1, 4)
+                _, ak, av = aqkv.split((q_per_kv, 1, 1), dim=2)
+                if self.config.n_query_groups != 1:
+                    # for MHA this is a no-op
+                    ak = ak.repeat_interleave(q_per_kv, dim=2)
+                    av = av.repeat_interleave(q_per_kv, dim=2)
+                ak = ak.view(1, -1, aT, self.config.head_size)  # (1, nh_ak, aT, hs)
+                av = av.view(1, -1, aT, self.config.head_size)  # (1, nh_av, aT, hs)
+                self.adapter_kv_cache = (ak, av)
+
+            amask = torch.ones(T, aT, dtype=torch.bool, device=query.device)
+            a_k_and_v = DefaultKeysAndValues(keys=ak, values=av)
+            ay, _ = mha.scaled_dot_product_attention(
+                query=query,
+                k_and_v=a_k_and_v,
+                mask=amask,
+                is_causal=False,
+            )
+            y = (y + self.gating_factor * ay).view(B, T, -1)
+
+        return y

     def reset_parameters(self) -> None:
         if hasattr(self, "gating_factor"):
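The adapter-specific logic now lives in `_transform_output`, which the refactored attention path presumably calls after the main scaled-dot-product attention with the attention output `y`, the query tensor, and the shared `MultiHeadSelfAttention` helper (the base-class side of that change is not in the hunks shown here). The prefix keys and values are still derived from `adapter_wte` through the fused `qkv` projection, and the reshaping is easiest to follow with concrete numbers. Below is a standalone shape check using made-up dimensions; none of these values come from the commit:

```python
import torch

# assumed toy dimensions: GQA with 8 query heads sharing 2 KV groups, 10 adapter tokens
n_head, n_query_groups, head_size, aT = 8, 2, 64, 10
n_embd = n_head * head_size
q_per_kv = n_head // n_query_groups  # 4 query heads per KV group

prefix = torch.randn(1, aT, n_embd)  # stands in for adapter_wte.weight.reshape(1, aT, n_embd)
qkv = torch.nn.Linear(n_embd, (n_head + 2 * n_query_groups) * head_size)

aqkv = qkv(prefix)                                    # (1, aT, (n_head + 2 * groups) * hs)
aqkv = aqkv.view(1, aT, n_query_groups, q_per_kv + 2, head_size)
aqkv = aqkv.permute(0, 2, 3, 1, 4)                    # (1, groups, q_per_kv + 2, aT, hs)
_, ak, av = aqkv.split((q_per_kv, 1, 1), dim=2)       # adapter keys/values: (1, groups, 1, aT, hs)

if n_query_groups != 1:
    # replicate the KV slice so every query head has a matching adapter key/value
    ak = ak.repeat_interleave(q_per_kv, dim=2)
    av = av.repeat_interleave(q_per_kv, dim=2)

ak = ak.view(1, -1, aT, head_size)                    # (1, n_head, aT, hs)
av = av.view(1, -1, aT, head_size)
assert ak.shape == av.shape == (1, n_head, aT, head_size)
```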

litgpt/adapter_v2.py

Lines changed: 26 additions & 7 deletions
@@ -19,6 +19,8 @@
 from litgpt.adapter import GPT as BaseModel
 from litgpt.adapter import CausalSelfAttention as BaseCausalSelfAttention
 from litgpt.adapter import Config as BaseConfig
+from litgpt.attention import MultiHeadSelfAttention
+from litgpt.kvcache.base import KVCache
 from litgpt.model import Block as BaseBlock
 from litgpt.scripts.convert_hf_checkpoint import qkv_reassemble
 from litgpt.utils import map_old_state_dict_weights
@@ -69,16 +71,23 @@ def __init__(self, config: Config) -> None:
         assert config.padded_vocab_size is not None
         self.config = config

-        self.lm_head = AdapterV2Linear(config.n_embd, config.padded_vocab_size, bias=config.lm_head_bias)
+        self.lm_head = AdapterV2Linear(
+            config.n_embd,
+            config.padded_vocab_size,
+            bias=config.lm_head_bias,
+        )
         self.transformer = nn.ModuleDict(
             dict(
                 wte=nn.Embedding(config.padded_vocab_size, config.n_embd),
                 h=nn.ModuleList(Block(config, block_idx) for block_idx in range(config.n_layer)),
                 ln_f=config.norm_class(config.n_embd, eps=config.norm_eps),
             )
         )
-        self.mask_cache: Optional[torch.Tensor] = None
+        self.mha = MultiHeadSelfAttention(config)
         self.max_seq_length = self.config.block_size
+        self._start_of_layer_hook = config.start_of_layer_hook
+        # Have dense KV caches been created by `set_kv_cache`?
+        self._default_kv_cache = False

     @classmethod
     def from_name(cls, name: str, **kwargs: Any) -> Self:
@@ -98,18 +107,28 @@ def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwa


 class Block(BaseBlock):
-    def __init__(self, config: Config, block_idx: int) -> None:
-        super().__init__(config, block_idx)
-        self.attn = CausalSelfAttention(config, block_idx)
+    def __init__(
+        self,
+        config: Config,
+        block_idx: int,
+        kv_cache: Optional[KVCache] = None,
+    ) -> None:
+        super().__init__(config, block_idx, kv_cache)
+        self.attn = CausalSelfAttention(config, block_idx, kv_cache=kv_cache)
         self.mlp = config.mlp_class(config)


 class CausalSelfAttention(BaseCausalSelfAttention):
     """A modification of `litgpt.adapter.CausalSelfAttention` that uses the Adapter V2 Linear class"""

     # Copy&paste from :class:`model.CausalSelfAttention`
-    def __init__(self, config: Config, block_idx: int) -> None:
-        super().__init__(config, block_idx)
+    def __init__(
+        self,
+        config: Config,
+        block_idx: int,
+        kv_cache: Optional[KVCache] = None,
+    ) -> None:
+        super().__init__(config, block_idx, kv_cache)
         # key, query, value projections for all heads, but in a batch
         shape = (config.n_head + 2 * config.n_query_groups) * config.head_size
         self.qkv = AdapterV2Linear(in_features=config.n_embd, out_features=shape, bias=config.bias or config.attn_bias)
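The Adapter V2 attention only swaps `nn.Linear` for `AdapterV2Linear`; the fused projection width stays `(n_head + 2 * n_query_groups) * head_size`, i.e. one output slice per query head plus one key and one value slice per KV group. A quick arithmetic check with example configurations (the numbers are illustrative, not taken from the commit):

```python
def fused_qkv_width(n_head: int, n_query_groups: int, head_size: int) -> int:
    # one slice per query head, plus one key and one value slice per KV group
    return (n_head + 2 * n_query_groups) * head_size

# multi-head attention: every query head has its own KV group
assert fused_qkv_width(n_head=12, n_query_groups=12, head_size=64) == 3 * 12 * 64      # 2304
# grouped-query attention: 32 query heads sharing 8 KV groups
assert fused_qkv_width(n_head=32, n_query_groups=8, head_size=128) == (32 + 16) * 128  # 6144
```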

litgpt/api.py

Lines changed: 19 additions & 9 deletions
@@ -383,7 +383,11 @@ def distribute(
                 kv_cache_size = model.max_seq_length
             else:
                 kv_cache_size = fixed_kv_cache_size
-            model.set_kv_cache(batch_size=1, max_seq_length=kv_cache_size, device=fabric.device)
+            model.set_kv_cache(
+                batch_size=1,
+                max_seq_length=kv_cache_size,
+                device=fabric.device,
+            )
             self.kv_cache_initialized = True
             self.fixed_kv_cache_size = fixed_kv_cache_size

@@ -508,20 +512,26 @@ def generate(
         prompt_length = input_ids.size(0)
         max_returned_tokens = prompt_length + max_new_tokens

+        if self.fabric is not None:
+            device = self.fabric.device
+        else:
+            device = self.preprocessor.device
         if not self.kv_cache_initialized:
-            if self.fabric is not None:
-                device = self.fabric.device
-            else:
-                device = self.preprocessor.device
-            self.model.set_kv_cache(batch_size=1, max_seq_length=max_returned_tokens, device=device)
+            self.model.set_kv_cache(
+                batch_size=1,
+                max_seq_length=max_returned_tokens,
+                device=device,
+            )
             self.kv_cache_initialized = True

         # Dynamically grow the kv cache size if necessary
         if not self.fixed_kv_cache_size and self.prev_generated_seq_length < max_returned_tokens:
-            tmp_device = self.model.mask_cache.device
             self.model.clear_kv_cache()
-            self.model.set_kv_cache(batch_size=1, max_seq_length=max_returned_tokens, device=tmp_device)
-
+            self.model.set_kv_cache(
+                batch_size=1,
+                max_seq_length=max_returned_tokens,
+                device=device,
+            )
         else:
             for block in self.model.transformer.h:
                 block.attn.kv_cache.reset_parameters()
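Taken together, the `api.py` changes resolve the cache device once (from the fabric, else the preprocessor) and either pin the KV cache size at distribution time via `fixed_kv_cache_size` or let `generate` allocate it on first use and grow it when a longer completion is requested. A hedged usage sketch of the high-level API, assuming the usual `litgpt.LLM` entry point and an illustrative checkpoint name:

```python
from litgpt import LLM

# illustrative checkpoint name; any downloaded litgpt checkpoint works
llm = LLM.load("microsoft/phi-2")

# Option 1: pin the KV cache size up front; generate() then reuses the same buffers
# and only resets their contents between calls.
llm.distribute(fixed_kv_cache_size=512)

# Option 2 (default): leave it unset and let generate() create the cache on first
# use and re-allocate it whenever a longer completion than before is requested.
# llm.distribute()

print(llm.generate("What do llamas eat?", max_new_tokens=64))
```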
