
Commit 1d6ebeb

Add (almost) full set of aimv2 model instances. Switch back to unpacked SwiGLU. Verify correctness. Add DFN L/14 39B weight.

1 parent a4146b7 commit 1d6ebeb

File tree: 1 file changed (+250 -20 lines)

timm/models/vision_transformer.py

Lines changed: 250 additions & 20 deletions
@@ -42,9 +42,9 @@
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD, \
     OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
-from timm.layers import PatchEmbed, Mlp, DropPath, AttentionPoolLatent, RmsNorm, PatchDropout, SwiGLUPacked, \
+from timm.layers import PatchEmbed, Mlp, DropPath, AttentionPoolLatent, RmsNorm, PatchDropout, SwiGLUPacked, SwiGLU, \
     trunc_normal_, lecun_normal_, resample_patch_embed, resample_abs_pos_embed, use_fused_attn, \
-    SwiGLU, get_act_layer, get_norm_layer, LayerType
+    get_act_layer, get_norm_layer, LayerType
 from ._builder import build_model_with_cfg
 from ._features import feature_take_indices
 from ._manipulate import named_apply, checkpoint_seq, adapt_input_conv
@@ -1159,13 +1159,16 @@ def _convert_aimv2(
         k = k.replace('trunk.', '')
         k = k.replace('post_trunk_norm.', 'norm.')
 
-        if 'mlp.fc1' in k:
-            if k in out_dict:
-                v = torch.cat([v, out_dict[k]], dim=0)
-        elif 'mlp.fc3' in k:
-            k = k.replace('mlp.fc3', 'mlp.fc1')
-            if k in out_dict:
-                v = torch.cat([out_dict[k], v], dim=0)
+        # packed ver, FIXME to delete
+        # if 'mlp.fc1' in k:
+        #     if k in out_dict:
+        #         v = torch.cat([v, out_dict[k]], dim=0)
+        # elif 'mlp.fc3' in k:
+        #     k = k.replace('mlp.fc3', 'mlp.fc1')
+        #     if k in out_dict:
+        #         v = torch.cat([out_dict[k], v], dim=0)
+        k = k.replace('mlp.fc1', 'mlp.fc1_g')
+        k = k.replace('mlp.fc3', 'mlp.fc1_x')
 
         out_dict[k] = v
 
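For context on the remap above: AIMv2 checkpoints store the SwiGLU MLP as three projections (fc1 gate, fc3 value, fc2 output). The old conversion concatenated fc1/fc3 into a single packed fc1; the new one simply renames the keys to timm's unpacked SwiGLU layout. A minimal sketch of the target layout, for shape intuition only and not a copy of timm's SwiGLU class:

import torch
import torch.nn as nn
import torch.nn.functional as F

class UnpackedSwiGLU(nn.Module):
    # Sketch of the unpacked layout the renamed keys map onto.
    def __init__(self, dim: int, hidden: int, bias: bool = False):
        super().__init__()
        self.fc1_g = nn.Linear(dim, hidden, bias=bias)  # receives checkpoint 'mlp.fc1'
        self.fc1_x = nn.Linear(dim, hidden, bias=bias)  # receives checkpoint 'mlp.fc3'
        self.fc2 = nn.Linear(hidden, dim, bias=bias)    # 'mlp.fc2', unchanged

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # SwiGLU: silu-gated branch times linear branch, then output projection
        return self.fc2(F.silu(self.fc1_g(x)) * self.fc1_x(x))

The commented-out packed path instead stacked the two weights into one (2 * hidden, dim) matrix and chunked it at forward time; the two are numerically equivalent, but the rename keeps a one-to-one key mapping.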
@@ -1682,18 +1685,27 @@ def _cfg(url: str = '', **kwargs) -> Dict[str, Any]:
 
     'vit_base_patch16_clip_224.dfn2b': _cfg(
         hf_hub_id='timm/',
+        license='apple-ascl',
         mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=512),
+    'vit_large_patch14_clip_224.dfn2b_s39b': _cfg(
+        #hf_hub_id='timm/',
+        hf_hub_id='apple/DFN2B-CLIP-ViT-L-14-39B', hf_hub_filename='open_clip_pytorch_model.bin',
+        license='apple-ascl',
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=768),
     'vit_large_patch14_clip_224.dfn2b': _cfg(
         hf_hub_id='timm/',
+        license='apple-ascl',
         notes=('natively QuickGELU, use quickgelu model variant for original results',),
         mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=768),
     'vit_huge_patch14_clip_224.dfn5b': _cfg(
         hf_hub_id='timm/',
+        license='apple-ascl',
         notes=('natively QuickGELU, use quickgelu model variant for original results',),
         mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=1024),
     'vit_huge_patch14_clip_378.dfn5b': _cfg(
         hf_hub_id='timm/',
         mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
+        license='apple-ascl',
         notes=('natively QuickGELU, use quickgelu model variant for original results',),
         crop_pct=1.0, input_size=(3, 378, 378), num_classes=1024),
 
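The new dfn2b_s39b entry points at Apple's hub repo and the open_clip checkpoint file directly rather than a timm re-upload. Loading it should be the usual one-liner; a usage sketch, assuming hub access:

import timm

# Pulls 'open_clip_pytorch_model.bin' from apple/DFN2B-CLIP-ViT-L-14-39B per the cfg above;
# num_classes=768 is the CLIP image projection width, not a classifier head.
model = timm.create_model('vit_large_patch14_clip_224.dfn2b_s39b', pretrained=True)
model.eval()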
@@ -2164,11 +2176,62 @@ def _cfg(url: str = '', **kwargs) -> Dict[str, Any]:
         input_size=(3, 448, 448), crop_pct=1.0, num_classes=0,
     ),
 
-    'vit_large_patch14_aimv2_224': _cfg(
+    'aimv2_large_patch14_224.apple_pt': _cfg(
         hf_hub_id='apple/aimv2-large-patch14-224',
-        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
-        input_size=(3, 224, 224), crop_pct=1.0,
-        num_classes=0),
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl',
+        crop_pct=1.0, num_classes=0),
+    'aimv2_large_patch14_224.apple_pt_dist': _cfg(
+        hf_hub_id='apple/aimv2-large-patch14-224-distilled',
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl',
+        crop_pct=1.0, num_classes=0),
+    'aimv2_huge_patch14_224.apple_pt': _cfg(
+        hf_hub_id='apple/aimv2-huge-patch14-224',
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl',
+        crop_pct=1.0, num_classes=0),
+    'aimv2_1b_patch14_224.apple_pt': _cfg(
+        hf_hub_id='apple/aimv2-1b-patch14-224',
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl',
+        crop_pct=1.0, num_classes=0),
+    'aimv2_3b_patch14_224.apple_pt': _cfg(
+        hf_hub_id='apple/aimv2-3b-patch14-224',
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl',
+        crop_pct=1.0, num_classes=0),
+    'aimv2_large_patch14_336.apple_pt': _cfg(
+        hf_hub_id='apple/aimv2-large-patch14-336',
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl',
+        input_size=(3, 336, 336), crop_pct=1.0, num_classes=0),
+    'aimv2_large_patch14_336.apple_pt_dist': _cfg(
+        hf_hub_id='apple/aimv2-large-patch14-336-distilled',
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl',
+        input_size=(3, 336, 336), crop_pct=1.0, num_classes=0),
+    'aimv2_huge_patch14_336.apple_pt': _cfg(
+        hf_hub_id='apple/aimv2-huge-patch14-336',
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl',
+        input_size=(3, 336, 336), crop_pct=1.0, num_classes=0),
+    'aimv2_1b_patch14_336.apple_pt': _cfg(
+        hf_hub_id='apple/aimv2-1b-patch14-336',
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl',
+        input_size=(3, 336, 336), crop_pct=1.0, num_classes=0),
+    'aimv2_3b_patch14_336.apple_pt': _cfg(
+        hf_hub_id='apple/aimv2-3b-patch14-336',
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl',
+        input_size=(3, 336, 336), crop_pct=1.0, num_classes=0),
+    'aimv2_large_patch14_448.apple_pt': _cfg(
+        hf_hub_id='apple/aimv2-large-patch14-448',
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl',
+        input_size=(3, 448, 448), crop_pct=1.0, num_classes=0),
+    'aimv2_huge_patch14_448.apple_pt': _cfg(
+        hf_hub_id='apple/aimv2-huge-patch14-448',
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl',
+        input_size=(3, 448, 448), crop_pct=1.0, num_classes=0),
+    'aimv2_1b_patch14_448.apple_pt': _cfg(
+        hf_hub_id='apple/aimv2-1b-patch14-448',
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl',
+        input_size=(3, 448, 448), crop_pct=1.0, num_classes=0),
+    'aimv2_3b_patch14_448.apple_pt': _cfg(
+        hf_hub_id='apple/aimv2-3b-patch14-448',
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl',
+        input_size=(3, 448, 448), crop_pct=1.0, num_classes=0),
 
     'test_vit.r160_in1k': _cfg(
         hf_hub_id='timm/',
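The fourteen aimv2 entries follow timm's <arch>.<tag> naming, with the .apple_pt / .apple_pt_dist tags resolving to Apple's hub repos and num_classes=0 marking them as pure feature encoders. A usage sketch, assuming the weights convert via _convert_aimv2 on load:

import timm
import torch

model = timm.create_model('aimv2_large_patch14_224.apple_pt', pretrained=True).eval()
with torch.no_grad():
    feats = model(torch.randn(1, 3, 224, 224))
print(feats.shape)  # expected (1, 1024): avg-pooled features at the Large width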
@@ -3442,17 +3505,171 @@ def vit_intern300m_patch14_448(pretrained: bool = False, **kwargs) -> VisionTransformer:
 
 
 @register_model
-def vit_large_patch14_aimv2_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
-    """ ViT-Large model (ViT-L/14) w/ parallel blocks and qk norm enabled.
+def aimv2_large_patch14_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
+    """ ViT Large AIM-v2 model
+    """
+    model_args = dict(
+        patch_size=14, embed_dim=1024, depth=24, num_heads=8, class_token=False, fc_norm=False,
+        mlp_ratio=2.75, global_pool='avg', qkv_bias=False, proj_bias=False, act_layer='silu',
+        norm_layer=partial(RmsNorm, eps=1e-5), embed_norm_layer=partial(RmsNorm, eps=1e-5), mlp_layer=SwiGLU,
+    )
+    model = _create_vision_transformer(
+        'aimv2_large_patch14_224', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def aimv2_huge_patch14_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
+    """ ViT Huge AIM-v2 model
+    """
+    model_args = dict(
+        patch_size=14, embed_dim=1536, depth=24, num_heads=12, class_token=False, fc_norm=False,
+        mlp_ratio=2.6667, global_pool='avg', qkv_bias=False, proj_bias=False, act_layer='silu',
+        norm_layer=partial(RmsNorm, eps=1e-5), embed_norm_layer=partial(RmsNorm, eps=1e-5), mlp_layer=SwiGLU,
+    )
+    model = _create_vision_transformer(
+        'aimv2_huge_patch14_224', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def aimv2_1b_patch14_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
+    """ ViT 1B AIM-v2 model
     """
-    rms_norm = partial(RmsNorm, eps=1e-5)
     model_args = dict(
-        patch_size=14, embed_dim=1024, depth=24, num_heads=16, class_token=False, fc_norm=False,
-        mlp_ratio=5.5, global_pool='avg', norm_layer=rms_norm, embed_norm_layer=rms_norm, mlp_layer=SwiGLUPacked,
-        qkv_bias=False, proj_bias=False, act_layer='silu'
+        patch_size=14, embed_dim=2048, depth=24, num_heads=16, class_token=False, fc_norm=False,
+        mlp_ratio=2.75, global_pool='avg', qkv_bias=False, proj_bias=False, act_layer='silu',
+        norm_layer=partial(RmsNorm, eps=1e-5), embed_norm_layer=partial(RmsNorm, eps=1e-5), mlp_layer=SwiGLU,
     )
     model = _create_vision_transformer(
-        'vit_large_patch14_aimv2_224', pretrained=pretrained, **dict(model_args, **kwargs))
+        'aimv2_1b_patch14_224', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def aimv2_3b_patch14_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
+    """ ViT 3B AIM-v2 model
+    """
+    model_args = dict(
+        patch_size=14, embed_dim=3072, depth=24, num_heads=24, class_token=False, fc_norm=False,
+        mlp_ratio=2.6667, global_pool='avg', qkv_bias=False, proj_bias=False, act_layer='silu',
+        norm_layer=partial(RmsNorm, eps=1e-5), embed_norm_layer=partial(RmsNorm, eps=1e-5), mlp_layer=SwiGLU,
+    )
+    model = _create_vision_transformer(
+        'aimv2_3b_patch14_224', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def aimv2_large_patch14_336(pretrained: bool = False, **kwargs) -> VisionTransformer:
+    """ ViT Large AIM-v2 model
+    """
+    model_args = dict(
+        patch_size=14, embed_dim=1024, depth=24, num_heads=8, class_token=False, fc_norm=False,
+        mlp_ratio=2.75, global_pool='avg', qkv_bias=False, proj_bias=False, act_layer='silu',
+        norm_layer=partial(RmsNorm, eps=1e-5), embed_norm_layer=partial(RmsNorm, eps=1e-5), mlp_layer=SwiGLU,
+    )
+    model = _create_vision_transformer(
+        'aimv2_large_patch14_336', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def aimv2_huge_patch14_336(pretrained: bool = False, **kwargs) -> VisionTransformer:
+    """ ViT Huge AIM-v2 model
+    """
+    model_args = dict(
+        patch_size=14, embed_dim=1536, depth=24, num_heads=12, class_token=False, fc_norm=False,
+        mlp_ratio=2.6667, global_pool='avg', qkv_bias=False, proj_bias=False, act_layer='silu',
+        norm_layer=partial(RmsNorm, eps=1e-5), embed_norm_layer=partial(RmsNorm, eps=1e-5), mlp_layer=SwiGLU,
+    )
+    model = _create_vision_transformer(
+        'aimv2_huge_patch14_336', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def aimv2_1b_patch14_336(pretrained: bool = False, **kwargs) -> VisionTransformer:
+    """ ViT 1B AIM-v2 model
+    """
+    model_args = dict(
+        patch_size=14, embed_dim=2048, depth=24, num_heads=16, class_token=False, fc_norm=False,
+        mlp_ratio=2.75, global_pool='avg', qkv_bias=False, proj_bias=False, act_layer='silu',
+        norm_layer=partial(RmsNorm, eps=1e-5), embed_norm_layer=partial(RmsNorm, eps=1e-5), mlp_layer=SwiGLU,
+    )
+    model = _create_vision_transformer(
+        'aimv2_1b_patch14_336', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def aimv2_3b_patch14_336(pretrained: bool = False, **kwargs) -> VisionTransformer:
+    """ ViT 3B AIM-v2 model
+    """
+    model_args = dict(
+        patch_size=14, embed_dim=3072, depth=24, num_heads=24, class_token=False, fc_norm=False,
+        mlp_ratio=2.6667, global_pool='avg', qkv_bias=False, proj_bias=False, act_layer='silu',
+        norm_layer=partial(RmsNorm, eps=1e-5), embed_norm_layer=partial(RmsNorm, eps=1e-5), mlp_layer=SwiGLU,
+    )
+    model = _create_vision_transformer(
+        'aimv2_3b_patch14_336', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def aimv2_large_patch14_448(pretrained: bool = False, **kwargs) -> VisionTransformer:
+    """ ViT Large AIM-v2 model
+    """
+    model_args = dict(
+        patch_size=14, embed_dim=1024, depth=24, num_heads=8, class_token=False, fc_norm=False,
+        mlp_ratio=2.75, global_pool='avg', qkv_bias=False, proj_bias=False, act_layer='silu',
+        norm_layer=partial(RmsNorm, eps=1e-5), embed_norm_layer=partial(RmsNorm, eps=1e-5), mlp_layer=SwiGLU,
+    )
+    model = _create_vision_transformer(
+        'aimv2_large_patch14_448', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def aimv2_huge_patch14_448(pretrained: bool = False, **kwargs) -> VisionTransformer:
+    """ ViT Huge AIM-v2 model
+    """
+    model_args = dict(
+        patch_size=14, embed_dim=1536, depth=24, num_heads=12, class_token=False, fc_norm=False,
+        mlp_ratio=2.6667, global_pool='avg', qkv_bias=False, proj_bias=False, act_layer='silu',
+        norm_layer=partial(RmsNorm, eps=1e-5), embed_norm_layer=partial(RmsNorm, eps=1e-5), mlp_layer=SwiGLU,
+    )
+    model = _create_vision_transformer(
+        'aimv2_huge_patch14_448', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def aimv2_1b_patch14_448(pretrained: bool = False, **kwargs) -> VisionTransformer:
+    """ ViT 1B AIM-v2 model
+    """
+    model_args = dict(
+        patch_size=14, embed_dim=2048, depth=24, num_heads=16, class_token=False, fc_norm=False,
+        mlp_ratio=2.75, global_pool='avg', qkv_bias=False, proj_bias=False, act_layer='silu',
+        norm_layer=partial(RmsNorm, eps=1e-5), embed_norm_layer=partial(RmsNorm, eps=1e-5), mlp_layer=SwiGLU,
+    )
+    model = _create_vision_transformer(
+        'aimv2_1b_patch14_448', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def aimv2_3b_patch14_448(pretrained: bool = False, **kwargs) -> VisionTransformer:
+    """ ViT 3B AIM-v2 model
+    """
+    model_args = dict(
+        patch_size=14, embed_dim=3072, depth=24, num_heads=24, class_token=False, fc_norm=False,
+        mlp_ratio=2.6667, global_pool='avg', qkv_bias=False, proj_bias=False, act_layer='silu',
+        norm_layer=partial(RmsNorm, eps=1e-5), embed_norm_layer=partial(RmsNorm, eps=1e-5), mlp_layer=SwiGLU,
+    )
+    model = _create_vision_transformer(
+        'aimv2_3b_patch14_448', pretrained=pretrained, **dict(model_args, **kwargs))
     return model
 
 
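A note on the ratios above: with the unpacked SwiGLU MLP, mlp_ratio now expresses the width of each gating branch directly, which is why the removed packed config used 5.5 where the new Large config uses 2.75 (the packed fc1 held both branches in one matrix). A quick arithmetic check, assuming timm's usual hidden_features = int(mlp_ratio * embed_dim) rounding:

embed_dim, mlp_ratio = 1024, 2.75       # aimv2_large_patch14_* config above
hidden = int(mlp_ratio * embed_dim)     # 2816: width of each branch (fc1_g, fc1_x)
packed = 2 * hidden                     # 5632 == int(5.5 * 1024), the old packed fc1 width
print(hidden, packed)                   # 2816 5632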
@@ -3487,6 +3704,19 @@ def test_vit3(pretrained: bool = False, **kwargs) -> VisionTransformer:
     return model
 
 
+@register_model
+def test_vit4(pretrained: bool = False, **kwargs) -> VisionTransformer:
+    """ ViT Test
+    """
+    model_args = dict(
+        patch_size=16, embed_dim=96, depth=9, num_heads=3, mlp_ratio=3,
+        class_token=False, reg_tokens=1, global_pool='avg', init_values=1e-5, dynamic_img_size=True,
+        norm_layer='rmsnorm',
+    )
+    model = _create_vision_transformer('test_vit4', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
 register_model_deprecations(__name__, {
     'vit_tiny_patch16_224_in21k': 'vit_tiny_patch16_224.augreg_in21k',
     'vit_small_patch32_224_in21k': 'vit_small_patch32_224.augreg_in21k',
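test_vit4 exercises string-resolved rmsnorm, register tokens, and dynamic_img_size in a model small enough for CI. A hypothetical smoke test (assumes the registry resolves the name even though this diff adds no pretrained cfg for it):

import timm
import torch

model = timm.create_model('test_vit4').eval()  # random init, no weights implied
with torch.no_grad():
    # dynamic_img_size=True should let the model accept non-default resolutions
    out = model(torch.randn(1, 3, 160, 160))
print(out.shape)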
