Skip to content

Commit 3c7822c

Browse files
wojtkerwightman
authored andcommitted
fix pos embed dynamic resampling for deit
1 parent 3ae3f44 commit 3c7822c

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

timm/models/deit.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,9 +75,11 @@ def set_distilled_training(self, enable=True):
7575
def _pos_embed(self, x):
7676
if self.dynamic_img_size:
7777
B, H, W, C = x.shape
78+
prev_grid_size = self.patch_embed.grid_size
7879
pos_embed = resample_abs_pos_embed(
7980
self.pos_embed,
80-
(H, W),
81+
new_size=(H, W),
82+
old_size=prev_grid_size,
8183
num_prefix_tokens=0 if self.no_embed_class else self.num_prefix_tokens,
8284
)
8385
x = x.view(B, -1, C)

0 commit comments

Comments
 (0)