@@ -438,6 +438,7 @@ def __init__(
             no_embed_class: bool = False,
             reg_tokens: int = 0,
             pre_norm: bool = False,
+            final_norm: bool = True,
             fc_norm: Optional[bool] = None,
             dynamic_img_size: bool = False,
             dynamic_img_pad: bool = False,
@@ -471,7 +472,9 @@ def __init__(
             class_token: Use class token.
             no_embed_class: Don't include position embeddings for class (or reg) tokens.
             reg_tokens: Number of register tokens.
-            fc_norm: Pre head norm after pool (instead of before), if None, enabled when global_pool == 'avg'.
+            pre_norm: Enable norm after embeddings, before transformer blocks (standard in CLIP ViT).
+            final_norm: Enable norm after transformer blocks, before head (standard in most ViT).
+            fc_norm: Move final norm after pool (instead of before), if None, enabled when global_pool == 'avg'.
             drop_rate: Head dropout rate.
             pos_drop_rate: Position embedding dropout rate.
             attn_drop_rate: Attention dropout rate.
@@ -554,7 +557,7 @@ def __init__(
             for i in range(depth)])
         self.feature_info = [
             dict(module=f'blocks.{i}', num_chs=embed_dim, reduction=reduction) for i in range(depth)]
-        self.norm = norm_layer(embed_dim) if not use_fc_norm else nn.Identity()
+        self.norm = norm_layer(embed_dim) if final_norm and not use_fc_norm else nn.Identity()
 
         # Classifier Head
         if global_pool == 'map':
@@ -566,7 +569,7 @@ def __init__(
             )
         else:
             self.attn_pool = None
-        self.fc_norm = norm_layer(embed_dim) if use_fc_norm else nn.Identity()
+        self.fc_norm = norm_layer(embed_dim) if final_norm and use_fc_norm else nn.Identity()
         self.head_drop = nn.Dropout(drop_rate)
         self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
 
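Taken together, pre_norm, final_norm, and fc_norm now control where LayerNorm is placed around the transformer blocks: pre_norm adds one after the embeddings, final_norm enables the one after the blocks, and fc_norm moves that final norm to after pooling. Below is a minimal sketch of the selection logic from the two hunks above, using a hypothetical standalone helper name (build_vit_norms) and the use_fc_norm flag that __init__ resolves from fc_norm and global_pool; it is not code from the commit.

import torch.nn as nn

def build_vit_norms(embed_dim, norm_layer=nn.LayerNorm, pre_norm=False, final_norm=True, use_fc_norm=False):
    # norm_pre: after patch/pos embedding, before the blocks (CLIP-style), when pre_norm=True
    norm_pre = norm_layer(embed_dim) if pre_norm else nn.Identity()
    # norm: after the blocks, before pooling, unless disabled or deferred to fc_norm
    norm = norm_layer(embed_dim) if final_norm and not use_fc_norm else nn.Identity()
    # fc_norm: after pooling, before the classifier head
    fc_norm = norm_layer(embed_dim) if final_norm and use_fc_norm else nn.Identity()
    return norm_pre, norm, fc_norm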
@@ -2051,6 +2054,12 @@ def _cfg(url: str = '', **kwargs) -> Dict[str, Any]:
     'vit_so150m_patch16_reg4_map_256.untrained': _cfg(
         input_size=(3, 256, 256)),
 
+    'vit_intern300m_patch14_448.ogvl_dist': _cfg(
+        hf_hub_id='timm/',
+        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD,
+        input_size=(3, 448, 448), crop_pct=1.0, num_classes=0,
+    ),
+
     'test_vit.r160_in1k': _cfg(
         hf_hub_id='timm/',
         input_size=(3, 160, 160), crop_pct=0.95),
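The new pretrained tag carries its own preprocessing (448x448 input, crop_pct=1.0, ImageNet mean/std) and is headless (num_classes=0). A hedged usage sketch for resolving that preprocessing from the config, assuming the 'ogvl_dist' weights are actually published on the HF hub under the hf_hub_id shown above:

import timm
from timm.data import resolve_model_data_config, create_transform

# Assumes the distilled weights are available on the hub; otherwise use pretrained=False.
model = timm.create_model('vit_intern300m_patch14_448.ogvl_dist', pretrained=True)
data_config = resolve_model_data_config(model)              # picks up 448x448, crop_pct=1.0, mean/std
transform = create_transform(**data_config, is_training=False)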
@@ -2091,7 +2100,7 @@ def _create_vision_transformer(variant: str, pretrained: bool = False, **kwargs)
     _filter_fn = checkpoint_filter_fn
 
     # FIXME attn pool (currently only in siglip) params removed if pool disabled, is there a better soln?
-    strict = True
+    strict = kwargs.pop('pretrained_strict', True)
     if 'siglip' in variant and kwargs.get('global_pool', None) != 'map':
         strict = False
 
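With strict now read from kwargs, callers can relax state-dict loading when their head or pooling setup no longer matches the checkpoint. A sketch of the intended usage, assuming extra kwargs passed to timm.create_model are forwarded to the model entrypoint (the usual timm pattern):

import timm

# pretrained_strict=False is popped by _create_vision_transformer and forwarded to
# build_model_with_cfg, so missing/unexpected keys in the checkpoint are tolerated.
model = timm.create_model(
    'vit_base_patch16_224',
    pretrained=True,
    num_classes=0,            # e.g. drop the classifier head
    pretrained_strict=False,
)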
@@ -3298,6 +3307,17 @@ def vit_so150m_patch16_reg4_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
     return model
 
 
+@register_model
+def vit_intern300m_patch14_448(pretrained: bool = False, **kwargs) -> VisionTransformer:
+    model_args = dict(
+        patch_size=14, embed_dim=1024, depth=24, num_heads=16,
+        init_values=0.1, final_norm=False, dynamic_img_size=True,
+    )
+    model = _create_vision_transformer(
+        'vit_intern300m_patch14_448', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
 @register_model
 def test_vit(pretrained: bool = False, **kwargs) -> VisionTransformer:
     """ ViT Test