@@ -52,6 +52,7 @@ def batch_patchify(

    nh, nw = H // ph, W // pw
    patches = x.view(B, C, nh, ph, nw, pw).permute(0, 2, 4, 3, 5, 1).reshape(B, nh * nw, ph * pw * C)
+    # FIXME confirm we want 'channels last' in the patch channel layout, e.g. (ph, pw, C) instead of (C, ph, pw)

    return patches, (nh, nw)

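A standalone sketch (toy sizes, not part of the diff) of what the patchify reshape above produces: each of the nh*nw patches is flattened in (ph, pw, C) order, i.e. channels-last within a patch.

import torch

x = torch.arange(2 * 3 * 4 * 6, dtype=torch.float32).reshape(2, 3, 4, 6)  # B, C, H, W
ph, pw = 2, 3
B, C, H, W = x.shape
nh, nw = H // ph, W // pw
patches = x.view(B, C, nh, ph, nw, pw).permute(0, 2, 4, 3, 5, 1).reshape(B, nh * nw, ph * pw * C)
print(patches.shape)  # torch.Size([2, 4, 18]) -> (B, num_patches, ph*pw*C)

# The first patch of the first image is the top-left 2x3 block, channels last
top_left = x[0, :, :ph, :pw].permute(1, 2, 0).reshape(-1)
assert torch.equal(patches[0, 0], top_left)
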
@@ -297,7 +298,7 @@ def _apply_learned_naflex_pos_embed(
            size_to_indices[k].append(bi)

        # Handle each batch element separately with its own grid size
-        pos_embed_nchw = self.pos_embed.permute(0, 3, 1, 2)  # B,C,H,W
+        pos_embed_nchw = self.pos_embed.permute(0, 3, 1, 2).float()  # B,C,H,W
        for k, batch_indices in size_to_indices.items():
            h, w = k
            #h, w = k >> 16, k & 0xFFFF  # FIXME can get jit compat with this
@@ -312,9 +313,14 @@ def _apply_learned_naflex_pos_embed(
                align_corners=False,
                antialias=True,
            ).flatten(2).transpose(1, 2)
+            pos_embed_flat = pos_embed_flat.to(dtype=x.dtype)

            seq_len = min(x.shape[1], pos_embed_flat.shape[1])
-            x[batch_indices, :seq_len].add_(pos_embed_flat[:, :seq_len])
+            x[:, :seq_len].index_add_(
+                0,
+                torch.as_tensor(batch_indices, device=x.device),
+                pos_embed_flat[:, :seq_len].expand(len(batch_indices), -1, -1)
+            )

    def _apply_learned_pos_embed(
            self,
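A standalone sketch (toy shapes, not part of the diff) of the update pattern this hunk switches to: indexing x with a Python list (x[batch_indices]) returns a copy, so calling .add_() on it would not write back into x, while index_add_ on the basic slice x[:, :seq_len] (a view) adds the interpolated embedding to just the selected batch elements in place.

import torch

x = torch.zeros(4, 6, 8)                 # B, N, C token tensor
pos_embed_flat = torch.randn(1, 6, 8)    # pos embed for one (h, w) grid size
batch_indices = [0, 2]                   # batch elements using this grid size
seq_len = min(x.shape[1], pos_embed_flat.shape[1])

x[:, :seq_len].index_add_(
    0,
    torch.as_tensor(batch_indices, device=x.device),
    pos_embed_flat[:, :seq_len].expand(len(batch_indices), -1, -1),
)
assert torch.equal(x[0], pos_embed_flat[0]) and torch.all(x[1] == 0)
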
@@ -328,12 +334,13 @@ def _apply_learned_pos_embed(
        else:
            # Resize if needed - directly using F.interpolate
            pos_embed_flat = F.interpolate(
-                self.pos_embed.permute(0, 3, 1, 2),  # B,C,H,W
+                self.pos_embed.permute(0, 3, 1, 2).float(),  # B,C,H,W
                size=grid_size,
                mode=self.pos_embed_interp_mode,
                align_corners=False,
                antialias=True,
            ).flatten(2).transpose(1, 2)
+            pos_embed_flat = pos_embed_flat.to(dtype=x.dtype)

        x.add_(pos_embed_flat)

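A toy sketch of the dtype handling these hunks add: the learned grid is resized in float32 (antialiased interpolation tends to be better supported and more accurate there than in half/bfloat16), then cast back to the activation dtype before the add. Shapes and the bicubic mode below are assumptions for illustration.

import torch
import torch.nn.functional as F

pos_embed = torch.randn(1, 7, 7, 32, dtype=torch.bfloat16)  # 1, H, W, C
x = torch.randn(2, 9 * 9, 32, dtype=torch.bfloat16)         # B, N, C

pos_embed_flat = F.interpolate(
    pos_embed.permute(0, 3, 1, 2).float(),   # 1, C, H, W in float32
    size=(9, 9),
    mode='bicubic',
    align_corners=False,
    antialias=True,
).flatten(2).transpose(1, 2)                 # 1, 81, C
x = x + pos_embed_flat.to(dtype=x.dtype)     # cast back to the activation dtype
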
@@ -806,21 +813,20 @@ def _pool(
        # Apply the mask to extract only valid tokens
        x = x[:, self.num_prefix_tokens:]  # prefix tokens not included in pooling

+        patch_valid_float = patch_valid.to(x.dtype)
        if pool_type == 'avg':
-            # Compute masked average pooling
-            # Sum valid tokens and divide by count of valid tokens
-            masked_sums = (x * patch_valid.unsqueeze(-1).float()).sum(dim=1)
-            valid_counts = patch_valid.float().sum(dim=1, keepdim=True).clamp(min=1)
+            # Compute masked average pooling, sum valid tokens and divide by count of valid tokens
+            masked_sums = (x * patch_valid_float.unsqueeze(-1)).sum(dim=1)
+            valid_counts = patch_valid_float.sum(dim=1, keepdim=True).clamp(min=1)
            pooled = masked_sums / valid_counts
            return pooled
        elif pool_type == 'avgmax':
            # For avgmax, compute masked average and masked max
-            # For max, we set masked positions to large negative value
-            masked_sums = (x * patch_valid.unsqueeze(-1).float()).sum(dim=1)
-            valid_counts = patch_valid.float().sum(dim=1, keepdim=True).clamp(min=1)
+            masked_sums = (x * patch_valid_float.unsqueeze(-1)).sum(dim=1)
+            valid_counts = patch_valid_float.sum(dim=1, keepdim=True).clamp(min=1)
            masked_avg = masked_sums / valid_counts

-            # For max pooling with mask
+            # For max pooling we set masked positions to large negative value
            masked_x = x.clone()
            masked_x[~patch_valid] = torch.finfo(masked_x.dtype).min
            masked_max = masked_x.max(dim=1)[0]
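A self-contained sketch (toy shapes) of the masked average pooling refactored above: padding tokens are zeroed out of the sum and excluded from the count, so the mean is taken over valid patches only.

import torch

x = torch.randn(2, 5, 4)                                   # B, N, C
patch_valid = torch.tensor([[1, 1, 1, 0, 0],
                            [1, 1, 1, 1, 1]], dtype=torch.bool)

patch_valid_float = patch_valid.to(x.dtype)
masked_sums = (x * patch_valid_float.unsqueeze(-1)).sum(dim=1)
valid_counts = patch_valid_float.sum(dim=1, keepdim=True).clamp(min=1)
pooled = masked_sums / valid_counts                        # (B, C)

# Check against pooling the valid slice directly for the first sample
assert torch.allclose(pooled[0], x[0, :3].mean(dim=0), atol=1e-6)
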
@@ -915,6 +921,82 @@ def get_init_weights_vit(mode: str = 'jax', head_bias: float = 0.0) -> Callable:
    return init_weights_vit_timm


+def checkpoint_filter_fn(state_dict, model):
+    """Handle state dict conversion from original ViT to the new version with combined embedding."""
+    from .vision_transformer import checkpoint_filter_fn as orig_filter_fn
+
+    # Handle CombinedEmbed module pattern
+    out_dict = {}
+    for k, v in state_dict.items():
+        # Convert tokens and embeddings to combined_embed structure
+        if k == 'pos_embed':
+            # Handle position embedding format conversion - from (1, N, C) to (1, H, W, C)
+            if hasattr(model.embeds, 'pos_embed') and v.ndim == 3:
+                num_cls_token = 0
+                num_reg_token = 0
+                if 'reg_token' in state_dict:
+                    num_reg_token = state_dict['reg_token'].shape[1]
+                if 'cls_token' in state_dict:
+                    num_cls_token = state_dict['cls_token'].shape[1]
+                num_prefix_tokens = num_cls_token + num_reg_token
+
+                # Original format is (1, N, C), need to reshape to (1, H, W, C)
+                num_patches = v.shape[1]
+                num_patches_no_prefix = num_patches - num_prefix_tokens
+                grid_size_no_prefix = math.sqrt(num_patches_no_prefix)
+                grid_size = math.sqrt(num_patches)
+                if (grid_size_no_prefix != grid_size and (
+                        grid_size_no_prefix.is_integer() and not grid_size.is_integer())):
+                    # make a decision, did the pos_embed of the original include the prefix tokens?
+                    num_patches = num_patches_no_prefix
+                    cls_token_emb = v[:, 0:num_cls_token]
+                    if cls_token_emb.numel():
+                        state_dict['cls_token'] += cls_token_emb
+                    reg_token_emb = v[:, num_cls_token:num_prefix_tokens]
+                    if reg_token_emb.numel():
+                        state_dict['reg_token'] += reg_token_emb
+                    v = v[:, num_prefix_tokens:]
+                    grid_size = grid_size_no_prefix
+                grid_size = int(grid_size)
+
+                # Check if it's a perfect square for a standard grid
+                if grid_size * grid_size == num_patches:
+                    # Reshape from (1, N, C) to (1, H, W, C)
+                    v = v.reshape(1, grid_size, grid_size, v.shape[2])
+                else:
+                    # Not a square grid, we need to get the actual dimensions
+                    if hasattr(model.embeds.patch_embed, 'grid_size'):
+                        h, w = model.embeds.patch_embed.grid_size
+                        if h * w == num_patches:
+                            # We have the right dimensions
+                            v = v.reshape(1, h, w, v.shape[2])
+                        else:
+                            # Dimensions don't match, use interpolation
+                            _logger.warning(
+                                f"Position embedding size mismatch: checkpoint={num_patches}, model={(h * w)}. "
+                                f"Using default initialization and will resize in forward pass."
+                            )
+                            # Keep v as is, the forward pass will handle resizing
+
+            out_dict['embeds.pos_embed'] = v
+        elif k == 'cls_token':
+            out_dict['embeds.cls_token'] = v
+        elif k == 'reg_token':
+            out_dict['embeds.reg_token'] = v
+        # Remap patch_embed.X keys to embeds.X
+        elif k.startswith('patch_embed.'):
+            suffix = k[12:]
+            if suffix == 'proj.weight':
+                # FIXME confirm patchify memory layout across use cases
+                v = v.permute(0, 2, 3, 1).flatten(1)
+            new_key = 'embeds.' + suffix
+            out_dict[new_key] = v
+        else:
+            out_dict[k] = v
+
+    return out_dict
+
+
def _cfg(url: str = '', **kwargs) -> Dict[str, Any]:
    return {
        'url': url,
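A standalone illustration (hypothetical ViT-B/16 shapes) of the pos_embed conversion this filter performs: an old-style (1, N, C) embedding whose N includes a class token is detected via the sqrt check, the prefix rows are split off, and the remainder is reshaped to the new (1, H, W, C) grid layout.

import math
import torch

v = torch.randn(1, 197, 768)   # classic ViT-B/16 pos_embed: 1 cls token + 14*14 patches
num_prefix_tokens = 1

grid_no_prefix = math.sqrt(v.shape[1] - num_prefix_tokens)   # 14.0 -> integer
grid_with_prefix = math.sqrt(v.shape[1])                     # ~14.04 -> not integer
assert grid_no_prefix.is_integer() and not grid_with_prefix.is_integer()

cls_pos = v[:, :num_prefix_tokens]          # folded into the cls_token weight by the filter
grid = int(grid_no_prefix)
v = v[:, num_prefix_tokens:].reshape(1, grid, grid, v.shape[2])
print(v.shape)                              # torch.Size([1, 14, 14, 768])
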
@@ -936,6 +1018,26 @@ def _cfg(url: str = '', **kwargs) -> Dict[str, Any]:
    'vit_naflex_base_patch16': _cfg(),
    'vit_naflex_base_patch16_gap': _cfg(),
    'vit_naflex_base_patch16_map': _cfg(),
+
+    # sbb model testing
+    'vit_naflex_mediumd_patch16_reg4_gap.sbb2_r256_e200_in12k_ft_in1k': _cfg(
+        hf_hub_id='timm/vit_mediumd_patch16_reg4_gap_256.sbb2_e200_in12k_ft_in1k',
+        input_size=(3, 256, 256), crop_pct=0.95),
+    'vit_naflex_so150m2_patch16_reg1_gap.sbb_r256_e200_in12k_ft_in1k': _cfg(
+        hf_hub_id='timm/vit_so150m2_patch16_reg1_gap_256.sbb_e200_in12k_ft_in1k',
+        input_size=(3, 256, 256), crop_pct=1.0),
+    'vit_naflex_so150m2_patch16_reg1_gap.sbb_r384_e200_in12k_ft_in1k': _cfg(
+        hf_hub_id='timm/vit_so150m2_patch16_reg1_gap_384.sbb_e200_in12k_ft_in1k',
+        input_size=(3, 384, 384), crop_pct=1.0),
+    'vit_naflex_so150m2_patch16_reg1_gap.sbb_r448_e200_in12k_ft_in1k': _cfg(
+        hf_hub_id='timm/vit_so150m2_patch16_reg1_gap_448.sbb_e200_in12k_ft_in1k',
+        input_size=(3, 448, 448), crop_pct=1.0, crop_mode='squash'),
+
+    # traditional vit testing
+    'vit_naflex_base_patch16.augreg2_r224_in21k_ft_in1k': _cfg(
+        hf_hub_id='timm/vit_base_patch16_224.augreg2_in21k_ft_in1k'),
+    'vit_naflex_base_patch8.augreg2_r224_in21k_ft_in1k': _cfg(
+        hf_hub_id='timm/vit_base_patch16_224.augreg2_in21k_ft_in1k'),
})


@@ -948,10 +1050,22 @@ def _create_vision_transformer_flex(variant, pretrained=False, **kwargs):
    return model


+@register_model
+def vit_naflex_mediumd_patch16_reg4_gap(pretrained=False, **kwargs):
+    """ViT-New with NaFlex functionality for variable aspect ratios and resolutions.
+    """
+    model_args = dict(
+        patch_size=16, embed_dim=512, depth=20, num_heads=8, init_values=1e-5,
+        global_pool='avg', class_token=False, reg_tokens=4, fc_norm=True, **kwargs)
+    model = _create_vision_transformer_flex(
+        'vit_naflex_mediumd_patch16_reg4_gap', pretrained=pretrained, **model_args)
+    return model
+
+
@register_model
def vit_naflex_base_patch16(pretrained=False, **kwargs):
    """ViT-New with NaFlex functionality for variable aspect ratios and resolutions.
-
+
    This model supports:
    1. Variable aspect ratios and resolutions via patch coordinates
    2. Position embedding interpolation for arbitrary grid sizes
@@ -987,54 +1101,28 @@ def vit_naflex_base_patch16_map(pretrained=False, **kwargs):
    return model


-def checkpoint_filter_fn(state_dict, model):
-    """Handle state dict conversion from original ViT to the new version with combined embedding."""
-    from .vision_transformer import checkpoint_filter_fn as orig_filter_fn
+@register_model
+def vit_naflex_so150m2_patch16_reg1_gap(pretrained=False, **kwargs):
+    """ViT-New with NaFlex functionality for variable aspect ratios and resolutions.

-    # FIXME conversion of existing vit checkpoints has not been finished or tested
+    This model supports:
+    1. Variable aspect ratios and resolutions via patch coordinates
+    2. Position embedding interpolation for arbitrary grid sizes
+    3. Explicit patch coordinates and valid token masking
+    """
+    model_args = dict(
+        patch_size=16, embed_dim=832, depth=21, num_heads=13, mlp_ratio=34/13, init_values=1e-5,
+        qkv_bias=False, class_token=False, reg_tokens=1, global_pool='avg', fc_norm=True, **kwargs)
+    model = _create_vision_transformer_flex(
+        'vit_naflex_so150m2_patch16_reg1_gap', pretrained=pretrained, **model_args)
+    return model

-    # Handle CombinedEmbed module pattern
-    out_dict = {}
-    for k, v in state_dict.items():
-        # Convert tokens and embeddings to combined_embed structure
-        if k == 'pos_embed':
-            # Handle position embedding format conversion - from (1, N, C) to (1, H, W, C)
-            if hasattr(model.embeds, 'pos_embed') and v.ndim == 3:
-                # Original format is (1, N, C) - need to reshape to (1, H, W, C)
-                num_patches = v.shape[1]
-                grid_size = int(math.sqrt(num_patches))
-
-                # Check if it's a perfect square for a standard grid
-                if grid_size * grid_size == num_patches:
-                    # Reshape from (1, N, C) to (1, H, W, C)
-                    v = v.reshape(1, grid_size, grid_size, v.shape[2])
-                else:
-                    # Not a square grid, we need to get the actual dimensions
-                    if hasattr(model.embeds.patch_embed, 'grid_size'):
-                        h, w = model.embeds.patch_embed.grid_size
-                        if h * w == num_patches:
-                            # We have the right dimensions
-                            v = v.reshape(1, h, w, v.shape[2])
-                        else:
-                            # Dimensions don't match, use interpolation
-                            _logger.warning(
-                                f"Position embedding size mismatch: checkpoint={num_patches}, model={(h * w)}. "
-                                f"Using default initialization and will resize in forward pass."
-                            )
-                            # Keep v as is, the forward pass will handle resizing
-
-            out_dict['embeds.pos_embed'] = v
-
-        elif k == 'cls_token':
-            out_dict['embeds.cls_token'] = v
-        elif k == 'reg_token':
-            out_dict['embeds.reg_token'] = v
-        # Convert patch_embed.X to embeds.patch_embed.X
-        elif k.startswith('patch_embed.'):
-            new_key = 'embeds.' + k[12:]
-            out_dict[new_key] = v
-        else:
-            out_dict[k] = v
-
-    # Call the original filter function to handle other patterns
-    return orig_filter_fn(out_dict, model)
+
+
+@register_model
+def vit_naflex_base_patch16(pretrained: bool = False, **kwargs):
+    """ ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
+    ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer.
+    """
+    model_args = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, global_pool='token', class_token=True, pos_embed_grid_size=(14, 14))
+    model = _create_vision_transformer_flex('vit_naflex_base_patch16', pretrained=pretrained, **dict(model_args, **kwargs))
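A hedged usage sketch, assuming a timm build that includes these registrations; exact behavior depends on the merged code. The pretrained tags added above point hf_hub_id at existing non-naflex checkpoints, so loading them exercises the checkpoint_filter_fn remapping introduced in this diff.

import timm

# Construct the registered naflex variant (random init, no weight download)
model = timm.create_model('vit_naflex_base_patch16', pretrained=False)
print(sum(p.numel() for p in model.parameters()))  # parameter count sanity check

# With pretrained=True, the remapped augreg2 checkpoint would be loaded, e.g.:
# model = timm.create_model('vit_naflex_base_patch16.augreg2_r224_in21k_ft_in1k', pretrained=True)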