
Commit 00202bf

ai-edge-bot authored and copybara-github committed
Add verify_util to run model numerical validation with a custom loader.
PiperOrigin-RevId: 759916326
1 parent b5489ae · commit 00202bf

File tree

17 files changed, +244 −662 lines changed


ai_edge_torch/generative/examples/amd_llama_135m/verify.py

Lines changed: 32 additions & 4 deletions
@@ -15,9 +15,15 @@
 
 """Verifies the reauthored AMD-Llama-135M model."""
 
+import logging
+import pathlib
+
 from absl import app
 from absl import flags
-from ai_edge_torch.generative.examples.amd_llama_135m import verify_util
+from ai_edge_torch.generative.examples.amd_llama_135m import amd_llama_135m
+from ai_edge_torch.generative.utilities import transformers_verifier
+from ai_edge_torch.generative.utilities import verifier
+import transformers
 
 
 _PROMPTS = flags.DEFINE_multi_string(
@@ -33,10 +39,32 @@
 
 
 def main(_):
-  verify_util.verify_amd_llama_135m(
-      "amd/AMD-Llama-135m",
+  checkpoint = "amd/AMD-Llama-135m"
+  logging.info("Loading the original model from: %s", checkpoint)
+  original_model = transformers.AutoModelForCausalLM.from_pretrained(
+      checkpoint, trust_remote_code=True
+  )
+
+  # Locate the cached dir.
+  cached_config_file = transformers.utils.cached_file(
+      checkpoint, transformers.utils.CONFIG_NAME
+  )
+  reauthored_checkpoint = pathlib.Path(cached_config_file).parent
+  logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
+  reauthored_model = amd_llama_135m.build_model(str(reauthored_checkpoint))
+
+  logging.info("Loading the tokenizer from: %s", checkpoint)
+  tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
+
+  verifier.verify_reauthored_model(
+      original_model=transformers_verifier.TransformersModelWrapper(
+          original_model
+      ),
+      reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
+      tokenizer=verifier.TokenizerWrapper(tokenizer),
+      generate_prompts=_PROMPTS.value,
       max_new_tokens=_MAX_NEW_TOKENS.value,
-      prompts=_PROMPTS.value,
+      atol=1e-04,
   )
ai_edge_torch/generative/examples/amd_llama_135m/verify_util.py

Lines changed: 0 additions & 76 deletions
This file was deleted.
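
The rewritten main() in verify.py above resolves the local checkpoint directory from the Hugging Face cache instead of taking a separate path. A minimal standalone sketch of that lookup, using only the transformers calls that appear in the diff (the printed path is illustrative; the actual cache location depends on the environment):

import pathlib
import transformers

checkpoint = "amd/AMD-Llama-135m"
# cached_file downloads (or finds) a single file from the Hugging Face cache
# and returns its local path; the parent directory of config.json is the
# snapshot directory that also holds the weights, so it can be passed to the
# reauthored model builder.
config_path = transformers.utils.cached_file(
    checkpoint, transformers.utils.CONFIG_NAME
)
snapshot_dir = pathlib.Path(config_path).parent
print(snapshot_dir)
# e.g. ~/.cache/huggingface/hub/models--amd--AMD-Llama-135m/snapshots/<hash>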

ai_edge_torch/generative/examples/deepseek/verify.py

Lines changed: 30 additions & 4 deletions
@@ -15,9 +15,15 @@
 
 """Verifies the reauthored DeepSeek R1 distilled 1.5B model."""
 
+import logging
+import pathlib
+
 from absl import app
 from absl import flags
-from ai_edge_torch.generative.examples.deepseek import verify_util
+from ai_edge_torch.generative.examples.deepseek import deepseek
+from ai_edge_torch.generative.utilities import transformers_verifier
+from ai_edge_torch.generative.utilities import verifier
+import transformers
 
 
 _PROMPTS = flags.DEFINE_multi_string(
@@ -33,10 +39,30 @@
 
 
 def main(_):
-  verify_util.verify_deepseek_r1_distill_1_5b(
-      "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
+  checkpoint = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
+  logging.info("Loading the original model from: %s", checkpoint)
+  original_model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint)
+
+  # Locate the cached dir.
+  cached_config_file = transformers.utils.cached_file(
+      checkpoint, transformers.utils.CONFIG_NAME
+  )
+  reauthored_checkpoint = pathlib.Path(cached_config_file).parent
+  logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
+  reauthored_model = deepseek.build_model(str(reauthored_checkpoint))
+
+  logging.info("Loading the tokenizer from: %s", checkpoint)
+  tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
+
+  verifier.verify_reauthored_model(
+      original_model=transformers_verifier.TransformersModelWrapper(
+          original_model
+      ),
+      reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
+      tokenizer=verifier.TokenizerWrapper(tokenizer),
+      generate_prompts=_PROMPTS.value,
       max_new_tokens=_MAX_NEW_TOKENS.value,
-      prompts=_PROMPTS.value,
+      atol=1e-04,
   )

ai_edge_torch/generative/examples/deepseek/verify_util.py

Lines changed: 0 additions & 76 deletions
This file was deleted.
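
Both rewritten verify.py examples pass atol=1e-04 to verifier.verify_reauthored_model. As a self-contained illustration of what an absolute tolerance of 1e-04 means when comparing model outputs (torch.allclose is used here for clarity; the verifier's internal comparison may differ):

import torch

original_logits = torch.tensor([0.10000, -1.23456, 2.00000])
reauthored_logits = torch.tensor([0.10004, -1.23451, 1.99997])

# True: every element differs by less than 1e-04 in absolute value.
print(torch.allclose(original_logits, reauthored_logits, rtol=0, atol=1e-04))

# False: the first element is now off by about 2e-03, above the tolerance.
reauthored_logits[0] += 2e-03
print(torch.allclose(original_logits, reauthored_logits, rtol=0, atol=1e-04))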

ai_edge_torch/generative/examples/gemma/verify_util.py

Lines changed: 6 additions & 51 deletions
@@ -17,13 +17,11 @@
 
 import logging
 import os
-from typing import Callable, Dict, List, Tuple
+from typing import List, Tuple
 
-from ai_edge_torch.generative.examples.gemma import gemma1
 from ai_edge_torch.generative.examples.gemma import gemma2
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
 import ai_edge_torch.generative.layers.kv_cache as kv_utils
-from ai_edge_torch.generative.utilities import loader
 from ai_edge_torch.generative.utilities import verifier
 from gemma import config as gemma_config
 from gemma import model as gemma_model
@@ -109,7 +107,6 @@ def verify_reauthored_gemma_model(
     generate_prompts: List[str],
     forward_input_ids: List[List[int]],
     weight_filename: str = "model.ckpt",
-    custom_loader: Callable[[str], Dict[str, torch.Tensor]] | None = None,
     tokenizer_filename: str = "tokenizer.model",
     max_new_tokens: int = 20,
     mask_as_input: bool = False,
@@ -128,14 +125,7 @@ def verify_reauthored_gemma_model(
 
   logging.info("Loading the original model from: %s", checkpoint)
   original_model = gemma_model.GemmaForCausalLM(config).eval()
-  checkpoint_path = os.path.join(checkpoint, weight_filename)
-  if custom_loader is None:
-    original_model.load_weights(checkpoint_path)
-  else:
-    original_model.load_state_dict(
-        custom_loader(checkpoint_path)["model_state_dict"],
-        strict=False,
-    )
+  original_model.load_weights(os.path.join(checkpoint, weight_filename))
 
   return verifier.verify_reauthored_model(
       original_model=GemmaWrapper(original_model),
@@ -154,62 +144,27 @@ def verify_reauthored_gemma_model(
 
 
 def verify_gemma2(
-    checkpoint_dir: str,
-    weight_filename: str,
+    gemma2_model_path: str,
     prompts: List[str],
     max_new_tokens: int,
     mask_as_input: bool = False,
     kv_layout: kv_utils.KVLayout = kv_utils.KV_LAYOUT_DEFAULT,
-    custom_loader: Callable[[str], Dict[str, torch.Tensor]] | None = None,
 ) -> bool:
   """Verifies the reauthored Gemma2 model.
 
   Return True if the verification passes, False otherwise.
   """
-  checkpoint_path = os.path.join(checkpoint_dir, weight_filename)
-  logging.info("Building the reauthored model from: %s", checkpoint_path)
-  reauthored_model = gemma2.build_2b_model(checkpoint_path, custom_loader)
+  logging.info("Building the reauthored model from: %s", gemma2_model_path)
+  reauthored_model = gemma2.build_2b_model(gemma2_model_path)
 
   return verify_reauthored_gemma_model(
-      checkpoint=checkpoint_dir,
+      checkpoint=gemma2_model_path,
       variant="2b-v2",
       reauthored_model=reauthored_model,
       generate_prompts=prompts,
       forward_input_ids=[[2, 651, 9456, 576, 573, 3520, 3858, 603, 235248]],
-      weight_filename=weight_filename,
-      custom_loader=custom_loader,
       max_new_tokens=max_new_tokens,
       mask_as_input=mask_as_input,
       kv_layout=kv_layout,
       atol=1e-04,
   )
-
-
-def verify_gemma1_with_custom_loader(checkpoint_dir: str) -> bool:
-  """Verifies the reauthored Gemma1 model with a custom loader."""
-  weight_filename = "gemma-2b-it.ckpt"
-  checkpoint_path = os.path.join(checkpoint_dir, weight_filename)
-  custom_loader = loader.get_custom_loader(checkpoint_path)
-  reauthored_model = gemma1.build_2b_model(checkpoint_path, custom_loader)
-  return verify_reauthored_gemma_model(
-      checkpoint=checkpoint_dir,
-      variant="2b",
-      reauthored_model=reauthored_model,
-      weight_filename=weight_filename,
-      custom_loader=custom_loader,
-      generate_prompts=["What is the meaning of life?"],
-      forward_input_ids=[[1, 2, 3, 4]],
-      max_new_tokens=30,
-  )
-
-
-def verify_gemma2_with_custom_loader(checkpoint_dir: str) -> bool:
-  """Verifies the reauthored Gemma2 model with a custom loader."""
-  return verify_gemma2(
-      checkpoint_dir=checkpoint_dir,
-      weight_filename="model.ckpt",
-      prompts=["What is the meaning of life?"],
-      max_new_tokens=30,
-      mask_as_input=True,
-      custom_loader=loader.get_custom_loader("", checkpoint_format="pt"),
-  )
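
For reference, the deleted branches show how custom_loader was consumed: a callable that maps a checkpoint path to a dict containing "model_state_dict". Below is a minimal hypothetical stand-in built on torch.load; the real loader.get_custom_loader utility may behave differently (for example, supporting safetensors):

from typing import Any, Callable, Dict
import torch

def make_torch_ckpt_loader() -> Callable[[str], Dict[str, Any]]:
  """Hypothetical loader compatible with the deleted call site above."""

  def load(checkpoint_path: str) -> Dict[str, Any]:
    # map_location="cpu" keeps verification independent of accelerator state.
    return torch.load(checkpoint_path, map_location="cpu")

  return load

# Usage mirroring the removed else-branch:
#   state = make_torch_ckpt_loader()(checkpoint_path)
#   original_model.load_state_dict(state["model_state_dict"], strict=False)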

ai_edge_torch/generative/examples/gemma3/verify_util.py

Lines changed: 0 additions & 13 deletions
@@ -22,7 +22,6 @@
 from ai_edge_torch.generative.examples.gemma3 import gemma3
 from ai_edge_torch.generative.layers import kv_cache as kv_utils
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
-from ai_edge_torch.generative.utilities import loader
 from ai_edge_torch.generative.utilities import verifier
 from gemma import config as gemma_config
 from gemma import model as gemma_model
@@ -261,15 +260,3 @@ def verify_gemma3(
       custom_loader=custom_loader,
       atol=1e-04,
   )
-
-
-def verify_gemma3_with_custom_loader(checkpoint: str) -> bool:
-  """Verifies the reauthored Gemma3 model with a custom loader."""
-  return verify_gemma3(
-      checkpoint=checkpoint,
-      prompts=["What is the meaning of life?"],
-      max_new_tokens=30,
-      variant="1b",
-      weight_filename="model.ckpt",
-      custom_loader=loader.get_custom_loader("", checkpoint_format="pt"),
-  )
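
verify_gemma3 itself keeps its custom_loader parameter (the surviving hunk above still forwards it), so the deleted convenience wrapper can be reproduced inline. A sketch assembled from the exact arguments the removed wrapper passed; the checkpoint path below is a placeholder:

from ai_edge_torch.generative.examples.gemma3 import verify_util
from ai_edge_torch.generative.utilities import loader

passed = verify_util.verify_gemma3(
    checkpoint="/path/to/gemma3/checkpoint",  # placeholder
    prompts=["What is the meaning of life?"],
    max_new_tokens=30,
    variant="1b",
    weight_filename="model.ckpt",
    custom_loader=loader.get_custom_loader("", checkpoint_format="pt"),
)
print("PASS" if passed else "FAIL")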
