
Commit ec7dfd2

haozha111 authored and copybara-github committed
Introduce a few configuration changes for Gemma3N.
- NormalizationConfig: support scale.
- Add a KV cache update strategy enum.
- Add a value norm config in AttentionConfig.
- Introduce a separate gating flag. This is to accommodate the case where JAX's GatedFeedforward may use only a single gating einsum parameter rather than two.

PiperOrigin-RevId: 762074125
1 parent 479696b commit ec7dfd2

File tree: 1 file changed, +33 -0 lines changed


ai_edge_torch/generative/layers/model_config.py

Lines changed: 33 additions & 0 deletions
@@ -69,10 +69,32 @@ class NormalizationConfig:
   enable_hlfb: bool = True
   epsilon: float = 1e-5
   zero_centered: bool = False
+  # Whether to use a scale parameter in the normalization.
+  with_scale: bool = False
+  # The shift to apply to the scale parameter.
+  scale_shift: float = 0.0
   # Number of groups used in group normalization.
   group_num: Optional[float] = None


+# Experimental feature; may be subject to change.
+class KVCacheUpdateStrategy(enum.Enum):
+  """Different alignment strategies of the KV cache.
+
+  Due to restrictions from different devices, we may need to apply different
+  alignment strategies to the KV cache during the Attention layer's cache
+  update.
+
+  Available options:
+    INPLACE: Update the existing cache in place using indexes.
+    PREPEND_LEFT: Prepend the new kv to the left of the existing cache. When
+      this cache update is applied, the newer kvs will always be placed at the
+      beginning of the cache.
+  """
+
+  INPLACE = enum.auto()
+  PREPEND_LEFT = enum.auto()
+
+
 @dataclasses.dataclass
 class AttentionConfig:
   """Attention model's parameters."""
@@ -108,6 +130,12 @@ class AttentionConfig:
   key_norm_config: NormalizationConfig = dataclasses.field(
       default_factory=NormalizationConfig
   )
+  # The normalization applied to value projection's output.
+  value_norm_config: NormalizationConfig = dataclasses.field(
+      default_factory=NormalizationConfig
+  )
+  # Whether the KV cache is shared with the previous attention block.
+  kv_shared: bool = False
   relative_attention_num_buckets: int = 0
   relative_attention_max_distance: int = 0
   # Softcap on the output logits.
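A hedged sketch of how the new AttentionConfig fields might be set on an existing config object: value_norm_config and kv_shared are the fields added in this hunk, while the helper function itself and the RMS_NORM norm type are illustrative assumptions.

from ai_edge_torch.generative.layers import model_config as cfg

def enable_value_norm_and_kv_sharing(
    attn_cfg: cfg.AttentionConfig,
) -> cfg.AttentionConfig:
  """Hypothetical helper, not part of this commit."""
  # Normalize the value projection's output, mirroring the existing
  # query/key norm configs.
  attn_cfg.value_norm_config = cfg.NormalizationConfig(
      type=cfg.NormalizationType.RMS_NORM, epsilon=1e-6
  )
  # Reuse the KV cache computed by the previous attention block.
  attn_cfg.kv_shared = True
  return attn_cfg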
@@ -118,6 +146,8 @@
   sliding_window_size: Optional[int] = None
   # The default causal mask value used by attention layer.
   causal_mask_value: float = float("-inf")
+  # The update strategy of the KV cache. Defaults to INPLACE.
+  kvcache_update_strategy: KVCacheUpdateStrategy = KVCacheUpdateStrategy.INPLACE


 @dataclasses.dataclass
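To opt into the new cache layout, a model config could switch the strategy away from the default. The snippet below is a minimal sketch assuming a mutable AttentionConfig instance; only the enum and field names come from this diff.

from ai_edge_torch.generative.layers import model_config as cfg

def use_prepend_left_cache(attn_cfg: cfg.AttentionConfig) -> None:
  # Hypothetical: for devices that cannot scatter into a preallocated cache,
  # prepend new kv entries instead of updating the cache in place.
  attn_cfg.kvcache_update_strategy = cfg.KVCacheUpdateStrategy.PREPEND_LEFT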
@@ -135,6 +165,9 @@ class FeedForwardConfig:
   type: FeedForwardType
   activation: ActivationConfig
   intermediate_size: int
+  # Whether to use two separate gating parameters or a single one in
+  # GatedFeedForward.
+  use_separate_gating: bool = True
   use_bias: bool = False
   # The normalization applied to feed forward's input.
   pre_ff_norm_config: NormalizationConfig = dataclasses.field(
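For the gating change, a config mirroring a JAX GatedFeedforward checkpoint that holds a single fused gating einsum parameter might look like the sketch below. Only use_separate_gating is added here; FeedForwardType.GATED, ActivationConfig, and ActivationType.SILU are assumed from the existing model_config.py, and the intermediate size is a placeholder.

from ai_edge_torch.generative.layers import model_config as cfg

ff_cfg = cfg.FeedForwardConfig(
    type=cfg.FeedForwardType.GATED,
    activation=cfg.ActivationConfig(cfg.ActivationType.SILU),
    intermediate_size=8192,  # placeholder value
    use_separate_gating=False,  # load a single fused gating parameter
)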
