 """Common building blocks for Attention layer."""
 
+import abc
 from typing import Optional, Tuple, Union
 
 from ai_edge_torch.generative.layers import builder
@@ -111,7 +112,42 @@ def forward(
     return output if kv is None else (output, kv)
 
 
-class CausalSelfAttention(nn.Module):
+class CausalSelfAttentionBase(nn.Module):
+  """Base class for causal self attention layer."""
+
+  def __init__(
+      self, dim: int, config: cfg.AttentionConfig, enable_hlfb: bool
+  ) -> None:
+    super().__init__()
+    self.dim = dim
+    self.config = config
+    self.enable_hlfb = enable_hlfb
+
+    self.query_norm = builder.build_norm(
+        self.config.head_dim, self.config.query_norm_config
+    )
+    self.key_norm = builder.build_norm(
+        self.config.head_dim, self.config.key_norm_config
+    )
+    self.value_norm = builder.build_norm(
+        self.config.head_dim, self.config.value_norm_config
+    )
+
+  @abc.abstractmethod
+  def forward(
+      self,
+      x: torch.Tensor,
+      rope: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+      mask: Optional[torch.Tensor] = None,
+      input_pos: Optional[torch.Tensor] = None,
+      kv_cache: Optional[kv_utils.KVCacheEntry] = None,
+      lora: Optional[lora_utils.LoRAEntry] = None,
+  ) -> Union[torch.Tensor, Tuple[torch.Tensor, kv_utils.KVCacheEntry]]:
+    raise NotImplementedError()
+
+
+class CausalSelfAttention(CausalSelfAttentionBase):
+  """Causal self attention layer implementation."""
 
   def __init__(
       self,
@@ -126,7 +162,7 @@ def __init__(
       config (cfg.AttentionConfig): attention specific configurations.
       enable_hlfb (bool): whether hlfb is enabled or not.
     """
-    super().__init__()
+    super().__init__(dim, config, enable_hlfb)
     self.kv_cache = None
     qkv_shape = (
         config.num_heads + 2 * config.num_query_groups
@@ -137,12 +173,6 @@ def __init__(
     self.output_projection = nn.Linear(
         output_shape, dim, bias=config.output_proj_use_bias
     )
-    self.query_norm = builder.build_norm(
-        config.head_dim, config.query_norm_config
-    )
-    self.key_norm = builder.build_norm(config.head_dim, config.key_norm_config)
-    self.config = config
-    self.enable_hlfb = enable_hlfb
 
   def forward(
       self,
@@ -204,6 +234,7 @@ def forward(
     q = self.query_norm(q)
     k = self.key_norm(k)
+    v = self.value_norm(v)
 
     q = q.reshape(B, T, -1, self.config.head_dim)
     k = k.reshape(B, T, -1, self.config.head_dim)
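For readers skimming the change: the commit lifts the shared query/key norm construction out of CausalSelfAttention into an abstract CausalSelfAttentionBase, adds a value norm built the same way, and applies it in forward alongside the existing query and key norms. Below is a minimal, standalone sketch of that pattern, not ai_edge_torch's actual code: build_norm is a hypothetical stub standing in for builder.build_norm, and SimpleCausalSelfAttention is an invented single-head subclass used only to exercise the base class.

import abc

import torch
from torch import nn


def build_norm(head_dim: int) -> nn.Module:
  # Hypothetical stand-in for builder.build_norm: with no norm configured it
  # returns an identity, so the norms are no-ops unless a config enables them.
  return nn.Identity()


class CausalSelfAttentionBase(nn.Module):
  """Mirrors the new base class: owns the q/k/v norms, leaves forward abstract."""

  def __init__(self, dim: int, head_dim: int) -> None:
    super().__init__()
    self.dim = dim
    self.query_norm = build_norm(head_dim)
    self.key_norm = build_norm(head_dim)
    self.value_norm = build_norm(head_dim)

  @abc.abstractmethod
  def forward(self, x: torch.Tensor) -> torch.Tensor:
    # The real signature also takes rope, mask, input_pos, kv_cache and lora;
    # they are dropped here for brevity.
    raise NotImplementedError()


class SimpleCausalSelfAttention(CausalSelfAttentionBase):
  """Invented single-head subclass, only to exercise the base class."""

  def __init__(self, dim: int) -> None:
    super().__init__(dim, head_dim=dim)
    self.qkv_projection = nn.Linear(dim, 3 * dim)
    self.output_projection = nn.Linear(dim, dim)

  def forward(self, x: torch.Tensor) -> torch.Tensor:
    q, k, v = self.qkv_projection(x).chunk(3, dim=-1)
    # The behavioral change in this commit: v is normalized alongside q and k.
    q, k, v = self.query_norm(q), self.key_norm(k), self.value_norm(v)
    y = nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)
    return self.output_projection(y)


x = torch.randn(2, 8, 16)  # (batch, sequence, embedding dim)
assert SimpleCausalSelfAttention(16)(x).shape == x.shape

Assuming builder.build_norm resolves an unset norm config to an identity transform, the new self.value_norm(v) call should leave existing models' outputs unchanged unless a value_norm_config is explicitly configured.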