@@ -69,10 +69,32 @@ class NormalizationConfig:
69
69
enable_hlfb : bool = True
70
70
epsilon : float = 1e-5
71
71
zero_centered : bool = False
72
+ # Whether to use a scale parameter in the normalization.
73
+ with_scale : bool = False
74
+ # The shift to apply to the scale parameter.
75
+ scale_shift : float = 0.0
72
76
# Number of groups used in group normalization.
73
77
group_num : Optional [float ] = None
74
78
75
79
80
+ # Experimental feature; may be subject to change.
81
+ class KVCacheUpdateStrategy (enum .Enum ):
82
+ """Alignment strategies for updating the KV cache.
83
+
84
+ Due to restrictions on different devices, different alignment strategies may
85
+ need to be applied to the KV cache when the attention layer updates it.
86
+
87
+ Available options:
88
+ INPLACE: Update the existing cache in place using indexes.
89
+ PREPEND_LEFT: Prepend the new kv to the left of the existing cache. With this
90
+ strategy, the newer kvs are always placed at the beginning of the
91
+ cache.
92
+ """
93
+
94
+ INPLACE = enum .auto ()
95
+ PREPEND_LEFT = enum .auto ()
96
+
97
+
76
98
@dataclasses .dataclass
77
99
class AttentionConfig :
78
100
"""Attention model's parameters."""
@@ -108,6 +130,12 @@ class AttentionConfig:
108
130
key_norm_config : NormalizationConfig = dataclasses .field (
109
131
default_factory = NormalizationConfig
110
132
)
133
+ # The normalization applied to value projection's output.
134
+ value_norm_config : NormalizationConfig = dataclasses .field (
135
+ default_factory = NormalizationConfig
136
+ )
137
+ # Whether the KV cache is shared with the previous attention block.
138
+ kv_shared : bool = False
111
139
relative_attention_num_buckets : int = 0
112
140
relative_attention_max_distance : int = 0
113
141
# Softcap on the output logits.
@@ -118,6 +146,8 @@ class AttentionConfig:
118
146
sliding_window_size : Optional [int ] = None
119
147
# The default causal mask value used by attention layer.
120
148
causal_mask_value : float = float ("-inf" )
149
+ # The update strategy of the KV cache. Defaults to INPLACE.
150
+ kvcache_update_strategy : KVCacheUpdateStrategy = KVCacheUpdateStrategy .INPLACE
121
151
122
152
123
153
@dataclasses .dataclass
@@ -135,6 +165,9 @@ class FeedForwardConfig:
135
165
type : FeedForwardType
136
166
activation : ActivationConfig
137
167
intermediate_size : int
168
+ # Whether to use two separate gating parameters or a single one in
169
+ # GatedFeedForward.
170
+ use_separate_gating : bool = True
138
171
use_bias : bool = False
139
172
# The normalization applied to feed forward's input.
140
173
pre_ff_norm_config : NormalizationConfig = dataclasses .field (
0 commit comments