Skip to content

Commit 37bbac1

Browse files
committed
Fix checkpoint_filter_fn
1 parent fc0b6ad commit 37bbac1

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

timm/models/tnt.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -488,7 +488,7 @@ def checkpoint_filter_fn(state_dict, model):
488488
k = k.replace('outer_attn', 'attn_out')
489489
k = k.replace('outer_norm2', 'norm_mlp')
490490
k = k.replace('outer_mlp', 'mlp')
491-
if k == 'pixel_pos':
491+
if k == 'pixel_pos' and model.pixel_embed.legacy == False:
492492
B, N, C = v.shape
493493
H = W = int(N ** 0.5)
494494
assert H * W == N

0 commit comments

Comments
 (0)