
Commit cdc7bce

Make 2D attention pool modules compatible with the head interface. Use attention pool in CLIP ResNets as the head. Add a separate set of GAP models with avg pool instead of attn pool.

1 parent 30ffa15 · commit cdc7bce
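For orientation, the "head interface" referred to here is timm's usual classifier-head contract: a `reset(num_classes, pool_type)` method plus a `forward(x, pre_logits=...)` path. A minimal usage sketch of the pool module as a head, assuming the constructor keywords visible in the diff below (`in_features`, `out_features`, `pool_type`) plus a `feat_size` argument, with illustrative sizes:

```python
import torch
from timm.layers import AttentionPool2d

# Illustrative sizes: 2048-dim 7x7 features, e.g. a ResNet50-style trunk.
head = AttentionPool2d(
    in_features=2048,
    out_features=512,
    feat_size=7,
    pool_type='token',
)

x = torch.randn(2, 2048, 7, 7)           # NCHW feature map from the backbone
out = head(x)                            # pooled + projected: (2, 512)
emb = head(x, pre_logits=True)           # pooled, before the projection: (2, 2048)

# Head-style reset, as added in this commit:
head.reset(num_classes=10)               # proj becomes Linear(in_features, 10)
head.reset(num_classes=0)                # proj becomes Identity; pooled embedding is returned
```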

File tree

2 files changed: +209 −104 lines


timm/layers/attention_pool2d.py

Lines changed: 59 additions & 5 deletions
@@ -41,9 +41,12 @@ def __init__(
             num_heads: Optional[int] = None,
             qkv_bias: bool = True,
             qkv_separate: bool = False,
-            drop: float = 0.,
+            pool_type: str = 'token',
+            avg_token: bool = True,
+            drop_rate: float = 0.,
     ):
         super().__init__()
+        assert pool_type in ('', 'token')
         self.embed_dim = embed_dim = embed_dim or in_features
         self.in_features = in_features
         self.out_features = out_features or in_features
@@ -56,6 +59,7 @@ def __init__(
             num_heads = embed_dim // head_dim
         self.num_heads = num_heads
         self.head_dim = head_dim
+        self.pool_type = pool_type.lower()
         self.scale = self.head_dim ** -0.5
         self.fused_attn = use_fused_attn()
 
@@ -66,6 +70,7 @@ def __init__(
             self.qkv = None
         else:
             self.qkv = nn.Linear(in_features, embed_dim * 3, bias=qkv_bias)
+        self.drop = nn.Dropout(drop_rate)
         self.proj = nn.Linear(embed_dim, self.out_features)
         self.pos_embed = RotaryEmbedding(self.head_dim, in_pixels=False, ref_feat_shape=ref_feat_size)
 
@@ -83,6 +88,23 @@ def init_weights(self, zero_init_last: bool = False):
             trunc_normal_(self.qkv.weight, std=in_features ** -0.5)
             nn.init.zeros_(self.qkv.bias)
 
+    def reset(self, num_classes: Optional[int] = None, pool_type: Optional[str] = None):
+        # NOTE: this module is being used as a head, so need compatible reset()
+        if pool_type is not None:
+            assert pool_type in ('', 'token')
+            self.pool_type = pool_type
+        if num_classes is not None:
+            self.proj = nn.Linear(self.in_features, num_classes) if num_classes > 0 else nn.Identity()
+            self.out_features = num_classes if num_classes > 0 else self.embed_dim
+
+    def _pool(self, x: torch.Tensor, H: int, W: int) -> torch.Tensor:
+        if self.pool_type == 'token':
+            x = x[:, 0]
+        else:
+            # if not pooled, return spatial output without token
+            x = x[:, 1:].reshape(x.shape[0], H, W, -1).permute(0, 3, 1, 2)
+        return x
+
     def forward(self, x, pre_logits: bool = False):
         B, _, H, W = x.shape
         N = H * W
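The `_pool()` helper added above decides the head's output layout: `'token'` keeps only the prefix token, while `''` drops the token and restores an NCHW spatial map. A standalone sketch of the two branches in plain PyTorch (shapes illustrative):

```python
import torch

B, H, W, C = 2, 7, 7, 512
x = torch.randn(B, H * W + 1, C)    # prefix token + H*W spatial tokens, as inside forward()

token = x[:, 0]                                              # pool_type='token' -> (B, C)
spatial = x[:, 1:].reshape(B, H, W, C).permute(0, 3, 1, 2)   # pool_type=''      -> (B, C, H, W)

print(token.shape, spatial.shape)   # torch.Size([2, 512]) torch.Size([2, 512, 7, 7])
```

Note that in `forward()` below the projection is applied before `_pool()`, so with `pool_type=''` the unpooled path yields a per-position output map rather than a single vector.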
@@ -111,8 +133,10 @@ def forward(self, x, pre_logits: bool = False):
         x = x[:, 0]
         x = self.drop(x)
         if pre_logits:
+            x = self._pool(x, H, W)
             return x
         x = self.proj(x)
+        x = self._pool(x, H, W)
         return x
 
 
@@ -137,9 +161,12 @@ def __init__(
             num_heads: Optional[int] = None,
             qkv_bias: bool = True,
             qkv_separate: bool = False,
-            drop: float = 0.,
+            pool_type: str = 'token',
+            learned_token: bool = False,
+            drop_rate: float = 0.,
     ):
         super().__init__()
+        assert pool_type in ('', 'token')
         self.embed_dim = embed_dim = embed_dim or in_features
         self.in_features = in_features
         self.out_features = out_features or in_features
@@ -153,9 +180,15 @@ def __init__(
         self.seq_len = self.feat_size[0] * self.feat_size[1]
         self.num_heads = num_heads
         self.head_dim = head_dim
+        self.pool_type = pool_type
         self.scale = self.head_dim ** -0.5
         self.fused_attn = use_fused_attn()
 
+        if learned_token:
+            self.token = nn.Parameter(torch.zeros(1, embed_dim))
+        else:
+            self.token = None
+
         if qkv_separate:
             self.q = nn.Linear(in_features, embed_dim, bias=qkv_bias)
             self.k = nn.Linear(in_features, embed_dim, bias=qkv_bias)
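The new `learned_token` option replaces the CLIP-style mean-token prefix with a trainable parameter. A standalone sketch of the two prefix constructions used in `forward()` further down (sizes illustrative):

```python
import torch
import torch.nn as nn

x = torch.randn(2, 49, 512)                    # (B, N, C) flattened 7x7 feature map

# learned_token=True: trainable (1, C) token, broadcast to (B, 1, C) by expand()
token = nn.Parameter(torch.zeros(1, 512))
x_learned = torch.cat([token.expand(x.shape[0], -1, -1), x], dim=1)   # (2, 50, 512)

# learned_token=False (default): CLIP-style prefix = mean of the spatial tokens
x_mean = torch.cat([x.mean(1, keepdim=True), x], dim=1)               # (2, 50, 512)
```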
@@ -164,7 +197,7 @@ def __init__(
         else:
             self.q = self.k = self.v = None
             self.qkv = nn.Linear(in_features, embed_dim * 3, bias=qkv_bias)
-        self.drop = nn.Dropout(drop)
+        self.drop = nn.Dropout(drop_rate)
         self.proj = nn.Linear(embed_dim, self.out_features)
         self.pos_embed = nn.Parameter(torch.zeros(self.seq_len + 1, in_features))
 
@@ -185,11 +218,31 @@ def init_weights(self, zero_init_last: bool = False):
             nn.init.zeros_(self.qkv.bias)
         trunc_normal_(self.pos_embed, std=in_features ** -0.5)
 
+    def reset(self, num_classes: Optional[int] = None, pool_type: Optional[str] = None):
+        # NOTE: this module is being used as a head, so need compatible reset()
+        if pool_type is not None:
+            assert pool_type in ('', 'token')
+            self.pool_type = pool_type
+        if num_classes is not None:
+            self.proj = nn.Linear(self.in_features, num_classes) if num_classes > 0 else nn.Identity()
+            self.out_features = num_classes if num_classes > 0 else self.embed_dim
+
+    def _pool(self, x: torch.Tensor, H: int, W: int) -> torch.Tensor:
+        if self.pool_type == 'token':
+            x = x[:, 0]
+        else:
+            # if not pooled, return spatial output without token
+            x = x[:, 1:].reshape(x.shape[0], H, W, -1).permute(0, 3, 1, 2)
+        return x
+
     def forward(self, x, pre_logits: bool = False):
         B, _, H, W = x.shape
         N = H * W
         x = x.flatten(2).transpose(1, 2)
-        x = torch.cat([x.mean(1, keepdim=True), x], dim=1)
+        if self.token is not None:
+            x = torch.cat([self.token.expand(x.shape[0], -1, -1), x], dim=1)
+        else:
+            x = torch.cat([x.mean(1, keepdim=True), x], dim=1)
         pos_embed = resample_abs_pos_embed(self.pos_embed.unsqueeze(0), (H, W), num_prefix_tokens=1)
         x = x + pos_embed
 
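The learned `pos_embed` is sized for `self.feat_size` plus one prefix position; `resample_abs_pos_embed` (already part of `timm.layers`) interpolates it to whatever grid the incoming feature map actually has, as used in the hunk above. A quick sketch with illustrative sizes:

```python
import torch
from timm.layers import resample_abs_pos_embed

pos_embed = torch.zeros(7 * 7 + 1, 2048)     # trained for a 7x7 grid + 1 prefix token
resampled = resample_abs_pos_embed(pos_embed.unsqueeze(0), (12, 12), num_prefix_tokens=1)
print(resampled.shape)                       # torch.Size([1, 145, 2048])
```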
@@ -209,9 +262,10 @@ def forward(self, x, pre_logits: bool = False):
             attn = attn.softmax(dim=-1)
             x = attn @ v
         x = x.transpose(1, 2).reshape(B, N + 1, -1)
-        x = x[:, 0]
         x = self.drop(x)
         if pre_logits:
+            x = self._pool(x, H, W)
             return x
         x = self.proj(x)
+        x = self._pool(x, H, W)
         return x
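With the projection applied to the full token sequence and `_pool()` choosing the final layout, the same head can emit either pooled logits or a dense per-position map. A hedged sketch through `AttentionPool2d`, under the same constructor-keyword assumptions as above:

```python
import torch
from timm.layers import AttentionPool2d

head = AttentionPool2d(in_features=2048, out_features=1000, feat_size=7)
x = torch.randn(2, 2048, 7, 7)

head.reset(pool_type='token')
print(head(x).shape)        # pooled logits: torch.Size([2, 1000])

head.reset(pool_type='')
print(head(x).shape)        # per-position logits map: torch.Size([2, 1000, 7, 7])
```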
