@@ -912,26 +912,40 @@ def resize_pos_embed(


 @torch.no_grad()
-def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str = '') -> None:
+def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str = '', load_bfloat16: bool = False) -> None:
     """ Load weights from .npz checkpoints for official Google Brain Flax implementation
     """
     import numpy as np
+    if load_bfloat16:
+        import jax.numpy as jnp
+        import ml_dtypes

-    def _n2p(w, t=True, idx=None):
+    def _n2p(_w, t=True, idx=None):
         if idx is not None:
-            w = w[idx]
-        if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1:
-            w = w.flatten()
+            _w = _w[idx]
+
+        if load_bfloat16:
+            _w = _w.view(ml_dtypes.bfloat16).astype(jnp.float32)
+            _w = np.array(_w)
+
+        if _w.ndim == 4 and _w.shape[0] == _w.shape[1] == _w.shape[2] == 1:
+            _w = _w.flatten()
         if t:
-            if w.ndim == 4:
-                w = w.transpose([3, 2, 0, 1])
-            elif w.ndim == 3:
-                w = w.transpose([2, 0, 1])
-            elif w.ndim == 2:
-                w = w.transpose([1, 0])
-        return torch.from_numpy(w)
-
-    w = np.load(checkpoint_path)
+            if _w.ndim == 4:
+                _w = _w.transpose([3, 2, 0, 1])
+            elif _w.ndim == 3:
+                _w = _w.transpose([2, 0, 1])
+            elif _w.ndim == 2:
+                _w = _w.transpose([1, 0])
+
+        _w = torch.from_numpy(_w)
+        return _w
+
+    if load_bfloat16:
+        w = jnp.load(checkpoint_path)
+    else:
+        w = np.load(checkpoint_path)
+
     interpolation = 'bilinear'
     antialias = False
     big_vision = False
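
The hunk above adds a `load_bfloat16` escape hatch: `.npz` checkpoints whose arrays were serialized as raw bfloat16 (as with the PaliGemma 2 JAX exports this change targets) can't be decoded by numpy alone, so the loader reinterprets the raw bytes via `ml_dtypes.bfloat16` and upcasts to float32 before handing tensors to PyTorch. A minimal usage sketch, assuming `jax` and `ml_dtypes` are installed and a bfloat16 `.npz` file is on disk (the checkpoint filename below is hypothetical):

```python
import timm
from timm.models.vision_transformer import _load_weights

# Build the matching architecture without pretrained weights, then load the
# Flax .npz checkpoint with the new bfloat16 decode path enabled.
model = timm.create_model('vit_so400m_patch14_siglip_gap_224', pretrained=False, num_classes=0)
_load_weights(model, 'paligemma2_vit_so400m.npz', load_bfloat16=True)  # hypothetical filename
```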
@@ -1593,18 +1607,18 @@ def _cfg(url: str = '', **kwargs) -> Dict[str, Any]:
         mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=1280),

     'vit_base_patch32_clip_224.laion400m_e32': _cfg(
-        hf_hub_id='timm/', hf_hub_filename='open_clip_pytorch_model.bin',
+        hf_hub_id='timm/',
         notes=('natively QuickGELU, use quickgelu model variant for original results',),
         mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=512),
     'vit_base_patch16_clip_224.laion400m_e32': _cfg(
-        hf_hub_id='timm/', hf_hub_filename='open_clip_pytorch_model.bin',
+        hf_hub_id='timm/',
         mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=512),
     'vit_base_patch16_plus_clip_240.laion400m_e32': _cfg(
-        hf_hub_id='timm/', hf_hub_filename='open_clip_pytorch_model.bin',
+        hf_hub_id='timm/',
         mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
-        input_size=(3, 240, 240), crop_pct=1.0, num_classes=512),
+        input_size=(3, 240, 240), crop_pct=1.0, num_classes=640),
     'vit_large_patch14_clip_224.laion400m_e32': _cfg(
-        hf_hub_id='timm/', hf_hub_filename='open_clip_pytorch_model.bin',
+        hf_hub_id='timm/',
         mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=768),

     'vit_base_patch32_clip_224.datacompxl': _cfg(
@@ -1622,22 +1636,18 @@ def _cfg(url: str = '', **kwargs) -> Dict[str, Any]:
         mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=768),

     'vit_base_patch16_clip_224.dfn2b': _cfg(
-        hf_hub_id='apple/DFN2B-CLIP-ViT-B-16',
-        hf_hub_filename='open_clip_pytorch_model.bin',
+        hf_hub_id='timm/',
         mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=512),
     'vit_large_patch14_clip_224.dfn2b': _cfg(
-        hf_hub_id='apple/DFN2B-CLIP-ViT-L-14',
-        hf_hub_filename='open_clip_pytorch_model.bin',
+        hf_hub_id='timm/',
         notes=('natively QuickGELU, use quickgelu model variant for original results',),
         mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=768),
     'vit_huge_patch14_clip_224.dfn5b': _cfg(
-        hf_hub_id='apple/DFN5B-CLIP-ViT-H-14',
-        hf_hub_filename='open_clip_pytorch_model.bin',
+        hf_hub_id='timm/',
         notes=('natively QuickGELU, use quickgelu model variant for original results',),
         mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=1024),
     'vit_huge_patch14_clip_378.dfn5b': _cfg(
-        hf_hub_id='apple/DFN5B-CLIP-ViT-H-14-378',
-        hf_hub_filename='open_clip_pytorch_model.bin',
+        hf_hub_id='timm/',
         mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
         notes=('natively QuickGELU, use quickgelu model variant for original results',),
         crop_pct=1.0, input_size=(3, 378, 378), num_classes=1024),
@@ -1700,7 +1710,7 @@ def _cfg(url: str = '', **kwargs) -> Dict[str, Any]:
         notes=('natively QuickGELU, use quickgelu model variant for original results',),
         mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=768),
     'vit_large_patch14_clip_336.openai': _cfg(
-        hf_hub_id='timm/', hf_hub_filename='open_clip_pytorch_model.bin',
+        hf_hub_id='timm/',
         notes=('natively QuickGELU, use quickgelu model variant for original results',),
         mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
         crop_pct=1.0, input_size=(3, 336, 336), num_classes=768),
@@ -1907,15 +1917,22 @@ def _cfg(url: str = '', **kwargs) -> Dict[str, Any]:
         hf_hub_id='timm/',
         num_classes=0),
     'vit_so400m_patch14_siglip_gap_224.pali_mix': _cfg(
-        hf_hub_id='google/paligemma-3b-mix-224-jax',
-        hf_hub_filename='paligemma-3b-mix-224.npz',
-        custom_load='hf',
+        hf_hub_id='timm/',
         num_classes=0),
     'vit_so400m_patch14_siglip_gap_224.pali_pt': _cfg(
-        hf_hub_id='google/paligemma-3b-pt-224-jax',
-        hf_hub_filename='paligemma-3b-pt-224.npz',
-        custom_load='hf',
+        hf_hub_id='timm/',
+        num_classes=0),
+    'vit_so400m_patch14_siglip_gap_224.pali2_3b_pt': _cfg(
+        hf_hub_id='timm/',
         num_classes=0),
+    'vit_so400m_patch14_siglip_gap_224.pali2_10b_pt': _cfg(
+        hf_hub_id='timm/',
+        num_classes=0),
+    # 'vit_so400m_patch14_siglip_gap_224.pali2_28b_pt': _cfg(
+    #     hf_hub_id='google/paligemma2-28b-pt-224-jax',
+    #     hf_hub_filename='pt_27b_224.npz',
+    #     custom_load='hf',
+    #     num_classes=0),
     'vit_so400m_patch16_siglip_gap_256.webli_i18n': _cfg(
         hf_hub_id='timm/',
         input_size=(3, 256, 256),
@@ -1929,23 +1946,69 @@ def _cfg(url: str = '', **kwargs) -> Dict[str, Any]:
         input_size=(3, 384, 384), crop_pct=1.0,
         num_classes=0),
     'vit_so400m_patch14_siglip_gap_448.pali_mix': _cfg(
-        hf_hub_id='google/paligemma-3b-mix-448-jax',
-        hf_hub_filename='paligemma-3b-mix-448.npz',
-        custom_load='hf',
+        hf_hub_id='timm/',
         input_size=(3, 448, 448), crop_pct=1.0,
         num_classes=0),
     'vit_so400m_patch14_siglip_gap_448.pali_pt': _cfg(
-        hf_hub_id='google/paligemma-3b-pt-448-jax',
-        hf_hub_filename='paligemma-3b-pt-448.npz',
-        custom_load='hf',
+        hf_hub_id='timm/',
+        input_size=(3, 448, 448), crop_pct=1.0,
+        num_classes=0),
+    'vit_so400m_patch14_siglip_gap_448.pali_refcoco_seg': _cfg(
+        hf_hub_id='timm/',
+        input_size=(3, 448, 448), crop_pct=1.0,
+        num_classes=0),
+    'vit_so400m_patch14_siglip_gap_448.pali_ocrvqa': _cfg(
+        hf_hub_id='timm/',
+        input_size=(3, 448, 448), crop_pct=1.0,
+        num_classes=0),
+    'vit_so400m_patch14_siglip_gap_448.pali2_3b_pt': _cfg(
+        hf_hub_id='timm/',
+        input_size=(3, 448, 448), crop_pct=1.0,
+        num_classes=0),
+    'vit_so400m_patch14_siglip_gap_448.pali2_10b_pt': _cfg(
+        hf_hub_id='timm/',
+        input_size=(3, 448, 448), crop_pct=1.0,
+        num_classes=0),
+    # 'vit_so400m_patch14_siglip_gap_448.pali2_28b_pt': _cfg(
+    #     hf_hub_id='google/paligemma2-28b-pt-448-jax',
+    #     hf_hub_filename='pt_27b_448.npz',
+    #     custom_load='hf',
+    #     input_size=(3, 448, 448), crop_pct=1.0,
+    #     num_classes=0),
+    'vit_so400m_patch14_siglip_gap_448.pali2_3b_docci': _cfg(
+        hf_hub_id='timm/',
+        input_size=(3, 448, 448), crop_pct=1.0,
+        num_classes=0),
+    'vit_so400m_patch14_siglip_gap_448.pali2_10b_docci': _cfg(
+        hf_hub_id='timm/',
         input_size=(3, 448, 448), crop_pct=1.0,
         num_classes=0),
     'vit_so400m_patch14_siglip_gap_896.pali_pt': _cfg(
-        hf_hub_id='google/paligemma-3b-pt-896-jax',
-        hf_hub_filename='paligemma-3b-pt-896.npz',
-        custom_load='hf',
+        hf_hub_id='timm/',
+        input_size=(3, 896, 896), crop_pct=1.0,
+        num_classes=0),
+    'vit_so400m_patch14_siglip_gap_896.pali_refcoco_seg': _cfg(
+        hf_hub_id='timm/',
+        input_size=(3, 896, 896), crop_pct=1.0,
+        num_classes=0),
+    'vit_so400m_patch14_siglip_gap_896.pali_ocrvqa': _cfg(
+        hf_hub_id='timm/',
+        input_size=(3, 896, 896), crop_pct=1.0,
+        num_classes=0),
+    'vit_so400m_patch14_siglip_gap_896.pali2_3b_pt': _cfg(
+        hf_hub_id='timm/',
+        input_size=(3, 896, 896), crop_pct=1.0,
+        num_classes=0),
+    'vit_so400m_patch14_siglip_gap_896.pali2_10b_pt': _cfg(
+        hf_hub_id='timm/',
         input_size=(3, 896, 896), crop_pct=1.0,
         num_classes=0),
+    # 'vit_so400m_patch14_siglip_gap_896.pali2_28b_pt': _cfg(
+    #     hf_hub_id='google/paligemma2-28b-pt-896-jax',
+    #     hf_hub_filename='pt_27b_896.npz',
+    #     custom_load='hf',
+    #     input_size=(3, 896, 896), crop_pct=1.0,
+    #     num_classes=0),

     'vit_so400m_patch14_siglip_378.webli_ft_in1k': _cfg(
         hf_hub_id='timm/',
@@ -1958,22 +2021,18 @@ def _cfg(url: str = '', **kwargs) -> Dict[str, Any]:

     'vit_xsmall_patch16_clip_224.tinyclip_yfcc15m': _cfg(
         hf_hub_id='timm/',
-        hf_hub_filename='open_clip_pytorch_model.bin',
         license='mit',
         mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=512),
     'vit_medium_patch32_clip_224.tinyclip_laion400m': _cfg(
         hf_hub_id='timm/',
-        hf_hub_filename='open_clip_pytorch_model.bin',
         license='mit',
         mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=512),
     'vit_medium_patch16_clip_224.tinyclip_yfcc15m': _cfg(
         hf_hub_id='timm/',
-        hf_hub_filename='open_clip_pytorch_model.bin',
         license='mit',
         mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=512),
     'vit_betwixt_patch32_clip_224.tinyclip_laion400m': _cfg(
         hf_hub_id='timm/',
-        hf_hub_filename='open_clip_pytorch_model.bin',
         license='mit',
         mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=512),

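With these cfg changes, the PaliGemma / PaliGemma 2 vision towers and the CLIP/TinyCLIP entries resolve through `hf_hub_id='timm/'` like other timm weights, instead of custom `.npz` or `open_clip_pytorch_model.bin` loads from third-party repos. A hedged sketch of pulling one of the new entries, assuming the converted weights are in fact published under the `timm/` hub namespace:

```python
import timm
import torch

# num_classes=0 in these cfgs: the checkpoints are feature extractors with no classifier head.
model = timm.create_model('vit_so400m_patch14_siglip_gap_448.pali2_3b_pt', pretrained=True)
model.eval()

with torch.no_grad():
    feats = model(torch.randn(1, 3, 448, 448))  # pooled (GAP) features
print(feats.shape)
```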