
Commit 65e8e9c

Merge pull request #2304 from huggingface/intern300m
Add intern300m vit w/ converted timm weights. Fix #2300
2 parents: 60f517c + 89dffc5

File tree: 10 files changed (+35, -13 lines)

timm/models/davit.py

Lines changed: 3 additions & 1 deletion
@@ -710,10 +710,12 @@ def checkpoint_filter_fn(state_dict, model):
 def _create_davit(variant, pretrained=False, **kwargs):
     default_out_indices = tuple(i for i, _ in enumerate(kwargs.get('depths', (1, 1, 3, 1))))
     out_indices = kwargs.pop('out_indices', default_out_indices)
-    strict = True
+
+    strict = kwargs.pop('pretrained_strict', True)
     if variant.endswith('_fl'):
         # FIXME cleaner approach to missing head norm?
         strict = False
+
     model = build_model_with_cfg(
         DaVit,
         variant,
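
Not part of the diff, but a hedged sketch of the new escape hatch: since timm.create_model forwards extra kwargs to the model entrypoint, the pretrained_strict kwarg popped here should be usable directly at model creation time.

import timm

# Sketch: opt out of strict state_dict loading (previously strict=True was
# hard-coded in _create_davit) when checkpoint and model are expected to
# differ slightly, e.g. while experimenting with architecture tweaks.
model = timm.create_model('davit_base', pretrained=True, pretrained_strict=False)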

timm/models/factory.py

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
 from ._factory import *
 
 import warnings
-warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.models", DeprecationWarning)
+warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.models", FutureWarning)
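
A brief aside on why this one-word change matters (a sketch, not part of the diff): under Python's default warning filters, DeprecationWarning is hidden unless it is raised from __main__, so the old shim notices were easy to miss, whereas FutureWarning is shown by default.

import warnings

# Optionally escalate the shim warning to an error, e.g. in CI, to catch
# lingering deprecated imports; with default filters it now prints anyway.
warnings.simplefilter('error', FutureWarning)

import timm.models.factory  # raises FutureWarning: "Importing from timm.models.factory is deprecated, ..."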

timm/models/features.py

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
 from ._features import *
 
 import warnings
-warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.models", DeprecationWarning)
+warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.models", FutureWarning)

timm/models/fx_features.py

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
 from ._features_fx import *
 
 import warnings
-warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.models", DeprecationWarning)
+warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.models", FutureWarning)

timm/models/helpers.py

Lines changed: 1 addition & 1 deletion
@@ -4,4 +4,4 @@
 from ._prune import *
 
 import warnings
-warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.models", DeprecationWarning)
+warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.models", FutureWarning)

timm/models/hub.py

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
 from ._hub import *
 
 import warnings
-warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.models", DeprecationWarning)
+warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.models", FutureWarning)

timm/models/layers/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -45,4 +45,4 @@
 from timm.layers.weight_init import trunc_normal_, trunc_normal_tf_, variance_scaling_, lecun_normal_
 
 import warnings
-warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.layers", DeprecationWarning)
+warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.layers", FutureWarning)

timm/models/mambaout.py

Lines changed: 1 addition & 1 deletion
@@ -151,7 +151,7 @@ def __init__(
         self.num_features = in_features
         self.pre_logits = nn.Identity()
 
-        self.fc = nn.Linear(hidden_size, num_classes, bias=bias)
+        self.fc = nn.Linear(hidden_size, num_classes, bias=bias) if num_classes > 0 else nn.Identity()
         self.head_dropout = nn.Dropout(drop_rate)
 
     def reset(self, num_classes: int, pool_type: Optional[str] = None, reset_other: bool = False):
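
For context, a minimal sketch of the case this change guards (assuming the mambaout_tiny variant name): with num_classes=0 the head's final fc is now nn.Identity() rather than a zero-width nn.Linear, so the model works head-less.

import timm
import torch

# num_classes=0 -> classifier removed; forward() returns pooled features.
model = timm.create_model('mambaout_tiny', pretrained=False, num_classes=0)
feats = model(torch.randn(1, 3, 224, 224))
print(feats.shape)  # pooled feature vector, no classification layer applied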

timm/models/registry.py

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
 from ._registry import *
 
 import warnings
-warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.models", DeprecationWarning)
+warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.models", FutureWarning)

timm/models/vision_transformer.py

Lines changed: 24 additions & 4 deletions
@@ -438,6 +438,7 @@ def __init__(
             no_embed_class: bool = False,
             reg_tokens: int = 0,
             pre_norm: bool = False,
+            final_norm: bool = True,
             fc_norm: Optional[bool] = None,
             dynamic_img_size: bool = False,
             dynamic_img_pad: bool = False,
@@ -471,7 +472,9 @@ def __init__(
             class_token: Use class token.
             no_embed_class: Don't include position embeddings for class (or reg) tokens.
             reg_tokens: Number of register tokens.
-            fc_norm: Pre head norm after pool (instead of before), if None, enabled when global_pool == 'avg'.
+            pre_norm: Enable norm after embeddings, before transformer blocks (standard in CLIP ViT).
+            final_norm: Enable norm after transformer blocks, before head (standard in most ViT).
+            fc_norm: Move final norm after pool (instead of before), if None, enabled when global_pool == 'avg'.
             drop_rate: Head dropout rate.
             pos_drop_rate: Position embedding dropout rate.
             attn_drop_rate: Attention dropout rate.
@@ -554,7 +557,7 @@ def __init__(
             for i in range(depth)])
         self.feature_info = [
             dict(module=f'blocks.{i}', num_chs=embed_dim, reduction=reduction) for i in range(depth)]
-        self.norm = norm_layer(embed_dim) if not use_fc_norm else nn.Identity()
+        self.norm = norm_layer(embed_dim) if final_norm and not use_fc_norm else nn.Identity()
 
         # Classifier Head
         if global_pool == 'map':
@@ -566,7 +569,7 @@ def __init__(
             )
         else:
             self.attn_pool = None
-        self.fc_norm = norm_layer(embed_dim) if use_fc_norm else nn.Identity()
+        self.fc_norm = norm_layer(embed_dim) if final_norm and use_fc_norm else nn.Identity()
         self.head_drop = nn.Dropout(drop_rate)
         self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
 
@@ -2051,6 +2054,12 @@ def _cfg(url: str = '', **kwargs) -> Dict[str, Any]:
     'vit_so150m_patch16_reg4_map_256.untrained': _cfg(
         input_size=(3, 256, 256)),
 
+    'vit_intern300m_patch14_448.ogvl_dist': _cfg(
+        hf_hub_id='timm/',
+        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD,
+        input_size=(3, 448, 448), crop_pct=1.0, num_classes=0,
+    ),
+
     'test_vit.r160_in1k': _cfg(
         hf_hub_id='timm/',
         input_size=(3, 160, 160), crop_pct=0.95),
@@ -2091,7 +2100,7 @@ def _create_vision_transformer(variant: str, pretrained: bool = False, **kwargs)
     _filter_fn = checkpoint_filter_fn
 
     # FIXME attn pool (currently only in siglip) params removed if pool disabled, is there a better soln?
-    strict = True
+    strict = kwargs.pop('pretrained_strict', True)
     if 'siglip' in variant and kwargs.get('global_pool', None) != 'map':
         strict = False
 
@@ -3298,6 +3307,17 @@ def vit_so150m_patch16_reg4_gap_256(pretrained: bool = False, **kwargs) -> Visio
     return model
 
 
+@register_model
+def vit_intern300m_patch14_448(pretrained: bool = False, **kwargs) -> VisionTransformer:
+    model_args = dict(
+        patch_size=14, embed_dim=1024, depth=24, num_heads=16,
+        init_values=0.1, final_norm=False, dynamic_img_size=True,
+    )
+    model = _create_vision_transformer(
+        'vit_intern300m_patch14_448', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
 @register_model
 def test_vit(pretrained: bool = False, **kwargs) -> VisionTransformer:
     """ ViT Test

0 commit comments
