@@ -41,9 +41,12 @@ def __init__(
             num_heads: Optional[int] = None,
             qkv_bias: bool = True,
             qkv_separate: bool = False,
-            drop: float = 0.,
+            pool_type: str = 'token',
+            avg_token: bool = True,
+            drop_rate: float = 0.,
     ):
         super().__init__()
+        assert pool_type in ('', 'token')
         self.embed_dim = embed_dim = embed_dim or in_features
         self.in_features = in_features
         self.out_features = out_features or in_features
@@ -56,6 +59,7 @@ def __init__(
             num_heads = embed_dim // head_dim
         self.num_heads = num_heads
         self.head_dim = head_dim
+        self.pool_type = pool_type.lower()
         self.scale = self.head_dim ** -0.5
         self.fused_attn = use_fused_attn()
 
@@ -66,6 +70,7 @@ def __init__(
             self.qkv = None
         else:
             self.qkv = nn.Linear(in_features, embed_dim * 3, bias=qkv_bias)
+        self.drop = nn.Dropout(drop_rate)
         self.proj = nn.Linear(embed_dim, self.out_features)
         self.pos_embed = RotaryEmbedding(self.head_dim, in_pixels=False, ref_feat_shape=ref_feat_size)
 
@@ -83,6 +88,23 @@ def init_weights(self, zero_init_last: bool = False):
             trunc_normal_(self.qkv.weight, std=in_features ** -0.5)
             nn.init.zeros_(self.qkv.bias)
 
+    def reset(self, num_classes: Optional[int] = None, pool_type: Optional[str] = None):
+        # NOTE: this module is being used as a head, so need compatible reset()
+        if pool_type is not None:
+            assert pool_type in ('', 'token')
+            self.pool_type = pool_type
+        if num_classes is not None:
+            self.proj = nn.Linear(self.in_features, num_classes) if num_classes > 0 else nn.Identity()
+            self.out_features = num_classes if num_classes > 0 else self.embed_dim
+
+    def _pool(self, x: torch.Tensor, H: int, W: int) -> torch.Tensor:
+        if self.pool_type == 'token':
+            x = x[:, 0]
+        else:
+            # if not pooled, return spatial output without token
+            x = x[:, 1:].reshape(x.shape[0], H, W, -1).permute(0, 3, 1, 2)
+        return x
+
     def forward(self, x, pre_logits: bool = False):
         B, _, H, W = x.shape
         N = H * W
@@ -111,8 +133,10 @@ def forward(self, x, pre_logits: bool = False):
         x = x[:, 0]
         x = self.drop(x)
         if pre_logits:
+            x = self._pool(x, H, W)
             return x
         x = self.proj(x)
+        x = self._pool(x, H, W)
         return x
 
 
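The hunks above turn the rotary-embedding attention pool into a drop-in classifier head: `drop` becomes `drop_rate`, a `pool_type` argument (restricted to `''` or `'token'`) selects between the pooled token and the unpooled spatial map, and `reset()` matches the head-reset interface used for other heads. A minimal standalone sketch of the `_pool()` shape handling; the function name `pool_tokens` and the toy sizes are illustrative, not part of the diff:

```python
import torch

def pool_tokens(x: torch.Tensor, H: int, W: int, pool_type: str = 'token') -> torch.Tensor:
    # x is (B, 1 + H*W, C): a leading pooled token followed by H*W spatial tokens,
    # mirroring the token layout produced inside forward() above.
    if pool_type == 'token':
        return x[:, 0]  # (B, C): keep only the leading token
    # pool_type == '': drop the token and restore an NCHW feature map
    return x[:, 1:].reshape(x.shape[0], H, W, -1).permute(0, 3, 1, 2)  # (B, C, H, W)

x = torch.randn(2, 1 + 7 * 7, 256)
print(pool_tokens(x, 7, 7, 'token').shape)  # torch.Size([2, 256])
print(pool_tokens(x, 7, 7, '').shape)       # torch.Size([2, 256, 7, 7])
```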
@@ -137,9 +161,12 @@ def __init__(
             num_heads: Optional[int] = None,
             qkv_bias: bool = True,
             qkv_separate: bool = False,
-            drop: float = 0.,
+            pool_type: str = 'token',
+            learned_token: bool = False,
+            drop_rate: float = 0.,
     ):
         super().__init__()
+        assert pool_type in ('', 'token')
         self.embed_dim = embed_dim = embed_dim or in_features
         self.in_features = in_features
         self.out_features = out_features or in_features
@@ -153,9 +180,15 @@ def __init__(
         self.seq_len = self.feat_size[0] * self.feat_size[1]
         self.num_heads = num_heads
         self.head_dim = head_dim
+        self.pool_type = pool_type
         self.scale = self.head_dim ** -0.5
         self.fused_attn = use_fused_attn()
 
+        if learned_token:
+            self.token = nn.Parameter(torch.zeros(1, embed_dim))
+        else:
+            self.token = None
+
         if qkv_separate:
             self.q = nn.Linear(in_features, embed_dim, bias=qkv_bias)
             self.k = nn.Linear(in_features, embed_dim, bias=qkv_bias)
@@ -164,7 +197,7 @@ def __init__(
         else:
             self.q = self.k = self.v = None
             self.qkv = nn.Linear(in_features, embed_dim * 3, bias=qkv_bias)
-        self.drop = nn.Dropout(drop)
+        self.drop = nn.Dropout(drop_rate)
         self.proj = nn.Linear(embed_dim, self.out_features)
         self.pos_embed = nn.Parameter(torch.zeros(self.seq_len + 1, in_features))
 
@@ -185,11 +218,31 @@ def init_weights(self, zero_init_last: bool = False):
             nn.init.zeros_(self.qkv.bias)
         trunc_normal_(self.pos_embed, std=in_features ** -0.5)
 
+    def reset(self, num_classes: Optional[int] = None, pool_type: Optional[str] = None):
+        # NOTE: this module is being used as a head, so need compatible reset()
+        if pool_type is not None:
+            assert pool_type in ('', 'token')
+            self.pool_type = pool_type
+        if num_classes is not None:
+            self.proj = nn.Linear(self.in_features, num_classes) if num_classes > 0 else nn.Identity()
+            self.out_features = num_classes if num_classes > 0 else self.embed_dim
+
+    def _pool(self, x: torch.Tensor, H: int, W: int) -> torch.Tensor:
+        if self.pool_type == 'token':
+            x = x[:, 0]
+        else:
+            # if not pooled, return spatial output without token
+            x = x[:, 1:].reshape(x.shape[0], H, W, -1).permute(0, 3, 1, 2)
+        return x
+
     def forward(self, x, pre_logits: bool = False):
         B, _, H, W = x.shape
         N = H * W
         x = x.flatten(2).transpose(1, 2)
-        x = torch.cat([x.mean(1, keepdim=True), x], dim=1)
+        if self.token is not None:
+            x = torch.cat([self.token.expand(x.shape[0], -1, -1), x], dim=1)
+        else:
+            x = torch.cat([x.mean(1, keepdim=True), x], dim=1)
         pos_embed = resample_abs_pos_embed(self.pos_embed.unsqueeze(0), (H, W), num_prefix_tokens=1)
         x = x + pos_embed
 
@@ -209,9 +262,10 @@ def forward(self, x, pre_logits: bool = False):
             attn = attn.softmax(dim=-1)
             x = attn @ v
         x = x.transpose(1, 2).reshape(B, N + 1, -1)
-        x = x[:, 0]
         x = self.drop(x)
         if pre_logits:
+            x = self._pool(x, H, W)
             return x
         x = self.proj(x)
+        x = self._pool(x, H, W)
         return x
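For the second pool (absolute position embedding, optional learned token), the same head interface applies. A hedged usage sketch follows; the class name `AttentionPool2d`, the `timm.layers` import path, and the constructor values below are assumptions for illustration, since the class definitions themselves are not shown in the diff:

```python
import torch
from timm.layers import AttentionPool2d  # assumed import path for the modified module

# Constructor values are illustrative; argument names other than those visible in
# the diff (pool_type, learned_token, drop_rate, qkv_bias, qkv_separate) are assumed.
head = AttentionPool2d(
    in_features=512,
    feat_size=7,
    pool_type='token',   # new: return the pooled token by default
    drop_rate=0.1,       # renamed from drop
)

x = torch.randn(2, 512, 7, 7)        # NCHW feature map matching feat_size
pooled = head(x)                      # (2, 512): token pooled, then projected
feats = head(x, pre_logits=True)      # (2, 512): pooled features before proj

# reset() re-purposes the pool as a classifier head and can switch pooling mode;
# with pool_type='' the token is dropped and a spatial map is returned instead.
head.reset(num_classes=10, pool_type='')
spatial_logits = head(x)              # (2, 10, 7, 7)
```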