
Commit d554ca1

haozha111 authored and copybara-github committed
* For all model conversion scripts, add support for a custom loader, so that each script can use a special loader that loads checkpoints from a remote source.
* Move the `custom_checkpoint_loader` flag to utilities/converter.py so that all conversion scripts can use this feature.

PiperOrigin-RevId: 759278654
1 parent b14e2a7 commit d554ca1

36 files changed: +320 −72 lines
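Only 10 of the 36 changed files are excerpted below; the new `loader.maybe_get_custom_loader` helper itself lives in ai_edge_torch/generative/utilities/loader.py, whose diff is not shown. Based on the call sites below and the inline logic removed from convert_gemma3_to_tflite.py, a minimal sketch of its implied contract (not the actual implementation) might look like:

from typing import Callable, Dict, Optional

import torch
from ai_edge_torch.generative.utilities import loader


def maybe_get_custom_loader(
    checkpoint_path: str,
    use_custom_loader: bool,
) -> Optional[Callable[[str], Dict[str, torch.Tensor]]]:
  # Sketch only: mirrors the inline logic removed from
  # convert_gemma3_to_tflite.py. Return a loader only when the flag
  # is set; otherwise the default file-based loading path is used.
  if not use_custom_loader:
    return None
  # If loading from a remote source, get a custom loader for the path.
  # loader.get_custom_loader is the pre-existing helper visible in the
  # removed gemma3 code below.
  return loader.get_custom_loader(checkpoint_path)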

ai_edge_torch/generative/examples/amd_llama_135m/amd_llama_135m.py

Lines changed: 9 additions & 2 deletions
@@ -15,8 +15,10 @@
 
 """Example of building AMD-Llama-135m."""
 
+from typing import Callable, Dict
 import ai_edge_torch.generative.layers.model_config as cfg
 from ai_edge_torch.generative.utilities import model_builder
+import torch
 from torch import nn
 
 TENSOR_NAMES = model_builder.TENSOR_NAMES_WITH_SEPARATE_LM_HEAD
@@ -80,10 +82,15 @@ def get_fake_model_config(**kwargs) -> cfg.ModelConfig:
   return config
 
 
-def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
+def build_model(
+    checkpoint_path: str,
+    custom_loader: Callable[[str], Dict[str, torch.Tensor]] | None = None,
+    **kwargs
+) -> nn.Module:
   return model_builder.build_decoder_only_model(
       checkpoint_path=checkpoint_path,
       config=get_model_config(**kwargs),
       tensor_names=TENSOR_NAMES,
-      model_class=AmdLlama
+      model_class=AmdLlama,
+      custom_loader=custom_loader,
   )
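The new `custom_loader` parameter accepts any callable mapping a checkpoint path to a state dict. A hypothetical usage sketch (the HTTP-based loader and the URL below are illustrative, not part of the commit):

import io
import urllib.request
from typing import Dict

import torch

from ai_edge_torch.generative.examples.amd_llama_135m import amd_llama_135m


def remote_loader(path: str) -> Dict[str, torch.Tensor]:
  # Illustrative loader: fetch a PyTorch checkpoint over HTTP and
  # deserialize the state dict on CPU.
  with urllib.request.urlopen(path) as response:
    data = response.read()
  return torch.load(io.BytesIO(data), map_location="cpu")


# Placeholder URL; any scheme your loader understands will do.
model = amd_llama_135m.build_model(
    "https://example.com/amd-llama-135m.pt",
    custom_loader=remote_loader,
)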

ai_edge_torch/generative/examples/amd_llama_135m/convert_to_tflite.py

Lines changed: 7 additions & 1 deletion
@@ -19,13 +19,19 @@
 from ai_edge_torch.generative.examples.amd_llama_135m import amd_llama_135m
 from ai_edge_torch.generative.utilities import converter
 from ai_edge_torch.generative.utilities import export_config
+from ai_edge_torch.generative.utilities import loader
 
 flags = converter.define_conversion_flags("amd-llama-135m")
 
 
 def main(_):
+  checkpoint_path = flags.FLAGS.checkpoint_path
   pytorch_model = amd_llama_135m.build_model(
-      flags.FLAGS.checkpoint_path, kv_cache_max_len=flags.FLAGS.kv_cache_max_len
+      checkpoint_path,
+      custom_loader=loader.maybe_get_custom_loader(
+          checkpoint_path, flags.FLAGS.custom_checkpoint_loader
+      ),
+      kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
   )
   converter.convert_to_tflite(
       pytorch_model,

ai_edge_torch/generative/examples/deepseek/convert_to_tflite.py

Lines changed: 7 additions & 2 deletions
@@ -17,15 +17,20 @@
 
 from absl import app
 from ai_edge_torch.generative.examples.deepseek import deepseek
-from ai_edge_torch.generative.layers import kv_cache
 from ai_edge_torch.generative.utilities import converter
 from ai_edge_torch.generative.utilities import export_config
+from ai_edge_torch.generative.utilities import loader
 
 flags = converter.define_conversion_flags('deepseek')
 
 def main(_):
+  checkpoint_path = flags.FLAGS.checkpoint_path
   pytorch_model = deepseek.build_model(
-      flags.FLAGS.checkpoint_path, kv_cache_max_len=flags.FLAGS.kv_cache_max_len
+      checkpoint_path,
+      custom_loader=loader.maybe_get_custom_loader(
+          checkpoint_path, flags.FLAGS.custom_checkpoint_loader
+      ),
+      kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
   )
   converter.convert_to_tflite(
       pytorch_model,

ai_edge_torch/generative/examples/deepseek/deepseek.py

Lines changed: 8 additions & 1 deletion
@@ -15,8 +15,10 @@
 
 """Example of building DeepSeek R1 distilled models."""
 
+from typing import Callable, Dict
 import ai_edge_torch.generative.layers.model_config as cfg
 from ai_edge_torch.generative.utilities import model_builder
+import torch
 from torch import nn
 
 TENSOR_NAMES = model_builder.TENSOR_NAMES_WITH_SEPARATE_LM_HEAD
@@ -84,10 +86,15 @@ def get_fake_model_config(**kwargs) -> cfg.ModelConfig:
   return config
 
 
-def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
+def build_model(
+    checkpoint_path: str,
+    custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
+    **kwargs
+) -> nn.Module:
   return model_builder.build_decoder_only_model(
       checkpoint_path=checkpoint_path,
       config=get_model_config(**kwargs),
       tensor_names=TENSOR_NAMES,
       model_class=DeepSeekDistillQwen,
+      custom_loader=custom_loader,
   )

ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py

Lines changed: 7 additions & 1 deletion
@@ -19,13 +19,19 @@
 from ai_edge_torch.generative.examples.gemma import gemma1
 from ai_edge_torch.generative.utilities import converter
 from ai_edge_torch.generative.utilities import export_config
+from ai_edge_torch.generative.utilities import loader
 
 flags = converter.define_conversion_flags("gemma-2b")
 
 
 def main(_):
+  checkpoint_path = flags.FLAGS.checkpoint_path
   pytorch_model = gemma1.build_2b_model(
-      flags.FLAGS.checkpoint_path, kv_cache_max_len=flags.FLAGS.kv_cache_max_len
+      checkpoint_path,
+      custom_loader=loader.maybe_get_custom_loader(
+          checkpoint_path, flags.FLAGS.custom_checkpoint_loader
+      ),
+      kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
   )
   converter.convert_to_tflite(
       pytorch_model,

ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py

Lines changed: 7 additions & 1 deletion
@@ -19,15 +19,21 @@
 from ai_edge_torch.generative.examples.gemma import gemma2
 from ai_edge_torch.generative.utilities import converter
 from ai_edge_torch.generative.utilities import export_config
+from ai_edge_torch.generative.utilities import loader
 
 flags = converter.define_conversion_flags(
     "gemma2-2b", default_mask_as_input=True, default_transpose_kv_cache=True
 )
 
 
 def main(_):
+  checkpoint_path = flags.FLAGS.checkpoint_path
   pytorch_model = gemma2.build_2b_model(
-      flags.FLAGS.checkpoint_path, kv_cache_max_len=flags.FLAGS.kv_cache_max_len
+      checkpoint_path,
+      custom_loader=loader.maybe_get_custom_loader(
+          checkpoint_path, flags.FLAGS.custom_checkpoint_loader
+      ),
+      kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
   )
   converter.convert_to_tflite(
       pytorch_model,

ai_edge_torch/generative/examples/gemma/gemma1.py

Lines changed: 9 additions & 1 deletion
@@ -15,9 +15,12 @@
 
 """Example of building a Gemma1 model."""
 
+from typing import Callable, Dict
+
 import ai_edge_torch.generative.layers.model_config as cfg
 from ai_edge_torch.generative.utilities import model_builder
 import ai_edge_torch.generative.utilities.loader as loading_utils
+import torch
 from torch import nn
 
 TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
@@ -99,10 +102,15 @@ def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
   return config
 
 
-def build_2b_model(checkpoint_path: str, **kwargs) -> nn.Module:
+def build_2b_model(
+    checkpoint_path: str,
+    custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
+    **kwargs
+) -> nn.Module:
   return model_builder.build_decoder_only_model(
       checkpoint_path=checkpoint_path,
       config=get_model_config_2b(**kwargs),
       tensor_names=TENSOR_NAMES,
       model_class=Gemma1,
+      custom_loader=custom_loader,
   )

ai_edge_torch/generative/examples/gemma/gemma2.py

Lines changed: 7 additions & 2 deletions
@@ -15,7 +15,7 @@
 
 """Example of building a Gemma2 model."""
 
-from typing import List, Optional, Tuple
+from typing import Callable, Dict, List, Optional, Tuple
 
 from ai_edge_torch.generative.layers import attention
 from ai_edge_torch.generative.layers import builder
@@ -306,14 +306,19 @@ def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
   return config
 
 
-def build_2b_model(checkpoint_path: str, **kwargs) -> nn.Module:
+def build_2b_model(
+    checkpoint_path: str,
+    custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
+    **kwargs,
+) -> nn.Module:
   for tensor_names in TENSOR_NAMES_DICT.values():
     try:
       return model_builder.build_decoder_only_model(
           checkpoint_path=checkpoint_path,
           config=get_model_config_2b(**kwargs),
           tensor_names=tensor_names,
           model_class=Gemma2,
+          custom_loader=custom_loader,
       )
     except KeyError as _:
       continue

ai_edge_torch/generative/examples/gemma3/convert_gemma3_to_tflite.py

Lines changed: 5 additions & 14 deletions
@@ -25,13 +25,6 @@
     'gemma3-1b', default_mask_as_input=True, default_transpose_kv_cache=True
 )
 
-_CUSTOM_CHECKPOINT_LOADER = flags.DEFINE_bool(
-    'custom_checkpoint_loader',
-    False,
-    'If true, the conversion script will use a custom checkpoint loader which'
-    ' will read a checkpoint from a remote source.',
-)
-
 _MODEL_SIZE = flags.DEFINE_string(
     'model_size',
     '1b',
@@ -40,16 +33,14 @@
 
 
 def main(_):
-  custom_loader = None
-  if flags.FLAGS.custom_checkpoint_loader:
-    # If loading from a remote source, try to get a custom loader first.
-    custom_loader = loader.get_custom_loader(flags.FLAGS.checkpoint_path)
-
+  checkpoint_path = flags.FLAGS.checkpoint_path
   if _MODEL_SIZE.value == '1b':
     pytorch_model = gemma3.build_model_1b(
-        flags.FLAGS.checkpoint_path,
+        checkpoint_path,
+        custom_loader=loader.maybe_get_custom_loader(
+            checkpoint_path, flags.FLAGS.custom_checkpoint_loader
+        ),
         kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
-        custom_loader=custom_loader,
     )
   else:
     raise ValueError(f'Unsupported model size: {_MODEL_SIZE.value}')
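gemma3 was previously the only script defining the flag locally; per the commit description the definition now lives in utilities/converter.py, whose diff is not among these excerpts. Presumably it was carried over into `define_conversion_flags` along the lines of:

from absl import flags

# Presumed new home of the flag, inside define_conversion_flags() in
# ai_edge_torch/generative/utilities/converter.py; the name, default,
# and help text are carried over from the block removed above.
flags.DEFINE_bool(
    'custom_checkpoint_loader',
    False,
    'If true, the conversion script will use a custom checkpoint loader which'
    ' will read a checkpoint from a remote source.',
)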

ai_edge_torch/generative/examples/hammer/convert_to_tflite.py

Lines changed: 7 additions & 1 deletion
@@ -19,6 +19,7 @@
 from ai_edge_torch.generative.examples.hammer import hammer
 from ai_edge_torch.generative.utilities import converter
 from ai_edge_torch.generative.utilities import export_config
+from ai_edge_torch.generative.utilities import loader
 
 flags = converter.define_conversion_flags('hammer')
 
@@ -36,8 +37,13 @@
 
 
 def main(_):
+  checkpoint_path = flags.FLAGS.checkpoint_path
   pytorch_model = _BUILDER[_MODEL_SIZE.value](
-      flags.FLAGS.checkpoint_path, kv_cache_max_len=flags.FLAGS.kv_cache_max_len
+      checkpoint_path,
+      custom_loader=loader.maybe_get_custom_loader(
+          checkpoint_path, flags.FLAGS.custom_checkpoint_loader
+      ),
+      kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
   )
   converter.convert_to_tflite(
       pytorch_model,
