
Commit 02d534a

haozha111 authored and copybara-github committed
In GatedFeedforward, add support for a single gating einsum parameter.
PiperOrigin-RevId: 762130610
1 parent ec7dfd2 commit 02d534a

3 files changed, +106 -61 lines changed
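For orientation before the per-file diffs: GatedFeedForward previously always created two separate gating projections, w1 and w3. After this change, a flag on FeedForwardConfig (use_separate_gating, defined in model_config.py and only read here) selects between that layout and a single stacked gating tensor w_gating of shape (2, dim, hidden_dim), i.e. the gating einsum stored as one parameter. A minimal standalone sketch of the two parameterizations, using the shapes from the diff (illustration only, not code from this commit):

import torch
import torch.nn.functional as F
from torch import nn

dim, hidden = 16, 32
x = torch.randn(1, dim)

# use_separate_gating=True: two independent gating projections (w1, w3), as before.
w1 = nn.Linear(dim, hidden, bias=False)
w3 = nn.Linear(dim, hidden, bias=False)
gated_separate = F.silu(w1(x)) * w3(x)

# use_separate_gating=False: one stacked gating tensor holding both projections.
w_gating = nn.Parameter(torch.ones((2, dim, hidden)), requires_grad=False)
gated_fused = F.silu(x @ w_gating[0]) * (x @ w_gating[1])

print(gated_separate.shape, gated_fused.shape)  # both (1, hidden)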

ai_edge_torch/generative/layers/builder.py

Lines changed: 5 additions & 10 deletions
@@ -15,9 +15,9 @@
 # Builder class for individual components.
 from typing import Callable

+from ai_edge_torch.generative.layers import normalization
 import ai_edge_torch.generative.layers.feed_forward as feed_forward
 import ai_edge_torch.generative.layers.model_config as cfg
-import ai_edge_torch.generative.layers.normalization as normalization
 import torch
 from torch import nn
 import torch.nn.functional as F
@@ -74,6 +74,8 @@ def build_norm(dim: int, config: cfg.NormalizationConfig):
         dim,
         eps=config.epsilon,
         zero_centered_gamma=config.zero_centered,
+        with_scale=config.with_scale,
+        scale_shift=config.scale_shift,
         enable_hlfb=config.enable_hlfb,
     )
   elif config.type == cfg.NormalizationType.LAYER_NORM:
@@ -107,20 +109,13 @@ def build_ff(dim: int, config: cfg.FeedForwardConfig):
   else:
     raise ValueError("Unsupported feedforward type.")

-  activation = get_activation(config.activation)
-
   pre_ff_norm = build_norm(dim, config.pre_ff_norm_config)
   post_ff_norm = build_norm(dim, config.post_ff_norm_config)

   return ff_module(
       dim=dim,
-      hidden_dim=config.intermediate_size,
-      activation=activation,
-      use_bias=config.use_bias,
-      use_glu=(
-          config.activation.type == cfg.ActivationType.GE_GLU
-          or config.activation.type == cfg.ActivationType.SILU_GLU
-      ),
+      activation=get_activation(config.activation),
+      config=config,
       pre_ff_norm=pre_ff_norm,
       post_ff_norm=post_ff_norm,
   )
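The builder change above is mechanical: build_ff no longer unpacks intermediate_size, use_bias, and the GLU flag; it resolves the activation and the pre/post norms and hands the whole FeedForwardConfig to the feedforward module, which derives those values itself (see FeedForwardBase in the next file). A minimal calling sketch, assuming pre_ff_norm_config and post_ff_norm_config have usable defaults in model_config.py (their defaults are not shown in this commit):

import torch
import ai_edge_torch.generative.layers.builder as builder
import ai_edge_torch.generative.layers.model_config as cfg

# Config fields mirror the updated test; norm configs are left at their
# (assumed) defaults since this commit does not show them.
ff_config = cfg.FeedForwardConfig(
    type=cfg.FeedForwardType.GATED,
    activation=cfg.ActivationConfig(cfg.ActivationType.SILU),
    intermediate_size=64,
    use_bias=False,
)
ff = builder.build_ff(dim=32, config=ff_config)  # config is now passed through wholesale
out = ff(torch.ones((2, 32)))                    # expected shape: (2, 32)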

ai_edge_torch/generative/layers/feed_forward.py

Lines changed: 86 additions & 41 deletions
@@ -14,45 +14,69 @@
 # ==============================================================================
 # Common building blocks for FeedForward layers.

-from typing import Callable, Optional
+import abc
+from typing import Callable

+import ai_edge_torch.generative.layers.model_config as cfg
 import torch
 from torch import nn


-class SequentialFeedForward(nn.Module):
+class FeedForwardBase(nn.Module):
+  """Base class for feedforward layer."""
+
+  def __init__(
+      self,
+      dim: int,
+      activation: Callable[[torch.Tensor], torch.Tensor],
+      config: cfg.FeedForwardConfig,
+      pre_ff_norm: Callable[[torch.Tensor], torch.Tensor] = lambda x: x,
+      post_ff_norm: Callable[[torch.Tensor], torch.Tensor] = lambda x: x,
+  ):
+    super().__init__()
+    self.dim = dim
+    self.act = activation
+    self.config = config
+    self.hidden_dim = config.intermediate_size
+    self.use_bias = config.use_bias
+    self.use_glu = (
+        config.activation.type == cfg.ActivationType.GE_GLU
+        or config.activation.type == cfg.ActivationType.SILU_GLU
+    )
+    self.pre_ff_norm = pre_ff_norm
+    self.post_ff_norm = post_ff_norm
+
+  @abc.abstractmethod
+  def forward(self, x: torch.Tensor) -> torch.Tensor:
+    raise NotImplementedError()
+
+
+class SequentialFeedForward(FeedForwardBase):
   """Vanilla sequential Feedforward with customizable activation."""

   def __init__(
       self,
       dim: int,
-      hidden_dim: int,
       activation: Callable[[torch.Tensor], torch.Tensor],
-      use_bias=False,
-      use_glu=False,
-      pre_ff_norm: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
-      post_ff_norm: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
+      config: cfg.FeedForwardConfig,
+      pre_ff_norm: Callable[[torch.Tensor], torch.Tensor] = lambda x: x,
+      post_ff_norm: Callable[[torch.Tensor], torch.Tensor] = lambda x: x,
   ):
     """Init function for feedforward layer.

     Args:
       dim (int): embedding size.
-      hidden_dim (int): hidden dim size of the feedforward layer.
       activation (Callable): activation function used in this block.
-      use_bias (Boolean): whether to use bias. Default is false.
-      use_glu (Boolean): whether to use glu in activation. Default is false.
-      pre_ff_norm (Callable): pre feedforward norm. Default is None.
-      post_ff_norm (Callable): post feedforward norm. Default is None.
+      config (cfg.FeedForwardConfig): feedforward layer configuration.
+      pre_ff_norm (Callable): pre feedforward norm. Default is identity.
+      post_ff_norm (Callable): post feedforward norm. Default is identity.
     """
-    super().__init__()
-    self.act = activation
-    if use_glu:
-      self.w1 = nn.Linear(dim, hidden_dim * 2, bias=use_bias)
+    super().__init__(dim, activation, config, pre_ff_norm, post_ff_norm)
+    if self.use_glu:
+      self.w1 = nn.Linear(dim, self.hidden_dim * 2, bias=self.use_bias)
     else:
-      self.w1 = nn.Linear(dim, hidden_dim, bias=use_bias)
-    self.w2 = nn.Linear(hidden_dim, dim, bias=use_bias)
-    self.pre_ff_norm = pre_ff_norm if pre_ff_norm else lambda x: x
-    self.post_ff_norm = post_ff_norm if post_ff_norm else lambda x: x
+      self.w1 = nn.Linear(dim, self.hidden_dim, bias=self.use_bias)
+    self.w2 = nn.Linear(self.hidden_dim, dim, bias=self.use_bias)

   def forward(self, x):
     """Forward pass for Feedforward layer.
@@ -68,7 +92,7 @@ def forward(self, x):
     return self.post_ff_norm(out)


-class GatedFeedForward(nn.Module):
+class GatedFeedForward(FeedForwardBase):
   """Gated Feedforward with customizable activation.

   https://arxiv.org/pdf/2002.05202v1.pdf
@@ -77,34 +101,48 @@ class GatedFeedForward(nn.Module):
   def __init__(
       self,
       dim: int,
-      hidden_dim: int,
       activation: Callable[[torch.Tensor], torch.Tensor],
-      use_bias=False,
-      use_glu=False,
-      pre_ff_norm: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
-      post_ff_norm: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
+      config: cfg.FeedForwardConfig,
+      pre_ff_norm: Callable[[torch.Tensor], torch.Tensor] = lambda x: x,
+      post_ff_norm: Callable[[torch.Tensor], torch.Tensor] = lambda x: x,
   ):
     """Init function for feedforward layer.

     Args:
       dim (int): embedding size.
-      hidden_dim (int): hidden dim size of the feedforward layer.
       activation (Callable): activation function used in this block.
-      use_bias (Boolean): whether to use bias. Default is false.
-      use_glu (Boolean): whether to use glu in activation. Default is false.
-      pre_ff_norm (Callable): pre feedforward norm. Default is None.
-      post_ff_norm (Callable): post feedforward norm. Default is None.
+      pre_ff_norm (Callable): pre feedforward norm. Default is identity.
+      post_ff_norm (Callable): post feedforward norm. Default is identity.
+      config (cfg.FeedForwardConfig): feedforward layer configuration.
     """
-    super().__init__()
-    self.act = activation
-    if use_glu:
-      self.w1 = nn.Linear(dim, hidden_dim * 2, bias=use_bias)
+    super().__init__(dim, activation, config, pre_ff_norm, post_ff_norm)
+
+    if self.use_glu:
+      assert (
+          self.config.use_separate_gating
+      ), 'use_separate_gating must be True for GE_GLU | SILU_GLU activation.'
+
+    if self.config.use_separate_gating:
+      if self.use_glu:
+        self.w1 = nn.Linear(dim, self.hidden_dim * 2, bias=self.use_bias)
+      else:
+        self.w1 = nn.Linear(dim, self.hidden_dim, bias=self.use_bias)
+      self.w3 = nn.Linear(dim, self.hidden_dim, bias=self.use_bias)
     else:
-      self.w1 = nn.Linear(dim, hidden_dim, bias=use_bias)
-    self.w2 = nn.Linear(hidden_dim, dim, bias=use_bias)
-    self.w3 = nn.Linear(dim, hidden_dim, bias=use_bias)
-    self.pre_ff_norm = pre_ff_norm if pre_ff_norm else lambda x: x
-    self.post_ff_norm = post_ff_norm if post_ff_norm else lambda x: x
+      self.w_gating = nn.Parameter(
+          torch.ones((2, dim, self.hidden_dim), dtype=torch.float32),
+          requires_grad=False,
+      )
+      self.gating_bias = (
+          nn.Parameter(
+              torch.zeros((2, self.hidden_dim), dtype=torch.float32),
+              requires_grad=False,
+          )
+          if self.use_bias
+          else torch.zeros((2, self.hidden_dim), dtype=torch.float32)
+      )
+
+    self.w2 = nn.Linear(self.hidden_dim, dim, bias=self.use_bias)

   def forward(self, x):
     """Forward pass for Feedforward layer.
@@ -116,5 +154,12 @@ def forward(self, x):
       torch.Tensor: output tensor after feedforward.
     """
     x_norm = self.pre_ff_norm(x)
-    out = self.w2(self.act(self.w1(x_norm)) * self.w3(x_norm))
+    if self.config.use_separate_gating:
+      out = self.w2(self.act(self.w1(x_norm)) * self.w3(x_norm))
+    else:
+      out = self.w2(
+          self.act(torch.matmul(x_norm, self.w_gating[0]) + self.gating_bias[0])
+          * (torch.matmul(x_norm, self.w_gating[1]) + self.gating_bias[1])
+      )
+
     return self.post_ff_norm(out)
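The two construction paths compute the same function when their weights agree: act(w1(x)) * w3(x) equals act(x @ w_gating[0]) * (x @ w_gating[1]) once w_gating[i] holds the transposed nn.Linear weights. A hedged equivalence check (illustration only, not part of the commit; it assumes FeedForwardConfig accepts use_separate_gating as a constructor field):

import torch
import torch.nn.functional as F
from ai_edge_torch.generative.layers import feed_forward
from ai_edge_torch.generative.layers import model_config as cfg

def make_config(separate: bool) -> cfg.FeedForwardConfig:
  return cfg.FeedForwardConfig(
      type=cfg.FeedForwardType.GATED,
      activation=cfg.ActivationConfig(cfg.ActivationType.SILU),
      intermediate_size=8,
      use_bias=False,
      use_separate_gating=separate,  # assumed to be a FeedForwardConfig field
  )

sep = feed_forward.GatedFeedForward(dim=4, activation=F.silu, config=make_config(True))
fused = feed_forward.GatedFeedForward(dim=4, activation=F.silu, config=make_config(False))

with torch.no_grad():
  # w_gating is (2, dim, hidden_dim); nn.Linear stores its weight as (hidden_dim, dim).
  fused.w_gating[0].copy_(sep.w1.weight.T)
  fused.w_gating[1].copy_(sep.w3.weight.T)
  fused.w2.weight.copy_(sep.w2.weight)

x = torch.randn(1, 4)
assert torch.allclose(sep(x), fused(x), atol=1e-6)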

ai_edge_torch/generative/layers/feed_forward_test.py

Lines changed: 15 additions & 10 deletions
@@ -14,6 +14,7 @@
 # ==============================================================================

 from ai_edge_torch.generative.layers import feed_forward
+from ai_edge_torch.generative.layers import model_config as cfg
 import torch
 import torch.nn.functional as F
 from absl.testing import absltest as googletest
@@ -22,28 +23,32 @@
 class FeedForwardTest(googletest.TestCase):

   def test_sequential_feed_forward(self):
+    ff_config = cfg.FeedForwardConfig(
+        type=cfg.FeedForwardType.SEQUENTIAL,
+        activation=cfg.ActivationConfig(cfg.ActivationType.SILU),
+        intermediate_size=10,
+        use_bias=True,
+    )
     ff = feed_forward.SequentialFeedForward(
         dim=10,
-        hidden_dim=10,
         activation=F.silu,
-        use_bias=True,
-        use_glu=False,
-        pre_ff_norm=torch.nn.Identity(),
-        post_ff_norm=torch.nn.Identity(),
+        config=ff_config,
     )
     x = torch.ones((1, 10))
     out = ff(x)
     self.assertEqual(out.shape, (1, 10))

   def test_gated_feed_forward(self):
+    ff_config = cfg.FeedForwardConfig(
+        type=cfg.FeedForwardType.GATED,
+        activation=cfg.ActivationConfig(cfg.ActivationType.SILU),
+        intermediate_size=10,
+        use_bias=True,
+    )
     ff = feed_forward.GatedFeedForward(
         dim=10,
-        hidden_dim=10,
         activation=F.silu,
-        use_bias=True,
-        use_glu=False,
-        pre_ff_norm=torch.nn.Identity(),
-        post_ff_norm=torch.nn.Identity(),
+        config=ff_config,
     )
     x = torch.ones((1, 10))
     out = ff(x)
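The updated tests construct FeedForwardConfig without setting use_separate_gating, so they cover whichever layout the config defaults to. A possible follow-up test that pins the new single-gating-parameter path, sketched here and not part of this commit (again assuming FeedForwardConfig accepts use_separate_gating), would sit inside FeedForwardTest:

  def test_gated_feed_forward_single_gating_param(self):
    ff_config = cfg.FeedForwardConfig(
        type=cfg.FeedForwardType.GATED,
        activation=cfg.ActivationConfig(cfg.ActivationType.SILU),
        intermediate_size=10,
        use_bias=False,
        use_separate_gating=False,  # assumed field: select the single w_gating parameter
    )
    ff = feed_forward.GatedFeedForward(
        dim=10,
        activation=F.silu,
        config=ff_config,
    )
    out = ff(torch.ones((1, 10)))
    self.assertEqual(out.shape, (1, 10))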
