@@ -41,9 +41,10 @@ def __init__(
             num_heads: Optional[int] = None,
             qkv_bias: bool = True,
             qkv_separate: bool = False,
+            drop: float = 0.,
     ):
         super().__init__()
-        embed_dim = embed_dim or in_features
+        self.embed_dim = embed_dim = embed_dim or in_features
         self.in_features = in_features
         self.out_features = out_features or in_features
         ref_feat_size = to_2tuple(ref_feat_size)
@@ -82,7 +83,7 @@ def init_weights(self, zero_init_last: bool = False):
             trunc_normal_(self.qkv.weight, std=in_features ** -0.5)
             nn.init.zeros_(self.qkv.bias)

-    def forward(self, x):
+    def forward(self, x, pre_logits: bool = False):
         B, _, H, W = x.shape
         N = H * W
         x = x.flatten(2).transpose(1, 2)
@@ -107,8 +108,12 @@ def forward(self, x):
         attn = attn.softmax(dim=-1)
         x = attn @ v
         x = x.transpose(1, 2).reshape(B, N + 1, -1)
+        x = x[:, 0]
+        x = self.drop(x)
+        if pre_logits:
+            return x
         x = self.proj(x)
-        return x[:, 0]
+        return x


 class AttentionPool2d(nn.Module):
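
Note: both pooling classes in this diff get the same treatment: the pooled class token is taken before the output projection, the new dropout is applied to it, `pre_logits=True` returns that unprojected embedding, and its width is now stored as `self.embed_dim`. A minimal sketch of the call pattern this enables; the `replace_head` helper and its name are hypothetical, only the `embed_dim` attribute and `pre_logits` flag come from the diff:

import torch.nn as nn

def replace_head(pool: nn.Module, num_classes: int) -> nn.Linear:
    # pool.embed_dim is the attribute this diff starts storing in __init__;
    # it matches the width of the token returned by forward(x, pre_logits=True).
    return nn.Linear(pool.embed_dim, num_classes)

# usage (assuming `pool` is one of the patched pooling modules):
# logits = replace_head(pool, num_classes=1000)(pool(feats, pre_logits=True))
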
@@ -132,9 +137,10 @@ def __init__(
             num_heads: Optional[int] = None,
             qkv_bias: bool = True,
             qkv_separate: bool = False,
+            drop: float = 0.,
     ):
         super().__init__()
-        embed_dim = embed_dim or in_features
+        self.embed_dim = embed_dim = embed_dim or in_features
         self.in_features = in_features
         self.out_features = out_features or in_features
         if num_heads is not None:
@@ -158,6 +164,7 @@ def __init__(
         else:
             self.q = self.k = self.v = None
             self.qkv = nn.Linear(in_features, embed_dim * 3, bias=qkv_bias)
+        self.drop = nn.Dropout(drop)
         self.proj = nn.Linear(embed_dim, self.out_features)
         self.pos_embed = nn.Parameter(torch.zeros(self.seq_len + 1, in_features))

@@ -178,15 +185,12 @@ def init_weights(self, zero_init_last: bool = False):
             nn.init.zeros_(self.qkv.bias)
         trunc_normal_(self.pos_embed, std=in_features ** -0.5)

-    def forward(self, x):
+    def forward(self, x, pre_logits: bool = False):
         B, _, H, W = x.shape
         N = H * W
         x = x.flatten(2).transpose(1, 2)
         x = torch.cat([x.mean(1, keepdim=True), x], dim=1)
-        if self.seq_len != N:
-            pos_embed = resample_abs_pos_embed(self.pos_embed.unsqueeze(0), (H, W), num_prefix_tokens=1)
-        else:
-            pos_embed = self.pos_embed.unsqueeze(0).to(x.dtype)
+        pos_embed = resample_abs_pos_embed(self.pos_embed.unsqueeze(0), (H, W), num_prefix_tokens=1)
         x = x + pos_embed

         if self.qkv is None:
@@ -205,5 +209,9 @@ def forward(self, x):
         attn = attn.softmax(dim=-1)
         x = attn @ v
         x = x.transpose(1, 2).reshape(B, N + 1, -1)
+        x = x[:, 0]
+        x = self.drop(x)
+        if pre_logits:
+            return x
         x = self.proj(x)
-        return x[:, 0]
+        return x
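
For reference, a minimal usage sketch of the second class under the new interface, assuming it is timm's `AttentionPool2d` exported from `timm.layers` with a default 7x7 `feat_size`; the import path and the concrete sizes are assumptions, while `drop`, `pre_logits`, `embed_dim`, and the always-resampled position embedding are from the diff:

import torch
from timm.layers import AttentionPool2d  # assumed import path for the patched class

# `drop` is the constructor argument added in this diff; the other arguments already existed.
pool = AttentionPool2d(in_features=768, out_features=512, num_heads=8, drop=0.1)
pool.eval()  # disable the new dropout for a deterministic check

feats = torch.randn(2, 768, 7, 7)        # NCHW feature map
pooled = pool(feats)                      # projected output: shape (2, 512)
embed = pool(feats, pre_logits=True)      # pre-projection token: shape (2, pool.embed_dim) == (2, 768)

# Because forward now always resamples the position embedding, a feature map whose
# spatial size differs from the one used at construction also works:
pooled_12 = pool(torch.randn(2, 768, 12, 12))  # still (2, 512)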