Commit 30ad826

haozha111 authored and copybara-github committed
Allow all verify_util helpers to take a custom loader object.
PiperOrigin-RevId: 765248503
1 parent f9c49af commit 30ad826

File tree: 9 files changed, +53 −51 lines

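In practice, this change means each verify helper can be driven by a caller-supplied loader instead of the built-in one. A minimal sketch of the new call pattern, using the AMD-Llama-135M helper from the first diff below; the loader body and checkpoint location are illustrative placeholders, not part of this commit:

    from typing import Dict

    from ai_edge_torch.generative.examples.amd_llama_135m import verify_util
    from safetensors.torch import load_file
    import torch


    def my_loader(path: str) -> Dict[str, torch.Tensor]:
      # Placeholder loader: anything matching the new
      # Callable[[str], Dict[str, torch.Tensor]] contract works here, e.g.
      # fetching weights from a remote store before deserializing them.
      return load_file(path)


    verify_util.verify_amd_llama_135m(
        checkpoint_dir="amd/AMD-Llama-135m",  # placeholder checkpoint location
        custom_loader=my_loader,
    )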

ai_edge_torch/generative/examples/amd_llama_135m/verify_util.py

Lines changed: 6 additions & 6 deletions
@@ -16,11 +16,13 @@
 import logging
 import os
 import pathlib
+from typing import Callable, Dict

 from ai_edge_torch.generative.examples.amd_llama_135m import amd_llama_135m
 from ai_edge_torch.generative.utilities import loader
 from ai_edge_torch.generative.utilities import transformers_verifier
 from ai_edge_torch.generative.utilities import verifier
+import torch
 import transformers


@@ -31,8 +33,9 @@ def verify_amd_llama_135m(
     checkpoint_dir: str,
     weight_filename: str = "model.safetensors",
     max_new_tokens: int = 30,
-    initialize_from_local: bool = True,
     prompts: list[str] | None = None,
+    initialize_from_local: bool = True,
+    custom_loader: Callable[[str], Dict[str, torch.Tensor]] | None = None,
 ) -> bool:
   """Verifies the reauthored AMD-Llama-135M model with a custom loader."""
   logging.info("Loading the original model from: %s", checkpoint_dir)
@@ -41,11 +44,8 @@ def verify_amd_llama_135m(
   )

   logging.info("Building the reauthored model from: %s", checkpoint_dir)
-  custom_loader = (
-      None
-      if initialize_from_local
-      else loader.get_custom_loader("", "safetensors")
-  )
+  if custom_loader is None and not initialize_from_local:
+    custom_loader = loader.get_custom_loader("", "safetensors")

   if initialize_from_local:
     # Locate the cached dir.

ai_edge_torch/generative/examples/deepseek/verify_util.py

Lines changed: 6 additions & 6 deletions
@@ -16,11 +16,13 @@
 import logging
 import os
 import pathlib
+from typing import Callable, Dict

 from ai_edge_torch.generative.examples.deepseek import deepseek
 from ai_edge_torch.generative.utilities import loader
 from ai_edge_torch.generative.utilities import transformers_verifier
 from ai_edge_torch.generative.utilities import verifier
+import torch
 import transformers


@@ -31,8 +33,9 @@ def verify_deepseek_r1_distill_1_5b(
     checkpoint_dir: str,
     weight_filename: str = "model.safetensors",
     max_new_tokens: int = 30,
-    initialize_from_local: bool = True,
     prompts: list[str] | None = None,
+    initialize_from_local: bool = True,
+    custom_loader: Callable[[str], Dict[str, torch.Tensor]] | None = None,
 ) -> bool:
   """Verifies the reauthored DeepSeek R1 distilled 1.5B model with a custom loader."""
   logging.info("Loading the original model from: %s", checkpoint_dir)
@@ -41,11 +44,8 @@ def verify_deepseek_r1_distill_1_5b(
   )

   logging.info("Building the reauthored model from: %s", checkpoint_dir)
-  custom_loader = (
-      None
-      if initialize_from_local
-      else loader.get_custom_loader("", "safetensors")
-  )
+  if custom_loader is None and not initialize_from_local:
+    custom_loader = loader.get_custom_loader("", "safetensors")

   if initialize_from_local:
     # Locate the cached dir.

ai_edge_torch/generative/examples/llama/verify_util.py

Lines changed: 6 additions & 6 deletions
@@ -16,11 +16,13 @@
 import logging
 import os
 import pathlib
+from typing import Callable, Dict

 from ai_edge_torch.generative.examples.llama import llama
 from ai_edge_torch.generative.utilities import loader
 from ai_edge_torch.generative.utilities import transformers_verifier
 from ai_edge_torch.generative.utilities import verifier
+import torch
 import transformers

 _BUILDER = {
@@ -36,8 +38,9 @@ def verify_llama_3_2(
     checkpoint_dir: str,
     weight_filename: str = "model.safetensors",
     max_new_tokens: int = 30,
-    initialize_from_local: bool = True,
     prompts: list[str] | None = None,
+    initialize_from_local: bool = True,
+    custom_loader: Callable[[str], Dict[str, torch.Tensor]] | None = None,
 ) -> bool:
   """Verifies the reauthored Llama 3.2 model with a custom loader."""
   logging.info("Loading the original model from: %s", checkpoint_dir)
@@ -46,11 +49,8 @@ def verify_llama_3_2(
   )

   logging.info("Building the reauthored model from: %s", checkpoint_dir)
-  custom_loader = (
-      None
-      if initialize_from_local
-      else loader.get_custom_loader("", "safetensors")
-  )
+  if custom_loader is None and not initialize_from_local:
+    custom_loader = loader.get_custom_loader("", "safetensors")

   if initialize_from_local:
     # Locate the cached dir.

ai_edge_torch/generative/examples/openelm/verify_util.py

Lines changed: 6 additions & 6 deletions
@@ -16,11 +16,13 @@
 import logging
 import os
 import pathlib
+from typing import Callable, Dict

 from ai_edge_torch.generative.examples.openelm import openelm
 from ai_edge_torch.generative.utilities import loader
 from ai_edge_torch.generative.utilities import transformers_verifier
 from ai_edge_torch.generative.utilities import verifier
+import torch
 import transformers


@@ -31,8 +33,9 @@ def verify_openelm(
     checkpoint_dir: str,
     weight_filename: str = "model.safetensors",
     max_new_tokens: int = 30,
-    initialize_from_local: bool = True,
     prompts: list[str] | None = None,
+    initialize_from_local: bool = True,
+    custom_loader: Callable[[str], Dict[str, torch.Tensor]] | None = None,
 ) -> bool:
   """Verifies the reauthored OpenELM model with a custom loader."""
   logging.info("Loading the original model from: %s", checkpoint_dir)
@@ -41,11 +44,8 @@ def verify_openelm(
   )

   logging.info("Building the reauthored model from: %s", checkpoint_dir)
-  custom_loader = (
-      None
-      if initialize_from_local
-      else loader.get_custom_loader("", "safetensors")
-  )
+  if custom_loader is None and not initialize_from_local:
+    custom_loader = loader.get_custom_loader("", "safetensors")

   if initialize_from_local:
     # Locate the cached dir.

ai_edge_torch/generative/examples/phi/verify_util.py

Lines changed: 7 additions & 6 deletions
@@ -13,14 +13,17 @@
 # limitations under the License.
 # ==============================================================================
 """Utils for verifying the Phi model."""
+
 import logging
 import os
 import pathlib
+from typing import Callable, Dict

 from ai_edge_torch.generative.examples.phi import phi2, phi3, phi4
 from ai_edge_torch.generative.utilities import loader
 from ai_edge_torch.generative.utilities import transformers_verifier
 from ai_edge_torch.generative.utilities import verifier
+import torch
 import transformers


@@ -38,9 +41,10 @@ def verify_phi(
     checkpoint_dir: str,
     weight_filename: str = "model.safetensors",
     max_new_tokens: int = 30,
-    initialize_from_local: bool = True,
     prompts: list[str] | None = None,
     atol: float = 1e-04,
+    initialize_from_local: bool = True,
+    custom_loader: Callable[[str], Dict[str, torch.Tensor]] | None = None,
 ) -> bool:
   """Verifies the reauthored Phi model with a custom loader."""
   logging.info("Loading the original model from: %s", checkpoint_dir)
@@ -49,11 +53,8 @@ def verify_phi(
   )

   logging.info("Building the reauthored model from: %s", checkpoint_dir)
-  custom_loader = (
-      None
-      if initialize_from_local
-      else loader.get_custom_loader("", "safetensors")
-  )
+  if custom_loader is None and not initialize_from_local:
+    custom_loader = loader.get_custom_loader("", "safetensors")

   if initialize_from_local:
     # Locate the cached dir.

ai_edge_torch/generative/examples/qwen/verify_util.py

Lines changed: 7 additions & 6 deletions
@@ -13,14 +13,17 @@
 # limitations under the License.
 # ==============================================================================
 """Utils for verifying the Qwen model."""
+
 import logging
 import os
 import pathlib
+from typing import Callable, Dict

 from ai_edge_torch.generative.examples.qwen import qwen, qwen3
 from ai_edge_torch.generative.utilities import loader
 from ai_edge_torch.generative.utilities import transformers_verifier
 from ai_edge_torch.generative.utilities import verifier
+import torch
 import transformers


@@ -50,8 +53,9 @@ def verify_qwen(
     checkpoint_dir: str,
     weight_filename: str = "model.safetensors",
     max_new_tokens: int = 30,
-    initialize_from_local: bool = True,
     prompts: list[str] | None = None,
+    initialize_from_local: bool = True,
+    custom_loader: Callable[[str], Dict[str, torch.Tensor]] | None = None,
 ) -> bool:
   """Verifies the reauthored Llama 3.2 model with a custom loader."""
   logging.info("Loading the original model from: %s", checkpoint_dir)
@@ -60,11 +64,8 @@ def verify_qwen(
   )

   logging.info("Building the reauthored model from: %s", checkpoint_dir)
-  custom_loader = (
-      None
-      if initialize_from_local
-      else loader.get_custom_loader("", "safetensors")
-  )
+  if custom_loader is None and not initialize_from_local:
+    custom_loader = loader.get_custom_loader("", "safetensors")

   if initialize_from_local:
     # Locate the cached dir.

ai_edge_torch/generative/examples/smollm/verify_util.py

Lines changed: 7 additions & 7 deletions
@@ -16,11 +16,12 @@
 import logging
 import os
 import pathlib
-
+from typing import Callable, Dict
 from ai_edge_torch.generative.examples.smollm import smollm
 from ai_edge_torch.generative.utilities import loader
 from ai_edge_torch.generative.utilities import transformers_verifier
 from ai_edge_torch.generative.utilities import verifier
+import torch
 import transformers

 _BUILDER = {
@@ -36,8 +37,9 @@ def verify_smollm_135m(
     checkpoint_dir: str,
     weight_filename: str = "model.safetensors",
     max_new_tokens: int = 30,
-    initialize_from_local: bool = True,
     prompts: list[str] | None = None,
+    initialize_from_local: bool = True,
+    custom_loader: Callable[[str], Dict[str, torch.Tensor]] | None = None,
 ) -> bool:
   """Verifies the reauthored SmoLLM model with a custom loader."""
   logging.info("Loading the original model from: %s", checkpoint_dir)
@@ -46,11 +48,9 @@ def verify_smollm_135m(
   )

   logging.info("Building the reauthored model from: %s", checkpoint_dir)
-  custom_loader = (
-      None
-      if initialize_from_local
-      else loader.get_custom_loader("", "safetensors")
-  )
+
+  if custom_loader is None and not initialize_from_local:
+    custom_loader = loader.get_custom_loader("", "safetensors")

   if initialize_from_local:
     # Locate the cached dir.

ai_edge_torch/generative/examples/tiny_llama/verify_util.py

Lines changed: 7 additions & 6 deletions
@@ -13,14 +13,17 @@
 # limitations under the License.
 # ==============================================================================
 """Utils for verifying the TinyLlama model."""
+
 import logging
 import os
 import pathlib
+from typing import Callable, Dict

 from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
 from ai_edge_torch.generative.utilities import loader
 from ai_edge_torch.generative.utilities import transformers_verifier
 from ai_edge_torch.generative.utilities import verifier
+import torch
 import transformers


@@ -31,8 +34,9 @@ def verify_tiny_llama(
     checkpoint_dir: str,
     weight_filename: str = "model.safetensors",
     max_new_tokens: int = 30,
-    initialize_from_local: bool = True,
     prompts: list[str] | None = None,
+    initialize_from_local: bool = True,
+    custom_loader: Callable[[str], Dict[str, torch.Tensor]] | None = None,
 ) -> bool:
   """Verifies the reauthored TinyLlama model with a custom loader."""
   logging.info("Loading the original model from: %s", checkpoint_dir)
@@ -41,11 +45,8 @@ def verify_tiny_llama(
   )

   logging.info("Building the reauthored model from: %s", checkpoint_dir)
-  custom_loader = (
-      None
-      if initialize_from_local
-      else loader.get_custom_loader("", "safetensors")
-  )
+  if custom_loader is None and not initialize_from_local:
+    custom_loader = loader.get_custom_loader("", "safetensors")

   if initialize_from_local:
     # Locate the cached dir.

ai_edge_torch/generative/utilities/loader.py

Lines changed: 1 addition & 2 deletions
@@ -23,14 +23,13 @@
 from safetensors.torch import load_file
 import torch

-
 def get_custom_loader(
     checkpoint_path: str,
     checkpoint_format: Optional[str] = None,
 ) -> Callable[[str], Dict[str, torch.Tensor]]:
   """Returns a custom loader for the given checkpoint path.

-  Those customer loaders can either support state dictionary or safetensors, and
+  Those custome loaders can either support state dictionary or safetensors, and
   the actual data might be fetched from a remote source.

   Args: