Skip to content

Commit 9ae6590

Browse files
haozha111copybara-github
authored andcommitted
Add BUILD rules and update copy.bara.sky for toy models.
PiperOrigin-RevId: 678471108
1 parent c0f0b63 commit 9ae6590

File tree

3 files changed

+109
-87
lines changed

3 files changed

+109
-87
lines changed
Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
# Copyright 2024 The AI Edge Torch Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
# A toy example which has a single-layer transformer block.
16+
from absl import app
17+
import ai_edge_torch
18+
from ai_edge_torch import lowertools
19+
from ai_edge_torch.generative.examples.test_models import toy_model
20+
from ai_edge_torch.generative.examples.test_models import toy_model_with_kv_cache
21+
from ai_edge_torch.generative.layers import kv_cache as kv_utils
22+
import torch
23+
24+
KV_CACHE_MAX_LEN = 100
25+
26+
27+
def convert_toy_model(_) -> None:
28+
"""Converts a toy model to tflite."""
29+
model = toy_model.ToySingleLayerModel(toy_model.get_model_config())
30+
idx = torch.unsqueeze(torch.arange(0, KV_CACHE_MAX_LEN), 0)
31+
input_pos = torch.arange(0, KV_CACHE_MAX_LEN)
32+
print('running an inference')
33+
print(
34+
model.forward(
35+
idx,
36+
input_pos,
37+
)
38+
)
39+
40+
# Convert model to tflite.
41+
print('converting model to tflite')
42+
edge_model = ai_edge_torch.convert(
43+
model,
44+
(
45+
idx,
46+
input_pos,
47+
),
48+
)
49+
edge_model.export('/tmp/toy_model.tflite')
50+
51+
52+
def _export_stablehlo_mlir(model, args):
53+
ep = torch.export.export(model, args)
54+
return lowertools.exported_program_to_mlir_text(ep)
55+
56+
57+
def convert_toy_model_with_kv_cache(_) -> None:
58+
"""Converts a toy model with kv cache to tflite."""
59+
dump_mlir = False
60+
61+
config = toy_model_with_kv_cache.get_model_config()
62+
model = toy_model_with_kv_cache.ToyModelWithKVCache(config)
63+
model.eval()
64+
print('running an inference')
65+
kv = kv_utils.KVCache.from_model_config(config)
66+
67+
tokens, input_pos = toy_model_with_kv_cache.get_sample_prefill_inputs()
68+
decode_token, decode_input_pos = (
69+
toy_model_with_kv_cache.get_sample_decode_inputs()
70+
)
71+
print(model.forward(tokens, input_pos, kv))
72+
73+
if dump_mlir:
74+
mlir_text = _export_stablehlo_mlir(model, (tokens, input_pos, kv))
75+
with open('/tmp/toy_model_with_external_kv.stablehlo.mlir', 'w') as f:
76+
f.write(mlir_text)
77+
78+
# Convert model to tflite with 2 signatures (prefill + decode).
79+
print('converting toy model to tflite with 2 signatures (prefill + decode)')
80+
edge_model = (
81+
ai_edge_torch.signature(
82+
'prefill',
83+
model,
84+
sample_kwargs={
85+
'tokens': tokens,
86+
'input_pos': input_pos,
87+
'kv_cache': kv,
88+
},
89+
)
90+
.signature(
91+
'decode',
92+
model,
93+
sample_kwargs={
94+
'tokens': decode_token,
95+
'input_pos': decode_input_pos,
96+
'kv_cache': kv,
97+
},
98+
)
99+
.convert()
100+
)
101+
edge_model.export('/tmp/toy_external_kv_cache.tflite')
102+
103+
104+
if __name__ == '__main__':
105+
app.run(convert_toy_model)

ai_edge_torch/generative/examples/test_models/toy_model.py

Lines changed: 2 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,12 @@
1515
# A toy example which has a single-layer transformer block.
1616
from typing import Tuple
1717

18-
import ai_edge_torch
18+
from ai_edge_torch.generative.layers import builder
1919
from ai_edge_torch.generative.layers.attention import TransformerBlock
2020
import ai_edge_torch.generative.layers.attention_utils as attn_utils
21-
import ai_edge_torch.generative.layers.builder as builder
2221
import ai_edge_torch.generative.layers.model_config as cfg
2322
import torch
24-
import torch.nn as nn
23+
from torch import nn
2524

2625
RoPECache = Tuple[torch.Tensor, torch.Tensor]
2726
KV_CACHE_MAX_LEN = 100
@@ -149,31 +148,3 @@ def get_model_config() -> cfg.ModelConfig:
149148
final_norm_config=norm_config,
150149
)
151150
return config
152-
153-
154-
def define_and_run() -> None:
155-
model = ToySingleLayerModel(get_model_config())
156-
idx = torch.unsqueeze(torch.arange(0, KV_CACHE_MAX_LEN), 0)
157-
input_pos = torch.arange(0, KV_CACHE_MAX_LEN)
158-
print('running an inference')
159-
print(
160-
model.forward(
161-
idx,
162-
input_pos,
163-
)
164-
)
165-
166-
# Convert model to tflite.
167-
print('converting model to tflite')
168-
edge_model = ai_edge_torch.convert(
169-
model,
170-
(
171-
idx,
172-
input_pos,
173-
),
174-
)
175-
edge_model.export('/tmp/toy_model.tflite')
176-
177-
178-
if __name__ == '__main__':
179-
define_and_run()

ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py

Lines changed: 2 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -17,15 +17,14 @@
1717

1818
from typing import Tuple
1919

20-
import ai_edge_torch
21-
from ai_edge_torch import lowertools
20+
from absl import app
2221
from ai_edge_torch.generative.layers import attention
2322
from ai_edge_torch.generative.layers import builder
2423
from ai_edge_torch.generative.layers import kv_cache as kv_utils
2524
import ai_edge_torch.generative.layers.attention_utils as attn_utils
2625
import ai_edge_torch.generative.layers.model_config as cfg
2726
import torch
28-
import torch.nn as nn
27+
from torch import nn
2928

3029
RoPECache = Tuple[torch.Tensor, torch.Tensor]
3130

@@ -87,11 +86,6 @@ def forward(
8786
return {'logits': self.lm_head(x), 'kv_cache': updated_kv_cache}
8887

8988

90-
def _export_stablehlo_mlir(model, args):
91-
ep = torch.export.export(model, args)
92-
return lowertools.exported_program_to_mlir_text(ep)
93-
94-
9589
def get_model_config() -> cfg.ModelConfig:
9690
attn_config = cfg.AttentionConfig(
9791
num_heads=32,
@@ -133,51 +127,3 @@ def get_sample_decode_inputs() -> Tuple[torch.Tensor, torch.Tensor]:
133127
tokens = torch.tensor([[1]], dtype=torch.int)
134128
input_pos = torch.tensor([10])
135129
return tokens, input_pos
136-
137-
138-
def define_and_run() -> None:
139-
dump_mlir = False
140-
141-
config = get_model_config()
142-
model = ToyModelWithExternalKV(config)
143-
model.eval()
144-
print('running an inference')
145-
kv = kv_utils.KVCache.from_model_config(config)
146-
147-
tokens, input_pos = get_sample_prefill_inputs()
148-
decode_token, decode_input_pos = get_sample_decode_inputs()
149-
print(model.forward(tokens, input_pos, kv))
150-
151-
if dump_mlir:
152-
mlir_text = _export_stablehlo_mlir(model, (tokens, input_pos, kv))
153-
with open('/tmp/toy_model_with_external_kv.stablehlo.mlir', 'w') as f:
154-
f.write(mlir_text)
155-
156-
# Convert model to tflite with 2 signatures (prefill + decode).
157-
print('converting toy model to tflite with 2 signatures (prefill + decode)')
158-
edge_model = (
159-
ai_edge_torch.signature(
160-
'prefill',
161-
model,
162-
sample_kwargs={
163-
'tokens': tokens,
164-
'input_pos': input_pos,
165-
'kv_cache': kv,
166-
},
167-
)
168-
.signature(
169-
'decode',
170-
model,
171-
sample_kwargs={
172-
'tokens': decode_token,
173-
'input_pos': decode_input_pos,
174-
'kv_cache': kv,
175-
},
176-
)
177-
.convert()
178-
)
179-
edge_model.export('/tmp/toy_external_kv_cache.tflite')
180-
181-
182-
if __name__ == '__main__':
183-
define_and_run()

0 commit comments

Comments
 (0)