
Commit 3dc90ed

Add naflex loader support to validate.py, fix bug in naflex pos embed add, and add classic vit weight loading for the naflex model
1 parent c527c37 commit 3dc90ed

File tree

4 files changed: 198 additions, 84 deletions


timm/data/naflex_transforms.py

Lines changed: 0 additions & 1 deletion
@@ -760,7 +760,6 @@ def patchify(
     nh, nw = h // ph, w // pw
     # Reshape image to patches [nh, nw, ph, pw, c]
     patches = img.view(c, nh, ph, nw, pw).permute(1, 3, 2, 4, 0).reshape(nh * nw, ph * pw * c)
-
     if include_info:
         # Create coordinate indices
         y_idx, x_idx = torch.meshgrid(torch.arange(nh), torch.arange(nw), indexing='ij')
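Side note on the patch layout kept above: each patch is flattened channels-last, i.e. (ph, pw, c) -> ph*pw*c. A small standalone sketch of the same view/permute/reshape sequence (arbitrary shapes, not calling timm's patchify itself):

    import torch

    img = torch.randn(3, 32, 48)   # (c, h, w)
    ph = pw = 16
    c, h, w = img.shape
    nh, nw = h // ph, w // pw
    # Split h/w into patch grids, move channels last, flatten each patch
    patches = img.view(c, nh, ph, nw, pw).permute(1, 3, 2, 4, 0).reshape(nh * nw, ph * pw * c)
    assert patches.shape == (nh * nw, ph * pw * c)   # (6, 768)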

timm/data/transforms_factory.py

Lines changed: 1 addition & 1 deletion
@@ -318,7 +318,7 @@ def transforms_imagenet_eval(
         tfl += [ResizeToSequence(
             patch_size=patch_size,
             max_seq_len=max_seq_len,
-            interpolation=interpolation
+            interpolation=interpolation,
         )]
     else:
         if crop_mode == 'squash':

timm/models/vision_transformer_flex.py

Lines changed: 149 additions & 61 deletions
@@ -52,6 +52,7 @@ def batch_patchify(

    nh, nw = H // ph, W // pw
    patches = x.view(B, C, nh, ph, nw, pw).permute(0, 2, 4, 3, 5, 1).reshape(B, nh * nw, ph * pw * C)
+    # FIXME confirm we want 'channels last' in the patch channel layout, e.g. (ph, pw, C) instead of (C, ph, pw)

    return patches, (nh, nw)
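For reference, a hedged sketch of what the batched layout above produces (standalone toy tensors, not the model's real inputs): the first patch of the first image equals its top-left ph x pw crop with channels moved last, which is exactly the 'channels last' layout the FIXME comment refers to.

    import torch

    B, C, H, W = 2, 3, 32, 48
    ph = pw = 16
    x = torch.randn(B, C, H, W)
    nh, nw = H // ph, W // pw
    patches = x.view(B, C, nh, ph, nw, pw).permute(0, 2, 4, 3, 5, 1).reshape(B, nh * nw, ph * pw * C)
    # top-left 16x16 crop of the first image, channels last, flattened
    ref = x[0, :, :ph, :pw].permute(1, 2, 0).reshape(-1)
    assert torch.equal(patches[0, 0], ref)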

@@ -297,7 +298,7 @@ def _apply_learned_naflex_pos_embed(
             size_to_indices[k].append(bi)

         # Handle each batch element separately with its own grid size
-        pos_embed_nchw = self.pos_embed.permute(0, 3, 1, 2)  # B,C,H,W
+        pos_embed_nchw = self.pos_embed.permute(0, 3, 1, 2).float()  # B,C,H,W
         for k, batch_indices in size_to_indices.items():
             h, w = k
             #h, w = k >> 16, k & 0xFFFF  # FIXME can get jit compat with this
@@ -312,9 +313,14 @@ def _apply_learned_naflex_pos_embed(
                 align_corners=False,
                 antialias=True,
             ).flatten(2).transpose(1, 2)
+            pos_embed_flat = pos_embed_flat.to(dtype=x.dtype)

             seq_len = min(x.shape[1], pos_embed_flat.shape[1])
-            x[batch_indices, :seq_len].add_(pos_embed_flat[:, :seq_len])
+            x[:, :seq_len].index_add_(
+                0,
+                torch.as_tensor(batch_indices, device=x.device),
+                pos_embed_flat[:, :seq_len].expand(len(batch_indices), -1, -1)
+            )

     def _apply_learned_pos_embed(
             self,
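This is the "fix bug in naflex pos embed add" part of the commit message: indexing x with a Python list of batch indices is advanced indexing, which returns a copy, so the old in-place add_() never wrote the interpolated position embedding back into x. index_add_() on a basic slice of x does update it. A toy, hedged illustration with made-up shapes:

    import torch

    x = torch.zeros(4, 3, 2)              # (B, N, C) token tensor
    pos = torch.ones(1, 3, 2)             # interpolated pos embed for one grid size
    batch_indices = [1, 3]
    seq_len = 3

    broken = x.clone()
    broken[batch_indices, :seq_len].add_(pos[:, :seq_len])   # add_ hits a copy; broken stays zero
    assert torch.equal(broken, torch.zeros_like(broken))

    x[:, :seq_len].index_add_(
        0,
        torch.as_tensor(batch_indices, device=x.device),
        pos[:, :seq_len].expand(len(batch_indices), -1, -1),
    )
    assert torch.equal(x[1], torch.ones(3, 2)) and torch.equal(x[0], torch.zeros(3, 2))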
@@ -328,12 +334,13 @@ def _apply_learned_pos_embed(
         else:
             # Resize if needed - directly using F.interpolate
             pos_embed_flat = F.interpolate(
-                self.pos_embed.permute(0, 3, 1, 2),  # B,C,H,W
+                self.pos_embed.permute(0, 3, 1, 2).float(),  # B,C,H,W
                 size=grid_size,
                 mode=self.pos_embed_interp_mode,
                 align_corners=False,
                 antialias=True,
             ).flatten(2).transpose(1, 2)
+            pos_embed_flat = pos_embed_flat.to(dtype=x.dtype)

         x.add_(pos_embed_flat)
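The .float() / .to(dtype=x.dtype) pair added in this and the previous hunks keeps the antialiased resize of the learned position embedding in fp32 even when the model runs in half or bfloat16, then casts the result back to the activation dtype. A hedged, standalone sketch of that pattern (grid sizes and dims are made up, and 'bicubic' stands in for pos_embed_interp_mode):

    import torch
    import torch.nn.functional as F

    pos_embed = torch.randn(1, 14, 14, 768, dtype=torch.bfloat16)   # learned (1, H, W, C) grid
    x = torch.randn(2, 10 * 12, 768, dtype=torch.bfloat16)          # tokens for a 10x12 target grid

    pos_embed_flat = F.interpolate(
        pos_embed.permute(0, 3, 1, 2).float(),    # (1, C, H, W), resized in fp32
        size=(10, 12),
        mode='bicubic',
        align_corners=False,
        antialias=True,
    ).flatten(2).transpose(1, 2)
    pos_embed_flat = pos_embed_flat.to(dtype=x.dtype)
    x = x + pos_embed_flat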

@@ -806,21 +813,20 @@ def _pool(
         # Apply the mask to extract only valid tokens
         x = x[:, self.num_prefix_tokens:]  # prefix tokens not included in pooling

+        patch_valid_float = patch_valid.to(x.dtype)
         if pool_type == 'avg':
-            # Compute masked average pooling
-            # Sum valid tokens and divide by count of valid tokens
-            masked_sums = (x * patch_valid.unsqueeze(-1).float()).sum(dim=1)
-            valid_counts = patch_valid.float().sum(dim=1, keepdim=True).clamp(min=1)
+            # Compute masked average pooling, sum valid tokens and divide by count of valid tokens
+            masked_sums = (x * patch_valid_float.unsqueeze(-1)).sum(dim=1)
+            valid_counts = patch_valid_float.sum(dim=1, keepdim=True).clamp(min=1)
             pooled = masked_sums / valid_counts
             return pooled
         elif pool_type == 'avgmax':
             # For avgmax, compute masked average and masked max
-            # For max, we set masked positions to large negative value
-            masked_sums = (x * patch_valid.unsqueeze(-1).float()).sum(dim=1)
-            valid_counts = patch_valid.float().sum(dim=1, keepdim=True).clamp(min=1)
+            masked_sums = (x * patch_valid_float.unsqueeze(-1)).sum(dim=1)
+            valid_counts = patch_valid_float.sum(dim=1, keepdim=True).clamp(min=1)
             masked_avg = masked_sums / valid_counts

-            # For max pooling with mask
+            # For max pooling we set masked positions to large negative value
             masked_x = x.clone()
             masked_x[~patch_valid] = torch.finfo(masked_x.dtype).min
             masked_max = masked_x.max(dim=1)[0]
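A toy walk-through of the masked pooling above (values invented): padding tokens are excluded from the average by zero-weighting them, and excluded from the max by overwriting them with the dtype's minimum before the reduction.

    import torch

    x = torch.tensor([[[1.0, 1.0], [3.0, 3.0], [9.0, 9.0]]])   # (B=1, N=3, C=2)
    patch_valid = torch.tensor([[True, True, False]])          # last token is padding

    patch_valid_float = patch_valid.to(x.dtype)
    masked_sums = (x * patch_valid_float.unsqueeze(-1)).sum(dim=1)           # [[4., 4.]]
    valid_counts = patch_valid_float.sum(dim=1, keepdim=True).clamp(min=1)   # [[2.]]
    masked_avg = masked_sums / valid_counts                                  # [[2., 2.]]

    masked_x = x.clone()
    masked_x[~patch_valid] = torch.finfo(masked_x.dtype).min
    masked_max = masked_x.max(dim=1)[0]                                      # [[3., 3.]], not 9.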
@@ -915,6 +921,82 @@ def get_init_weights_vit(mode: str = 'jax', head_bias: float = 0.0) -> Callable:
     return init_weights_vit_timm


+def checkpoint_filter_fn(state_dict, model):
+    """Handle state dict conversion from original ViT to the new version with combined embedding."""
+    from .vision_transformer import checkpoint_filter_fn as orig_filter_fn
+
+    # Handle CombinedEmbed module pattern
+    out_dict = {}
+    for k, v in state_dict.items():
+        # Convert tokens and embeddings to combined_embed structure
+        if k == 'pos_embed':
+            # Handle position embedding format conversion - from (1, N, C) to (1, H, W, C)
+            if hasattr(model.embeds, 'pos_embed') and v.ndim == 3:
+                num_cls_token = 0
+                num_reg_token = 0
+                if 'reg_token' in state_dict:
+                    num_reg_token = state_dict['reg_token'].shape[1]
+                if 'cls_token' in state_dict:
+                    num_cls_token = state_dict['cls_token'].shape[1]
+                num_prefix_tokens = num_cls_token + num_reg_token
+
+                # Original format is (1, N, C), need to reshape to (1, H, W, C)
+                num_patches = v.shape[1]
+                num_patches_no_prefix = num_patches - num_prefix_tokens
+                grid_size_no_prefix = math.sqrt(num_patches_no_prefix)
+                grid_size = math.sqrt(num_patches)
+                if (grid_size_no_prefix != grid_size and (
+                        grid_size_no_prefix.is_integer() and not grid_size.is_integer())):
+                    # make a decision, did the pos_embed of the original include the prefix tokens?
+                    num_patches = num_patches_no_prefix
+                    cls_token_emb = v[:, 0:num_cls_token]
+                    if cls_token_emb.numel():
+                        state_dict['cls_token'] += cls_token_emb
+                    reg_token_emb = v[:, num_cls_token:num_reg_token]
+                    if reg_token_emb.numel():
+                        state_dict['reg_token'] += reg_token_emb
+                    v = v[:, num_prefix_tokens:]
+                    grid_size = grid_size_no_prefix
+                grid_size = int(grid_size)
+
+                # Check if it's a perfect square for a standard grid
+                if grid_size * grid_size == num_patches:
+                    # Reshape from (1, N, C) to (1, H, W, C)
+                    v = v.reshape(1, grid_size, grid_size, v.shape[2])
+                else:
+                    # Not a square grid, we need to get the actual dimensions
+                    if hasattr(model.embeds.patch_embed, 'grid_size'):
+                        h, w = model.embeds.patch_embed.grid_size
+                        if h * w == num_patches:
+                            # We have the right dimensions
+                            v = v.reshape(1, h, w, v.shape[2])
+                        else:
+                            # Dimensions don't match, use interpolation
+                            _logger.warning(
+                                f"Position embedding size mismatch: checkpoint={num_patches}, model={(h * w)}. "
+                                f"Using default initialization and will resize in forward pass."
+                            )
+                            # Keep v as is, the forward pass will handle resizing
+
+            out_dict['embeds.pos_embed'] = v
+        elif k == 'cls_token':
+            out_dict['embeds.cls_token'] = v
+        elif k == 'reg_token':
+            out_dict['embeds.reg_token'] = v
+        # Convert patch_embed.X to embeds.patch_embed.X
+        elif k.startswith('patch_embed.'):
+            suffix = k[12:]
+            if suffix == 'proj.weight':
+                # FIXME confirm patchify memory layout across use cases
+                v = v.permute(0, 2, 3, 1).flatten(1)
+            new_key = 'embeds.' + suffix
+            out_dict[new_key] = v
+        else:
+            out_dict[k] = v
+
+    return out_dict
+
+
 def _cfg(url: str = '', **kwargs) -> Dict[str, Any]:
     return {
         'url': url,
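The new proj.weight handling is what makes classic ViT checkpoints loadable here: the Conv2d patch embedding weight (D, C, ph, pw) is re-laid-out as a linear projection over channels-last patches (ph*pw*C -> D), matching the batch_patchify layout at the top of this file. A hedged equivalence check with arbitrary sizes:

    import torch
    import torch.nn.functional as F

    D, C, ph, pw = 768, 3, 16, 16
    conv_w = torch.randn(D, C, ph, pw)                  # classic patch_embed.proj.weight
    linear_w = conv_w.permute(0, 2, 3, 1).flatten(1)    # (D, ph*pw*C), as in the filter fn above

    patch = torch.randn(ph, pw, C)                      # one channels-last patch
    out_linear = linear_w @ patch.reshape(-1)
    out_conv = F.conv2d(patch.permute(2, 0, 1).unsqueeze(0), conv_w).reshape(-1)
    assert torch.allclose(out_linear, out_conv, atol=1e-4)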
@@ -936,6 +1018,26 @@ def _cfg(url: str = '', **kwargs) -> Dict[str, Any]:
     'vit_naflex_base_patch16': _cfg(),
     'vit_naflex_base_patch16_gap': _cfg(),
     'vit_naflex_base_patch16_map': _cfg(),
+
+    # sbb model testing
+    'vit_naflex_mediumd_patch16_reg4_gap.sbb2_r256_e200_in12k_ft_in1k': _cfg(
+        hf_hub_id='timm/vit_mediumd_patch16_reg4_gap_256.sbb2_e200_in12k_ft_in1k',
+        input_size=(3, 256, 256), crop_pct=0.95),
+    'vit_naflex_so150m2_patch16_reg1_gap.sbb_r256_e200_in12k_ft_in1k': _cfg(
+        hf_hub_id='timm/vit_so150m2_patch16_reg1_gap_256.sbb_e200_in12k_ft_in1k',
+        input_size=(3, 256, 256), crop_pct=1.0),
+    'vit_naflex_so150m2_patch16_reg1_gap.sbb_r384_e200_in12k_ft_in1k': _cfg(
+        hf_hub_id='timm/vit_so150m2_patch16_reg1_gap_384.sbb_e200_in12k_ft_in1k',
+        input_size=(3, 384, 384), crop_pct=1.0),
+    'vit_naflex_so150m2_patch16_reg1_gap.sbb_r448_e200_in12k_ft_in1k': _cfg(
+        hf_hub_id='timm/vit_so150m2_patch16_reg1_gap_448.sbb_e200_in12k_ft_in1k',
+        input_size=(3, 448, 448), crop_pct=1.0, crop_mode='squash'),
+
+    # traditional vit testing
+    'vit_naflex_base_patch16.augreg2_r224_in21k_ft_in1k': _cfg(
+        hf_hub_id='timm/vit_base_patch16_224.augreg2_in21k_ft_in1k'),
+    'vit_naflex_base_patch8.augreg2_r224_in21k_ft_in1k': _cfg(
+        hf_hub_id='timm/vit_base_patch16_224.augreg2_in21k_ft_in1k'),
 })

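With these cfgs plus the checkpoint_filter_fn above, loading one of the classic augreg weights into the NaFlex model should reduce to the usual timm entrypoint. Untested sketch (model name taken from the cfg list above; assumes the standard fixed-size tensor input path works for this model):

    import timm
    import torch

    model = timm.create_model('vit_naflex_base_patch16.augreg2_r224_in21k_ft_in1k', pretrained=True)
    model.eval()
    with torch.no_grad():
        logits = model(torch.randn(1, 3, 224, 224))
    print(logits.shape)   # expecting torch.Size([1, 1000])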

@@ -948,10 +1050,22 @@ def _create_vision_transformer_flex(variant, pretrained=False, **kwargs):
     return model


+@register_model
+def vit_naflex_mediumd_patch16_reg4_gap(pretrained=False, **kwargs):
+    """ViT-New with NaFlex functionality for variable aspect ratios and resolutions.
+    """
+    model_args = dict(
+        patch_size=16, embed_dim=512, depth=20, num_heads=8, init_values=1e-5,
+        global_pool='avg', class_token=False, reg_tokens=4, fc_norm=True, **kwargs)
+    model = _create_vision_transformer_flex(
+        'vit_naflex_mediumd_patch16_reg4_gap', pretrained=pretrained, **model_args)
+    return model
+
+
 @register_model
 def vit_naflex_base_patch16(pretrained=False, **kwargs):
     """ViT-New with NaFlex functionality for variable aspect ratios and resolutions.
-
+
     This model supports:
     1. Variable aspect ratios and resolutions via patch coordinates
     2. Position embedding interpolation for arbitrary grid sizes
@@ -987,54 +1101,28 @@ def vit_naflex_base_patch16_map(pretrained=False, **kwargs):
     return model


-def checkpoint_filter_fn(state_dict, model):
-    """Handle state dict conversion from original ViT to the new version with combined embedding."""
-    from .vision_transformer import checkpoint_filter_fn as orig_filter_fn
+@register_model
+def vit_naflex_so150m2_patch16_reg1_gap(pretrained=False, **kwargs):
+    """ViT-New with NaFlex functionality for variable aspect ratios and resolutions.

-    # FIXME conversion of existing vit checkpoints has not been finished or tested
+    This model supports:
+    1. Variable aspect ratios and resolutions via patch coordinates
+    2. Position embedding interpolation for arbitrary grid sizes
+    3. Explicit patch coordinates and valid token masking
+    """
+    model_args = dict(
+        patch_size=16, embed_dim=832, depth=21, num_heads=13, mlp_ratio=34/13, init_values=1e-5,
+        qkv_bias=False, class_token=False, reg_tokens=1, global_pool='avg', fc_norm=True, **kwargs)
+    model = _create_vision_transformer_flex(
+        'vit_naflex_so150m2_patch16_reg1_gap', pretrained=pretrained, **model_args)
+    return model

-    # Handle CombinedEmbed module pattern
-    out_dict = {}
-    for k, v in state_dict.items():
-        # Convert tokens and embeddings to combined_embed structure
-        if k == 'pos_embed':
-            # Handle position embedding format conversion - from (1, N, C) to (1, H, W, C)
-            if hasattr(model.embeds, 'pos_embed') and v.ndim == 3:
-                # Original format is (1, N, C) - need to reshape to (1, H, W, C)
-                num_patches = v.shape[1]
-                grid_size = int(math.sqrt(num_patches))
-
-                # Check if it's a perfect square for a standard grid
-                if grid_size * grid_size == num_patches:
-                    # Reshape from (1, N, C) to (1, H, W, C)
-                    v = v.reshape(1, grid_size, grid_size, v.shape[2])
-                else:
-                    # Not a square grid, we need to get the actual dimensions
-                    if hasattr(model.embeds.patch_embed, 'grid_size'):
-                        h, w = model.embeds.patch_embed.grid_size
-                        if h * w == num_patches:
-                            # We have the right dimensions
-                            v = v.reshape(1, h, w, v.shape[2])
-                        else:
-                            # Dimensions don't match, use interpolation
-                            _logger.warning(
-                                f"Position embedding size mismatch: checkpoint={num_patches}, model={(h * w)}. "
-                                f"Using default initialization and will resize in forward pass."
-                            )
-                            # Keep v as is, the forward pass will handle resizing
-
-            out_dict['embeds.pos_embed'] = v
-
-        elif k == 'cls_token':
-            out_dict['embeds.cls_token'] = v
-        elif k == 'reg_token':
-            out_dict['embeds.reg_token'] = v
-        # Convert patch_embed.X to embeds.patch_embed.X
-        elif k.startswith('patch_embed.'):
-            new_key = 'embeds.' + k[12:]
-            out_dict[new_key] = v
-        else:
-            out_dict[k] = v
-
-    # Call the original filter function to handle other patterns
-    return orig_filter_fn(out_dict, model)
+
+@register_model
+def vit_naflex_base_patch16(pretrained: bool = False, **kwargs):
+    """ ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
+    ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer.
+    """
+    model_args = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, global_pool='token', class_token=True, pos_embed_grid_size=(14, 14))
+    model = _create_vision_transformer_flex('vit_naflex_base_patch16', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
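One small difference between the registrations above: the last entrypoint builds its arguments with dict(model_args, **kwargs), which lets caller overrides replace the defaults, while the others splat **kwargs directly into the dict(...) call, where a duplicated key would raise a TypeError instead of overriding. Quick illustration of the first form (plain Python, no timm involved):

    model_args = dict(patch_size=16, embed_dim=768, global_pool='token')
    kwargs = dict(global_pool='avg')
    merged = dict(model_args, **kwargs)
    assert merged['global_pool'] == 'avg' and merged['embed_dim'] == 768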
