Skip to content

Commit a1f379e

Browse files
committed
Add intern300m vit w/ converted timm weights. Fix #2300
1 parent 60f517c commit a1f379e

File tree

2 files changed: +27 −5 lines changed

timm/models/davit.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -710,10 +710,12 @@ def checkpoint_filter_fn(state_dict, model):
710710
def _create_davit(variant, pretrained=False, **kwargs):
711711
default_out_indices = tuple(i for i, _ in enumerate(kwargs.get('depths', (1, 1, 3, 1))))
712712
out_indices = kwargs.pop('out_indices', default_out_indices)
713-
strict = True
713+
714+
strict = kwargs.pop('pretrained_strict', True)
714715
if variant.endswith('_fl'):
715716
# FIXME cleaner approach to missing head norm?
716717
strict = False
718+
717719
model = build_model_with_cfg(
718720
DaVit,
719721
variant,

timm/models/vision_transformer.py

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -438,6 +438,7 @@ def __init__(
438438
no_embed_class: bool = False,
439439
reg_tokens: int = 0,
440440
pre_norm: bool = False,
441+
final_norm: bool = True,
441442
fc_norm: Optional[bool] = None,
442443
dynamic_img_size: bool = False,
443444
dynamic_img_pad: bool = False,
@@ -471,7 +472,9 @@ def __init__(
471472
class_token: Use class token.
472473
no_embed_class: Don't include position embeddings for class (or reg) tokens.
473474
reg_tokens: Number of register tokens.
474-
fc_norm: Pre head norm after pool (instead of before), if None, enabled when global_pool == 'avg'.
475+
pre_norm: Enable norm after embeddings, before transformer blocks (standard in CLIP ViT).
476+
final_norm: Enable norm after transformer blocks, before head (standard in most ViT).
477+
fc_norm: Move final norm after pool (instead of before), if None, enabled when global_pool == 'avg'.
475478
drop_rate: Head dropout rate.
476479
pos_drop_rate: Position embedding dropout rate.
477480
attn_drop_rate: Attention dropout rate.
@@ -554,7 +557,7 @@ def __init__(
554557
for i in range(depth)])
555558
self.feature_info = [
556559
dict(module=f'blocks.{i}', num_chs=embed_dim, reduction=reduction) for i in range(depth)]
557-
self.norm = norm_layer(embed_dim) if not use_fc_norm else nn.Identity()
560+
self.norm = norm_layer(embed_dim) if final_norm and not use_fc_norm else nn.Identity()
558561

559562
# Classifier Head
560563
if global_pool == 'map':
@@ -566,7 +569,7 @@ def __init__(
566569
)
567570
else:
568571
self.attn_pool = None
569-
self.fc_norm = norm_layer(embed_dim) if use_fc_norm else nn.Identity()
572+
self.fc_norm = norm_layer(embed_dim) if final_norm and use_fc_norm else nn.Identity()
570573
self.head_drop = nn.Dropout(drop_rate)
571574
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
572575

@@ -2051,6 +2054,12 @@ def _cfg(url: str = '', **kwargs) -> Dict[str, Any]:
20512054
'vit_so150m_patch16_reg4_map_256.untrained': _cfg(
20522055
input_size=(3, 256, 256)),
20532056

2057+
'vit_intern300m_patch14_448.ogvl_dist': _cfg(
2058+
hf_hub_id='timm/',
2059+
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD,
2060+
input_size=(3, 448, 448), crop_pct=1.0, num_classes=0,
2061+
),
2062+
20542063
'test_vit.r160_in1k': _cfg(
20552064
hf_hub_id='timm/',
20562065
input_size=(3, 160, 160), crop_pct=0.95),
@@ -2091,7 +2100,7 @@ def _create_vision_transformer(variant: str, pretrained: bool = False, **kwargs)
20912100
_filter_fn = checkpoint_filter_fn
20922101

20932102
# FIXME attn pool (currently only in siglip) params removed if pool disabled, is there a better soln?
2094-
strict = True
2103+
strict = kwargs.pop('pretrained_strict', True)
20952104
if 'siglip' in variant and kwargs.get('global_pool', None) != 'map':
20962105
strict = False
20972106

@@ -3298,6 +3307,17 @@ def vit_so150m_patch16_reg4_gap_256(pretrained: bool = False, **kwargs) -> Visio
32983307
return model
32993308

33003309

3310+
@register_model
3311+
def vit_intern300m_patch14_448(pretrained: bool = False, **kwargs) -> VisionTransformer:
3312+
model_args = dict(
3313+
patch_size=14, embed_dim=1024, depth=24, num_heads=16,
3314+
init_values=0.1, final_norm=False, dynamic_img_size=True,
3315+
)
3316+
model = _create_vision_transformer(
3317+
'vit_intern300m_patch14_448', pretrained=pretrained, **dict(model_args, **kwargs))
3318+
return model
3319+
3320+
33013321
@register_model
33023322
def test_vit(pretrained: bool = False, **kwargs) -> VisionTransformer:
33033323
""" ViT Test

Comments (0)