 # ==============================================================================
 # Common building blocks for FeedForward layers.

-from typing import Callable, Optional
+import abc
+from typing import Callable

+import ai_edge_torch.generative.layers.model_config as cfg
 import torch
 from torch import nn


-class SequentialFeedForward(nn.Module):
+class FeedForwardBase(nn.Module):
+  """Base class for feedforward layer."""
+
+  def __init__(
+      self,
+      dim: int,
+      activation: Callable[[torch.Tensor], torch.Tensor],
+      config: cfg.FeedForwardConfig,
+      pre_ff_norm: Callable[[torch.Tensor], torch.Tensor] = lambda x: x,
+      post_ff_norm: Callable[[torch.Tensor], torch.Tensor] = lambda x: x,
+  ):
+    super().__init__()
+    self.dim = dim
+    self.act = activation
+    self.config = config
+    self.hidden_dim = config.intermediate_size
+    self.use_bias = config.use_bias
+    self.use_glu = (
+        config.activation.type == cfg.ActivationType.GE_GLU
+        or config.activation.type == cfg.ActivationType.SILU_GLU
+    )
+    self.pre_ff_norm = pre_ff_norm
+    self.post_ff_norm = post_ff_norm
+
+  @abc.abstractmethod
+  def forward(self, x: torch.Tensor) -> torch.Tensor:
+    raise NotImplementedError()
+
+
+class SequentialFeedForward(FeedForwardBase):
   """Vanilla sequential Feedforward with customizable activation."""

   def __init__(
       self,
       dim: int,
-      hidden_dim: int,
       activation: Callable[[torch.Tensor], torch.Tensor],
-      use_bias=False,
-      use_glu=False,
-      pre_ff_norm: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
-      post_ff_norm: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
+      config: cfg.FeedForwardConfig,
+      pre_ff_norm: Callable[[torch.Tensor], torch.Tensor] = lambda x: x,
+      post_ff_norm: Callable[[torch.Tensor], torch.Tensor] = lambda x: x,
   ):
     """Init function for feedforward layer.

     Args:
       dim (int): embedding size.
-      hidden_dim (int): hidden dim size of the feedforward layer.
       activation (Callable): activation function used in this block.
-      use_bias (Boolean): whether to use bias. Default is false.
-      use_glu (Boolean): whether to use glu in activation. Default is false.
-      pre_ff_norm (Callable): pre feedforward norm. Default is None.
-      post_ff_norm (Callable): post feedforward norm. Default is None.
+      config (cfg.FeedForwardConfig): feedforward layer configuration.
+      pre_ff_norm (Callable): pre feedforward norm. Default is identity.
+      post_ff_norm (Callable): post feedforward norm. Default is identity.
     """
-    super().__init__()
-    self.act = activation
-    if use_glu:
-      self.w1 = nn.Linear(dim, hidden_dim * 2, bias=use_bias)
+    super().__init__(dim, activation, config, pre_ff_norm, post_ff_norm)
+    if self.use_glu:
+      self.w1 = nn.Linear(dim, self.hidden_dim * 2, bias=self.use_bias)
     else:
-      self.w1 = nn.Linear(dim, hidden_dim, bias=use_bias)
-    self.w2 = nn.Linear(hidden_dim, dim, bias=use_bias)
-    self.pre_ff_norm = pre_ff_norm if pre_ff_norm else lambda x: x
-    self.post_ff_norm = post_ff_norm if post_ff_norm else lambda x: x
+      self.w1 = nn.Linear(dim, self.hidden_dim, bias=self.use_bias)
+    self.w2 = nn.Linear(self.hidden_dim, dim, bias=self.use_bias)

   def forward(self, x):
     """Forward pass for Feedforward layer.
@@ -68,7 +92,7 @@ def forward(self, x):
     return self.post_ff_norm(out)


-class GatedFeedForward(nn.Module):
+class GatedFeedForward(FeedForwardBase):
   """Gated Feedforward with customizable activation.

   https://arxiv.org/pdf/2002.05202v1.pdf
@@ -77,34 +101,48 @@ class GatedFeedForward(nn.Module):
   def __init__(
       self,
       dim: int,
-      hidden_dim: int,
       activation: Callable[[torch.Tensor], torch.Tensor],
-      use_bias=False,
-      use_glu=False,
-      pre_ff_norm: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
-      post_ff_norm: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
+      config: cfg.FeedForwardConfig,
+      pre_ff_norm: Callable[[torch.Tensor], torch.Tensor] = lambda x: x,
+      post_ff_norm: Callable[[torch.Tensor], torch.Tensor] = lambda x: x,
   ):
     """Init function for feedforward layer.

     Args:
       dim (int): embedding size.
-      hidden_dim (int): hidden dim size of the feedforward layer.
       activation (Callable): activation function used in this block.
-      use_bias (Boolean): whether to use bias. Default is false.
-      use_glu (Boolean): whether to use glu in activation. Default is false.
-      pre_ff_norm (Callable): pre feedforward norm. Default is None.
-      post_ff_norm (Callable): post feedforward norm. Default is None.
+      pre_ff_norm (Callable): pre feedforward norm. Default is identity.
+      post_ff_norm (Callable): post feedforward norm. Default is identity.
+      config (cfg.FeedForwardConfig): feedforward layer configuration.
     """
-    super().__init__()
-    self.act = activation
-    if use_glu:
-      self.w1 = nn.Linear(dim, hidden_dim * 2, bias=use_bias)
+    super().__init__(dim, activation, config, pre_ff_norm, post_ff_norm)
+
+    if self.use_glu:
+      assert (
+          self.config.use_separate_gating
+      ), 'use_separate_gating must be True for GE_GLU | SILU_GLU activation.'
+
+    if self.config.use_separate_gating:
+      if self.use_glu:
+        self.w1 = nn.Linear(dim, self.hidden_dim * 2, bias=self.use_bias)
+      else:
+        self.w1 = nn.Linear(dim, self.hidden_dim, bias=self.use_bias)
+      self.w3 = nn.Linear(dim, self.hidden_dim, bias=self.use_bias)
     else:
-      self.w1 = nn.Linear(dim, hidden_dim, bias=use_bias)
-    self.w2 = nn.Linear(hidden_dim, dim, bias=use_bias)
-    self.w3 = nn.Linear(dim, hidden_dim, bias=use_bias)
-    self.pre_ff_norm = pre_ff_norm if pre_ff_norm else lambda x: x
-    self.post_ff_norm = post_ff_norm if post_ff_norm else lambda x: x
+      self.w_gating = nn.Parameter(
+          torch.ones((2, dim, self.hidden_dim), dtype=torch.float32),
+          requires_grad=False,
+      )
+      self.gating_bias = (
+          nn.Parameter(
+              torch.zeros((2, self.hidden_dim), dtype=torch.float32),
+              requires_grad=False,
+          )
+          if self.use_bias
+          else torch.zeros((2, self.hidden_dim), dtype=torch.float32)
+      )
+
+    self.w2 = nn.Linear(self.hidden_dim, dim, bias=self.use_bias)

   def forward(self, x):
     """Forward pass for Feedforward layer.
@@ -116,5 +154,12 @@ def forward(self, x):
       torch.Tensor: output tensor after feedforward.
     """
     x_norm = self.pre_ff_norm(x)
-    out = self.w2(self.act(self.w1(x_norm)) * self.w3(x_norm))
+    if self.config.use_separate_gating:
+      out = self.w2(self.act(self.w1(x_norm)) * self.w3(x_norm))
+    else:
+      out = self.w2(
+          self.act(torch.matmul(x_norm, self.w_gating[0]) + self.gating_bias[0])
+          * (torch.matmul(x_norm, self.w_gating[1]) + self.gating_bias[1])
+      )
+
     return self.post_ff_norm(out)
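
Usage sketch (not part of the diff above): a minimal example of constructing the refactored GatedFeedForward from a FeedForwardConfig instead of the old loose keyword arguments. It assumes the classes live in ai_edge_torch.generative.layers.feed_forward; intermediate_size, use_bias, use_separate_gating, and activation.type are taken from the diff itself, while FeedForwardType.GATED, ActivationConfig, and ActivationType.SILU are assumptions about model_config.

import torch
import torch.nn.functional as F

import ai_edge_torch.generative.layers.model_config as cfg
from ai_edge_torch.generative.layers.feed_forward import GatedFeedForward

# Hypothetical config values, for illustration only.
ff_config = cfg.FeedForwardConfig(
    type=cfg.FeedForwardType.GATED,  # assumed field, not shown in the diff
    activation=cfg.ActivationConfig(cfg.ActivationType.SILU),  # assumed constructor shape
    intermediate_size=256,  # read by FeedForwardBase as self.hidden_dim
    use_bias=False,
    use_separate_gating=True,  # True -> w1/w3 projections; False -> fused w_gating path
)

ff = GatedFeedForward(dim=64, activation=F.silu, config=ff_config)
out = ff(torch.randn(2, 16, 64))  # output shape: (2, 16, 64)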