Merged
8 changes: 4 additions & 4 deletions tests/test_models.py
@@ -52,15 +52,15 @@
'vision_transformer', 'vision_transformer_sam', 'vision_transformer_hybrid', 'vision_transformer_relpos',
'beit', 'mvitv2', 'eva', 'cait', 'xcit', 'volo', 'twins', 'deit', 'swin_transformer', 'swin_transformer_v2',
'swin_transformer_v2_cr', 'maxxvit', 'efficientnet', 'mobilenetv3', 'levit', 'efficientformer', 'resnet',
-'regnet', 'byobnet', 'byoanet', 'mlp_mixer'
+'regnet', 'byobnet', 'byoanet', 'mlp_mixer', 'hiera',
]

# transformer / hybrid models don't support full set of spatial / feature APIs and/or have spatial output.
NON_STD_FILTERS = [
'vit_*', 'tnt_*', 'pit_*', 'coat_*', 'cait_*', '*mixer_*', 'gmlp_*', 'resmlp_*', 'twins_*',
'convit_*', 'levit*', 'visformer*', 'deit*', 'xcit_*', 'crossvit_*', 'beit*',
'poolformer_*', 'volo_*', 'sequencer2d_*', 'mvitv2*', 'gcvit*', 'efficientformer*',
-'eva_*', 'flexivit*', 'eva02*', 'samvit_*', 'efficientvit_m*', 'tiny_vit_*'
+'eva_*', 'flexivit*', 'eva02*', 'samvit_*', 'efficientvit_m*', 'tiny_vit_*', 'hiera_*'
]
NUM_NON_STD = len(NON_STD_FILTERS)

@@ -77,7 +77,7 @@
EXCLUDE_FILTERS = ['*enormous*']
NON_STD_EXCLUDE_FILTERS = ['*gigantic*', '*enormous*']

-EXCLUDE_JIT_FILTERS = []
+EXCLUDE_JIT_FILTERS = ['hiera_*']

TARGET_FWD_SIZE = MAX_FWD_SIZE = 384
TARGET_BWD_SIZE = 128
@@ -486,7 +486,7 @@ def _create_fx_model(model, train=False):
return fx_model


-EXCLUDE_FX_FILTERS = ['vit_gi*']
+EXCLUDE_FX_FILTERS = ['vit_gi*', 'hiera*']
# not enough memory to run fx on more models than other tests
if 'GITHUB_ACTIONS' in os.environ:
EXCLUDE_FX_FILTERS += [
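For context, these exclusion lists are plain fnmatch-style patterns matched against registered model names. A minimal sketch of how a list like `EXCLUDE_JIT_FILTERS` filters the test set (the helper `_filter_models` here is illustrative, not the suite's actual plumbing):

```python
import fnmatch

EXCLUDE_JIT_FILTERS = ['hiera_*']

def _filter_models(model_names, exclude_filters):
    # Keep only names that match none of the exclusion patterns.
    return [
        name for name in model_names
        if not any(fnmatch.fnmatch(name, pat) for pat in exclude_filters)
    ]

print(_filter_models(['hiera_tiny_224', 'resnet50'], EXCLUDE_JIT_FILTERS))
# -> ['resnet50']
```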
4 changes: 2 additions & 2 deletions timm/layers/classifier.py
@@ -108,7 +108,7 @@ def __init__(
self.fc = fc
self.flatten = nn.Flatten(1) if use_conv and pool_type else nn.Identity()

-def reset(self, num_classes, pool_type=None):
+def reset(self, num_classes: int, pool_type: Optional[str] = None):
if pool_type is not None and pool_type != self.global_pool.pool_type:
self.global_pool, self.fc = create_classifier(
self.in_features,
@@ -180,7 +180,7 @@ def __init__(
self.drop = nn.Dropout(drop_rate)
self.fc = linear_layer(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

-def reset(self, num_classes, pool_type=None):
+def reset(self, num_classes: int, pool_type: Optional[str] = None):
if pool_type is not None:
self.global_pool = SelectAdaptivePool2d(pool_type=pool_type)
self.flatten = nn.Flatten(1) if pool_type else nn.Identity()
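A short usage sketch of the now-typed `reset` (the constructor arguments shown are assumptions about a typical `ClassifierHead` setup):

```python
from timm.layers import ClassifierHead

head = ClassifierHead(in_features=768, num_classes=1000, pool_type='avg')
head.reset(10, pool_type='max')  # new 10-class fc and max pooling
head.reset(5)                    # pool_type=None keeps the current pooling
```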
2 changes: 1 addition & 1 deletion timm/layers/create_norm.py
@@ -47,7 +47,7 @@ def get_norm_layer(norm_layer):
if isinstance(norm_layer, str):
if not norm_layer:
return None
-layer_name = norm_layer.replace('_', '')
+layer_name = norm_layer.replace('_', '').lower()
norm_layer = _NORM_MAP[layer_name]
else:
norm_layer = norm_layer
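With the added `.lower()`, string lookup becomes case-insensitive, so the common spellings all resolve to the same `_NORM_MAP` entry. A minimal check:

```python
from timm.layers import get_norm_layer

# All three should now resolve to the same normalization class.
assert get_norm_layer('layernorm') is get_norm_layer('LayerNorm') is get_norm_layer('layer_norm')
```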
1 change: 1 addition & 0 deletions timm/models/__init__.py
@@ -26,6 +26,7 @@
from .ghostnet import *
from .hardcorenas import *
from .hgnet import *
+from .hiera import *
from .hrnet import *
from .inception_next import *
from .inception_resnet_v2 import *
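Once the module is imported here, the Hiera entrypoints register with the model factory; a minimal smoke test (the variant name `hiera_tiny_224` is assumed from the new module's naming and may differ):

```python
import timm
import torch

model = timm.create_model('hiera_tiny_224', pretrained=False, num_classes=10)
out = model(torch.randn(1, 3, 224, 224))
print(out.shape)  # expected: torch.Size([1, 10])
```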
2 changes: 1 addition & 1 deletion timm/models/beit.py
@@ -395,7 +395,7 @@ def group_matcher(self, coarse=False):
def get_classifier(self):
return self.head

-def reset_classifier(self, num_classes, global_pool=None):
+def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
self.num_classes = num_classes
if global_pool is not None:
self.global_pool = global_pool
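The same annotation change (`num_classes: int`, `global_pool: Optional[str] = None`) repeats across the models below; behavior is unchanged. A usage sketch:

```python
import timm

model = timm.create_model('beit_base_patch16_224', pretrained=False)
model.reset_classifier(0)                       # drop the head, keep pooling
model.reset_classifier(100, global_pool='avg')  # new head, explicit pooling
```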
2 changes: 1 addition & 1 deletion timm/models/cait.py
@@ -331,7 +331,7 @@ def _matcher(name):
def get_classifier(self):
return self.head

-def reset_classifier(self, num_classes, global_pool=None):
+def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
self.num_classes = num_classes
if global_pool is not None:
assert global_pool in ('', 'token', 'avg')
5 changes: 2 additions & 3 deletions timm/models/coat.py
@@ -7,8 +7,7 @@

Modified from timm/models/vision_transformer.py
"""
-from functools import partial
-from typing import Tuple, List, Union
+from typing import List, Optional, Union, Tuple

import torch
import torch.nn as nn
@@ -560,7 +559,7 @@ def group_matcher(self, coarse=False):
def get_classifier(self):
return self.head

-def reset_classifier(self, num_classes, global_pool=None):
+def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
self.num_classes = num_classes
if global_pool is not None:
assert global_pool in ('token', 'avg')
5 changes: 2 additions & 3 deletions timm/models/convit.py
@@ -21,8 +21,7 @@
'''These modules are adapted from those of timm, see
https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
'''
-
-from functools import partial
+from typing import Optional

import torch
import torch.nn as nn
@@ -349,7 +348,7 @@ def set_grad_checkpointing(self, enable=True):
def get_classifier(self):
return self.head

-def reset_classifier(self, num_classes, global_pool=None):
+def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
self.num_classes = num_classes
if global_pool is not None:
assert global_pool in ('', 'token', 'avg')
4 changes: 3 additions & 1 deletion timm/models/convmixer.py
@@ -1,6 +1,8 @@
""" ConvMixer

"""
+from typing import Optional
+
import torch
import torch.nn as nn

@@ -75,7 +77,7 @@ def set_grad_checkpointing(self, enable=True):
def get_classifier(self):
return self.head

-def reset_classifier(self, num_classes, global_pool=None):
+def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
self.num_classes = num_classes
if global_pool is not None:
self.pooling = SelectAdaptivePool2d(pool_type=global_pool, flatten=True)
1 change: 0 additions & 1 deletion timm/models/convnext.py
@@ -37,7 +37,6 @@
# LICENSE file in the root directory of this source tree (Attribution-NonCommercial 4.0 International (CC BY-NC 4.0))
# No code was used directly from ConvNeXt-V2, however the weights are CC BY-NC 4.0 so beware if using commercially.

-from collections import OrderedDict
from functools import partial
from typing import Callable, List, Optional, Tuple, Union

5 changes: 2 additions & 3 deletions timm/models/crossvit.py
@@ -25,8 +25,7 @@

"""
from functools import partial
-from typing import List
-from typing import Tuple
+from typing import List, Optional, Tuple

import torch
import torch.hub
@@ -419,7 +418,7 @@ def set_grad_checkpointing(self, enable=True):
def get_classifier(self):
return self.head

-def reset_classifier(self, num_classes, global_pool=None):
+def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
self.num_classes = num_classes
if global_pool is not None:
assert global_pool in ('token', 'avg')
4 changes: 2 additions & 2 deletions timm/models/davit.py
@@ -12,7 +12,7 @@
# All rights reserved.
# This source code is licensed under the MIT license
from functools import partial
-from typing import Tuple
+from typing import Optional, Tuple

import torch
import torch.nn as nn
@@ -568,7 +568,7 @@ def set_grad_checkpointing(self, enable=True):
def get_classifier(self):
return self.head.fc

-def reset_classifier(self, num_classes, global_pool=None):
+def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
self.head.reset(num_classes, global_pool)

def forward_features(self, x):
5 changes: 2 additions & 3 deletions timm/models/deit.py
@@ -11,7 +11,7 @@
# Copyright (c) 2015-present, Facebook, Inc.
# All rights reserved.
from functools import partial
-from typing import Sequence, Union
+from typing import Optional

import torch
from torch import nn as nn
@@ -20,7 +20,6 @@
from timm.layers import resample_abs_pos_embed
from timm.models.vision_transformer import VisionTransformer, trunc_normal_, checkpoint_filter_fn
from ._builder import build_model_with_cfg
-from ._manipulate import checkpoint_seq
from ._registry import generate_default_cfgs, register_model, register_model_deprecations

__all__ = ['VisionTransformerDistilled'] # model_registry will add each entrypoint fn to this
@@ -64,7 +63,7 @@ def group_matcher(self, coarse=False):
def get_classifier(self):
return self.head, self.head_dist

-def reset_classifier(self, num_classes, global_pool=None):
+def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
self.num_classes = num_classes
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()
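DeiT's distilled variant carries two heads, so `reset_classifier` rebuilds both; a usage sketch (the model name is one of the registered distilled DeiT variants):

```python
import timm

model = timm.create_model('deit_tiny_distilled_patch16_224', pretrained=False)
model.reset_classifier(10)
head, head_dist = model.get_classifier()  # both now project to 10 classes
```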
3 changes: 1 addition & 2 deletions timm/models/edgenext.py
@@ -8,7 +8,6 @@
Modifications and additions for timm by / Copyright 2022, Ross Wightman
"""
import math
-from collections import OrderedDict
from functools import partial
from typing import Tuple

@@ -17,7 +16,7 @@
from torch import nn

from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from timm.layers import trunc_normal_tf_, DropPath, LayerNorm2d, Mlp, SelectAdaptivePool2d, create_conv2d, \
+from timm.layers import trunc_normal_tf_, DropPath, LayerNorm2d, Mlp, create_conv2d, \
use_fused_attn, NormMlpClassifierHead, ClassifierHead
from ._builder import build_model_with_cfg
from ._features_fx import register_notrace_module
2 changes: 1 addition & 1 deletion timm/models/efficientformer.py
@@ -449,7 +449,7 @@ def set_grad_checkpointing(self, enable=True):
def get_classifier(self):
return self.head, self.head_dist

-def reset_classifier(self, num_classes, global_pool=None):
+def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
self.num_classes = num_classes
if global_pool is not None:
self.global_pool = global_pool
4 changes: 2 additions & 2 deletions timm/models/efficientformer_v2.py
@@ -16,7 +16,7 @@
"""
import math
from functools import partial
-from typing import Dict
+from typing import Dict, Optional

import torch
import torch.nn as nn
@@ -612,7 +612,7 @@ def set_grad_checkpointing(self, enable=True):
def get_classifier(self):
return self.head, self.head_dist

-def reset_classifier(self, num_classes, global_pool=None):
+def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
self.num_classes = num_classes
if global_pool is not None:
self.global_pool = global_pool
5 changes: 2 additions & 3 deletions timm/models/efficientvit_mit.py
@@ -13,7 +13,6 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
-from torch.nn.modules.batchnorm import _BatchNorm

from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import SelectAdaptivePool2d, create_conv2d, GELUTanh
@@ -740,7 +739,7 @@ def set_grad_checkpointing(self, enable=True):
def get_classifier(self):
return self.head.classifier[-1]

-def reset_classifier(self, num_classes, global_pool=None):
+def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
self.num_classes = num_classes
if global_pool is not None:
self.global_pool = global_pool
@@ -858,7 +857,7 @@ def set_grad_checkpointing(self, enable=True):
def get_classifier(self):
return self.head.classifier[-1]

-def reset_classifier(self, num_classes, global_pool=None):
+def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
self.num_classes = num_classes
if global_pool is not None:
self.global_pool = global_pool
4 changes: 2 additions & 2 deletions timm/models/efficientvit_msra.py
@@ -9,7 +9,7 @@
__all__ = ['EfficientVitMsra']
import itertools
from collections import OrderedDict
-from typing import Dict
+from typing import Dict, Optional

import torch
import torch.nn as nn
@@ -464,7 +464,7 @@ def set_grad_checkpointing(self, enable=True):
def get_classifier(self):
return self.head.linear

-def reset_classifier(self, num_classes, global_pool=None):
+def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
self.num_classes = num_classes
if global_pool is not None:
if global_pool == 'avg':
2 changes: 1 addition & 1 deletion timm/models/eva.py
@@ -539,7 +539,7 @@ def group_matcher(self, coarse=False):
def get_classifier(self):
return self.head

-def reset_classifier(self, num_classes, global_pool=None):
+def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
self.num_classes = num_classes
if global_pool is not None:
self.global_pool = global_pool
4 changes: 2 additions & 2 deletions timm/models/fastvit.py
@@ -396,7 +396,7 @@ def reparameterize(self) -> None:

@staticmethod
def _fuse_bn(
-conv: torch.Tensor, bn: nn.BatchNorm2d
+conv: nn.Conv2d, bn: nn.BatchNorm2d
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Method to fuse batchnorm layer with conv layer.

@@ -1232,7 +1232,7 @@ def set_grad_checkpointing(self, enable=True):
def get_classifier(self):
return self.head.fc

-def reset_classifier(self, num_classes, global_pool=None):
+def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
self.num_classes = num_classes
self.head.reset(num_classes, global_pool)

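The `_fuse_bn` fix corrects the annotation: the method receives a conv module, not a tensor. For reference, a generic standalone sketch of the standard conv+BN folding identity (not FastViT's exact method):

```python
import torch
import torch.nn as nn

@torch.no_grad()
def fuse_conv_bn(conv: nn.Conv2d, bn: nn.BatchNorm2d) -> nn.Conv2d:
    # Fold BN into conv: y = gamma * (conv(x) - mean) / std + beta
    scale = bn.weight / (bn.running_var + bn.eps).sqrt()  # per-channel scale
    fused = nn.Conv2d(
        conv.in_channels, conv.out_channels, conv.kernel_size,
        stride=conv.stride, padding=conv.padding,
        dilation=conv.dilation, groups=conv.groups, bias=True,
    )
    fused.weight.copy_(conv.weight * scale.reshape(-1, 1, 1, 1))
    bias = conv.bias if conv.bias is not None else torch.zeros(conv.out_channels)
    fused.bias.copy_(bn.bias + (bias - bn.running_mean) * scale)
    return fused
```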
2 changes: 1 addition & 1 deletion timm/models/focalnet.py
@@ -454,7 +454,7 @@ def set_grad_checkpointing(self, enable=True):
def get_classifier(self):
return self.head.fc

-def reset_classifier(self, num_classes, global_pool=None):
+def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
self.head.reset(num_classes, pool_type=global_pool)

def forward_features(self, x):
2 changes: 1 addition & 1 deletion timm/models/gcvit.py
@@ -489,7 +489,7 @@ def set_grad_checkpointing(self, enable=True):
def get_classifier(self):
return self.head.fc

def reset_classifier(self, num_classes, global_pool=None):
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
self.num_classes = num_classes
if global_pool is None:
global_pool = self.head.global_pool.pool_type