@@ -280,41 +280,51 @@ def forward(self, x: torch.Tensor, patch_coord: Optional[torch.Tensor] = None):
         return x
 
+    #@torch.compiler.disable()
     def _apply_learned_naflex_pos_embed(
            self,
            x: torch.Tensor,
            naflex_grid_sizes: List[Tuple[int, int]],
    ):
-        orig_h, orig_w = self.pos_embed.shape[1:3]
-
-        # Determine unique grid sizes
-        size_to_indices: Dict[Tuple[int, int], List[int]] = {}
-        for bi, (h, w) in enumerate(naflex_grid_sizes):
-            #k = h << 16 | w  # FIXME can get jit compat with this
-            k = (h, w)
-            if not k in size_to_indices:
-                size_to_indices[k] = [bi]
-            else:
-                size_to_indices[k].append(bi)
-
         # Handle each batch element separately with its own grid size
+        orig_h, orig_w = self.pos_embed.shape[1:3]
         pos_embed_nchw = self.pos_embed.permute(0, 3, 1, 2).float()  # B,C,H,W
-        for k, batch_indices in size_to_indices.items():
-            h, w = k
-            #h, w = k >> 16, k & 0xFFFF  # FIXME can get jit compat with this
-            # Interpolate only once for this (h, w)
-            if (h == orig_h) and (w == orig_w):
+
+        def _interp(_size):
+            if (_size[0] == orig_h) and (_size[1] == orig_w):
                 pos_embed_flat = self.pos_embed.reshape(1, orig_h * orig_w, -1)
             else:
                 pos_embed_flat = F.interpolate(
                     pos_embed_nchw,
-                    size=(h, w),
+                    size=_size,
                     mode=self.pos_embed_interp_mode,
                     align_corners=False,
                     antialias=True,
                 ).flatten(2).transpose(1, 2)
-            pos_embed_flat = pos_embed_flat.to(dtype=x.dtype)
+            return pos_embed_flat.to(dtype=x.dtype)
+
+        # FIXME leaving alternative code commented here for now for comparisons
+        # pos_embed_cache: Dict[Tuple[int, int], torch.Tensor] = {}
+        # for i, s in enumerate(naflex_grid_sizes):
+        #     if s in pos_embed_cache:
+        #         pos_embed_flat = pos_embed_cache[s]
+        #     else:
+        #         pos_embed_flat = _interp(s)
+        #         pos_embed_cache[s] = pos_embed_flat
+        #
+        #     seq_len = min(x.shape[1], pos_embed_flat.shape[1])
+        #     x[i, :seq_len] += pos_embed_flat[0, :seq_len]
 
+        # Determine unique grid sizes
+        size_to_indices: Dict[Tuple[int, int], List[int]] = {}
+        for bi, k in enumerate(naflex_grid_sizes):
+            # k = h << 16 | w  # FIXME can get jit compat with this
+            size_to_indices.setdefault(k, []).append(bi)
+
+        for k, batch_indices in size_to_indices.items():
+            # h, w = k >> 16, k & 0xFFFF  # FIXME can get jit compat with this
+            # Interpolate only once for this (h, w)
+            pos_embed_flat = _interp(k)
 
             seq_len = min(x.shape[1], pos_embed_flat.shape[1])
             x[:, :seq_len].index_add_(
                 0,
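A standalone sketch of the bucketing pattern introduced above may help: group batch indices by grid size, interpolate the position-embedding table once per unique (h, w), and scatter the result into all matching samples with index_add_. This is an illustrative reconstruction, not the module itself; the (1, H, W, C) table layout, the bicubic mode, and the function name are assumptions.

import torch
import torch.nn.functional as F
from typing import Dict, List, Tuple


# Hypothetical standalone sketch of the strategy above; names and defaults
# are illustrative, not the module's actual state.
def add_naflex_pos_embed(
        x: torch.Tensor,                    # (B, N, C) patch tokens, modified in place
        pos_embed: torch.Tensor,            # (1, H, W, C) learned embedding table
        grid_sizes: List[Tuple[int, int]],  # per-sample (h, w) patch grids
) -> torch.Tensor:
    orig_h, orig_w = pos_embed.shape[1:3]
    pos_nchw = pos_embed.permute(0, 3, 1, 2).float()

    # Bucket batch indices by grid size so each unique (h, w) is resized once.
    size_to_indices: Dict[Tuple[int, int], List[int]] = {}
    for bi, k in enumerate(grid_sizes):
        size_to_indices.setdefault(k, []).append(bi)

    for (h, w), batch_indices in size_to_indices.items():
        if (h, w) == (orig_h, orig_w):
            flat = pos_embed.reshape(1, orig_h * orig_w, -1)
        else:
            flat = F.interpolate(
                pos_nchw, size=(h, w), mode='bicubic',
                align_corners=False, antialias=True,
            ).flatten(2).transpose(1, 2)
        flat = flat.to(dtype=x.dtype)
        seq_len = min(x.shape[1], flat.shape[1])
        # index_add_ on dim 0 of the (B, seq_len, C) view touches only the
        # samples that use this grid size.
        x[:, :seq_len].index_add_(
            0,
            torch.as_tensor(batch_indices, device=x.device),
            flat[:, :seq_len].expand(len(batch_indices), -1, -1),
        )
    return x

The commented-out cache variant kept in the diff does the same work one sample at a time; the bucketed form performs a single interpolation and a single index_add_ per unique grid size.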
@@ -1015,7 +1025,6 @@ def _cfg(url: str = '', **kwargs) -> Dict[str, Any]:
 
 
 default_cfgs = generate_default_cfgs({
-    'vit_naflex_base_patch16': _cfg(),
     'vit_naflex_base_patch16_gap': _cfg(),
     'vit_naflex_base_patch16_map': _cfg(),
@@ -1050,43 +1059,15 @@ def _create_vision_transformer_flex(variant, pretrained=False, **kwargs):
     return model
 
 
-@register_model
-def vit_naflex_mediumd_patch16_reg4_gap(pretrained=False, **kwargs):
-    """ViT-New with NaFlex functionality for variable aspect ratios and resolutions.
-    """
-    model_args = dict(
-        patch_size=16, embed_dim=512, depth=20, num_heads=8, init_values=1e-5,
-        global_pool='avg', class_token=False, reg_tokens=4, fc_norm=True, **kwargs)
-    model = _create_vision_transformer_flex(
-        'vit_naflex_mediumd_patch16_reg4_gap', pretrained=pretrained, **model_args)
-    return model
-
-
-@register_model
-def vit_naflex_base_patch16(pretrained=False, **kwargs):
-    """ViT-New with NaFlex functionality for variable aspect ratios and resolutions.
-
-    This model supports:
-    1. Variable aspect ratios and resolutions via patch coordinates
-    2. Position embedding interpolation for arbitrary grid sizes
-    3. Explicit patch coordinates and valid token masking
-    """
-    model_args = dict(
-        patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
-    model = _create_vision_transformer_flex(
-        'vit_naflex_base_patch16', pretrained=pretrained, **model_args)
-    return model
-
-
 @register_model
 def vit_naflex_base_patch16_gap(pretrained=False, **kwargs):
     """ViT-New with NaFlex functionality for variable aspect ratios and resolutions.
     """
     model_args = dict(
-        patch_size=16, embed_dim=768, depth=12, num_heads=12,
-        global_pool='avg', class_token=False, reg_tokens=4, **kwargs)
+        patch_size=16, embed_dim=768, depth=12, num_heads=12, init_values=1e-5,
+        global_pool='avg', class_token=False, reg_tokens=4, fc_norm=True)
     model = _create_vision_transformer_flex(
-        'vit_naflex_base_patch16_gap', pretrained=pretrained, **model_args)
+        'vit_naflex_base_patch16_gap', pretrained=pretrained, **dict(model_args, **kwargs))
     return model
 
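The recurring change in these registrations is worth spelling out: `**kwargs` moves out of `model_args` and into the call as `dict(model_args, **kwargs)`, so caller overrides win over the variant defaults instead of raising a duplicate-keyword error. A quick illustration with hypothetical names:

# Hypothetical names; illustrates why the merge moved to the call site.
defaults = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12)

merged = dict(defaults, **{'depth': 24})   # overrides win on key collisions
assert merged['depth'] == 24 and merged['embed_dim'] == 768

# The old form baked kwargs into the dict() literal itself, so overriding a
# default raised: TypeError: dict() got multiple values for keyword argument 'depth'
# dict(depth=12, **{'depth': 24})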
@@ -1095,9 +1076,10 @@ def vit_naflex_base_patch16_map(pretrained=False, **kwargs):
     """ViT-New with NaFlex functionality for variable aspect ratios and resolutions.
     """
     model_args = dict(
-        patch_size=16, embed_dim=768, depth=12, num_heads=12, global_pool='map', **kwargs)
+        patch_size=16, embed_dim=768, depth=12, num_heads=12, init_values=1e-5,
+        global_pool='map', reg_tokens=1)
     model = _create_vision_transformer_flex(
-        'vit_naflex_base_patch16_map', pretrained=pretrained, **model_args)
+        'vit_naflex_base_patch16_map', pretrained=pretrained, **dict(model_args, **kwargs))
     return model
 
@@ -1112,9 +1094,9 @@ def vit_naflex_so150m2_patch16_reg1_gap(pretrained=False, **kwargs):
     """
     model_args = dict(
         patch_size=16, embed_dim=832, depth=21, num_heads=13, mlp_ratio=34/13, init_values=1e-5,
-        qkv_bias=False, class_token=False, reg_tokens=1, global_pool='avg', fc_norm=True, **kwargs)
+        qkv_bias=False, class_token=False, reg_tokens=1, global_pool='avg', fc_norm=True)
     model = _create_vision_transformer_flex(
-        'vit_naflex_so150m2_patch16_reg1_gap', pretrained=pretrained, **model_args)
+        'vit_naflex_so150m2_patch16_reg1_gap', pretrained=pretrained, **dict(model_args, **kwargs))
     return model
 
@@ -1123,6 +1105,8 @@ def vit_naflex_base_patch16(pretrained: bool = False, **kwargs):
     """ ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
     ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer.
     """
-    model_args = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, global_pool='token', class_token=True, pos_embed_grid_size=(14, 14))
+    model_args = dict(
+        patch_size=16, embed_dim=768, depth=12, num_heads=12,
+        global_pool='token', class_token=True, pos_embed_grid_size=(14, 14))
     model = _create_vision_transformer_flex('vit_naflex_base_patch16', pretrained=pretrained, **dict(model_args, **kwargs))
     return model
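A usage sketch for this variant, assuming this branch of timm is installed, the model is registered, and the NaFlex forward path accepts plain NCHW images when patch_coord is None (as the signature in the first hunk suggests):

import timm
import torch

# pos_embed_grid_size=(14, 14) matches the native 224/16 = 14x14 grid of the
# original ViT-B/16 checkpoints, so classic fixed-size inputs need no resize
# of the position table.
model = timm.create_model('vit_naflex_base_patch16', pretrained=False)
model.eval()

with torch.no_grad():
    out = model(torch.randn(1, 3, 224, 224))  # assumed image-input path
print(out.shape)  # expected: torch.Size([1, 1000]) with the default head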