
Commit 0f48501

haozha111 authored and copybara-github committed
Implement Gemma3N model's attention layer.
PiperOrigin-RevId: 762244211
1 parent 02d534a commit 0f48501


1 file changed: +39 −8 lines changed


ai_edge_torch/generative/layers/attention.py

Lines changed: 39 additions & 8 deletions
@@ -15,6 +15,7 @@

 """Common building blocks for Attention layer."""

+import abc
 from typing import Optional, Tuple, Union

 from ai_edge_torch.generative.layers import builder
@@ -111,7 +112,42 @@ def forward(
     return output if kv is None else (output, kv)


-class CausalSelfAttention(nn.Module):
+class CausalSelfAttentionBase(nn.Module):
+  """Base class for causal self attention layer."""
+
+  def __init__(
+      self, dim: int, config: cfg.AttentionConfig, enable_hlfb: bool
+  ) -> None:
+    super().__init__()
+    self.dim = dim
+    self.config = config
+    self.enable_hlfb = enable_hlfb
+
+    self.query_norm = builder.build_norm(
+        self.config.head_dim, self.config.query_norm_config
+    )
+    self.key_norm = builder.build_norm(
+        self.config.head_dim, self.config.key_norm_config
+    )
+    self.value_norm = builder.build_norm(
+        self.config.head_dim, self.config.value_norm_config
+    )
+
+  @abc.abstractmethod
+  def forward(
+      self,
+      x: torch.Tensor,
+      rope: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+      mask: Optional[torch.Tensor] = None,
+      input_pos: Optional[torch.Tensor] = None,
+      kv_cache: Optional[kv_utils.KVCacheEntry] = None,
+      lora: Optional[lora_utils.LoRAEntry] = None,
+  ) -> Union[torch.Tensor, Tuple[torch.Tensor, kv_utils.KVCacheEntry]]:
+    raise NotImplementedError()
+
+
+class CausalSelfAttention(CausalSelfAttentionBase):
+  """Causal self attention layer implementation."""

   def __init__(
       self,
@@ -126,7 +162,7 @@ def __init__(
       config (cfg.AttentionConfig): attention specific configurations.
      enable_hlfb (bool): whether hlfb is enabled or not.
     """
-    super().__init__()
+    super().__init__(dim, config, enable_hlfb)
     self.kv_cache = None
     qkv_shape = (
         config.num_heads + 2 * config.num_query_groups
@@ -137,12 +173,6 @@ def __init__(
     self.output_projection = nn.Linear(
         output_shape, dim, bias=config.output_proj_use_bias
     )
-    self.query_norm = builder.build_norm(
-        config.head_dim, config.query_norm_config
-    )
-    self.key_norm = builder.build_norm(config.head_dim, config.key_norm_config)
-    self.config = config
-    self.enable_hlfb = enable_hlfb

   def forward(
       self,
@@ -204,6 +234,7 @@ def forward(

     q = self.query_norm(q)
     k = self.key_norm(k)
+    v = self.value_norm(v)

     q = q.reshape(B, T, -1, self.config.head_dim)
     k = k.reshape(B, T, -1, self.config.head_dim)

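The commit factors the per-head query/key/value normalization into a new CausalSelfAttentionBase class and applies the value norm (v = self.value_norm(v)) in the shared forward path, so model-specific attention variants such as Gemma3N's can subclass the base rather than re-declaring the norms. Below is a minimal, self-contained sketch of the same pattern in plain PyTorch; the SimpleCausalSelfAttention* names, the fused qkv projection, the use of nn.RMSNorm (PyTorch 2.4+) in place of builder.build_norm, and the F.scaled_dot_product_attention call are illustrative assumptions, not code from this commit.

```python
import abc

import torch
from torch import nn
import torch.nn.functional as F


class SimpleCausalSelfAttentionBase(nn.Module):
  """Standalone analogue of CausalSelfAttentionBase: owns the q/k/v norms."""

  def __init__(self, head_dim: int) -> None:
    super().__init__()
    # The real layer builds these via builder.build_norm(head_dim, norm_config);
    # nn.RMSNorm (PyTorch 2.4+) is a stand-in with the same per-head shape.
    self.query_norm = nn.RMSNorm(head_dim)
    self.key_norm = nn.RMSNorm(head_dim)
    self.value_norm = nn.RMSNorm(head_dim)

  @abc.abstractmethod
  def forward(self, x: torch.Tensor) -> torch.Tensor:
    raise NotImplementedError()


class SimpleCausalSelfAttention(SimpleCausalSelfAttentionBase):
  """Minimal subclass that normalizes q, k and v before attention."""

  def __init__(self, dim: int, num_heads: int) -> None:
    head_dim = dim // num_heads
    super().__init__(head_dim)
    self.num_heads = num_heads
    self.head_dim = head_dim
    self.qkv_projection = nn.Linear(dim, 3 * dim, bias=False)
    self.output_projection = nn.Linear(dim, dim, bias=False)

  def forward(self, x: torch.Tensor) -> torch.Tensor:
    B, T, _ = x.shape
    q, k, v = self.qkv_projection(x).chunk(3, dim=-1)
    # Normalize each head vector, mirroring the commit's placement of the
    # q/k norms and the newly added v norm.
    q = self.query_norm(q.reshape(B, T, self.num_heads, self.head_dim))
    k = self.key_norm(k.reshape(B, T, self.num_heads, self.head_dim))
    v = self.value_norm(v.reshape(B, T, self.num_heads, self.head_dim))
    # Switch to (B, num_heads, T, head_dim) for scaled_dot_product_attention.
    q, k, v = (t.transpose(1, 2) for t in (q, k, v))
    y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
    y = y.transpose(1, 2).reshape(B, T, -1)
    return self.output_projection(y)


if __name__ == "__main__":
  attn = SimpleCausalSelfAttention(dim=64, num_heads=4)
  print(attn(torch.randn(2, 10, 64)).shape)  # torch.Size([2, 10, 64])
```

Keeping the norm modules on the base class means every subclass normalizes q, k, and v at the same per-head shape (head_dim), which is the contract documented by the abstract forward signature in the diff.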