
Commit e41e8d4

Support for advanced KV caching and batch generation
1 parent 4ea2542 commit e41e8d4

30 files changed: +2,501 −772 lines

litgpt/adapter.py

Lines changed: 34 additions & 11 deletions
@@ -19,6 +19,7 @@
 from litgpt.model import GPT as BaseModel
 from litgpt.model import Block as BaseBlock
 from litgpt.model import CausalSelfAttention as BaseCausalSelfAttention
+from litgpt.kvcache.base import KVCache, KeysAndValues, DefaultKeysAndValues


 @dataclass
@@ -49,6 +50,7 @@ def __init__(self, config: Config) -> None:
         )
         self.mask_cache: Optional[torch.Tensor] = None
         self.max_seq_length = self.config.block_size
+        self._default_kv_cache = False

     @classmethod
     def from_name(cls, name: str, **kwargs: Any) -> Self:
@@ -62,17 +64,27 @@ def _init_weights(self, module: nn.Module) -> None:


 class Block(BaseBlock):
-    def __init__(self, config: Config, block_idx: int) -> None:
-        super().__init__(config, block_idx)
-        self.attn = CausalSelfAttention(config, block_idx)
+    def __init__(
+        self,
+        config: Config,
+        block_idx: int,
+        kv_cache: Optional[KVCache] = None,
+    ) -> None:
+        super().__init__(config, block_idx, kv_cache)
+        self.attn = CausalSelfAttention(config, block_idx, kv_cache=kv_cache)


 class CausalSelfAttention(BaseCausalSelfAttention):
     """A modification of `litgpt.model.CausalSelfAttention` that adds the attention
     over the adaption prompt."""

-    def __init__(self, config: Config, block_idx: int) -> None:
-        super().__init__(config, block_idx)
+    def __init__(
+        self,
+        config: Config,
+        block_idx: int,
+        kv_cache: Optional[KVCache] = None,
+    ) -> None:
+        super().__init__(config, block_idx, kv_cache)
         if block_idx >= config.adapter_start_layer:
             # adapter embedding layer
             self.adapter_wte = nn.Embedding(config.adapter_prompt_length, config.n_embd)
@@ -82,11 +94,16 @@ def __init__(self, config: Config, block_idx: int) -> None:
             self.adapter_kv_cache: Optional[Tuple[torch.Tensor, torch.Tensor]] = None

     def scaled_dot_product_attention(
-        self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, mask: Optional[torch.Tensor] = None
-    ) -> torch.Tensor:
-        y = super().scaled_dot_product_attention(q, k, v, mask)
+        self,
+        q: torch.Tensor,
+        k_and_v: KeysAndValues,
+        mask: Optional[torch.Tensor] = None,
+        is_causal: bool = True,
+        return_scores: bool = False,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+        y, scores = super().scaled_dot_product_attention(q, k_and_v, mask, is_causal, return_scores)
         if self.block_idx < self.config.adapter_start_layer:
-            return y
+            return y, scores

         aT = self.config.adapter_prompt_length
         if self.adapter_kv_cache is not None:
@@ -110,8 +127,14 @@ def scaled_dot_product_attention(

         T = q.size(2)
         amask = torch.ones(T, aT, dtype=torch.bool, device=q.device)
-        ay = super().scaled_dot_product_attention(q, ak, av, amask)
-        return y + self.gating_factor * ay
+        a_k_and_v = DefaultKeysAndValues(keys=ak, values=av)
+        ay, _ = super().scaled_dot_product_attention(
+            q=q,
+            k_and_v=a_k_and_v,
+            mask=amask,
+            is_causal=False,
+        )
+        return y + self.gating_factor * ay, scores

     def reset_parameters(self) -> None:
         if hasattr(self, "gating_factor"):

litgpt/adapter_v2.py

Lines changed: 18 additions & 6 deletions
@@ -9,7 +9,7 @@
 """

 from dataclasses import dataclass
-from typing import Any, Dict, Type, Optional
+from typing import Any, Dict, Type, Optional, List

 import torch
 import torch.nn as nn
@@ -22,6 +22,7 @@
 from litgpt.adapter import Config as BaseConfig
 from litgpt.scripts.convert_hf_checkpoint import qkv_reassemble
 from litgpt.utils import map_old_state_dict_weights
+from litgpt.kvcache.base import KVCache


 @dataclass
@@ -84,6 +85,7 @@ def __init__(self, config: Config) -> None:
         )
         self.mask_cache: Optional[torch.Tensor] = None
         self.max_seq_length = self.config.block_size
+        self._default_kv_cache = False

     @classmethod
     def from_name(cls, name: str, **kwargs: Any) -> Self:
@@ -103,18 +105,28 @@ def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwa


 class Block(BaseBlock):
-    def __init__(self, config: Config, block_idx: int) -> None:
-        super().__init__(config, block_idx)
-        self.attn = CausalSelfAttention(config, block_idx)
+    def __init__(
+        self,
+        config: Config,
+        block_idx: int,
+        kv_cache: Optional[KVCache] = None,
+    ) -> None:
+        super().__init__(config, block_idx, kv_cache)
+        self.attn = CausalSelfAttention(config, block_idx, kv_cache=kv_cache)
         self.mlp = config.mlp_class(config)


 class CausalSelfAttention(BaseCausalSelfAttention):
     """A modification of `litgpt.adapter.CausalSelfAttention` that uses the Adapter V2 Linear class"""

     # Copy&paste from :class:`model.CausalSelfAttention`
-    def __init__(self, config: Config, block_idx: int) -> None:
-        super().__init__(config, block_idx)
+    def __init__(
+        self,
+        config: Config,
+        block_idx: int,
+        kv_cache: Optional[KVCache] = None,
+    ) -> None:
+        super().__init__(config, block_idx, kv_cache)
         # key, query, value projections for all heads, but in a batch
         shape = (config.n_head + 2 * config.n_query_groups) * config.head_size
         self.qkv = AdapterV2Linear(
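
The same optional `kv_cache` constructor argument is threaded through the Adapter V2 classes. A small sketch of what this enables, assuming the post-commit API and a hypothetical model name; leaving the argument out preserves the previous behaviour:

from litgpt.adapter_v2 import Block, Config

# Hypothetical model choice; any config name known to litgpt would do.
config = Config.from_name("pythia-14m")

# kv_cache defaults to None, so existing call sites keep working unchanged.
block = Block(config, block_idx=0)

# A custom cache implementing litgpt.kvcache.base.KVCache could be injected:
# block = Block(config, block_idx=0, kv_cache=my_cache)  # `my_cache` is hypothetical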

litgpt/api.py

Lines changed: 0 additions & 1 deletion
@@ -504,7 +504,6 @@ def generate(
             tmp_device = self.model.mask_cache.device
             self.model.clear_kv_cache()
             self.model.set_kv_cache(batch_size=1, max_seq_length=max_returned_tokens, device=tmp_device)
-
         else:
             for block in self.model.transformer.h:
                 block.attn.kv_cache.reset_parameters()

litgpt/chat/base.py

Lines changed: 77 additions & 21 deletions
@@ -31,6 +31,7 @@ def generate(
     prompt: torch.Tensor,
     max_returned_tokens: int,
     *,
+    prompt_chunksize: int = 1,
     temperature: float = 1.0,
     top_k: Optional[int] = None,
     top_p: float = 1.0,
@@ -60,35 +61,60 @@ def generate(
             or https://huyenchip.com/2024/01/16/sampling.html#top_p
         stop_tokens: If specified, stop generating any more token once one of this list is generated.
     """
-    from litgpt.generate.base import generate_fn
-    return generate_fn(
-        include_prompt=False,
-        include_eos=False,
-        model=model,
-        prompt=prompt,
-        max_returned_tokens=max_returned_tokens,
-        temperature=temperature,
-        top_k=top_k,
-        top_p=top_p,
-        stop_tokens=stop_tokens
+    from litgpt.generate.base import batched_generate_fn
+
+    return map(
+        lambda lst: lst[0],
+        batched_generate_fn(
+            model=model,
+            prompts=[prompt],
+            max_returned_tokens=max_returned_tokens,
+            prompt_chunksize=prompt_chunksize,
+            sample_args = dict(
+                temperature=temperature,
+                top_k=top_k,
+                top_p=top_p,
+            ),
+            stop_tokens=stop_tokens,
+            include_prompt=False,
+            include_eos=False,
+        )
     )


-def process_prompt(prompt, model, tokenizer, prompt_style, fabric, temperature, max_new_tokens, top_k, top_p, stop_tokens):
+def process_prompt(
+    prompt: str,
+    model: GPT,
+    tokenizer,
+    prompt_style,
+    fabric,
+    max_new_tokens: int,
+    prompt_chunksize: int,
+    temperature: float,
+    top_k: Optional[int],
+    top_p: float,
+    stop_tokens: Tuple[List[int], ...],
+):
     prompt = prompt_style.apply(prompt=prompt)
     encoded_prompt = tokenizer.encode(prompt, device=fabric.device)

     if max_new_tokens is None:
         max_returned_tokens = model.max_seq_length
     else:
-        first_turn = model.mask_cache is None
         max_returned_tokens = encoded_prompt.size(0) + max_new_tokens
-        if first_turn or max_returned_tokens > model.max_seq_length:
+        msl = model.max_seq_length
+        if max_returned_tokens > msl or model.config.block_size == msl:
             model.max_seq_length = max_returned_tokens
-            model.set_kv_cache(batch_size=1, device=fabric.device)

     y: Iterator[torch.Tensor] = generate(
-        model, encoded_prompt, max_returned_tokens, temperature=temperature, top_k=top_k, top_p=top_p, stop_tokens=stop_tokens
+        model=model,
+        prompt=encoded_prompt,
+        max_returned_tokens=max_returned_tokens,
+        prompt_chunksize=prompt_chunksize,
+        temperature=temperature,
+        top_k=top_k,
+        top_p=top_p,
+        stop_tokens=stop_tokens,
     )
     token_generator: Iterator[str] = tokenizer.decode_stream(y, device=fabric.device)

@@ -103,8 +129,7 @@ def process_prompt(prompt, model, tokenizer, prompt_style, fabric, temperature,

     t = time.perf_counter() - t0

-    for block in model.transformer.h:
-        block.attn.kv_cache.reset_parameters()
+    model.clear_kv_cache()
     fabric.print(
         f"\nTime for inference: {t:.02f} sec total, {tokens_generated / t:.02f} tokens/sec,"
         f" {tokens_generated} tokens",
@@ -113,7 +138,19 @@ def process_prompt(prompt, model, tokenizer, prompt_style, fabric, temperature,
     fabric.print()


-def interact(multiline, model, tokenizer, prompt_style, fabric, temperature, max_new_tokens, top_k, top_p, stop_tokens):
+def interact(
+    multiline: bool,
+    model: GPT,
+    tokenizer,
+    prompt_style,
+    fabric,
+    max_new_tokens: int,
+    prompt_chunksize: int,
+    temperature: float,
+    top_k: Optional[int],
+    top_p: float,
+    stop_tokens: Tuple[List[int], ...],
+):
     while True:
         try:
             if not multiline:
@@ -135,14 +172,27 @@ def interact(multiline, model, tokenizer, prompt_style, fabric, temperature, max
         if not prompt or prompt in ("!quit", "!exit"):
            break

-        process_prompt(prompt, model, tokenizer, prompt_style, fabric, temperature, max_new_tokens, top_k, top_p, stop_tokens)
+        process_prompt(
+            prompt=prompt,
+            model=model,
+            tokenizer=tokenizer,
+            prompt_style=prompt_style,
+            fabric=fabric,
+            temperature=temperature,
+            max_new_tokens=max_new_tokens,
+            prompt_chunksize=prompt_chunksize,
+            top_k=top_k,
+            top_p=top_p,
+            stop_tokens=stop_tokens,
+        )


 @torch.inference_mode()
 def main(
     checkpoint_dir: Path,
     *,
     max_new_tokens: int = 50,
+    prompt_chunksize: int = 1,
     top_k: Optional[int] = 50,
     top_p: float = 1.0,
     temperature: float = 0.8,
@@ -158,6 +208,11 @@ def main(
         checkpoint_dir: A local path to a directory containing the model weights or a valid model name.
             You can get a list of valid model names via the `litgpt download list` command line argument.
         max_new_tokens: The number of generation steps to take.
+        prompt_chunksize: If even the shortest prompt is longer than the KV
+            cache, prompts are processed in chunks of this size in the
+            prefill phase. Once the shortest has been processed to the
+            end, we proceed with chunk size 1.
+            Defaults to 1, but larger values are recommended for long prompts.
         top_k: The number of top most probable tokens to consider in the sampling process.
         top_p: If specified, it represents the cumulative probability threshold to consider in the sampling process.
             In top-p sampling, the next token is sampled from the highest probability tokens
@@ -252,8 +307,9 @@ def main(
         tokenizer=tokenizer,
         prompt_style=prompt_style,
         fabric=fabric,
-        temperature=temperature,
         max_new_tokens=(None if compile else max_new_tokens),
+        prompt_chunksize=prompt_chunksize,
+        temperature=temperature,
         top_k=top_k,
         top_p=top_p,
         stop_tokens=stop_tokens
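
A short usage sketch of the reworked chat helper above (not from the commit itself): it assumes `model`, `tokenizer`, and `device` are already set up the way `main()` does, and shows the two user-visible changes, the `prompt_chunksize` knob and the fact that the single-prompt path is now a thin `map` over `batched_generate_fn`.

# Sketch only; `model`, `tokenizer`, and `device` are assumed to exist already.
from litgpt.chat.base import generate  # signature extended by this commit

encoded = tokenizer.encode("What is a KV cache?", device=device)
stream = generate(
    model=model,
    prompt=encoded,
    max_returned_tokens=encoded.size(0) + 128,
    prompt_chunksize=16,          # prefill long prompts 16 tokens at a time
    temperature=0.8,
    top_k=50,
    stop_tokens=([tokenizer.eos_id],),
)
# Internally this is batched_generate_fn(prompts=[encoded], ...) with the
# per-step results unwrapped, so the stream yields one token tensor at a time.
for piece in tokenizer.decode_stream(stream, device=device):
    print(piece, end="", flush=True)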

litgpt/finetune/adapter.py

Lines changed: 6 additions & 2 deletions
@@ -399,8 +399,12 @@ def generate_example(fabric: L.Fabric, model: GPT, tokenizer: Tokenizer, eval: E
     # do not set `max_seq_length=max_returned_token` because memory is not a concern here
     model.set_kv_cache(batch_size=1)
     output = generate(
-        model, encoded, max_returned_tokens=max_returned_tokens, temperature=0.8, eos_id=tokenizer.eos_id
-    )
+        model=model,
+        prompts=[encoded],
+        max_returned_tokens=max_returned_tokens,
+        temperature=0.8,
+        eos_id=tokenizer.eos_id,
+    )[0]
     model.clear_kv_cache()
     model.train()
     output = tokenizer.decode(output)
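
The same pattern recurs in the other finetuning scripts below: `generate` from `litgpt.generate.base` now takes a list of prompts and returns one generated sequence per prompt, so the single-example case indexes the result with `[0]`. A hedged sketch, assuming `model`, `tokenizer`, and an encoded prompt tensor `encoded` already exist:

# Sketch of the batch-capable call used by the finetuning scripts.
from litgpt.generate.base import generate

model.set_kv_cache(batch_size=1)                 # one sequence in the batch
outputs = generate(
    model=model,
    prompts=[encoded],                           # list: one entry per prompt
    max_returned_tokens=encoded.size(0) + 64,    # illustrative budget
    temperature=0.8,
    eos_id=tokenizer.eos_id,
)
output = outputs[0]                              # one output sequence per prompt
model.clear_kv_cache()
print(tokenizer.decode(output))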

litgpt/finetune/adapter_v2.py

Lines changed: 6 additions & 2 deletions
@@ -396,8 +396,12 @@ def generate_example(fabric: L.Fabric, model: GPT, tokenizer: Tokenizer, eval: E
     # do not set `max_seq_length=max_returned_token` because memory is not a concern here
     model.set_kv_cache(batch_size=1)
     output = generate(
-        model, encoded, max_returned_tokens=max_returned_tokens, temperature=0.8, eos_id=tokenizer.eos_id
-    )
+        model=model,
+        prompts=[encoded],
+        max_returned_tokens=max_returned_tokens,
+        temperature=0.8,
+        eos_id=tokenizer.eos_id,
+    )[0]
     model.clear_kv_cache()
     model.train()
     output = tokenizer.decode(output)

litgpt/finetune/full.py

Lines changed: 6 additions & 2 deletions
@@ -366,8 +366,12 @@ def generate_example(fabric: L.Fabric, model: GPT, tokenizer: Tokenizer, eval: E
     # do not set `max_seq_length=max_returned_token` because memory is not a concern here
     model.set_kv_cache(batch_size=1)
     output = generate(
-        model, encoded, max_returned_tokens=max_returned_tokens, temperature=0.8, eos_id=tokenizer.eos_id
-    )
+        model=model,
+        prompts=[encoded],
+        max_returned_tokens=max_returned_tokens,
+        temperature=0.8,
+        eos_id=tokenizer.eos_id,
+    )[0]
     model.clear_kv_cache()
     model.train()
     output = tokenizer.decode(output)

litgpt/finetune/lora.py

Lines changed: 6 additions & 2 deletions
@@ -428,8 +428,12 @@ def generate_example(fabric: L.Fabric, model: GPT, tokenizer: Tokenizer, eval: E
     # do not set `max_seq_length=max_returned_token` because memory is not a concern here
     model.set_kv_cache(batch_size=1)
     output = generate(
-        model, encoded, max_returned_tokens=max_returned_tokens, temperature=0.8, eos_id=tokenizer.eos_id
-    )
+        model=model,
+        prompts=[encoded],
+        max_returned_tokens=max_returned_tokens,
+        temperature=0.8,
+        eos_id=tokenizer.eos_id,
+    )[0]
     model.clear_kv_cache()
     model.train()
     output = tokenizer.decode(output)
