
Commit 7bfe606

Merge remote-tracking branch 'origin/main' into naflex

2 parents f001b15 + e7925ea


47 files changed: +2655 −305 lines

README.md

Lines changed: 2 additions & 1 deletion
@@ -566,7 +566,7 @@ Model validation results can be found in the [results tables](results/README.md)
 
 The official documentation can be found at https://huggingface.co/docs/hub/timm. Documentation contributions are welcome.
 
-[Getting Started with PyTorch Image Models (timm): A Practitioner’s Guide](https://towardsdatascience.com/getting-started-with-pytorch-image-models-timm-a-practitioners-guide-4e77b4bf9055) by [Chris Hughes](https://github.com/Chris-hughes10) is an extensive blog post covering many aspects of `timm` in detail.
+[Getting Started with PyTorch Image Models (timm): A Practitioner’s Guide](https://towardsdatascience.com/getting-started-with-pytorch-image-models-timm-a-practitioners-guide-4e77b4bf9055-2/) by [Chris Hughes](https://github.com/Chris-hughes10) is an extensive blog post covering many aspects of `timm` in detail.
 
 [timmdocs](http://timm.fast.ai/) is an alternate set of documentation for `timm`. A big thanks to [Aman Arora](https://github.com/amaarora) for his efforts creating timmdocs.
 
@@ -598,6 +598,7 @@ One of the greatest assets of PyTorch is the community and their contributions.
 
 ### Training / Frameworks
 * fastai - https://github.com/fastai/fastai
+* lightly_train - https://github.com/lightly-ai/lightly-train
 
 ## Licenses
 

onnx_export.py

Lines changed: 12 additions & 2 deletions
@@ -43,11 +43,13 @@
                     metavar='N', help='mini-batch size (default: 1)')
 parser.add_argument('--img-size', default=None, type=int,
                     metavar='N', help='Input image dimension, uses model default if empty')
+parser.add_argument('--input-size', default=None, nargs=3, type=int, metavar='N',
+                    help='Input all image dimensions (d h w, e.g. --input-size 3 224 224), uses model default if empty')
 parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN',
                     help='Override mean pixel value of dataset')
 parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD',
                     help='Override std deviation of of dataset')
-parser.add_argument('--num-classes', type=int, default=1000,
+parser.add_argument('--num-classes', type=int, default=None,
                     help='Number classes in dataset')
 parser.add_argument('--checkpoint', default='', type=str, metavar='PATH',
                     help='path to checkpoint (default: none)')
@@ -82,6 +84,14 @@ def main():
     if args.reparam:
         model = reparameterize_model(model)
 
+    if args.input_size is not None:
+        assert len(args.input_size) == 3, 'input-size should be N H W (channels, height, width)'
+        input_size = args.input_size
+    elif args.img_size is not None:
+        input_size = (3, args.img_size, args.img_size)
+    else:
+        input_size = None
+
     onnx_export(
         model,
         args.output,
@@ -93,7 +103,7 @@ def main():
         training=args.training,
         verbose=args.verbose,
         use_dynamo=args.dynamo,
-        input_size=(3, args.img_size, args.img_size),
+        input_size=input_size,
         batch_size=args.batch_size,
     )
 
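The new --input-size flag lets the export script accept non-square resolutions, while --img-size remains a square-only fallback used when --input-size is absent. A hypothetical invocation (output path is illustrative, not taken from this commit; the script's default model is used):

    python onnx_export.py output/model.onnx --input-size 3 224 288    # channels, height, width
    python onnx_export.py output/model.onnx --img-size 224            # square input, only used without --input-size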

tests/test_models.py

Lines changed: 9 additions & 2 deletions
@@ -53,7 +53,10 @@
     'vision_transformer', 'vision_transformer_sam', 'vision_transformer_hybrid', 'vision_transformer_relpos',
     'beit', 'mvitv2', 'eva', 'cait', 'xcit', 'volo', 'twins', 'deit', 'swin_transformer', 'swin_transformer_v2',
     'swin_transformer_v2_cr', 'maxxvit', 'efficientnet', 'mobilenetv3', 'levit', 'efficientformer', 'resnet',
-    'regnet', 'byobnet', 'byoanet', 'mlp_mixer', 'hiera', 'fastvit', 'hieradet_sam2', 'aimv2*'
+    'regnet', 'byobnet', 'byoanet', 'mlp_mixer', 'hiera', 'fastvit', 'hieradet_sam2', 'aimv2*', 'tnt',
+    'tiny_vit', 'vovnet', 'tresnet', 'rexnet', 'resnetv2', 'repghost', 'repvit', 'pvt_v2', 'nextvit', 'nest',
+    'mambaout', 'inception_next', 'inception_v4', 'hgnet', 'gcvit', 'focalnet', 'efficientformer_v2', 'edgenext',
+    'davit', 'rdnet', 'convnext', 'pit'
 ]
 
 # transformer / hybrid models don't support full set of spatial / feature APIs and/or have spatial output.
@@ -508,8 +511,9 @@ def test_model_forward_intermediates(model_name, batch_size):
     spatial_axis = get_spatial_dim(output_fmt)
     import math
 
+    inpt = torch.randn((batch_size, *input_size))
     output, intermediates = model.forward_intermediates(
-        torch.randn((batch_size, *input_size)),
+        inpt,
         output_fmt=output_fmt,
     )
     assert len(expected_channels) == len(intermediates)
@@ -521,6 +525,9 @@ def test_model_forward_intermediates(model_name, batch_size):
         assert o.shape[0] == batch_size
         assert not torch.isnan(o).any()
 
+    output2 = model.forward_features(inpt)
+    assert torch.allclose(output, output2)
+
 
 def _create_fx_model(model, train=False):
     # This block of code does a bit of juggling to handle any case where there are multiple outputs in train mode
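The test now reuses the same random input so it can assert that forward_intermediates() returns the same final features as forward_features(). A minimal sketch of that API outside the test harness, assuming a recent timm install and an illustrative model name (not taken from the commit):

import torch
import timm

model = timm.create_model('vit_small_patch16_224', pretrained=False).eval()
x = torch.randn(1, 3, 224, 224)

final, intermediates = model.forward_intermediates(x)  # final features plus per-stage outputs
print(final.shape, len(intermediates))

# the new assertion: the final output must match a plain forward_features() pass
assert torch.allclose(final, model.forward_features(x))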

timm/data/dataset_factory.py

Lines changed: 1 addition & 0 deletions
@@ -144,6 +144,7 @@ def create_dataset(
             use_train = split in _TRAIN_SYNONYM
             ds = QMNIST(train=use_train, **torch_kwargs)
         elif name == 'imagenet':
+            torch_kwargs.pop('download')
             assert has_imagenet, 'Please update to a newer PyTorch and torchvision for ImageNet dataset.'
             if split in _EVAL_SYNONYM:
                 split = 'val'
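create_dataset() passes a download entry in torch_kwargs to the torchvision-backed datasets, but torchvision's ImageNet class does not accept one, so it is popped before construction. A hedged sketch of the affected call path (the root path is illustrative; ImageNet data must already be present on disk):

from timm.data import create_dataset

ds = create_dataset(
    'torch/imagenet',         # torchvision-backed ImageNet
    root='/data/imagenet',    # pre-downloaded/extracted ImageNet tree
    split='validation',
    download=False,           # popped internally; ImageNet cannot be auto-downloaded
)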

timm/layers/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -1,7 +1,7 @@
 from .activations import *
 from .adaptive_avgmax_pool import \
     adaptive_avgmax_pool2d, select_adaptive_pool2d, AdaptiveAvgMaxPool2d, SelectAdaptivePool2d
-from .attention import Attention
+from .attention import Attention, AttentionRope
 from .attention2d import MultiQueryAttention2d, Attention2d, MultiQueryAttentionV2
 from .attention_pool import AttentionPoolLatent
 from .attention_pool2d import AttentionPool2d, RotAttentionPool2d, RotaryEmbedding
@@ -42,6 +42,7 @@
 from .padding import get_padding, get_same_padding, pad_same
 from .patch_dropout import PatchDropout
 from .patch_embed import PatchEmbed, PatchEmbedWithSize, resample_patch_embed
+from .pool1d import global_pool_nlc
 from .pool2d_same import AvgPool2dSame, create_pool2d
 from .pos_embed import resample_abs_pos_embed, resample_abs_pos_embed_nhwc
 from .pos_embed_rel import RelPosMlp, RelPosBias, RelPosBiasTf, gen_relative_position_index, gen_relative_log_coords, \

timm/layers/attention.py

Lines changed: 147 additions & 1 deletion
@@ -1,13 +1,21 @@
-from typing import Final, Type, Optional
+from typing import Final, Optional, Type
 
 import torch
 from torch import nn as nn
 from torch.nn import functional as F
 
 from .config import use_fused_attn
+from .pos_embed_sincos import apply_rot_embed_cat
 
 
 class Attention(nn.Module):
+    """Standard Multi-head Self Attention module with QKV projection.
+
+    This module implements the standard multi-head attention mechanism used in transformers.
+    It supports both the fused attention implementation (scaled_dot_product_attention) for
+    efficiency when available, and a manual implementation otherwise. The module includes
+    options for QK normalization, attention dropout, and projection dropout.
+    """
     fused_attn: Final[bool]
 
     def __init__(
@@ -21,6 +29,18 @@ def __init__(
             proj_drop: float = 0.,
             norm_layer: Type[nn.Module] = nn.LayerNorm,
     ) -> None:
+        """Initialize the Attention module.
+
+        Args:
+            dim: Input dimension of the token embeddings
+            num_heads: Number of attention heads
+            qkv_bias: Whether to use bias in the query, key, value projections
+            qk_norm: Whether to apply normalization to query and key vectors
+            proj_bias: Whether to use bias in the output projection
+            attn_drop: Dropout rate applied to the attention weights
+            proj_drop: Dropout rate applied after the output projection
+            norm_layer: Normalization layer constructor for QK normalization if enabled
+        """
         super().__init__()
         assert dim % num_heads == 0, 'dim should be divisible by num_heads'
         self.num_heads = num_heads
@@ -64,3 +84,129 @@ def forward(
         x = self.proj(x)
         x = self.proj_drop(x)
         return x
+
+
+class AttentionRope(nn.Module):
+    """ A Self Attention module with ROPE support.
+
+    Includes options for:
+     * QK normalization option
+     * Attention output (scale) normalization
+     * Fused or unfused QKV projection support
+    """
+    fused_attn: torch.jit.Final[bool]
+
+    def __init__(
+            self,
+            dim: int,
+            num_heads: int = 8,
+            qkv_bias: bool = True,
+            qkv_fused: bool = True,
+            num_prefix_tokens: int = 1,
+            attn_drop: float = 0.,
+            proj_drop: float = 0.,
+            attn_head_dim: Optional[int] = None,
+            norm_layer: Type[nn.Module] = None,
+            qk_norm: bool = False,
+            scale_norm: bool = False,
+    ):
+        """Initialize the Attention module.
+
+        Args:
+            dim: Input dimension of the token embeddings
+            num_heads: Number of attention heads
+            qkv_bias: Whether to add a bias term to the query, key, and value projections
+            num_prefix_tokens: Number of reg/cls tokens at the beginning of the sequence that
+                should not have position embeddings applied
+            attn_drop: Dropout rate for attention weights
+            proj_drop: Dropout rate for the output projection
+            attn_head_dim: Dimension of each attention head (if None, computed as dim // num_heads)
+            norm_layer: Normalization layer constructor to use for QK and scale normalization
+            qk_norm: Enable normalization of query (Q) and key (K) vectors with norm_layer
+            scale_norm: Enable normalization (scaling) of attention output with norm_layer
+        """
+        super().__init__()
+        if scale_norm or qk_norm:
+            assert norm_layer is not None, 'norm_layer must be provided if qk_norm or scale_norm is True'
+        self.num_heads = num_heads
+        head_dim = dim // num_heads
+        if attn_head_dim is not None:
+            head_dim = attn_head_dim
+        attn_dim = head_dim * self.num_heads
+        self.scale = head_dim ** -0.5
+        self.num_prefix_tokens = num_prefix_tokens
+        self.fused_attn = use_fused_attn()
+
+        if qkv_fused:
+            self.qkv = nn.Linear(dim, attn_dim * 3, bias=qkv_bias)
+            self.q_proj = self.k_proj = self.v_proj = None
+        else:
+            self.qkv = None
+            self.q_proj = nn.Linear(dim, attn_dim, bias=qkv_bias)
+            self.k_proj = nn.Linear(dim, attn_dim, bias=qkv_bias)
+            self.v_proj = nn.Linear(dim, attn_dim, bias=qkv_bias)
+
+        self.q_norm = norm_layer(head_dim) if qk_norm else nn.Identity()
+        self.k_norm = norm_layer(head_dim) if qk_norm else nn.Identity()
+        self.attn_drop = nn.Dropout(attn_drop)
+        self.norm = norm_layer(attn_dim) if scale_norm else nn.Identity()
+        self.proj = nn.Linear(attn_dim, dim)
+        self.proj_drop = nn.Dropout(proj_drop)
+
+    def forward(
+            self,
+            x,
+            rope: Optional[torch.Tensor] = None,
+            attn_mask: Optional[torch.Tensor] = None,
+    ):
+        """Forward pass for the attention module.
+
+        Args:
+            x: Input tensor of shape (batch_size, sequence_length, embedding_dim)
+            rope: Rotary position embeddings tensor for position-aware attention
+            attn_mask: Optional attention mask to apply during attention computation
+
+        Returns:
+            Tensor of shape (batch_size, sequence_length, embedding_dim)
+        """
+        B, N, C = x.shape
+
+        if self.qkv is not None:
+            qkv = self.qkv(x)
+            qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
+            q, k, v = qkv.unbind(0)  # B, num_heads, N, head_dim
+        else:
+            q = self.q_proj(x).reshape(B, N, self.num_heads, -1).transpose(1, 2)  # B, num_heads, N, head_dim
+            k = self.k_proj(x).reshape(B, N, self.num_heads, -1).transpose(1, 2)
+            v = self.v_proj(x).reshape(B, N, self.num_heads, -1).transpose(1, 2)
+
+        q, k = self.q_norm(q), self.k_norm(k)
+
+        if rope is not None:
+            npt = self.num_prefix_tokens
+            q = torch.cat([q[:, :, :npt, :], apply_rot_embed_cat(q[:, :, npt:, :], rope)], dim=2).type_as(v)
+            k = torch.cat([k[:, :, :npt, :], apply_rot_embed_cat(k[:, :, npt:, :], rope)], dim=2).type_as(v)
+
+        if self.fused_attn:
+            x = F.scaled_dot_product_attention(
+                q, k, v,
+                attn_mask=attn_mask,
+                dropout_p=self.attn_drop.p if self.training else 0.,
+            )
+        else:
+            q = q * self.scale
+            attn = (q @ k.transpose(-2, -1))
+
+            if attn_mask is not None:
+                attn_mask = attn_mask.to(torch.bool)
+                attn = attn.masked_fill(~attn_mask[:, None, None, :], float("-inf"))
+            attn = attn.softmax(dim=-1)
+
+            attn = self.attn_drop(attn)
+            x = attn @ v
+
+        x = x.transpose(1, 2).reshape(B, N, C)
+        x = self.norm(x)
+        x = self.proj(x)
+        x = self.proj_drop(x)
+        return x
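A minimal usage sketch for the new AttentionRope layer, assuming a timm build that includes this commit. The rope tensor below is a random placeholder standing in for the concatenated sin/cos table normally produced by a rotary embedding module such as RotaryEmbeddingCat; its layout here, (num_patch_tokens, 2 * head_dim), is an assumption for illustration:

import torch
from timm.layers import AttentionRope

attn = AttentionRope(dim=768, num_heads=12, qkv_fused=True, num_prefix_tokens=1)
x = torch.randn(2, 197, 768)               # 1 class token + 14*14 patch tokens

out = attn(x)                              # no rope: behaves like standard self-attention
rope = torch.randn(196, 2 * (768 // 12))   # placeholder sin/cos table for the 196 patch tokens
out = attn(x, rope=rope)                   # the prefix (class) token skips the rotary embedding
print(out.shape)                           # torch.Size([2, 197, 768])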

timm/layers/attention_pool.py

Lines changed: 3 additions & 3 deletions
@@ -1,4 +1,4 @@
-from typing import Optional
+from typing import Optional, Type
 
 import torch
 import torch.nn as nn
@@ -28,8 +28,8 @@ def __init__(
             latent_dim: int = None,
             pos_embed: str = '',
             pool_type: str = 'token',
-            norm_layer: Optional[nn.Module] = None,
-            act_layer: Optional[nn.Module] = nn.GELU,
+            norm_layer: Optional[Type[nn.Module]] = None,
+            act_layer: Optional[Type[nn.Module]] = nn.GELU,
             drop: float = 0.0,
     ):
         super().__init__()

timm/layers/patch_embed.py

Lines changed: 1 addition & 1 deletion
@@ -31,7 +31,7 @@ class PatchEmbed(nn.Module):
 
     def __init__(
             self,
-            img_size: Optional[int] = 224,
+            img_size: Union[int, Tuple[int, int]] = 224,
             patch_size: int = 16,
             in_chans: int = 3,
             embed_dim: int = 768,

timm/layers/pool1d.py

Lines changed: 26 additions & 0 deletions
@@ -0,0 +1,26 @@
+import torch
+
+
+def global_pool_nlc(
+        x: torch.Tensor,
+        pool_type: str = 'token',
+        num_prefix_tokens: int = 1,
+        reduce_include_prefix: bool = False,
+):
+    if not pool_type:
+        return x
+
+    if pool_type == 'token':
+        x = x[:, 0]  # class token
+    else:
+        x = x if reduce_include_prefix else x[:, num_prefix_tokens:]
+        if pool_type == 'avg':
+            x = x.mean(dim=1)
+        elif pool_type == 'avgmax':
+            x = 0.5 * (x.amax(dim=1) + x.mean(dim=1))
+        elif pool_type == 'max':
+            x = x.amax(dim=1)
+        else:
+            assert not pool_type, f'Unknown pool type {pool_type}'
+
+    return x
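A short sketch of the new global_pool_nlc helper on an NLC (batch, tokens, channels) tensor, assuming a single class token at position 0 and that global_pool_nlc is imported via the timm.layers export added in this commit:

import torch
from timm.layers import global_pool_nlc

x = torch.randn(2, 197, 768)                        # 1 cls token + 196 patch tokens

cls_tok = global_pool_nlc(x, pool_type='token')     # (2, 768), just the class token
avg = global_pool_nlc(x, pool_type='avg')           # (2, 768), mean over the 196 patch tokens
mix = global_pool_nlc(x, pool_type='avgmax')        # (2, 768), 0.5 * (max + mean)
none = global_pool_nlc(x, pool_type='')             # (2, 197, 768), unpooled passthrough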
