Commit 30ffa15

Fix load of larger ResNet CLIP models, experimenting with making AttentionPool *the* head, seems to fine-tune better, one less layer.
1 parent 5e9ff57 commit 30ffa15

2 files changed, +116 −64 lines changed
timm/layers/attention_pool2d.py

Lines changed: 18 additions & 10 deletions
@@ -41,9 +41,10 @@ def __init__(
             num_heads: Optional[int] = None,
             qkv_bias: bool = True,
             qkv_separate: bool = False,
+            drop: float = 0.,
     ):
         super().__init__()
-        embed_dim = embed_dim or in_features
+        self.embed_dim = embed_dim = embed_dim or in_features
         self.in_features = in_features
         self.out_features = out_features or in_features
         ref_feat_size = to_2tuple(ref_feat_size)
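
Two small changes land in this hunk: the pool now accepts a drop rate (wired into an nn.Dropout placed before the projection in the later hunks), and embed_dim is kept as a module attribute so downstream code can see the pooled feature width. A minimal, hypothetical sketch of the standalone pattern these two pieces enable; make_proj_head is illustrative only and not part of this commit:

import torch.nn as nn

def make_proj_head(pool: nn.Module, num_classes: int, drop: float = 0.) -> nn.Module:
    # Hypothetical helper: reads the embed_dim attribute this commit adds and
    # rebuilds the dropout + linear projection for a new output width.
    return nn.Sequential(
        nn.Dropout(p=drop),                      # mirrors the new drop argument
        nn.Linear(pool.embed_dim, num_classes),  # mirrors self.proj
    )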
@@ -82,7 +83,7 @@ def init_weights(self, zero_init_last: bool = False):
             trunc_normal_(self.qkv.weight, std=in_features ** -0.5)
             nn.init.zeros_(self.qkv.bias)

-    def forward(self, x):
+    def forward(self, x, pre_logits: bool = False):
         B, _, H, W = x.shape
         N = H * W
         x = x.flatten(2).transpose(1, 2)
@@ -107,8 +108,12 @@ def forward(self, x):
         attn = attn.softmax(dim=-1)
         x = attn @ v
         x = x.transpose(1, 2).reshape(B, N + 1, -1)
+        x = x[:, 0]
+        x = self.drop(x)
+        if pre_logits:
+            return x
         x = self.proj(x)
-        return x[:, 0]
+        return x


 class AttentionPool2d(nn.Module):
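
With the tail of forward() rewritten as above, the pooled (prefix) token is selected and dropout applied before the projection, and pre_logits=True returns those pooled features without projecting them, so proj can double as the output/classifier layer. A rough usage sketch, assuming a ResNet50-style 2048x7x7 feature map; shapes and constructor values are illustrative, parameter names follow the file shown here:

import torch
from timm.layers.attention_pool2d import RotAttentionPool2d

pool = RotAttentionPool2d(in_features=2048, out_features=1024, drop=0.1)
feat = torch.randn(2, 2048, 7, 7)        # (B, C, H, W) trunk output

out = pool(feat)                         # (2, 1024): pooled, dropped, projected
pooled = pool(feat, pre_logits=True)     # (2, 2048): pooled features, proj skipped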
@@ -132,9 +137,10 @@ def __init__(
             num_heads: Optional[int] = None,
             qkv_bias: bool = True,
             qkv_separate: bool = False,
+            drop: float = 0.,
     ):
         super().__init__()
-        embed_dim = embed_dim or in_features
+        self.embed_dim = embed_dim = embed_dim or in_features
         self.in_features = in_features
         self.out_features = out_features or in_features
         if num_heads is not None:
@@ -158,6 +164,7 @@ def __init__(
         else:
             self.q = self.k = self.v = None
             self.qkv = nn.Linear(in_features, embed_dim * 3, bias=qkv_bias)
+        self.drop = nn.Dropout(drop)
         self.proj = nn.Linear(embed_dim, self.out_features)
         self.pos_embed = nn.Parameter(torch.zeros(self.seq_len + 1, in_features))

@@ -178,15 +185,12 @@ def init_weights(self, zero_init_last: bool = False):
             nn.init.zeros_(self.qkv.bias)
         trunc_normal_(self.pos_embed, std=in_features ** -0.5)

-    def forward(self, x):
+    def forward(self, x, pre_logits: bool = False):
         B, _, H, W = x.shape
         N = H * W
         x = x.flatten(2).transpose(1, 2)
         x = torch.cat([x.mean(1, keepdim=True), x], dim=1)
-        if self.seq_len != N:
-            pos_embed = resample_abs_pos_embed(self.pos_embed.unsqueeze(0), (H, W), num_prefix_tokens=1)
-        else:
-            pos_embed = self.pos_embed.unsqueeze(0).to(x.dtype)
+        pos_embed = resample_abs_pos_embed(self.pos_embed.unsqueeze(0), (H, W), num_prefix_tokens=1)
         x = x + pos_embed

         if self.qkv is None:
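
The conditional around the positional embedding is gone: forward() now always routes through resample_abs_pos_embed, which interpolates the learned (seq_len + 1) embedding to whatever H x W the incoming feature map has while carrying the prefix (pooled) token through unchanged. A small, hedged illustration of that helper's behavior, assuming its usual timm signature:

import torch
from timm.layers import resample_abs_pos_embed

# Learned embedding for an 8x8 grid plus one prefix token: (1, 65, C).
pos_embed = torch.zeros(1, 8 * 8 + 1, 256)

# Resample to a 12x12 feature map; the prefix token is kept as-is.
resampled = resample_abs_pos_embed(pos_embed, (12, 12), num_prefix_tokens=1)
print(resampled.shape)  # torch.Size([1, 145, 256])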
@@ -205,5 +209,9 @@ def forward(self, x):
         attn = attn.softmax(dim=-1)
         x = attn @ v
         x = x.transpose(1, 2).reshape(B, N + 1, -1)
+        x = x[:, 0]
+        x = self.drop(x)
+        if pre_logits:
+            return x
         x = self.proj(x)
-        return x[:, 0]
+        return x
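
Taken together with the commit message ("making AttentionPool *the* head"), the intent appears to be that a ResNet CLIP model can hand its feature map straight to the pool, whose proj produces the final output width, removing a separate fc layer. A hypothetical stub showing that layout; the class and its wiring are assumptions for illustration, not the actual timm model integration:

import torch.nn as nn
from timm.layers.attention_pool2d import AttentionPool2d

class ClipResNetStub(nn.Module):
    # Hypothetical: trunk -> AttentionPool2d acting as the head, no extra fc layer.
    def __init__(self, trunk: nn.Module, in_features: int, num_outputs: int):
        super().__init__()
        self.trunk = trunk
        self.head = AttentionPool2d(in_features, out_features=num_outputs, drop=0.1)

    def forward(self, x, pre_logits: bool = False):
        x = self.trunk(x)                    # (B, C, H, W) feature map
        return self.head(x, pre_logits=pre_logits)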
