
Commit 93edc84

protobird-git authored and copybara-github committed
Remove kv_cache_max_len from ModelConfig.
- This is the first step to make kv_cache_max_len configurable when the model is loaded for inference.
- Infer kv_cache_max_len from kv_cache or mask. Either of them must be non-null.
- Pass kv_cache_max_len as a parameter during export.
- Build mask_cache only when mask_as_input is false.
- Confirmed that conversion generates the same tflite files before and after for gemma3, llama, and deepseek.

PiperOrigin-RevId: 766326675
1 parent 5492f26 commit 93edc84
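The change described above makes the export path derive kv_cache_max_len at call time from whichever of the KV cache or the mask is provided, instead of reading it from ModelConfig. Below is a minimal sketch of that inference logic, assuming a KV cache whose per-layer k_cache tensors carry the cache length on dimension 1 and a mask whose last dimension spans the KV positions; the helper name and the shapes are illustrative, not the library's actual implementation.

def infer_kv_cache_max_len(kv_cache, mask) -> int:
  # Either kv_cache or mask must be non-null, per the commit description.
  if kv_cache is not None:
    # Assumed layout: k_cache is [batch, kv_cache_max, num_heads, head_dim].
    return kv_cache.caches[0].k_cache.shape[1]
  if mask is not None:
    # Assumed layout: the mask's last dimension covers the KV cache positions.
    return mask.shape[-1]
  raise ValueError("Either kv_cache or mask must be provided.")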


50 files changed: +470 −545 lines

ai_edge_torch/generative/examples/amd_llama_135m/amd_llama_135m.py

Lines changed: 7 additions & 15 deletions
@@ -29,16 +29,8 @@ class AmdLlama(model_builder.DecoderOnlyModel):
   pass
 
 
-def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
-  """Returns the model config for an AMD-Llama-135m model.
-
-  Args:
-    kv_cache_max_len (int): The maximum sequence length of the KV cache. Default
-      is 1024.
-
-  Returns:
-    The model config for an AMD-Llama-135m model.
-  """
+def get_model_config() -> cfg.ModelConfig:
+  """Returns the model config for an AMD-Llama-135m model."""
   attn_config = cfg.AttentionConfig(
       num_heads=12,
       head_dim=64,
@@ -63,16 +55,15 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       num_layers=12,
       max_seq_len=2048,
       embedding_dim=768,
-      kv_cache_max_len=kv_cache_max_len,
       block_configs=block_config,
       final_norm_config=norm_config,
       lm_head_share_weight_with_embedding=False,
   )
   return config
 
 
-def get_fake_model_config(**kwargs) -> cfg.ModelConfig:
-  config = get_model_config(**kwargs)
+def get_fake_model_config() -> cfg.ModelConfig:
+  config = get_model_config()
   config.vocab_size = 128
   config.num_layers = 2
   config.block_config(0).ff_config.intermediate_size = 64
@@ -82,12 +73,13 @@ def get_fake_model_config(**kwargs) -> cfg.ModelConfig:
 def build_model(
     checkpoint_path: str,
     custom_loader: Callable[[str], Dict[str, torch.Tensor]] | None = None,
-    **kwargs
+    mask_cache_size: int = 0,
 ) -> nn.Module:
   return model_builder.build_decoder_only_model(
       checkpoint_path=checkpoint_path,
-      config=get_model_config(**kwargs),
+      config=get_model_config(),
       tensor_names=TENSOR_NAMES,
       model_class=AmdLlama,
       custom_loader=custom_loader,
+      mask_cache_size=mask_cache_size,
   )

ai_edge_torch/generative/examples/amd_llama_135m/convert_to_tflite.py

Lines changed: 2 additions & 1 deletion
@@ -31,13 +31,14 @@ def main(_):
       custom_loader=loader.maybe_get_custom_loader(
           checkpoint_path, flags.FLAGS.custom_checkpoint_loader
       ),
-      kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
+      mask_cache_size=converter.get_mask_cache_size_from_flags(),
   )
   converter.convert_to_tflite(
       pytorch_model,
       output_path=flags.FLAGS.output_path,
       output_name_prefix=flags.FLAGS.output_name_prefix,
       prefill_seq_len=flags.FLAGS.prefill_seq_lens,
+      kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
       quantize=flags.FLAGS.quantize,
       lora_ranks=flags.FLAGS.lora_ranks,
       export_config=export_config.get_from_flags(),
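Taken together, these two diffs change the calling convention for the example models: build_model now takes mask_cache_size, while kv_cache_max_len becomes an argument of converter.convert_to_tflite. A hedged sketch of the resulting export flow follows; only the parameter names come from the diff, the paths and sizes are placeholders.

from ai_edge_torch.generative.examples.amd_llama_135m import amd_llama_135m
from ai_edge_torch.generative.utilities import converter

# Placeholder checkpoint path and cache sizes, for illustration only.
pytorch_model = amd_llama_135m.build_model(
    "/path/to/amd_llama_135m",   # hypothetical checkpoint location
    mask_cache_size=1280,        # mask cache is built inside the model
)
converter.convert_to_tflite(
    pytorch_model,
    output_path="/tmp",                  # hypothetical output directory
    output_name_prefix="amd_llama_135m",
    prefill_seq_len=[64, 1024],          # hypothetical prefill lengths
    kv_cache_max_len=1280,               # now supplied at export time
)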

ai_edge_torch/generative/examples/deepseek/convert_to_tflite.py

Lines changed: 3 additions & 1 deletion
@@ -23,20 +23,22 @@
 
 flags = converter.define_conversion_flags('deepseek')
 
+
 def main(_):
   checkpoint_path = flags.FLAGS.checkpoint_path
   pytorch_model = deepseek.build_model(
       checkpoint_path,
       custom_loader=loader.maybe_get_custom_loader(
           checkpoint_path, flags.FLAGS.custom_checkpoint_loader
       ),
-      kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
+      mask_cache_size=converter.get_mask_cache_size_from_flags(),
   )
   converter.convert_to_tflite(
       pytorch_model,
       output_path=flags.FLAGS.output_path,
       output_name_prefix=flags.FLAGS.output_name_prefix,
       prefill_seq_len=flags.FLAGS.prefill_seq_lens,
+      kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
       quantize=flags.FLAGS.quantize,
       lora_ranks=flags.FLAGS.lora_ranks,
       export_config=export_config.get_from_flags(),
export_config=export_config.get_from_flags(),

ai_edge_torch/generative/examples/deepseek/deepseek.py

Lines changed: 7 additions & 15 deletions
@@ -29,16 +29,8 @@ class DeepSeekDistillQwen(model_builder.DecoderOnlyModel):
   pass
 
 
-def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
-  """Returns the model config for a Qwen 2.5 3B model.
-
-  Args:
-    kv_cache_max_len (int): The maximum sequence length of the KV cache. Default
-      is 1024.
-
-  Returns:
-    The model config for a SmolLM model.
-  """
+def get_model_config() -> cfg.ModelConfig:
+  """Returns the model config for a Qwen 2.5 3B model."""
   attn_config = cfg.AttentionConfig(
       num_heads=12,
       head_dim=128,
@@ -66,16 +58,15 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       num_layers=28,
       max_seq_len=4096,
       embedding_dim=1536,
-      kv_cache_max_len=kv_cache_max_len,
       block_configs=block_config,
       final_norm_config=norm_config,
       lm_head_share_weight_with_embedding=False,
   )
   return config
 
 
-def get_fake_model_config(**kwargs) -> cfg.ModelConfig:
-  config = get_model_config(**kwargs)
+def get_fake_model_config() -> cfg.ModelConfig:
+  config = get_model_config()
   config.vocab_size = 128
   config.num_layers = 2
   # DeepSeek-R1-Distill-Qwen has only one block config.
@@ -86,12 +77,13 @@ def get_fake_model_config(**kwargs) -> cfg.ModelConfig:
 def build_model(
     checkpoint_path: str,
     custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
-    **kwargs
+    mask_cache_size: int = 0,
 ) -> nn.Module:
   return model_builder.build_decoder_only_model(
       checkpoint_path=checkpoint_path,
-      config=get_model_config(**kwargs),
+      config=get_model_config(),
       tensor_names=TENSOR_NAMES,
       model_class=DeepSeekDistillQwen,
       custom_loader=custom_loader,
+      mask_cache_size=mask_cache_size,
   )

ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py

Lines changed: 2 additions & 1 deletion
@@ -31,13 +31,14 @@ def main(_):
       custom_loader=loader.maybe_get_custom_loader(
           checkpoint_path, flags.FLAGS.custom_checkpoint_loader
       ),
-      kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
+      mask_cache_size=converter.get_mask_cache_size_from_flags(),
   )
   converter.convert_to_tflite(
       pytorch_model,
       output_path=flags.FLAGS.output_path,
       output_name_prefix=flags.FLAGS.output_name_prefix,
       prefill_seq_len=flags.FLAGS.prefill_seq_lens,
+      kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
       quantize=flags.FLAGS.quantize,
       lora_ranks=flags.FLAGS.lora_ranks,
       export_config=export_config.get_from_flags(),

ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py

Lines changed: 2 additions & 1 deletion
@@ -33,13 +33,14 @@ def main(_):
       custom_loader=loader.maybe_get_custom_loader(
           checkpoint_path, flags.FLAGS.custom_checkpoint_loader
       ),
-      kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
+      mask_cache_size=converter.get_mask_cache_size_from_flags(),
   )
   converter.convert_to_tflite(
       pytorch_model,
       output_path=flags.FLAGS.output_path,
       output_name_prefix=flags.FLAGS.output_name_prefix,
       prefill_seq_len=flags.FLAGS.prefill_seq_lens,
+      kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
       quantize=flags.FLAGS.quantize,
       lora_ranks=flags.FLAGS.lora_ranks,
       export_config=export_config.get_from_flags(),

ai_edge_torch/generative/examples/gemma/gemma1.py

Lines changed: 8 additions & 16 deletions
@@ -42,16 +42,8 @@ class Gemma1(model_builder.DecoderOnlyModel):
   pass
 
 
-def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
-  """Returns the model config for a Gemma 2B model.
-
-  Args:
-    kv_cache_max_len (int): The maximum sequence length of the KV cache. Default
-      is 1024.
-
-  Returns:
-    The model config for a Gemma 2B model.
-  """
+def get_model_config_2b() -> cfg.ModelConfig:
+  """Returns the model config for a Gemma 2B model."""
   attn_config = cfg.AttentionConfig(
       num_heads=8,
       head_dim=256,
@@ -80,33 +72,33 @@ def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       max_seq_len=8192,
       embedding_dim=embedding_dim,
       embedding_scale=embedding_dim**0.5,
-      kv_cache_max_len=kv_cache_max_len,
       block_configs=block_config,
       final_norm_config=norm_config,
       lm_head_use_bias=False,
   )
   return config
 
 
-def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
-  config = get_model_config_2b(kv_cache_max_len)
+def get_fake_model_config() -> cfg.ModelConfig:
+  config = get_model_config_2b()
   # Gemma has only one block config.
   config.block_config(0).ff_config.intermediate_size = 128
   config.vocab_size = 128
   config.num_layers = 2
-  config.max_seq_len = 2 * kv_cache_max_len
+  config.max_seq_len = 256
   return config
 
 
 def build_2b_model(
     checkpoint_path: str,
     custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
-    **kwargs
+    mask_cache_size: int = 0,
 ) -> nn.Module:
   return model_builder.build_decoder_only_model(
       checkpoint_path=checkpoint_path,
-      config=get_model_config_2b(**kwargs),
+      config=get_model_config_2b(),
       tensor_names=TENSOR_NAMES,
       model_class=Gemma1,
       custom_loader=custom_loader,
+      mask_cache_size=mask_cache_size,
   )

ai_edge_torch/generative/examples/gemma/gemma2.py

Lines changed: 24 additions & 24 deletions
@@ -104,7 +104,7 @@ def forward(
 class Gemma2(nn.Module):
   """A Gemma2 model built from the Edge Generative API layers."""
 
-  def __init__(self, config: cfg.ModelConfig):
+  def __init__(self, config: cfg.ModelConfig, mask_cache_size: int = 0):
     super().__init__()
 
     # Construct model layers.
@@ -126,17 +126,24 @@ def __init__(self, config: cfg.ModelConfig):
         config.embedding_dim,
         config.final_norm_config,
     )
-    self.mask_cache = attn_utils.build_causal_mask_cache(
-        size=config.kv_cache_max,
-    )
+    self.config = config
+    self.build_mask_cache(mask_cache_size)
+
+  def build_mask_cache(self, mask_cache_size: int):
+    assert (
+        mask_cache_size <= self.config.max_seq_len
+    ), "Mask cache size must be less than or equal to the max seq length."
+    if mask_cache_size <= 0:
+      self.mask_cache = None
+      self.sliding_window_mask_cache = None
+      return
+    self.mask_cache = attn_utils.build_causal_mask_cache(mask_cache_size)
     # Gemma2 has same hyper parameters for each layer except for attention
     # types. Use the first layer.
-    attn_config = config.block_config(0).attn_config
     self.sliding_window_mask_cache = attn_utils.build_sliding_window_mask_cache(
-        size=config.kv_cache_max,
-        window_size=attn_config.sliding_window_size,
+        size=mask_cache_size,
+        window_size=self.config.block_config(0).attn_config.sliding_window_size,
     )
-    self.config = config
 
   def get_attention_mask(
       self, attn_type: cfg.AttentionType, input_pos: torch.Tensor
@@ -167,6 +174,7 @@ def forward(
     n_elem = int(attn_config.rotary_percentage * attn_config.head_dim)
     rope = rotary_pos_emb.build_rope(input_pos, n_elem, attn_config.rotary_base)
     if mask is None:
+      assert self.mask_cache is not None, "Mask cache must be built."
      mask = [
          self.get_attention_mask(
              self.config.block_config(i).attn_config.attn_type, input_pos
@@ -222,16 +230,8 @@ def _forward_with_embeds(
     return {"logits": res, "kv_cache": updated_kv_cache}
 
 
-def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
-  """Returns the model config for a Gemma2 2B model.
-
-  Args:
-    kv_cache_max_len (int): The maximum sequence length of the KV cache. Default
-      is 1024.
-
-  Returns:
-    The model config for a Gemma 2B model.
-  """
+def get_model_config_2b() -> cfg.ModelConfig:
+  """Returns the model config for a Gemma2 2B model."""
   norm_config = cfg.NormalizationConfig(
       type=cfg.NormalizationType.RMS_NORM, epsilon=1e-6, zero_centered=True
   )
@@ -277,7 +277,6 @@ def get_block_config(idx: int) -> cfg.TransformerBlockConfig:
       max_seq_len=8192,
       embedding_dim=embedding_dim,
       embedding_scale=embedding_dim**0.5,
-      kv_cache_max_len=kv_cache_max_len,
       block_configs=[get_block_config(i) for i in range(num_layers)],
       final_norm_config=norm_config,
       lm_head_use_bias=False,
@@ -286,11 +285,11 @@ def get_block_config(idx: int) -> cfg.TransformerBlockConfig:
   return config
 
 
-def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
-  config = get_model_config_2b(kv_cache_max_len)
+def get_fake_model_config() -> cfg.ModelConfig:
+  config = get_model_config_2b()
   config.vocab_size = 128
   config.num_layers = 2
-  config.max_seq_len = 2 * kv_cache_max_len
+  config.max_seq_len = 256
   config.embedding_dim = 128
   config.embedding_scale = config.embedding_dim**0.5
   config.block_configs = config.block_configs[: config.num_layers]
@@ -305,16 +304,17 @@ def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
 def build_2b_model(
     checkpoint_path: str,
     custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
-    **kwargs,
+    mask_cache_size: int = 0,
 ) -> nn.Module:
   for tensor_names in TENSOR_NAMES_DICT.values():
     try:
       return model_builder.build_decoder_only_model(
           checkpoint_path=checkpoint_path,
-          config=get_model_config_2b(**kwargs),
+          config=get_model_config_2b(),
           tensor_names=tensor_names,
          model_class=Gemma2,
          custom_loader=custom_loader,
+          mask_cache_size=mask_cache_size,
       )
     except KeyError as _:
       continue
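With this restructuring, Gemma2 builds its causal and sliding-window mask caches only when mask_cache_size is positive; with the default of 0 the caches stay None and forward() asserts that a mask is passed in. A small sketch of the two construction modes, using the fake config from this file (the surrounding setup is illustrative, not part of the commit):

from ai_edge_torch.generative.examples.gemma import gemma2

config = gemma2.get_fake_model_config()

# Positive size: __init__ builds both mask caches up front.
model_with_cache = gemma2.Gemma2(config, mask_cache_size=256)
assert model_with_cache.mask_cache is not None

# Zero size (the default): no caches are built, so the caller must pass
# `mask` to forward(), e.g. when exporting with mask_as_input enabled.
model_without_cache = gemma2.Gemma2(config, mask_cache_size=0)
assert model_without_cache.mask_cache is None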

ai_edge_torch/generative/examples/gemma3/convert_gemma3_to_tflite.py

Lines changed: 2 additions & 1 deletion
@@ -40,7 +40,7 @@ def main(_):
         custom_loader=loader.maybe_get_custom_loader(
             checkpoint_path, flags.FLAGS.custom_checkpoint_loader
         ),
-        kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
+        mask_cache_size=converter.get_mask_cache_size_from_flags(),
     )
   else:
     raise ValueError(f'Unsupported model size: {_MODEL_SIZE.value}')
@@ -50,6 +50,7 @@ def main(_):
       output_path=flags.FLAGS.output_path,
       output_name_prefix=flags.FLAGS.output_name_prefix,
       prefill_seq_len=flags.FLAGS.prefill_seq_lens,
+      kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
       quantize=flags.FLAGS.quantize,
       lora_ranks=flags.FLAGS.lora_ranks,
       export_config=export_config.get_from_flags(),
