Commit 304a7b9

haozha111 authored and copybara-github committed
No public description
PiperOrigin-RevId: 760021381
1 parent 00202bf commit 304a7b9

File tree

17 files changed: +662 -244 lines changed

ai_edge_torch/generative/examples/amd_llama_135m/verify.py

Lines changed: 4 additions & 32 deletions
@@ -15,15 +15,9 @@
 
 """Verifies the reauthored AMD-Llama-135M model."""
 
-import logging
-import pathlib
-
 from absl import app
 from absl import flags
-from ai_edge_torch.generative.examples.amd_llama_135m import amd_llama_135m
-from ai_edge_torch.generative.utilities import transformers_verifier
-from ai_edge_torch.generative.utilities import verifier
-import transformers
+from ai_edge_torch.generative.examples.amd_llama_135m import verify_util
 
 
 _PROMPTS = flags.DEFINE_multi_string(
@@ -39,32 +33,10 @@
 
 
 def main(_):
-  checkpoint = "amd/AMD-Llama-135m"
-  logging.info("Loading the original model from: %s", checkpoint)
-  original_model = transformers.AutoModelForCausalLM.from_pretrained(
-      checkpoint, trust_remote_code=True
-  )
-
-  # Locate the cached dir.
-  cached_config_file = transformers.utils.cached_file(
-      checkpoint, transformers.utils.CONFIG_NAME
-  )
-  reauthored_checkpoint = pathlib.Path(cached_config_file).parent
-  logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
-  reauthored_model = amd_llama_135m.build_model(str(reauthored_checkpoint))
-
-  logging.info("Loading the tokenizer from: %s", checkpoint)
-  tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
-
-  verifier.verify_reauthored_model(
-      original_model=transformers_verifier.TransformersModelWrapper(
-          original_model
-      ),
-      reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
-      tokenizer=verifier.TokenizerWrapper(tokenizer),
-      generate_prompts=_PROMPTS.value,
+  verify_util.verify_amd_llama_135m(
+      "amd/AMD-Llama-135m",
       max_new_tokens=_MAX_NEW_TOKENS.value,
-      atol=1e-04,
+      prompts=_PROMPTS.value,
   )
 
ai_edge_torch/generative/examples/amd_llama_135m/verify_util.py

Lines changed: 76 additions & 0 deletions
@@ -0,0 +1,76 @@
+# Copyright 2025 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Utils for verifying the AMD-Llama-135M model."""
+import logging
+import os
+import pathlib
+
+from ai_edge_torch.generative.examples.amd_llama_135m import amd_llama_135m
+from ai_edge_torch.generative.utilities import loader
+from ai_edge_torch.generative.utilities import transformers_verifier
+from ai_edge_torch.generative.utilities import verifier
+import transformers
+
+
+DEFAULT_PROMPTS = ["Tell me a story?\nOnce upon a time"]
+
+
+def verify_amd_llama_135m(
+    checkpoint_dir: str,
+    weight_filename: str = "model.safetensors",
+    max_new_tokens: int = 30,
+    initialize_from_local: bool = True,
+    prompts: list[str] | None = None,
+) -> bool:
+  """Verifies the reauthored AMD-Llama-135M model with a custom loader."""
+  logging.info("Loading the original model from: %s", checkpoint_dir)
+  original_model = transformers.AutoModelForCausalLM.from_pretrained(
+      checkpoint_dir
+  )
+
+  logging.info("Building the reauthored model from: %s", checkpoint_dir)
+  custom_loader = (
+      None
+      if initialize_from_local
+      else loader.get_custom_loader("", "safetensors")
+  )
+
+  if initialize_from_local:
+    # Locate the cached dir.
+    cached_config_file = transformers.utils.cached_file(
+        checkpoint_dir, transformers.utils.CONFIG_NAME
+    )
+    reauthored_checkpoint = pathlib.Path(cached_config_file).parent
+  else:
+    reauthored_checkpoint = os.path.join(checkpoint_dir, weight_filename)
+
+  logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
+  reauthored_model = amd_llama_135m.build_model(
+      checkpoint_path=reauthored_checkpoint,
+      custom_loader=custom_loader,
+  )
+
+  logging.info("Loading the tokenizer from: %s", checkpoint_dir)
+  tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint_dir)
+  return verifier.verify_reauthored_model(
+      original_model=transformers_verifier.TransformersModelWrapper(
+          original_model
+      ),
+      reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
+      tokenizer=verifier.TokenizerWrapper(tokenizer),
+      generate_prompts=DEFAULT_PROMPTS if prompts is None else prompts,
+      max_new_tokens=max_new_tokens,
+      atol=1e-04,
+  )
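A minimal usage sketch of the new helper, assuming a locally downloaded checkpoint (the directory path below is a hypothetical placeholder; initialize_from_local=False routes weight loading through the safetensors custom loader, as in the function above):

from ai_edge_torch.generative.examples.amd_llama_135m import verify_util

# Hypothetical local directory holding config.json, tokenizer files, and
# model.safetensors.
ok = verify_util.verify_amd_llama_135m(
    "/tmp/checkpoints/amd-llama-135m",
    weight_filename="model.safetensors",
    max_new_tokens=30,
    initialize_from_local=False,  # load weights via loader.get_custom_loader
    prompts=None,  # None falls back to DEFAULT_PROMPTS
)
print("verification passed:", ok)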

ai_edge_torch/generative/examples/deepseek/verify.py

Lines changed: 4 additions & 30 deletions
@@ -15,15 +15,9 @@
 
 """Verifies the reauthored DeepSeek R1 distilled 1.5B model."""
 
-import logging
-import pathlib
-
 from absl import app
 from absl import flags
-from ai_edge_torch.generative.examples.deepseek import deepseek
-from ai_edge_torch.generative.utilities import transformers_verifier
-from ai_edge_torch.generative.utilities import verifier
-import transformers
+from ai_edge_torch.generative.examples.deepseek import verify_util
 
 
 _PROMPTS = flags.DEFINE_multi_string(
@@ -39,30 +33,10 @@
 
 
 def main(_):
-  checkpoint = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
-  logging.info("Loading the original model from: %s", checkpoint)
-  original_model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint)
-
-  # Locate the cached dir.
-  cached_config_file = transformers.utils.cached_file(
-      checkpoint, transformers.utils.CONFIG_NAME
-  )
-  reauthored_checkpoint = pathlib.Path(cached_config_file).parent
-  logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
-  reauthored_model = deepseek.build_model(str(reauthored_checkpoint))
-
-  logging.info("Loading the tokenizer from: %s", checkpoint)
-  tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
-
-  verifier.verify_reauthored_model(
-      original_model=transformers_verifier.TransformersModelWrapper(
-          original_model
-      ),
-      reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
-      tokenizer=verifier.TokenizerWrapper(tokenizer),
-      generate_prompts=_PROMPTS.value,
+  verify_util.verify_deepseek_r1_distill_1_5b(
+      "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
       max_new_tokens=_MAX_NEW_TOKENS.value,
-      atol=1e-04,
+      prompts=_PROMPTS.value,
   )
 
ai_edge_torch/generative/examples/deepseek/verify_util.py

Lines changed: 76 additions & 0 deletions
@@ -0,0 +1,76 @@
+# Copyright 2025 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Utils for verifying the DeepSeek R1 distilled 1.5B model."""
+import logging
+import os
+import pathlib
+
+from ai_edge_torch.generative.examples.deepseek import deepseek
+from ai_edge_torch.generative.utilities import loader
+from ai_edge_torch.generative.utilities import transformers_verifier
+from ai_edge_torch.generative.utilities import verifier
+import transformers
+
+
+DEFAULT_PROMPTS = ["What is the meaning of life?"]
+
+
+def verify_deepseek_r1_distill_1_5b(
+    checkpoint_dir: str,
+    weight_filename: str = "model.safetensors",
+    max_new_tokens: int = 30,
+    initialize_from_local: bool = True,
+    prompts: list[str] | None = None,
+) -> bool:
+  """Verifies the reauthored DeepSeek R1 distilled 1.5B model with a custom loader."""
+  logging.info("Loading the original model from: %s", checkpoint_dir)
+  original_model = transformers.AutoModelForCausalLM.from_pretrained(
+      checkpoint_dir
+  )
+
+  logging.info("Building the reauthored model from: %s", checkpoint_dir)
+  custom_loader = (
+      None
+      if initialize_from_local
+      else loader.get_custom_loader("", "safetensors")
+  )
+
+  if initialize_from_local:
+    # Locate the cached dir.
+    cached_config_file = transformers.utils.cached_file(
+        checkpoint_dir, transformers.utils.CONFIG_NAME
+    )
+    reauthored_checkpoint = pathlib.Path(cached_config_file).parent
+  else:
+    reauthored_checkpoint = os.path.join(checkpoint_dir, weight_filename)
+
+  logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
+  reauthored_model = deepseek.build_model(
+      checkpoint_path=reauthored_checkpoint,
+      custom_loader=custom_loader,
+  )
+
+  logging.info("Loading the tokenizer from: %s", checkpoint_dir)
+  tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint_dir)
+  return verifier.verify_reauthored_model(
+      original_model=transformers_verifier.TransformersModelWrapper(
+          original_model
+      ),
+      reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
+      tokenizer=verifier.TokenizerWrapper(tokenizer),
+      generate_prompts=DEFAULT_PROMPTS if prompts is None else prompts,
+      max_new_tokens=max_new_tokens,
+      atol=1e-04,
+  )
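Both new verify_util modules share the same custom_loader contract: a plain callable from a checkpoint path to a name-to-tensor dict. A hedged stand-in for what loader.get_custom_loader("", "safetensors") is expected to behave like (illustration only; the real helper lives in ai_edge_torch.generative.utilities.loader):

from typing import Callable, Dict

from safetensors import torch as safetensors_torch
import torch


def make_safetensors_loader() -> Callable[[str], Dict[str, torch.Tensor]]:
  """Illustrative stand-in for loader.get_custom_loader("", "safetensors")."""

  def load(checkpoint_path: str) -> Dict[str, torch.Tensor]:
    # safetensors.torch.load_file returns a flat name -> tensor mapping.
    return safetensors_torch.load_file(checkpoint_path)

  return load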

ai_edge_torch/generative/examples/gemma/verify_util.py

Lines changed: 51 additions & 6 deletions
@@ -17,11 +17,13 @@
 
 import logging
 import os
-from typing import List, Tuple
+from typing import Callable, Dict, List, Tuple
 
+from ai_edge_torch.generative.examples.gemma import gemma1
 from ai_edge_torch.generative.examples.gemma import gemma2
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
 import ai_edge_torch.generative.layers.kv_cache as kv_utils
+from ai_edge_torch.generative.utilities import loader
 from ai_edge_torch.generative.utilities import verifier
 from gemma import config as gemma_config
 from gemma import model as gemma_model
@@ -107,6 +109,7 @@ def verify_reauthored_gemma_model(
     generate_prompts: List[str],
     forward_input_ids: List[List[int]],
     weight_filename: str = "model.ckpt",
+    custom_loader: Callable[[str], Dict[str, torch.Tensor]] | None = None,
     tokenizer_filename: str = "tokenizer.model",
     max_new_tokens: int = 20,
     mask_as_input: bool = False,
@@ -125,7 +128,14 @@ def verify_reauthored_gemma_model(
 
   logging.info("Loading the original model from: %s", checkpoint)
   original_model = gemma_model.GemmaForCausalLM(config).eval()
-  original_model.load_weights(os.path.join(checkpoint, weight_filename))
+  checkpoint_path = os.path.join(checkpoint, weight_filename)
+  if custom_loader is None:
+    original_model.load_weights(checkpoint_path)
+  else:
+    original_model.load_state_dict(
+        custom_loader(checkpoint_path)["model_state_dict"],
+        strict=False,
+    )
 
   return verifier.verify_reauthored_model(
       original_model=GemmaWrapper(original_model),
@@ -144,27 +154,62 @@ def verify_reauthored_gemma_model(
 
 
 def verify_gemma2(
-    gemma2_model_path: str,
+    checkpoint_dir: str,
+    weight_filename: str,
     prompts: List[str],
     max_new_tokens: int,
     mask_as_input: bool = False,
     kv_layout: kv_utils.KVLayout = kv_utils.KV_LAYOUT_DEFAULT,
+    custom_loader: Callable[[str], Dict[str, torch.Tensor]] | None = None,
 ) -> bool:
   """Verifies the reauthored Gemma2 model.
 
   Return True if the verification passes, False otherwise.
   """
-  logging.info("Building the reauthored model from: %s", gemma2_model_path)
-  reauthored_model = gemma2.build_2b_model(gemma2_model_path)
+  checkpoint_path = os.path.join(checkpoint_dir, weight_filename)
+  logging.info("Building the reauthored model from: %s", checkpoint_path)
+  reauthored_model = gemma2.build_2b_model(checkpoint_path, custom_loader)
 
   return verify_reauthored_gemma_model(
-      checkpoint=gemma2_model_path,
+      checkpoint=checkpoint_dir,
       variant="2b-v2",
       reauthored_model=reauthored_model,
       generate_prompts=prompts,
       forward_input_ids=[[2, 651, 9456, 576, 573, 3520, 3858, 603, 235248]],
+      weight_filename=weight_filename,
+      custom_loader=custom_loader,
       max_new_tokens=max_new_tokens,
       mask_as_input=mask_as_input,
       kv_layout=kv_layout,
      atol=1e-04,
   )
+
+
+def verify_gemma1_with_custom_loader(checkpoint_dir: str) -> bool:
+  """Verifies the reauthored Gemma1 model with a custom loader."""
+  weight_filename = "gemma-2b-it.ckpt"
+  checkpoint_path = os.path.join(checkpoint_dir, weight_filename)
+  custom_loader = loader.get_custom_loader(checkpoint_path)
+  reauthored_model = gemma1.build_2b_model(checkpoint_path, custom_loader)
+  return verify_reauthored_gemma_model(
+      checkpoint=checkpoint_dir,
+      variant="2b",
+      reauthored_model=reauthored_model,
+      weight_filename=weight_filename,
+      custom_loader=custom_loader,
+      generate_prompts=["What is the meaning of life?"],
+      forward_input_ids=[[1, 2, 3, 4]],
+      max_new_tokens=30,
+  )
+
+
+def verify_gemma2_with_custom_loader(checkpoint_dir: str) -> bool:
+  """Verifies the reauthored Gemma2 model with a custom loader."""
+  return verify_gemma2(
+      checkpoint_dir=checkpoint_dir,
+      weight_filename="model.ckpt",
+      prompts=["What is the meaning of life?"],
+      max_new_tokens=30,
+      mask_as_input=True,
+      custom_loader=loader.get_custom_loader("", checkpoint_format="pt"),
+  )
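As a quick illustration of the two new convenience wrappers, a sketch of verifying local Gemma checkpoints (the directory paths are hypothetical placeholders; verify_gemma1_with_custom_loader expects gemma-2b-it.ckpt and verify_gemma2_with_custom_loader expects model.ckpt, per the code above):

from ai_edge_torch.generative.examples.gemma import verify_util

# Hypothetical local directories, each also containing tokenizer.model.
assert verify_util.verify_gemma1_with_custom_loader("/tmp/checkpoints/gemma-2b")
assert verify_util.verify_gemma2_with_custom_loader("/tmp/checkpoints/gemma2-2b")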

ai_edge_torch/generative/examples/gemma3/verify_util.py

Lines changed: 13 additions & 0 deletions
@@ -22,6 +22,7 @@
 from ai_edge_torch.generative.examples.gemma3 import gemma3
 from ai_edge_torch.generative.layers import kv_cache as kv_utils
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
+from ai_edge_torch.generative.utilities import loader
 from ai_edge_torch.generative.utilities import verifier
 from gemma import config as gemma_config
 from gemma import model as gemma_model
@@ -260,3 +261,15 @@ def verify_gemma3(
       custom_loader=custom_loader,
       atol=1e-04,
   )
+
+
+def verify_gemma3_with_custom_loader(checkpoint: str) -> bool:
+  """Verifies the reauthored Gemma3 model with a custom loader."""
+  return verify_gemma3(
+      checkpoint=checkpoint,
+      prompts=["What is the meaning of life?"],
+      max_new_tokens=30,
+      variant="1b",
+      weight_filename="model.ckpt",
+      custom_loader=loader.get_custom_loader("", checkpoint_format="pt"),
+  )
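The "pt" loader flavor used by the Gemma2 and Gemma3 wrappers is, judging from the load_state_dict call in gemma/verify_util.py above, a callable that torch.loads the .ckpt file into a dict wrapping the weights under "model_state_dict". A hedged stand-in (illustration only; the real helper is loader.get_custom_loader):

from typing import Any, Callable, Dict

import torch


def make_pt_loader() -> Callable[[str], Dict[str, Any]]:
  """Illustrative stand-in for loader.get_custom_loader("", checkpoint_format="pt")."""

  def load(checkpoint_path: str) -> Dict[str, Any]:
    # Gemma .ckpt files wrap the weights under the "model_state_dict" key.
    return torch.load(checkpoint_path, map_location="cpu")

  return load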
