
Commit ee27b73: Further pos embed tweaks, rejig model defs for testing
Parent: 3dc90ed

1 file changed: +40 -56 lines

timm/models/vision_transformer_flex.py
@@ -280,41 +280,51 @@ def forward(self, x: torch.Tensor, patch_coord: Optional[torch.Tensor] = None):
 
         return x
 
+    #@torch.compiler.disable()
     def _apply_learned_naflex_pos_embed(
             self,
             x: torch.Tensor,
             naflex_grid_sizes: List[Tuple[int, int]],
     ):
-        orig_h, orig_w = self.pos_embed.shape[1:3]
-
-        # Determine unique grid sizes
-        size_to_indices: Dict[Tuple[int, int], List[int]] = {}
-        for bi, (h, w) in enumerate(naflex_grid_sizes):
-            #k = h << 16 | w # FIXME can get jit compat with this
-            k = (h, w)
-            if not k in size_to_indices:
-                size_to_indices[k] = [bi]
-            else:
-                size_to_indices[k].append(bi)
-
         # Handle each batch element separately with its own grid size
+        orig_h, orig_w = self.pos_embed.shape[1:3]
         pos_embed_nchw = self.pos_embed.permute(0, 3, 1, 2).float()  # B,C,H,W
-        for k, batch_indices in size_to_indices.items():
-            h, w = k
-            #h, w = k >> 16, k & 0xFFFF # FIXME can get jit compat with this
-            # Interpolate only once for this (h, w)
-            if (h == orig_h) and (w == orig_w):
+
+        def _interp(_size):
+            if (_size[0] == orig_h) and (_size[1] == orig_w):
                 pos_embed_flat = self.pos_embed.reshape(1, orig_h * orig_w, -1)
             else:
                 pos_embed_flat = F.interpolate(
                     pos_embed_nchw,
-                    size=(h, w),
+                    size=_size,
                     mode=self.pos_embed_interp_mode,
                     align_corners=False,
                     antialias=True,
                 ).flatten(2).transpose(1, 2)
-            pos_embed_flat = pos_embed_flat.to(dtype=x.dtype)
+            return pos_embed_flat.to(dtype=x.dtype)
+
+        # FIXME leaving alternative code commented here for now for comparisons
+        # pos_embed_cache: Dict[Tuple[int, int], torch.Tensor] = {}
+        # for i, s in enumerate(naflex_grid_sizes):
+        #     if s in pos_embed_cache:
+        #         pos_embed_flat = pos_embed_cache[s]
+        #     else:
+        #         pos_embed_flat = _interp(s)
+        #         pos_embed_cache[s] = pos_embed_flat
+        #
+        #     seq_len = min(x.shape[1], pos_embed_flat.shape[1])
+        #     x[i, :seq_len] += pos_embed_flat[0, :seq_len]
 
+        # Determine unique grid sizes
+        size_to_indices: Dict[Tuple[int, int], List[int]] = {}
+        for bi, k in enumerate(naflex_grid_sizes):
+            # k = h << 16 | w # FIXME can get jit compat with this
+            size_to_indices.setdefault(k, []).append(bi)
+
+        for k, batch_indices in size_to_indices.items():
+            # h, w = k >> 16, k & 0xFFFF # FIXME can get jit compat with this
+            # Interpolate only once for this (h, w)
+            pos_embed_flat = _interp(k)
             seq_len = min(x.shape[1], pos_embed_flat.shape[1])
             x[:, :seq_len].index_add_(
                 0,
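The rework factors the resize into a local `_interp` closure, then groups batch indices by grid size so each unique (h, w) is interpolated once and scattered to all matching rows via `index_add_`. A minimal, self-contained sketch of that pattern follows; `add_grouped_pos_embed`, the example shapes, and the hard-coded `'bicubic'` mode are illustrative assumptions (the method itself reads `self.pos_embed_interp_mode`):

```python
import torch
import torch.nn.functional as F
from typing import Dict, List, Tuple


def add_grouped_pos_embed(
        x: torch.Tensor,                    # (B, N, C) patch tokens, modified in place
        pos_embed: torch.Tensor,            # (1, H, W, C) learned position table
        grid_sizes: List[Tuple[int, int]],  # per-sample (h, w) token grid
) -> torch.Tensor:
    orig_h, orig_w = pos_embed.shape[1:3]
    pos_nchw = pos_embed.permute(0, 3, 1, 2).float()  # (1, C, H, W) for F.interpolate

    def _interp(size: Tuple[int, int]) -> torch.Tensor:
        # Native grid: skip interpolation, just flatten the table.
        if size == (orig_h, orig_w):
            flat = pos_embed.reshape(1, orig_h * orig_w, -1)
        else:
            flat = F.interpolate(
                pos_nchw, size=size, mode='bicubic',  # mode assumed here
                align_corners=False, antialias=True,
            ).flatten(2).transpose(1, 2)  # (1, h*w, C)
        return flat.to(x.dtype)

    # Group batch indices by grid size so each unique (h, w) is interpolated once.
    size_to_indices: Dict[Tuple[int, int], List[int]] = {}
    for bi, k in enumerate(grid_sizes):
        size_to_indices.setdefault(k, []).append(bi)

    for k, batch_indices in size_to_indices.items():
        flat = _interp(k)
        seq_len = min(x.shape[1], flat.shape[1])
        # index_add_ adds the same interpolated embedding to every batch row
        # sharing this grid size in one scatter, instead of a Python loop.
        x[:, :seq_len].index_add_(
            0,
            torch.as_tensor(batch_indices, device=x.device),
            flat[:, :seq_len].expand(len(batch_indices), -1, -1),
        )
    return x


# Example: batch of 3 where two samples share a 7x7 grid and one is 14x10.
tokens = torch.zeros(3, 140, 64)
table = torch.randn(1, 14, 14, 64)
out = add_grouped_pos_embed(tokens, table, [(7, 7), (14, 10), (7, 7)])
print(out.shape)  # torch.Size([3, 140, 64])
```

Compared with the commented-out per-sample loop kept in the diff for comparison, the grouped version pays the interpolation and scatter cost once per unique grid size rather than once per sample.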
@@ -1015,7 +1025,6 @@ def _cfg(url: str = '', **kwargs) -> Dict[str, Any]:
 
 
 default_cfgs = generate_default_cfgs({
-    'vit_naflex_base_patch16': _cfg(),
     'vit_naflex_base_patch16_gap': _cfg(),
     'vit_naflex_base_patch16_map': _cfg(),
 

@@ -1050,43 +1059,15 @@ def _create_vision_transformer_flex(variant, pretrained=False, **kwargs):
     return model
 
 
-@register_model
-def vit_naflex_mediumd_patch16_reg4_gap(pretrained=False, **kwargs):
-    """ViT-New with NaFlex functionality for variable aspect ratios and resolutions.
-    """
-    model_args = dict(
-        patch_size=16, embed_dim=512, depth=20, num_heads=8, init_values=1e-5,
-        global_pool='avg', class_token=False, reg_tokens=4, fc_norm=True, **kwargs)
-    model = _create_vision_transformer_flex(
-        'vit_naflex_mediumd_patch16_reg4_gap', pretrained=pretrained, **model_args)
-    return model
-
-
-@register_model
-def vit_naflex_base_patch16(pretrained=False, **kwargs):
-    """ViT-New with NaFlex functionality for variable aspect ratios and resolutions.
-
-    This model supports:
-    1. Variable aspect ratios and resolutions via patch coordinates
-    2. Position embedding interpolation for arbitrary grid sizes
-    3. Explicit patch coordinates and valid token masking
-    """
-    model_args = dict(
-        patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
-    model = _create_vision_transformer_flex(
-        'vit_naflex_base_patch16', pretrained=pretrained, **model_args)
-    return model
-
-
 @register_model
 def vit_naflex_base_patch16_gap(pretrained=False, **kwargs):
     """ViT-New with NaFlex functionality for variable aspect ratios and resolutions.
     """
     model_args = dict(
-        patch_size=16, embed_dim=768, depth=12, num_heads=12,
-        global_pool='avg', class_token=False, reg_tokens=4, **kwargs)
+        patch_size=16, embed_dim=768, depth=12, num_heads=12, init_values=1e-5,
+        global_pool='avg', class_token=False, reg_tokens=4, fc_norm=True, **kwargs)
     model = _create_vision_transformer_flex(
-        'vit_naflex_base_patch16_gap', pretrained=pretrained, **model_args)
+        'vit_naflex_base_patch16_gap', pretrained=pretrained, **dict(model_args, **kwargs))
     return model
 
 
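Across the remaining model defs, the rejig moves `**kwargs` out of the defaults dict and into a `dict(model_args, **kwargs)` merge at the call site. A small sketch of the difference, with hypothetical values:

```python
# The new call-site pattern: dict(model_args, **kwargs) applies kwargs last,
# so caller overrides win over the variant's baked-in defaults.
model_args = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12)
kwargs = dict(depth=6)  # hypothetical caller override, e.g. a small test config

merged = dict(model_args, **kwargs)
assert merged['depth'] == 6        # caller value replaces the default
assert merged['patch_size'] == 16  # untouched defaults pass through

# The old pattern baked **kwargs into the dict literal itself; with a
# duplicate key it raises instead of overriding:
try:
    dict(patch_size=16, depth=12, **kwargs)
except TypeError as e:
    print(e)  # e.g. "got multiple values for keyword argument 'depth'"
```

That makes each registered variant overridable from `create_model` without tripping on duplicate keywords.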
@@ -1095,9 +1076,10 @@ def vit_naflex_base_patch16_map(pretrained=False, **kwargs):
     """ViT-New with NaFlex functionality for variable aspect ratios and resolutions.
     """
     model_args = dict(
-        patch_size=16, embed_dim=768, depth=12, num_heads=12, global_pool='map', **kwargs)
+        patch_size=16, embed_dim=768, depth=12, num_heads=12, init_values=1e-5,
+        global_pool='map', reg_tokens=1)
     model = _create_vision_transformer_flex(
-        'vit_naflex_base_patch16_map', pretrained=pretrained, **model_args)
+        'vit_naflex_base_patch16_map', pretrained=pretrained, **dict(model_args, **kwargs))
     return model
 
 
@@ -1112,9 +1094,9 @@ def vit_naflex_so150m2_patch16_reg1_gap(pretrained=False, **kwargs):
     """
     model_args = dict(
         patch_size=16, embed_dim=832, depth=21, num_heads=13, mlp_ratio=34/13, init_values=1e-5,
-        qkv_bias=False, class_token=False, reg_tokens=1, global_pool='avg', fc_norm=True, **kwargs)
+        qkv_bias=False, class_token=False, reg_tokens=1, global_pool='avg', fc_norm=True)
     model = _create_vision_transformer_flex(
-        'vit_naflex_so150m2_patch16_reg1_gap', pretrained=pretrained, **model_args)
+        'vit_naflex_so150m2_patch16_reg1_gap', pretrained=pretrained, **dict(model_args, **kwargs))
     return model
 
 
@@ -1123,6 +1105,8 @@ def vit_naflex_base_patch16(pretrained: bool = False, **kwargs):
     """ ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
     ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer.
     """
-    model_args = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, global_pool='token', class_token=True, pos_embed_grid_size=(14, 14))
+    model_args = dict(
+        patch_size=16, embed_dim=768, depth=12, num_heads=12,
+        global_pool='token', class_token=True, pos_embed_grid_size=(14, 14))
     model = _create_vision_transformer_flex('vit_naflex_base_patch16', pretrained=pretrained, **dict(model_args, **kwargs))
     return model
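With the defs rejigged, the registered names can be smoke-tested directly. A hypothetical check, assuming a checkout where this in-development module is importable, and relying on `patch_coord` being `Optional` in `forward` (see the first hunk):

```python
import timm
import torch

# Hypothetical smoke test: build the NaFlex base model without weights.
# The variant name comes from the @register_model defs in the diff above.
model = timm.create_model('vit_naflex_base_patch16', pretrained=False)
model.eval()

# With patch_coord omitted, a dense image batch should take the default
# path using the (14, 14) pos_embed_grid_size.
with torch.no_grad():
    out = model(torch.randn(2, 3, 224, 224))
print(out.shape)  # expected: torch.Size([2, 1000]) with the default classifier
```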
