Commit 6a0361e

protobird-git authored and copybara-github committed
Make enable_hlfb true by default
- It's been true for most cases except stable_diffusion's decoder and diffuser, which use unet config instead of model config and set enable_hlfb explicitly based on cpu/gpu
- Set enable_hlfb false explicitly for cpu_only/gemma3, T5, toy_models, and stable_diffusion clip

PiperOrigin-RevId: 759665348
1 parent: c5ab892
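
To make the flip concrete, here is a minimal sketch of how a caller's config code changes. The cfg import alias follows the example files in the diffs below; the default itself lives in model_config.py, which is not shown in this listing, so treat the exact import form as an assumption:

    import ai_edge_torch.generative.layers.model_config as cfg

    # Before this commit, HLFB (high-level function boundary) marking had to
    # be requested explicitly on each config:
    norm_config = cfg.NormalizationConfig(
        type=cfg.NormalizationType.RMS_NORM, enable_hlfb=True
    )

    # After this commit, enable_hlfb defaults to True, so the flag can
    # simply be dropped:
    norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.RMS_NORM)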

File tree

26 files changed: +37 -95 lines

ai_edge_torch/generative/examples/amd_llama_135m/amd_llama_135m.py

Lines changed: 1 addition & 4 deletions
@@ -51,9 +51,7 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       activation=cfg.ActivationConfig(cfg.ActivationType.SILU),
       intermediate_size=2048,
   )
-  norm_config = cfg.NormalizationConfig(
-      type=cfg.NormalizationType.RMS_NORM, enable_hlfb=True
-  )
+  norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.RMS_NORM)
   block_config = cfg.TransformerBlockConfig(
       attn_config=attn_config,
       ff_config=ff_config,
@@ -69,7 +67,6 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       block_configs=block_config,
       final_norm_config=norm_config,
       lm_head_share_weight_with_embedding=False,
-      enable_hlfb=True,
   )
   return config

ai_edge_torch/generative/examples/deepseek/deepseek.py

Lines changed: 1 addition & 4 deletions
@@ -53,9 +53,7 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       intermediate_size=8960,
   )
   norm_config = cfg.NormalizationConfig(
-      type=cfg.NormalizationType.RMS_NORM,
-      epsilon=1e-06,
-      enable_hlfb=True,
+      type=cfg.NormalizationType.RMS_NORM, epsilon=1e-06
   )
   block_config = cfg.TransformerBlockConfig(
       attn_config=attn_config,
@@ -72,7 +70,6 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       block_configs=block_config,
       final_norm_config=norm_config,
       lm_head_share_weight_with_embedding=False,
-      enable_hlfb=True,
   )
   return config

ai_edge_torch/generative/examples/gemma/gemma1.py

Lines changed: 1 addition & 5 deletions
@@ -65,10 +65,7 @@ def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       intermediate_size=16384,
   )
   norm_config = cfg.NormalizationConfig(
-      type=cfg.NormalizationType.RMS_NORM,
-      epsilon=1e-6,
-      zero_centered=True,
-      enable_hlfb=True,
+      type=cfg.NormalizationType.RMS_NORM, epsilon=1e-6, zero_centered=True
   )
   block_config = cfg.TransformerBlockConfig(
       attn_config=attn_config,
@@ -87,7 +84,6 @@ def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       block_configs=block_config,
       final_norm_config=norm_config,
       lm_head_use_bias=False,
-      enable_hlfb=True,
   )
   return config

ai_edge_torch/generative/examples/gemma/gemma2.py

Lines changed: 1 addition & 5 deletions
@@ -233,10 +233,7 @@ def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
     The model config for a Gemma 2B model.
   """
   norm_config = cfg.NormalizationConfig(
-      type=cfg.NormalizationType.RMS_NORM,
-      epsilon=1e-6,
-      zero_centered=True,
-      enable_hlfb=True,
+      type=cfg.NormalizationType.RMS_NORM, epsilon=1e-6, zero_centered=True
   )
   ff_config = cfg.FeedForwardConfig(
       type=cfg.FeedForwardType.GATED,
@@ -284,7 +281,6 @@ def get_block_config(idx: int) -> cfg.TransformerBlockConfig:
       block_configs=[get_block_config(i) for i in range(num_layers)],
       final_norm_config=norm_config,
       lm_head_use_bias=False,
-      enable_hlfb=True,
       final_logit_softcap=30.0,
   )
   return config

ai_edge_torch/generative/examples/gemma3/decoder.py

Lines changed: 1 addition & 5 deletions
@@ -329,10 +329,7 @@ def get_decoder_config_1b(kv_cache_max_len: int = 2048) -> cfg.ModelConfig:
     The model config for a Gemma 1B model.
   """
   norm_config = cfg.NormalizationConfig(
-      type=cfg.NormalizationType.RMS_NORM,
-      epsilon=1e-6,
-      zero_centered=True,
-      enable_hlfb=True,
+      type=cfg.NormalizationType.RMS_NORM, epsilon=1e-6, zero_centered=True,
   )
   ff_config = cfg.FeedForwardConfig(
       type=cfg.FeedForwardType.GATED,
@@ -379,7 +376,6 @@ def get_block_config(idx: int) -> cfg.TransformerBlockConfig:
       block_configs=[get_block_config(i) for i in range(num_layers)],
       final_norm_config=norm_config,
       lm_head_use_bias=False,
-      enable_hlfb=True,
       final_logit_softcap=None,
   )
   return config

ai_edge_torch/generative/examples/gemma3/gemma3.py

Lines changed: 1 addition & 3 deletions
@@ -158,9 +158,7 @@ def get_fake_model_config(**kwargs) -> Gemma3MMConfig:
       image_projection_scale=128**0.5,
       image_projection_use_bias=False,
       mm_norm_config=cfg.NormalizationConfig(
-          type=cfg.NormalizationType.LAYER_NORM,
-          epsilon=1e-6,
-          enable_hlfb=True,
+          type=cfg.NormalizationType.LAYER_NORM, epsilon=1e-6
       ),
       mm_extra_tokens=32,
   )

ai_edge_torch/generative/examples/gemma3/image_encoder.py

Lines changed: 1 addition & 4 deletions
@@ -98,9 +98,7 @@ def get_image_encoder_config() -> cfg.ModelConfig:
       output_proj_use_bias=True,
   )
   norm_config = cfg.NormalizationConfig(
-      type=cfg.NormalizationType.LAYER_NORM,
-      epsilon=1e-6,
-      enable_hlfb=True,
+      type=cfg.NormalizationType.LAYER_NORM, epsilon=1e-6
   )
   ff_config = cfg.FeedForwardConfig(
       type=cfg.FeedForwardType.SEQUENTIAL,
@@ -123,7 +121,6 @@ def get_image_encoder_config() -> cfg.ModelConfig:
       image_embedding=image_embedding_config,
       block_configs=block_config,
       final_norm_config=norm_config,
-      enable_hlfb=True,
       num_mm_tokens_per_image=256,
   )
   return config

ai_edge_torch/generative/examples/hammer/hammer.py

Lines changed: 1 addition & 4 deletions
@@ -45,9 +45,7 @@ def get_1_5b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       intermediate_size=8960,
   )
   norm_config = cfg.NormalizationConfig(
-      type=cfg.NormalizationType.RMS_NORM,
-      epsilon=1e-06,
-      enable_hlfb=True,
+      type=cfg.NormalizationType.RMS_NORM, epsilon=1e-06
   )
   block_config = cfg.TransformerBlockConfig(
       attn_config=attn_config,
@@ -63,7 +61,6 @@ def get_1_5b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       kv_cache_max_len=kv_cache_max_len,
       block_configs=block_config,
       final_norm_config=norm_config,
-      enable_hlfb=True,
   )
   return config

ai_edge_torch/generative/examples/llama/llama.py

Lines changed: 1 addition & 4 deletions
@@ -121,9 +121,7 @@ def get_1b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       activation=cfg.ActivationConfig(cfg.ActivationType.SILU),
       intermediate_size=8192,
   )
-  norm_config = cfg.NormalizationConfig(
-      type=cfg.NormalizationType.RMS_NORM, enable_hlfb=True,
-  )
+  norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.RMS_NORM)
   block_config = cfg.TransformerBlockConfig(
       attn_config=attn_config,
       ff_config=ff_config,
@@ -152,7 +150,6 @@ def get_1b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       kv_cache_max_len=kv_cache_max_len,
       block_configs=block_config,
       final_norm_config=norm_config,
-      enable_hlfb=True,
       build_rope=build_rope,
   )
   return config

ai_edge_torch/generative/examples/openelm/openelm.py

Lines changed: 1 addition & 2 deletions
@@ -53,7 +53,7 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
     The model config for an OpenELM model.
   """
   norm_config = cfg.NormalizationConfig(
-      type=cfg.NormalizationType.RMS_NORM, epsilon=1e-6, enable_hlfb=True
+      type=cfg.NormalizationType.RMS_NORM, epsilon=1e-6
   )
   num_heads = [12] * 4 + [16] * 14 + [20] * 12 + [24] * 6
   num_query_groups = [3] * 4 + [4] * 14 + [5] * 12 + [6] * 6
@@ -101,7 +101,6 @@ def get_block_config(idx: int) -> cfg.TransformerBlockConfig:
       kv_cache_max_len=kv_cache_max_len,
       block_configs=[get_block_config(i) for i in range(num_layers)],
       final_norm_config=norm_config,
-      enable_hlfb=True,
   )
   return config
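
Per the commit message, the files this listing does not show (cpu_only/gemma3, T5, toy_models, and stable_diffusion's clip) move in the opposite direction: with True now the default, they must opt out explicitly. A hedged sketch of that opt-out pattern, reusing the same config types as the diffs above:

    import ai_edge_torch.generative.layers.model_config as cfg

    # With enable_hlfb defaulting to True, configs that must avoid HLFB
    # (e.g. CPU-only variants) now disable it explicitly instead of relying
    # on the old False default.
    norm_config = cfg.NormalizationConfig(
        type=cfg.NormalizationType.RMS_NORM, enable_hlfb=False
    )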
