
Commit 5af0dab

haozha111 authored and copybara-github committed
* Move logging functions to separate file.
* Add some missing BUILD rules.

PiperOrigin-RevId: 678962505
1 parent 9ae6590 commit 5af0dab

File tree

11 files changed: 59 additions & 52 deletions

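The change repeated across the example scripts below is mechanical: calls to the ad-hoc verifier.log_msg helper are replaced with the standard-library logging module, switching from joined positional arguments to printf-style %s placeholders. A minimal standalone sketch of the before/after pattern (the checkpoint value is illustrative, taken from the Gemma diff below):

import logging

logging.basicConfig(level=logging.INFO)  # stdlib root logger defaults to WARNING

checkpoint = "google/gemma/pyTorch/2b-it"  # illustrative value

# Before: verifier.log_msg("Building the reauthored model from", checkpoint)
# After: %s defers interpolation until the record is actually emitted.
logging.info("Building the reauthored model from: %s", checkpoint)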

ai_edge_torch/generative/examples/gemma/gemma2.py

Lines changed: 0 additions & 2 deletions
@@ -16,7 +16,6 @@
 """Example of building a Gemma2 model."""
 
 import os
-import pathlib
 from typing import Optional, Tuple
 
 from ai_edge_torch.generative.layers import attention
@@ -25,7 +24,6 @@
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
 import ai_edge_torch.generative.layers.model_config as cfg
 import ai_edge_torch.generative.utilities.loader as loading_utils
-import numpy as np
 import torch
 from torch import nn

ai_edge_torch/generative/examples/gemma/verify_gemma1.py

Lines changed: 3 additions & 2 deletions
@@ -15,13 +15,14 @@
 
 """Verifies the reauthored Gemma1 model."""
 
+import logging
 from absl import app
 from absl import flags
 from ai_edge_torch.generative.examples.gemma import gemma1
 from ai_edge_torch.generative.examples.gemma import verify_util
-from ai_edge_torch.generative.utilities import verifier
 import kagglehub
 
+
 _PROMPTS = flags.DEFINE_multi_string(
     "prompts",
     "What is the meaning of life?",
@@ -37,7 +38,7 @@
 def main(_):
   checkpoint = kagglehub.model_download("google/gemma/pyTorch/2b-it")
 
-  verifier.log_msg("Building the reauthored model from", checkpoint)
+  logging.info("Building the reauthored model from: %s", checkpoint)
   reauthored_model = gemma1.build_2b_model(checkpoint)
 
   verify_util.verify_reauthored_gemma_model(

ai_edge_torch/generative/examples/gemma/verify_gemma2.py

Lines changed: 3 additions & 1 deletion
@@ -15,13 +15,15 @@
 
 """Verifies the reauthored Gemma2 model."""
 
+import logging
 from absl import app
 from absl import flags
 from ai_edge_torch.generative.examples.gemma import gemma2
 from ai_edge_torch.generative.examples.gemma import verify_util
 from ai_edge_torch.generative.utilities import verifier
 import kagglehub
 
+
 _PROMPTS = flags.DEFINE_multi_string(
     "prompts",
     "What is the meaning of life?",
@@ -37,7 +39,7 @@
 def main(_):
   checkpoint = kagglehub.model_download("google/gemma-2/pyTorch/gemma-2-2b-it")
 
-  verifier.log_msg("Building the reauthored model from", checkpoint)
+  logging.info("Building the reauthored model from: %s", checkpoint)
   reauthored_model = gemma2.build_2b_model(checkpoint)
 
   verify_util.verify_reauthored_gemma_model(

ai_edge_torch/generative/examples/gemma/verify_util.py

Lines changed: 2 additions & 1 deletion
@@ -16,6 +16,7 @@
 """Utility functions to verify the reauthored Gemma model."""
 
 import dataclasses
+import logging
 import os
 from typing import List, Tuple
 
@@ -127,7 +128,7 @@ def verify_reauthored_gemma_model(
   # Use float32 to be compatible with the reauthored model.
   config.dtype = torch.float32
 
-  verifier.log_msg("Loading the original model from", checkpoint)
+  logging.info("Loading the original model from: %s", checkpoint)
   original_model = gemma_model.GemmaForCausalLM(config).eval()
   original_model.load_weights(os.path.join(checkpoint, weight_filename))

ai_edge_torch/generative/examples/openelm/verify.py

Lines changed: 5 additions & 4 deletions
@@ -15,14 +15,15 @@
 
 """Verifies the reauthored OpenELM-3B model."""
 
+import logging
 import pathlib
-
 from absl import app
 from absl import flags
 from ai_edge_torch.generative.examples.openelm import openelm
 from ai_edge_torch.generative.utilities import verifier
 import transformers
 
+
 _PROMPTS = flags.DEFINE_multi_string(
     "prompts",
     "What is the meaning of life?",
@@ -32,7 +33,7 @@
 
 def main(_):
   checkpoint = "apple/OpenELM-3B"
-  verifier.log_msg("Loading the original model from", checkpoint)
+  logging.info("Loading the original model from: %s", checkpoint)
   wrapper_model = verifier.ModelWrapper(
       model=transformers.AutoModelForCausalLM.from_pretrained(
           checkpoint, trust_remote_code=True
@@ -44,11 +45,11 @@ def main(_):
       checkpoint, transformers.utils.CONFIG_NAME
   )
   reauthored_checkpoint = pathlib.Path(cached_config_file).parent
-  verifier.log_msg("Building the reauthored model from", reauthored_checkpoint)
+  logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
   reauthored_model = openelm.build_model(reauthored_checkpoint)
 
   tokenizer_checkpoint = "meta-llama/Llama-2-7b-hf"
-  verifier.log_msg("Loading the tokenizer from", tokenizer_checkpoint)
+  logging.info("Loading the tokenizer from: %s", tokenizer_checkpoint)
   tokenizer = transformers.AutoTokenizer.from_pretrained(tokenizer_checkpoint)
 
   verifier.verify_reauthored_model(
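
One idiom in this script (repeated in the Phi-3.5, SmolLM, and TinyLlama verifiers below) is recovering the reauthored checkpoint directory from the Hugging Face cache: the model's config file is located, and its parent directory is used as the checkpoint root. A standalone sketch, assuming the truncated call above is transformers.utils.cached_file (the diff context does not show the function name, so that helper is an assumption):

import logging
import pathlib

import transformers

logging.basicConfig(level=logging.INFO)

checkpoint = "apple/OpenELM-3B"
# Resolve config.json inside the local HF cache. cached_file is assumed
# here because the enclosing call is truncated in the diff.
cached_config_file = transformers.utils.cached_file(
    checkpoint, transformers.utils.CONFIG_NAME
)
# The directory that holds config.json doubles as the checkpoint root.
reauthored_checkpoint = pathlib.Path(cached_config_file).parent
logging.info("Building the reauthored model from: %s", reauthored_checkpoint)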

ai_edge_torch/generative/examples/phi/verify.py

Lines changed: 5 additions & 3 deletions
@@ -14,6 +14,7 @@
 # ==============================================================================
 
 """Verifies the reauthored Phi-2 model."""
+import logging
 
 from absl import app
 from absl import flags
@@ -22,6 +23,7 @@
 import kagglehub
 import transformers
 
+
 _PROMPTS = flags.DEFINE_multi_string(
     "prompts",
     "Instruct: Write an email about the weather Output:",
@@ -36,18 +38,18 @@
 
 def main(_):
   checkpoint = kagglehub.model_download("Microsoft/phi/transformers/2")
-  verifier.log_msg("Loading the original model from", checkpoint)
+  logging.info("Loading the original model from: %s", checkpoint)
   generation_config = transformers.GenerationConfig.from_pretrained(checkpoint)
   generation_config.max_new_tokens = _MAX_NEW_TOKENS.value
   wrapper_model = verifier.ModelWrapper(
       model=transformers.AutoModelForCausalLM.from_pretrained(checkpoint),
       hf_generation_config=generation_config,
   )
 
-  verifier.log_msg("Building the reauthored model from", checkpoint)
+  logging.info("Building the reauthored model from: %s", checkpoint)
   reauthored_model = phi2.build_model(checkpoint)
 
-  verifier.log_msg("Loading the tokenizer from", checkpoint)
+  logging.info("Loading the tokenizer from: %s", checkpoint)
   tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
 
   verifier.verify_reauthored_model(

ai_edge_torch/generative/examples/phi/verify_phi3.py

Lines changed: 5 additions & 3 deletions
@@ -15,6 +15,7 @@
 
 """Verifies the reauthored Phi-3.5 model."""
 
+import logging
 import pathlib
 
 from absl import app
@@ -23,6 +24,7 @@
 from ai_edge_torch.generative.utilities import verifier
 import transformers
 
+
 _PROMPTS = flags.DEFINE_multi_string(
     "prompts",
     "Instruct: Write an email about the weather Output:",
@@ -37,7 +39,7 @@
 
 def main(_):
   checkpoint = "microsoft/Phi-3.5-mini-instruct"
-  verifier.log_msg("Loading the original model from", checkpoint)
+  logging.info("Loading the original model from: %s", checkpoint)
   generation_config = transformers.GenerationConfig.from_pretrained(checkpoint)
   generation_config.max_new_tokens = _MAX_NEW_TOKENS.value
   wrapper_model = verifier.ModelWrapper(
@@ -50,10 +52,10 @@ def main(_):
       checkpoint, transformers.utils.CONFIG_NAME
   )
   reauthored_checkpoint = pathlib.Path(cached_config_file).parent
-  verifier.log_msg("Building the reauthored model from", reauthored_checkpoint)
+  logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
   reauthored_model = phi3.build_model(reauthored_checkpoint)
 
-  verifier.log_msg("Loading the tokenizer from", checkpoint)
+  logging.info("Loading the tokenizer from: %s", checkpoint)
   tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
 
   verifier.verify_reauthored_model(

ai_edge_torch/generative/examples/smollm/verify.py

Lines changed: 5 additions & 3 deletions
@@ -15,6 +15,7 @@
 
 """Verifies the reauthored SmolLM-135M model."""
 
+import logging
 import pathlib
 
 from absl import app
@@ -23,6 +24,7 @@
 from ai_edge_torch.generative.utilities import verifier
 import transformers
 
+
 _PROMPTS = flags.DEFINE_multi_string(
     "prompts",
     "What is the meaning of life?",
@@ -32,7 +34,7 @@
 
 def main(_):
   checkpoint = "HuggingFaceTB/SmolLM-135M"
-  verifier.log_msg("Loading the original model from", checkpoint)
+  logging.info("Loading the original model from: %s", checkpoint)
   wrapper_model = verifier.ModelWrapper(
       model=transformers.AutoModelForCausalLM.from_pretrained(checkpoint),
   )
@@ -41,10 +43,10 @@ def main(_):
       checkpoint, transformers.utils.CONFIG_NAME
   )
   reauthored_checkpoint = pathlib.Path(cached_config_file).parent
-  verifier.log_msg("Building the reauthored model from", reauthored_checkpoint)
+  logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
   reauthored_model = smollm.build_model(reauthored_checkpoint)
 
-  verifier.log_msg("Loading the tokenizer from", checkpoint)
+  logging.info("Loading the tokenizer from: %s", checkpoint)
   tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
 
   verifier.verify_reauthored_model(

ai_edge_torch/generative/examples/tiny_llama/verify.py

Lines changed: 5 additions & 3 deletions
@@ -15,6 +15,7 @@
 
 """Verifies the reauthored TinyLlama-1.1B model."""
 
+import logging
 import pathlib
 
 from absl import app
@@ -23,6 +24,7 @@
 from ai_edge_torch.generative.utilities import verifier
 import transformers
 
+
 _PROMPTS = flags.DEFINE_multi_string(
     "prompts",
     "Show me the program to add 2 and 3.",
@@ -32,7 +34,7 @@
 
 def main(_):
   checkpoint = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
-  verifier.log_msg("Loading the original model from", checkpoint)
+  logging.info("Loading the original model from: %s", checkpoint)
   wrapper_model = verifier.ModelWrapper(
       model=transformers.AutoModelForCausalLM.from_pretrained(
          checkpoint, trust_remote_code=True
@@ -43,10 +45,10 @@ def main(_):
       checkpoint, transformers.utils.CONFIG_NAME
   )
   reauthored_checkpoint = pathlib.Path(cached_config_file).parent
-  verifier.log_msg("Building the reauthored model from", reauthored_checkpoint)
+  logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
   reauthored_model = tiny_llama.build_model(reauthored_checkpoint)
 
-  verifier.log_msg("Loading the tokenizer from", checkpoint)
+  logging.info("Loading the tokenizer from: %s", checkpoint)
   tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
 
   verifier.verify_reauthored_model(

ai_edge_torch/generative/tools/batch_convert.py

Lines changed: 11 additions & 11 deletions
@@ -17,6 +17,7 @@
 
 import dataclasses
 import enum
+import logging
 import os
 import pathlib
 from typing import Callable, Sequence
@@ -29,7 +30,6 @@
 from ai_edge_torch.generative.examples.phi import phi3
 from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
 from ai_edge_torch.generative.utilities import converter
-from ai_edge_torch.generative.utilities import verifier
 import torch
 
 _CHECKPOINT_ROOT_PATH = flags.DEFINE_string(
@@ -67,12 +67,12 @@ class ConversionConfig:
 
   def print_config(self) -> None:
     """Prints the conversion config."""
-    verifier.log_msg("Model name:", self.model_name)
-    verifier.log_msg("Input checkpoint:", self.input_checkpoint)
-    verifier.log_msg("TF Lite output path:", self.tflite_output_path)
-    verifier.log_msg("Prefill seq len:", self.prefill_seq_len)
-    verifier.log_msg("KV cache max len:", self.kv_cache_max_len)
-    verifier.log_msg("Export precision:", self.export_precision)
+    logging.info("Model name: %s", self.model_name)
+    logging.info("Input checkpoint: %s", self.input_checkpoint)
+    logging.info("TF Lite output path: %s", self.tflite_output_path)
+    logging.info("Prefill seq len: %s", self.prefill_seq_len)
+    logging.info("KV cache max len: %s", self.kv_cache_max_len)
+    logging.info("Export precision: %s", self.export_precision)
 
 
 def prepare_conversion_configs() -> Sequence[ConversionConfig]:
@@ -143,7 +143,7 @@ def get_output_filename(
   if precision == ExportPrecision.INT8:
     precision_str = "q8"
   elif precision == ExportPrecision.FP32:
-    precision_str = "fp32"
+    precision_str = "f32"
   else:
     raise ValueError(f"Unsupported precision: {precision}")
   return f"{model_name}_{precision_str}_seq{prefill_seq_len}_ekv{kv_cache_max_len}.tflite"
@@ -152,8 +152,8 @@ def get_output_filename(
 def convert_models(conversion_configs: Sequence[ConversionConfig]) -> None:
   """Executes the conversion for a batch of models specified by the `conversion_configs`."""
   for config in conversion_configs:
-    verifier.log_msg(
-        "Converting model:", config.model_name, " with the following config:"
+    logging.info(
+        "Converting model: %s with the following config:", config.model_name
     )
     config.print_config()
     pytorch_model = config.model_builder(
@@ -172,7 +172,7 @@ def convert_models(conversion_configs: Sequence[ConversionConfig]) -> None:
         prefill_seq_len=config.prefill_seq_len,
         quantize=True if precision == ExportPrecision.INT8 else False,
     )
-    verifier.log_msg("Successfully converted model:", output_filename)
+    logging.info("Successfully converted model: %s", output_filename)
 
 
 def main(_):
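
Beyond the logging migration, note two details in this file: the only multi-argument verifier.log_msg call is collapsed into a single printf-style template, since logging.info takes the format string as its first argument, and the FP32 suffix changes from "fp32" to "f32", which renames converted .tflite artifacts. A small sketch of the filename format from get_output_filename (the model name and lengths are illustrative values, not from the diff):

model_name = "tiny_llama"  # illustrative
precision_str = "f32"  # was "fp32" before this commit
prefill_seq_len = 512  # illustrative
kv_cache_max_len = 1024  # illustrative

# Same f-string as get_output_filename in the diff above.
print(f"{model_name}_{precision_str}_seq{prefill_seq_len}_ekv{kv_cache_max_len}.tflite")
# -> tiny_llama_f32_seq512_ekv1024.tflite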
