Skip to content

Commit fc0b6ad

Browse files
committed
Fix default_cfgs
1 parent 848b8c3 commit fc0b6ad

File tree

2 files changed

+7
-4
lines changed

2 files changed

+7
-4
lines changed

tests/test_models.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@
5353
'vision_transformer', 'vision_transformer_sam', 'vision_transformer_hybrid', 'vision_transformer_relpos',
5454
'beit', 'mvitv2', 'eva', 'cait', 'xcit', 'volo', 'twins', 'deit', 'swin_transformer', 'swin_transformer_v2',
5555
'swin_transformer_v2_cr', 'maxxvit', 'efficientnet', 'mobilenetv3', 'levit', 'efficientformer', 'resnet',
56-
'regnet', 'byobnet', 'byoanet', 'mlp_mixer', 'hiera', 'fastvit', 'hieradet_sam2', 'aimv2*'
56+
'regnet', 'byobnet', 'byoanet', 'mlp_mixer', 'hiera', 'fastvit', 'hieradet_sam2', 'aimv2*', 'tnt',
5757
]
5858

5959
# transformer / hybrid models don't support full set of spatial / feature APIs and/or have spatial output.

timm/models/tnt.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from ._builder import build_model_with_cfg
2121
from ._features import feature_take_indices
2222
from ._manipulate import checkpoint
23-
from ._registry import register_model
23+
from ._registry import generate_default_cfgs, register_model
2424

2525

2626
__all__ = ['TNT'] # model_registry will add each entrypoint fn to this
@@ -450,11 +450,14 @@ def _cfg(url='', **kwargs):
450450
'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
451451
'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD,
452452
'first_conv': 'pixel_embed.proj', 'classifier': 'head',
453+
'paper_ids': 'arXiv:2103.00112',
454+
'paper_name': 'Transformer in Transformer',
455+
'origin_url': 'https://github.com/huawei-noah/Efficient-AI-Backbones/tree/master/tnt_pytorch',
453456
**kwargs
454457
}
455458

456459

457-
default_cfgs = {
460+
default_cfgs = generate_default_cfgs({
458461
'tnt_s_patch16_224.in1k': _cfg(
459462
# hf_hub_id='timm/',
460463
# url='https://github.com/contrastive/pytorch-image-models/releases/download/TNT/tnt_s_patch16_224.pth.tar',
@@ -464,7 +467,7 @@ def _cfg(url='', **kwargs):
464467
# hf_hub_id='timm/',
465468
url='https://github.com/huawei-noah/Efficient-AI-Backbones/releases/download/tnt/tnt_b_82.9.pth.tar',
466469
),
467-
}
470+
})
468471

469472

470473
def checkpoint_filter_fn(state_dict, model):

0 commit comments

Comments
 (0)