Commit 3ae3f44

wojtke authored and rwightman committed
Fix positional embedding resampling for non-square inputs in ViT
1 parent 51ac8d2 commit 3ae3f44

File tree

1 file changed: 3 additions, 1 deletion


timm/models/vision_transformer.py

Lines changed: 3 additions & 1 deletion
@@ -669,9 +669,11 @@ def _pos_embed(self, x: torch.Tensor) -> torch.Tensor:
 
         if self.dynamic_img_size:
            B, H, W, C = x.shape
+            prev_grid_size = self.patch_embed.grid_size
            pos_embed = resample_abs_pos_embed(
                self.pos_embed,
-                (H, W),
+                new_size=(H, W),
+                old_size=prev_grid_size,
                num_prefix_tokens=0 if self.no_embed_class else self.num_prefix_tokens,
            )
        x = x.view(B, -1, C)
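To illustrate why passing `old_size` matters, here is a minimal sketch of an absolute position-embedding resampler. It is not timm's actual `resample_abs_pos_embed` implementation (the function name, signature, and internals below are illustrative assumptions); it only reproduces the failure mode this commit avoids: when the old grid size is omitted, the resampler must guess a square grid from the token count, which breaks for non-square inputs.

```python
# Hedged sketch, assuming a bicubic-interpolation resampler; not timm's code.
import math
import torch
import torch.nn.functional as F

def resample_pos_embed(pos_embed, new_size, old_size=None):
    """Resample (1, N, C) absolute position embeddings (no prefix tokens)."""
    num_pos = pos_embed.shape[1]
    if old_size is None:
        # Square-grid guess: only valid when the original input was square.
        side = int(math.sqrt(num_pos))
        old_size = (side, side)
    assert old_size[0] * old_size[1] == num_pos, "old_size does not match embedding count"
    c = pos_embed.shape[-1]
    # (1, N, C) -> (1, C, H_old, W_old) for spatial interpolation.
    grid = pos_embed.reshape(1, old_size[0], old_size[1], c).permute(0, 3, 1, 2)
    grid = F.interpolate(grid, size=new_size, mode="bicubic", align_corners=False)
    # Back to (1, H_new * W_new, C).
    return grid.permute(0, 2, 3, 1).reshape(1, -1, c)

# A non-square 8x16 patch grid yields 128 embeddings; sqrt(128) is not an
# integer, so the square-grid guess would pick 11x11 = 121 and fail.
pe = torch.randn(1, 8 * 16, 32)
out = resample_pos_embed(pe, new_size=(12, 24), old_size=(8, 16))
print(out.shape)  # torch.Size([1, 288, 32])
```

This mirrors the intent of the diff above: by passing `old_size=self.patch_embed.grid_size`, the model supplies the true (possibly non-square) original grid instead of letting the resampler infer a square one from the embedding count.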

0 commit comments
