Skip to content

Commit eb84e4b

Browse files
committed
Switch hf hub entries for new aimv2 / dfn weights to point to timm locations. Undo forced device for SDR linspace, part of another change.
1 parent 874037e commit eb84e4b

File tree

1 file changed

+17
-28
lines changed

1 file changed

+17
-28
lines changed

timm/models/vision_transformer.py

Lines changed: 17 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -556,7 +556,7 @@ def __init__(
556556
self.patch_drop = nn.Identity()
557557
self.norm_pre = norm_layer(embed_dim) if pre_norm else nn.Identity()
558558

559-
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth, device='cpu')] # stochastic depth decay rule
559+
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
560560
self.blocks = nn.Sequential(*[
561561
block_fn(
562562
dim=embed_dim,
@@ -1158,22 +1158,12 @@ def _convert_aimv2(
11581158
k = k.replace('preprocessor.pos_embed', 'pos_embed')
11591159
k = k.replace('trunk.', '')
11601160
k = k.replace('post_trunk_norm.', 'norm.')
1161-
1162-
# packed ver, FIXME to delete
1163-
# if 'mlp.fc1' in k:
1164-
# if k in out_dict:
1165-
# v = torch.cat([v, out_dict[k]], dim=0)
1166-
# elif 'mlp.fc3' in k:
1167-
# k = k.replace('mlp.fc3', 'mlp.fc1')
1168-
# if k in out_dict:
1169-
# v = torch.cat([out_dict[k], v], dim=0)
11701161
k = k.replace('mlp.fc1', 'mlp.fc1_g')
11711162
k = k.replace('mlp.fc3', 'mlp.fc1_x')
1172-
11731163
out_dict[k] = v
1174-
11751164
return out_dict
11761165

1166+
11771167
def checkpoint_filter_fn(
11781168
state_dict: Dict[str, torch.Tensor],
11791169
model: VisionTransformer,
@@ -1688,8 +1678,7 @@ def _cfg(url: str = '', **kwargs) -> Dict[str, Any]:
16881678
license='apple-ascl',
16891679
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=512),
16901680
'vit_large_patch14_clip_224.dfn2b_s39b': _cfg(
1691-
#hf_hub_id='timm/',
1692-
hf_hub_id='apple/DFN2B-CLIP-ViT-L-14-39B', hf_hub_filename='open_clip_pytorch_model.bin',
1681+
hf_hub_id='timm/',
16931682
license='apple-ascl',
16941683
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=768),
16951684
'vit_large_patch14_clip_224.dfn2b': _cfg(
@@ -2177,59 +2166,59 @@ def _cfg(url: str = '', **kwargs) -> Dict[str, Any]:
21772166
),
21782167

21792168
'aimv2_large_patch14_224.apple_pt': _cfg(
2180-
hf_hub_id='apple/aimv2-large-patch14-224',
2169+
hf_hub_id='timm/',
21812170
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl',
21822171
crop_pct=1.0, num_classes=0),
21832172
'aimv2_large_patch14_224.apple_pt_dist': _cfg(
2184-
hf_hub_id='apple/aimv2-large-patch14-224-distilled',
2173+
hf_hub_id='timm/',
21852174
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl',
21862175
crop_pct=1.0, num_classes=0),
21872176
'aimv2_huge_patch14_224.apple_pt': _cfg(
2188-
hf_hub_id='apple/aimv2-huge-patch14-224',
2177+
hf_hub_id='timm/',
21892178
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl',
21902179
crop_pct=1.0, num_classes=0),
21912180
'aimv2_1b_patch14_224.apple_pt': _cfg(
2192-
hf_hub_id='apple/aimv2-1b-patch14-224',
2181+
hf_hub_id='timm/',
21932182
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl',
21942183
crop_pct=1.0, num_classes=0),
21952184
'aimv2_3b_patch14_224.apple_pt': _cfg(
2196-
hf_hub_id='apple/aimv2-3b-patch14-224',
2185+
hf_hub_id='timm/',
21972186
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl',
21982187
crop_pct=1.0, num_classes=0),
21992188
'aimv2_large_patch14_336.apple_pt': _cfg(
2200-
hf_hub_id='apple/aimv2-large-patch14-336',
2189+
hf_hub_id='timm/',
22012190
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl',
22022191
input_size=(3, 336, 336), crop_pct=1.0, num_classes=0),
22032192
'aimv2_large_patch14_336.apple_pt_dist': _cfg(
2204-
hf_hub_id='apple/aimv2-large-patch14-336-distilled',
2193+
hf_hub_id='timm/',
22052194
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl',
22062195
input_size=(3, 336, 336), crop_pct=1.0, num_classes=0),
22072196
'aimv2_huge_patch14_336.apple_pt': _cfg(
2208-
hf_hub_id='apple/aimv2-huge-patch14-336',
2197+
hf_hub_id='timm/',
22092198
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl',
22102199
input_size=(3, 336, 336), crop_pct=1.0, num_classes=0),
22112200
'aimv2_1b_patch14_336.apple_pt': _cfg(
2212-
hf_hub_id='apple/aimv2-1b-patch14-336',
2201+
hf_hub_id='timm/',
22132202
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl',
22142203
input_size=(3, 336, 336), crop_pct=1.0, num_classes=0),
22152204
'aimv2_3b_patch14_336.apple_pt': _cfg(
2216-
hf_hub_id='apple/aimv2-3b-patch14-336',
2205+
hf_hub_id='timm/',
22172206
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl',
22182207
input_size=(3, 336, 336), crop_pct=1.0, num_classes=0),
22192208
'aimv2_large_patch14_448.apple_pt': _cfg(
2220-
hf_hub_id='apple/aimv2-large-patch14-448',
2209+
hf_hub_id='timm/',
22212210
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl',
22222211
input_size=(3, 448, 448), crop_pct=1.0, num_classes=0),
22232212
'aimv2_huge_patch14_448.apple_pt': _cfg(
2224-
hf_hub_id='apple/aimv2-huge-patch14-448',
2213+
hf_hub_id='timm/',
22252214
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl',
22262215
input_size=(3, 448, 448), crop_pct=1.0, num_classes=0),
22272216
'aimv2_1b_patch14_448.apple_pt': _cfg(
2228-
hf_hub_id='apple/aimv2-1b-patch14-448',
2217+
hf_hub_id='timm/',
22292218
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl',
22302219
input_size=(3, 448, 448), crop_pct=1.0, num_classes=0),
22312220
'aimv2_3b_patch14_448.apple_pt': _cfg(
2232-
hf_hub_id='apple/aimv2-3b-patch14-448',
2221+
hf_hub_id='timm/',
22332222
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl',
22342223
input_size=(3, 448, 448), crop_pct=1.0, num_classes=0),
22352224

0 commit comments

Comments
 (0)