@@ -52,6 +52,7 @@ def batch_patchify(
 
     nh, nw = H // ph, W // pw
     patches = x.view(B, C, nh, ph, nw, pw).permute(0, 2, 4, 3, 5, 1).reshape(B, nh * nw, ph * pw * C)
+    # FIXME confirm we want 'channels last' in the patch channel layout, e.g. (ph, pw, C) instead of (C, ph, pw)
 
     return patches, (nh, nw)
 
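A quick aside on the FIXME above: the view/permute/reshape does flatten each patch in channels-last (ph, pw, C) order. A minimal self-contained sketch (shapes are illustrative, not taken from the patch) that checks this:

import torch

B, C, H, W = 2, 3, 8, 8
ph = pw = 4
x = torch.arange(B * C * H * W, dtype=torch.float32).reshape(B, C, H, W)

nh, nw = H // ph, W // pw
patches = x.view(B, C, nh, ph, nw, pw).permute(0, 2, 4, 3, 5, 1).reshape(B, nh * nw, ph * pw * C)

# Reference: flatten the top-left patch directly in (ph, pw, C) order.
ref = x[0, :, :ph, :pw].permute(1, 2, 0).reshape(-1)
assert torch.equal(patches[0, 0], ref)  # confirms channels-last within each patch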
@@ -297,7 +298,7 @@ def _apply_learned_naflex_pos_embed(
             size_to_indices[k].append(bi)
 
         # Handle each batch element separately with its own grid size
-        pos_embed_nchw = self.pos_embed.permute(0, 3, 1, 2)  # B,C,H,W
+        pos_embed_nchw = self.pos_embed.permute(0, 3, 1, 2).float()  # B,C,H,W
         for k, batch_indices in size_to_indices.items():
             h, w = k
             #h, w = k >> 16, k & 0xFFFF  # FIXME can get jit compat with this
@@ -312,9 +313,14 @@ def _apply_learned_naflex_pos_embed(
                 align_corners=False,
                 antialias=True,
             ).flatten(2).transpose(1, 2)
+            pos_embed_flat = pos_embed_flat.to(dtype=x.dtype)
 
             seq_len = min(x.shape[1], pos_embed_flat.shape[1])
-            x[batch_indices, :seq_len].add_(pos_embed_flat[:, :seq_len])
+            x[:, :seq_len].index_add_(
+                0,
+                torch.as_tensor(batch_indices, device=x.device),
+                pos_embed_flat[:, :seq_len].expand(len(batch_indices), -1, -1)
+            )
 
     def _apply_learned_pos_embed(
             self,
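Why the index_add_ rewrite above matters: indexing x with a Python list of batch indices is advanced indexing, which returns a copy, so the previous in-place add_ never wrote back into x; index_add_ applied to a basic slice (a view) does update x. A small sketch of the difference, with made-up shapes:

import torch

x = torch.zeros(4, 3, 2)
pos = torch.ones(1, 3, 2)
batch_indices = [1, 3]

# Old pattern: add_ runs on a temporary copy, so x stays unchanged.
x[batch_indices, :3].add_(pos[:, :3])
assert torch.equal(x, torch.zeros(4, 3, 2))

# New pattern: index_add_ accumulates through the view into x.
x[:, :3].index_add_(
    0,
    torch.as_tensor(batch_indices, device=x.device),
    pos[:, :3].expand(len(batch_indices), -1, -1),
)
assert torch.equal(x[1], torch.ones(3, 2)) and torch.equal(x[0], torch.zeros(3, 2))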
@@ -328,12 +334,13 @@ def _apply_learned_pos_embed(
         else:
             # Resize if needed - directly using F.interpolate
             pos_embed_flat = F.interpolate(
-                self.pos_embed.permute(0, 3, 1, 2),  # B,C,H,W
+                self.pos_embed.permute(0, 3, 1, 2).float(),  # B,C,H,W
                 size=grid_size,
                 mode=self.pos_embed_interp_mode,
                 align_corners=False,
                 antialias=True,
             ).flatten(2).transpose(1, 2)
+            pos_embed_flat = pos_embed_flat.to(dtype=x.dtype)
 
         x.add_(pos_embed_flat)
 
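The .float() / .to(dtype=x.dtype) pairs added in the two hunks above follow an upcast-compute-downcast pattern: the antialiased bicubic resize is done in fp32 (low-precision kernels for it may be missing or less accurate depending on backend and PyTorch version), then the result is cast back to the activation dtype. A hedged sketch of the pattern with illustrative shapes:

import torch
import torch.nn.functional as F

pos_embed = torch.randn(1, 16, 16, 768, dtype=torch.bfloat16)  # stored as (B, H, W, C)
x = torch.randn(1, 10 * 10, 768, dtype=torch.bfloat16)         # token sequence (B, N, C)

pos_embed_flat = F.interpolate(
    pos_embed.permute(0, 3, 1, 2).float(),  # (B, C, H, W), resized in fp32
    size=(10, 10),
    mode='bicubic',
    align_corners=False,
    antialias=True,
).flatten(2).transpose(1, 2).to(dtype=x.dtype)  # back to (B, N, C) and the activation dtype

x = x + pos_embed_flat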
@@ -806,21 +813,20 @@ def _pool(
         # Apply the mask to extract only valid tokens
         x = x[:, self.num_prefix_tokens:]  # prefix tokens not included in pooling
 
+        patch_valid_float = patch_valid.to(x.dtype)
         if pool_type == 'avg':
-            # Compute masked average pooling
-            # Sum valid tokens and divide by count of valid tokens
-            masked_sums = (x * patch_valid.unsqueeze(-1).float()).sum(dim=1)
-            valid_counts = patch_valid.float().sum(dim=1, keepdim=True).clamp(min=1)
+            # Compute masked average pooling, sum valid tokens and divide by count of valid tokens
+            masked_sums = (x * patch_valid_float.unsqueeze(-1)).sum(dim=1)
+            valid_counts = patch_valid_float.sum(dim=1, keepdim=True).clamp(min=1)
             pooled = masked_sums / valid_counts
             return pooled
         elif pool_type == 'avgmax':
             # For avgmax, compute masked average and masked max
-            # For max, we set masked positions to large negative value
-            masked_sums = (x * patch_valid.unsqueeze(-1).float()).sum(dim=1)
-            valid_counts = patch_valid.float().sum(dim=1, keepdim=True).clamp(min=1)
+            masked_sums = (x * patch_valid_float.unsqueeze(-1)).sum(dim=1)
+            valid_counts = patch_valid_float.sum(dim=1, keepdim=True).clamp(min=1)
             masked_avg = masked_sums / valid_counts
 
-            # For max pooling with mask
+            # For max pooling we set masked positions to large negative value
             masked_x = x.clone()
             masked_x[~patch_valid] = torch.finfo(masked_x.dtype).min
             masked_max = masked_x.max(dim=1)[0]
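The masked 'avg' branch above reduces to a per-sample mean over valid tokens only; a small self-contained check (shapes made up):

import torch

B, N, C = 2, 5, 4
x = torch.randn(B, N, C)
patch_valid = torch.tensor([[1, 1, 1, 0, 0],
                            [1, 1, 1, 1, 1]], dtype=torch.bool)

patch_valid_float = patch_valid.to(x.dtype)
masked_sums = (x * patch_valid_float.unsqueeze(-1)).sum(dim=1)
valid_counts = patch_valid_float.sum(dim=1, keepdim=True).clamp(min=1)
pooled = masked_sums / valid_counts

assert torch.allclose(pooled[0], x[0, :3].mean(dim=0))  # only the 3 valid tokens
assert torch.allclose(pooled[1], x[1].mean(dim=0))      # all tokens valid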
@@ -915,6 +921,82 @@ def get_init_weights_vit(mode: str = 'jax', head_bias: float = 0.0) -> Callable:
     return init_weights_vit_timm
 
 
+def checkpoint_filter_fn(state_dict, model):
+    """Handle state dict conversion from original ViT to the new version with combined embedding."""
+    from .vision_transformer import checkpoint_filter_fn as orig_filter_fn
+
+    # Handle CombinedEmbed module pattern
+    out_dict = {}
+    for k, v in state_dict.items():
+        # Convert tokens and embeddings to combined_embed structure
+        if k == 'pos_embed':
+            # Handle position embedding format conversion - from (1, N, C) to (1, H, W, C)
+            if hasattr(model.embeds, 'pos_embed') and v.ndim == 3:
+                num_cls_token = 0
+                num_reg_token = 0
+                if 'reg_token' in state_dict:
+                    num_reg_token = state_dict['reg_token'].shape[1]
+                if 'cls_token' in state_dict:
+                    num_cls_token = state_dict['cls_token'].shape[1]
+                num_prefix_tokens = num_cls_token + num_reg_token
+
+                # Original format is (1, N, C), need to reshape to (1, H, W, C)
+                num_patches = v.shape[1]
+                num_patches_no_prefix = num_patches - num_prefix_tokens
+                grid_size_no_prefix = math.sqrt(num_patches_no_prefix)
+                grid_size = math.sqrt(num_patches)
+                if (grid_size_no_prefix != grid_size and (
+                        grid_size_no_prefix.is_integer() and not grid_size.is_integer())):
+                    # make a decision, did the pos_embed of the original include the prefix tokens?
+                    num_patches = num_patches_no_prefix
+                    cls_token_emb = v[:, 0:num_cls_token]
+                    if cls_token_emb.numel():
+                        state_dict['cls_token'] += cls_token_emb
+                    reg_token_emb = v[:, num_cls_token:num_cls_token + num_reg_token]
+                    if reg_token_emb.numel():
+                        state_dict['reg_token'] += reg_token_emb
+                    v = v[:, num_prefix_tokens:]
+                    grid_size = grid_size_no_prefix
+                grid_size = int(grid_size)
+
+                # Check if it's a perfect square for a standard grid
+                if grid_size * grid_size == num_patches:
+                    # Reshape from (1, N, C) to (1, H, W, C)
+                    v = v.reshape(1, grid_size, grid_size, v.shape[2])
+                else:
+                    # Not a square grid, we need to get the actual dimensions
+                    if hasattr(model.embeds.patch_embed, 'grid_size'):
+                        h, w = model.embeds.patch_embed.grid_size
+                        if h * w == num_patches:
+                            # We have the right dimensions
+                            v = v.reshape(1, h, w, v.shape[2])
+                        else:
+                            # Dimensions don't match, use interpolation
+                            _logger.warning(
+                                f"Position embedding size mismatch: checkpoint={num_patches}, model={h * w}. "
+                                f"Using default initialization and will resize in forward pass."
+                            )
+                            # Keep v as is, the forward pass will handle resizing
+
+            out_dict['embeds.pos_embed'] = v
+        elif k == 'cls_token':
+            out_dict['embeds.cls_token'] = v
+        elif k == 'reg_token':
+            out_dict['embeds.reg_token'] = v
+        # Convert patch_embed.X to embeds.X (the combined embed module holds the proj directly)
+        elif k.startswith('patch_embed.'):
+            suffix = k[12:]
+            if suffix == 'proj.weight':
+                # FIXME confirm patchify memory layout across use cases
+                v = v.permute(0, 2, 3, 1).flatten(1)
+            new_key = 'embeds.' + suffix
+            out_dict[new_key] = v
+        else:
+            out_dict[k] = v
+
+    return out_dict
+
+
 def _cfg(url: str = '', **kwargs) -> Dict[str, Any]:
     return {
         'url': url,
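The new checkpoint_filter_fn above decides whether a checkpoint's (1, N, C) pos_embed includes the prefix (cls/reg) tokens by testing which token count forms a square grid. A tiny sketch of that decision rule, using assumed ViT-B/16 @ 224 numbers (196 patches, 1 cls token); the helper name is hypothetical:

import math

def pos_embed_includes_prefix(num_patches: int, num_prefix_tokens: int) -> bool:
    # Hypothetical helper mirroring the sqrt / is_integer test in checkpoint_filter_fn.
    g_all = math.sqrt(num_patches)
    g_no_prefix = math.sqrt(num_patches - num_prefix_tokens)
    return g_no_prefix != g_all and g_no_prefix.is_integer() and not g_all.is_integer()

assert pos_embed_includes_prefix(197, 1)       # 196 is square, 197 is not -> strip prefix rows
assert not pos_embed_includes_prefix(196, 1)   # already a 14x14 grid -> nothing to strip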
@@ -936,6 +1018,26 @@ def _cfg(url: str = '', **kwargs) -> Dict[str, Any]:
     'vit_naflex_base_patch16': _cfg(),
     'vit_naflex_base_patch16_gap': _cfg(),
     'vit_naflex_base_patch16_map': _cfg(),
+
+    # sbb model testing
+    'vit_naflex_mediumd_patch16_reg4_gap.sbb2_r256_e200_in12k_ft_in1k': _cfg(
+        hf_hub_id='timm/vit_mediumd_patch16_reg4_gap_256.sbb2_e200_in12k_ft_in1k',
+        input_size=(3, 256, 256), crop_pct=0.95),
+    'vit_naflex_so150m2_patch16_reg1_gap.sbb_r256_e200_in12k_ft_in1k': _cfg(
+        hf_hub_id='timm/vit_so150m2_patch16_reg1_gap_256.sbb_e200_in12k_ft_in1k',
+        input_size=(3, 256, 256), crop_pct=1.0),
+    'vit_naflex_so150m2_patch16_reg1_gap.sbb_r384_e200_in12k_ft_in1k': _cfg(
+        hf_hub_id='timm/vit_so150m2_patch16_reg1_gap_384.sbb_e200_in12k_ft_in1k',
+        input_size=(3, 384, 384), crop_pct=1.0),
+    'vit_naflex_so150m2_patch16_reg1_gap.sbb_r448_e200_in12k_ft_in1k': _cfg(
+        hf_hub_id='timm/vit_so150m2_patch16_reg1_gap_448.sbb_e200_in12k_ft_in1k',
+        input_size=(3, 448, 448), crop_pct=1.0, crop_mode='squash'),
+
+    # traditional vit testing
+    'vit_naflex_base_patch16.augreg2_r224_in21k_ft_in1k': _cfg(
+        hf_hub_id='timm/vit_base_patch16_224.augreg2_in21k_ft_in1k'),
+    'vit_naflex_base_patch8.augreg2_r224_in21k_ft_in1k': _cfg(
+        hf_hub_id='timm/vit_base_patch16_224.augreg2_in21k_ft_in1k'),
 })
 
 
@@ -948,10 +1050,22 @@ def _create_vision_transformer_flex(variant, pretrained=False, **kwargs):
     return model
 
 
+@register_model
+def vit_naflex_mediumd_patch16_reg4_gap(pretrained=False, **kwargs):
+    """ViT-New with NaFlex functionality for variable aspect ratios and resolutions.
+    """
+    model_args = dict(
+        patch_size=16, embed_dim=512, depth=20, num_heads=8, init_values=1e-5,
+        global_pool='avg', class_token=False, reg_tokens=4, fc_norm=True, **kwargs)
+    model = _create_vision_transformer_flex(
+        'vit_naflex_mediumd_patch16_reg4_gap', pretrained=pretrained, **model_args)
+    return model
+
+
 @register_model
 def vit_naflex_base_patch16(pretrained=False, **kwargs):
     """ViT-New with NaFlex functionality for variable aspect ratios and resolutions.
-    
+
     This model supports:
     1. Variable aspect ratios and resolutions via patch coordinates
     2. Position embedding interpolation for arbitrary grid sizes
@@ -987,54 +1101,28 @@ def vit_naflex_base_patch16_map(pretrained=False, **kwargs):
     return model
 
 
-def checkpoint_filter_fn(state_dict, model):
-    """Handle state dict conversion from original ViT to the new version with combined embedding."""
-    from .vision_transformer import checkpoint_filter_fn as orig_filter_fn
+@register_model
+def vit_naflex_so150m2_patch16_reg1_gap(pretrained=False, **kwargs):
+    """ViT-New with NaFlex functionality for variable aspect ratios and resolutions.
 
-    # FIXME conversion of existing vit checkpoints has not been finished or tested
+    This model supports:
+    1. Variable aspect ratios and resolutions via patch coordinates
+    2. Position embedding interpolation for arbitrary grid sizes
+    3. Explicit patch coordinates and valid token masking
+    """
+    model_args = dict(
+        patch_size=16, embed_dim=832, depth=21, num_heads=13, mlp_ratio=34 / 13, init_values=1e-5,
+        qkv_bias=False, class_token=False, reg_tokens=1, global_pool='avg', fc_norm=True, **kwargs)
+    model = _create_vision_transformer_flex(
+        'vit_naflex_so150m2_patch16_reg1_gap', pretrained=pretrained, **model_args)
+    return model
 
-    # Handle CombinedEmbed module pattern
-    out_dict = {}
-    for k, v in state_dict.items():
-        # Convert tokens and embeddings to combined_embed structure
-        if k == 'pos_embed':
-            # Handle position embedding format conversion - from (1, N, C) to (1, H, W, C)
-            if hasattr(model.embeds, 'pos_embed') and v.ndim == 3:
-                # Original format is (1, N, C) - need to reshape to (1, H, W, C)
-                num_patches = v.shape[1]
-                grid_size = int(math.sqrt(num_patches))
-
-                # Check if it's a perfect square for a standard grid
-                if grid_size * grid_size == num_patches:
-                    # Reshape from (1, N, C) to (1, H, W, C)
-                    v = v.reshape(1, grid_size, grid_size, v.shape[2])
-                else:
-                    # Not a square grid, we need to get the actual dimensions
-                    if hasattr(model.embeds.patch_embed, 'grid_size'):
-                        h, w = model.embeds.patch_embed.grid_size
-                        if h * w == num_patches:
-                            # We have the right dimensions
-                            v = v.reshape(1, h, w, v.shape[2])
-                        else:
-                            # Dimensions don't match, use interpolation
-                            _logger.warning(
-                                f"Position embedding size mismatch: checkpoint={num_patches}, model={h * w}. "
-                                f"Using default initialization and will resize in forward pass."
-                            )
-                            # Keep v as is, the forward pass will handle resizing
-
-                out_dict['embeds.pos_embed'] = v
-
-        elif k == 'cls_token':
-            out_dict['embeds.cls_token'] = v
-        elif k == 'reg_token':
-            out_dict['embeds.reg_token'] = v
-        # Convert patch_embed.X to embeds.patch_embed.X
-        elif k.startswith('patch_embed.'):
-            new_key = 'embeds.' + k[12:]
-            out_dict[new_key] = v
-        else:
-            out_dict[k] = v
-
-    # Call the original filter function to handle other patterns
-    return orig_filter_fn(out_dict, model)
+
+@register_model
+def vit_naflex_base_patch16(pretrained: bool = False, **kwargs):
+    """ ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
+    ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer.
+    """
+    model_args = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, global_pool='token', class_token=True, pos_embed_grid_size=(14, 14))
+    model = _create_vision_transformer_flex('vit_naflex_base_patch16', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
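A note on the so150m2 args above: mlp_ratio=34/13 looks odd, but it pairs with embed_dim=832 (a multiple of 13) so the derived MLP hidden width comes out to an integer. A quick check with the values from model_args:

embed_dim = 832
hidden_dim = embed_dim * 34 // 13  # the width mlp_ratio=34/13 targets
assert embed_dim * 34 % 13 == 0 and hidden_dim == 2176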