Add EVA ViT based PE (Perceptual Encoder) impl #2487

Merged 2 commits on May 15, 2025

2 changes: 2 additions & 0 deletions timm/layers/__init__.py
@@ -1,6 +1,7 @@
from .activations import *
from .adaptive_avgmax_pool import \
adaptive_avgmax_pool2d, select_adaptive_pool2d, AdaptiveAvgMaxPool2d, SelectAdaptivePool2d
from .attention import Attention, AttentionRope
from .attention2d import MultiQueryAttention2d, Attention2d, MultiQueryAttentionV2
from .attention_pool import AttentionPoolLatent
from .attention_pool2d import AttentionPool2d, RotAttentionPool2d, RotaryEmbedding
@@ -41,6 +42,7 @@
from .padding import get_padding, get_same_padding, pad_same
from .patch_dropout import PatchDropout
from .patch_embed import PatchEmbed, PatchEmbedWithSize, resample_patch_embed
from .pool1d import global_pool_nlc
from .pool2d_same import AvgPool2dSame, create_pool2d
from .pos_embed import resample_abs_pos_embed, resample_abs_pos_embed_nhwc
from .pos_embed_rel import RelPosMlp, RelPosBias, RelPosBiasTf, gen_relative_position_index, gen_relative_log_coords, \
212 changes: 212 additions & 0 deletions timm/layers/attention.py
@@ -0,0 +1,212 @@
from typing import Final, Optional, Type

import torch
from torch import nn as nn
from torch.nn import functional as F

from .config import use_fused_attn
from .pos_embed_sincos import apply_rot_embed_cat


class Attention(nn.Module):
"""Standard Multi-head Self Attention module with QKV projection.
This module implements the standard multi-head attention mechanism used in transformers.
It supports both the fused attention implementation (scaled_dot_product_attention) for
efficiency when available, and a manual implementation otherwise. The module includes
options for QK normalization, attention dropout, and projection dropout.
"""
fused_attn: Final[bool]

def __init__(
self,
dim: int,
num_heads: int = 8,
qkv_bias: bool = False,
qk_norm: bool = False,
proj_bias: bool = True,
attn_drop: float = 0.,
proj_drop: float = 0.,
norm_layer: Type[nn.Module] = nn.LayerNorm,
) -> None:
"""Initialize the Attention module.
Args:
dim: Input dimension of the token embeddings
num_heads: Number of attention heads
qkv_bias: Whether to use bias in the query, key, value projections
qk_norm: Whether to apply normalization to query and key vectors
proj_bias: Whether to use bias in the output projection
attn_drop: Dropout rate applied to the attention weights
proj_drop: Dropout rate applied after the output projection
norm_layer: Normalization layer constructor for QK normalization if enabled
"""
super().__init__()
assert dim % num_heads == 0, 'dim should be divisible by num_heads'
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.scale = self.head_dim ** -0.5
self.fused_attn = use_fused_attn()

self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim, bias=proj_bias)
self.proj_drop = nn.Dropout(proj_drop)

def forward(
self,
x: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
q, k, v = qkv.unbind(0)
q, k = self.q_norm(q), self.k_norm(k)

if self.fused_attn:
x = F.scaled_dot_product_attention(
q, k, v,
attn_mask=attn_mask,
dropout_p=self.attn_drop.p if self.training else 0.,
)
else:
q = q * self.scale
attn = q @ k.transpose(-2, -1)
if attn_mask is not None:
attn = attn + attn_mask
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = attn @ v

x = x.transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x


class AttentionRope(nn.Module):
""" A Self Attention module with ROPE support.
Includes options for:
* QK normalization option
* Attention output (scale) normalization
* Fused or unfused QKV projection support
"""
fused_attn: torch.jit.Final[bool]

def __init__(
self,
dim: int,
num_heads: int = 8,
qkv_bias: bool = True,
qkv_fused: bool = True,
num_prefix_tokens: int = 1,
attn_drop: float = 0.,
proj_drop: float = 0.,
attn_head_dim: Optional[int] = None,
norm_layer: Optional[Type[nn.Module]] = None,
qk_norm: bool = False,
scale_norm: bool = False,
):
"""Initialize the Attention module.
Args:
dim: Input dimension of the token embeddings
num_heads: Number of attention heads
qkv_bias: Whether to add a bias term to the query, key, and value projections
qkv_fused: Whether to use a single fused QKV projection or separate Q, K, V projections
num_prefix_tokens: Number of reg/cls tokens at the beginning of the sequence that
should not have position embeddings applied
attn_drop: Dropout rate for attention weights
proj_drop: Dropout rate for the output projection
attn_head_dim: Dimension of each attention head (if None, computed as dim // num_heads)
norm_layer: Normalization layer constructor to use for QK and scale normalization
qk_norm: Enable normalization of query (Q) and key (K) vectors with norm_layer
scale_norm: Enable normalization (scaling) of attention output with norm_layer
"""
super().__init__()
if scale_norm or qk_norm:
assert norm_layer is not None, 'norm_layer must be provided if qk_norm or scale_norm is True'
self.num_heads = num_heads
head_dim = dim // num_heads
if attn_head_dim is not None:
head_dim = attn_head_dim
attn_dim = head_dim * self.num_heads
self.scale = head_dim ** -0.5
self.num_prefix_tokens = num_prefix_tokens
self.fused_attn = use_fused_attn()

if qkv_fused:
self.qkv = nn.Linear(dim, attn_dim * 3, bias=qkv_bias)
self.q_proj = self.k_proj = self.v_proj = None
else:
self.qkv = None
self.q_proj = nn.Linear(dim, attn_dim, bias=qkv_bias)
self.k_proj = nn.Linear(dim, attn_dim, bias=qkv_bias)
self.v_proj = nn.Linear(dim, attn_dim, bias=qkv_bias)

self.q_norm = norm_layer(head_dim) if qk_norm else nn.Identity()
self.k_norm = norm_layer(head_dim) if qk_norm else nn.Identity()
self.attn_drop = nn.Dropout(attn_drop)
self.norm = norm_layer(attn_dim) if scale_norm else nn.Identity()
self.proj = nn.Linear(attn_dim, dim)
self.proj_drop = nn.Dropout(proj_drop)

def forward(
self,
x,
rope: Optional[torch.Tensor] = None,
attn_mask: Optional[torch.Tensor] = None,
):
"""Forward pass for the attention module.
Args:
x: Input tensor of shape (batch_size, sequence_length, embedding_dim)
rope: Rotary position embeddings tensor for position-aware attention
attn_mask: Optional attention mask to apply during attention computation
Returns:
Tensor of shape (batch_size, sequence_length, embedding_dim)
"""
B, N, C = x.shape

if self.qkv is not None:
qkv = self.qkv(x)
qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
q, k, v = qkv.unbind(0) # B, num_heads, N, head_dim
else:
q = self.q_proj(x).reshape(B, N, self.num_heads, -1).transpose(1, 2) # B, num_heads, N, head_dim
k = self.k_proj(x).reshape(B, N, self.num_heads, -1).transpose(1, 2)
v = self.v_proj(x).reshape(B, N, self.num_heads, -1).transpose(1, 2)

q, k = self.q_norm(q), self.k_norm(k)

if rope is not None:
npt = self.num_prefix_tokens
q = torch.cat([q[:, :, :npt, :], apply_rot_embed_cat(q[:, :, npt:, :], rope)], dim=2).type_as(v)
k = torch.cat([k[:, :, :npt, :], apply_rot_embed_cat(k[:, :, npt:, :], rope)], dim=2).type_as(v)

if self.fused_attn:
x = F.scaled_dot_product_attention(
q, k, v,
attn_mask=attn_mask,
dropout_p=self.attn_drop.p if self.training else 0.,
)
else:
q = q * self.scale
attn = (q @ k.transpose(-2, -1))

if attn_mask is not None:
attn_mask = attn_mask.to(torch.bool)
attn = attn.masked_fill(~attn_mask[:, None, None, :], float("-inf"))
attn = attn.softmax(dim=-1)

attn = self.attn_drop(attn)
x = attn @ v

x = x.transpose(1, 2).reshape(B, N, C)
x = self.norm(x)
x = self.proj(x)
x = self.proj_drop(x)
return x
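
Note: a minimal usage sketch of the two new modules (not part of this diff). It assumes the timm.layers exports added above and uses RotaryEmbeddingCat from the existing pos_embed_sincos module to produce the concatenated sin/cos tensor that AttentionRope.forward expects; shapes and hyperparameters are illustrative.

import torch
from timm.layers import Attention, AttentionRope
from timm.layers.pos_embed_sincos import RotaryEmbeddingCat

B, H, W, dim, heads = 2, 14, 14, 768, 12
x = torch.randn(B, 1 + H * W, dim)  # class token + 14x14 patch tokens

# plain multi-head self attention with QK norm enabled
attn = Attention(dim, num_heads=heads, qkv_bias=True, qk_norm=True)
y = attn(x)  # (B, 197, 768)

# RoPE variant: rotary embeddings are skipped for the single prefix (class) token
rope = RotaryEmbeddingCat(dim // heads, in_pixels=False)
rope_embed = rope.get_embed((H, W))  # concatenated sin/cos, shape (H * W, 2 * head_dim)
attn_rope = AttentionRope(dim, num_heads=heads, num_prefix_tokens=1)
y = attn_rope(x, rope=rope_embed)  # (B, 197, 768)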
6 changes: 3 additions & 3 deletions timm/layers/attention_pool.py
@@ -1,4 +1,4 @@
from typing import Optional
from typing import Optional, Type

import torch
import torch.nn as nn
@@ -28,8 +28,8 @@ def __init__(
latent_dim: int = None,
pos_embed: str = '',
pool_type: str = 'token',
norm_layer: Optional[nn.Module] = None,
act_layer: Optional[nn.Module] = nn.GELU,
norm_layer: Optional[Type[nn.Module]] = None,
act_layer: Optional[Type[nn.Module]] = nn.GELU,
drop: float = 0.0,
):
super().__init__()
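
The change above only tightens the type hints: norm_layer and act_layer are layer classes (constructors), not module instances. A hedged sketch of the intended call pattern; the in_features and num_heads constructor arguments are assumed from the surrounding class rather than shown in this hunk.

import torch
import torch.nn as nn
from timm.layers import AttentionPoolLatent

# layer types are passed as classes (Type[nn.Module]) and instantiated internally
pool = AttentionPoolLatent(
    in_features=768,  # assumed constructor argument name
    num_heads=12,
    norm_layer=nn.LayerNorm,
    act_layer=nn.GELU,
)
x = torch.randn(2, 196, 768)  # NLC token sequence
pooled = pool(x)  # pooled output, (2, 768) with the default 'token' pool_type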
26 changes: 26 additions & 0 deletions timm/layers/pool1d.py
@@ -0,0 +1,26 @@
import torch


def global_pool_nlc(
x: torch.Tensor,
pool_type: str = 'token',
num_prefix_tokens: int = 1,
reduce_include_prefix: bool = False,
):
if not pool_type:
return x

if pool_type == 'token':
x = x[:, 0] # class token
else:
x = x if reduce_include_prefix else x[:, num_prefix_tokens:]
if pool_type == 'avg':
x = x.mean(dim=1)
elif pool_type == 'avgmax':
x = 0.5 * (x.amax(dim=1) + x.mean(dim=1))
elif pool_type == 'max':
x = x.amax(dim=1)
else:
assert not pool_type, f'Unknown pool type {pool_type}'

return x
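
A quick illustration of the new helper's pooling modes over an NLC token sequence (shapes only, not part of the diff):

import torch
from timm.layers import global_pool_nlc

x = torch.randn(2, 197, 768)  # (batch, 1 class token + 196 patch tokens, channels)

global_pool_nlc(x, pool_type='token').shape   # (2, 768), class token only
global_pool_nlc(x, pool_type='avg').shape     # (2, 768), mean over the 196 patch tokens
global_pool_nlc(x, pool_type='avgmax').shape  # (2, 768), 0.5 * (amax + mean)
global_pool_nlc(x, pool_type='').shape        # (2, 197, 768), unpooled passthrough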
40 changes: 37 additions & 3 deletions timm/layers/pos_embed_sincos.py
@@ -87,6 +87,8 @@ def build_fourier_pos_embed(
include_grid: bool = False,
in_pixels: bool = True,
ref_feat_shape: Optional[List[int]] = None,
grid_offset: float = 0.,
grid_indexing: str = 'ij',
dtype: torch.dtype = torch.float32,
device: Optional[torch.device] = None,
) -> List[torch.Tensor]:
@@ -102,6 +104,8 @@ def build_fourier_pos_embed(
include_grid: Include the spatial grid in output.
in_pixels: Output in pixel freq.
ref_feat_shape: Reference feature shape for resize / fine-tune.
grid_offset: Constant offset to add to grid for non-pixel freq.
grid_indexing: Indexing mode for meshgrid ('ij' or 'xy')
dtype: Output dtype.
device: Output device.
@@ -130,15 +134,21 @@ def build_fourier_pos_embed(
dtype = bands.dtype

if in_pixels:
t = [torch.linspace(-1., 1., steps=s, device=device, dtype=torch.float32) for s in feat_shape]
t = [
torch.linspace(-1., 1., steps=s, device=device, dtype=torch.float32)
for s in feat_shape
]
else:
t = [torch.arange(s, device=device, dtype=torch.int64).to(torch.float32) for s in feat_shape]
t = [
torch.arange(s, device=device, dtype=torch.int64).to(torch.float32) + grid_offset
for s in feat_shape
]

if ref_feat_shape is not None:
# eva's scheme for resizing rope embeddings (ref shape = pretrain)
t = [x / f * r for x, f, r in zip(t, feat_shape, ref_feat_shape)]

grid = torch.stack(ndgrid(t), dim=-1)
grid = torch.stack(torch.meshgrid(t, indexing=grid_indexing), dim=-1)
grid = grid.unsqueeze(-1)
pos = grid * bands

@@ -229,6 +239,8 @@ def build_rotary_pos_embed(
linear_bands: bool = False,
in_pixels: bool = True,
ref_feat_shape: Optional[List[int]] = None,
grid_offset: float = 0.,
grid_indexing: str = 'ij',
dtype: torch.dtype = torch.float32,
device: Optional[torch.device] = None,
):
@@ -242,6 +254,9 @@ def build_rotary_pos_embed(
temperature: Temperature (inv freq) for non-pixel mode
linear_bands: Linearly (instead of log) spaced bands for pixel mode
in_pixels: Pixel vs language (inv freq) mode.
ref_feat_shape: Reference feature shape for resize / fine-tune.
grid_offset: Constant offset to add to grid for non-pixel freq.
grid_indexing: Indexing mode for meshgrid ('ij' or 'xy')
dtype: Output dtype.
device: Output device.
@@ -257,6 +272,8 @@ def build_rotary_pos_embed(
linear_bands=linear_bands,
in_pixels=in_pixels,
ref_feat_shape=ref_feat_shape,
grid_offset=grid_offset,
grid_indexing=grid_indexing,
device=device,
dtype=dtype,
)
@@ -289,6 +306,8 @@ def __init__(
linear_bands: bool = False,
feat_shape: Optional[List[int]] = None,
ref_feat_shape: Optional[List[int]] = None,
grid_offset: float = 0.,
grid_indexing: str = 'ij',
):
super().__init__()
self.dim = dim
@@ -297,6 +316,8 @@ def __init__(
self.in_pixels = in_pixels
self.feat_shape = feat_shape
self.ref_feat_shape = ref_feat_shape
self.grid_offset = grid_offset
self.grid_indexing = grid_indexing

if feat_shape is None:
# only cache bands
@@ -328,6 +349,8 @@ def __init__(
linear_bands=linear_bands,
in_pixels=in_pixels,
ref_feat_shape=self.ref_feat_shape,
grid_offset=self.grid_offset,
grid_indexing=self.grid_indexing,
)
self.bands = None
self.register_buffer(
@@ -349,6 +372,9 @@ def get_embed(self, shape: Optional[List[int]] = None):
shape,
self.bands,
in_pixels=self.in_pixels,
ref_feat_shape=self.ref_feat_shape,
grid_offset=self.grid_offset,
grid_indexing=self.grid_indexing,
)
else:
return self.pos_embed_sin, self.pos_embed_cos
@@ -376,6 +402,8 @@ def __init__(
linear_bands: bool = False,
feat_shape: Optional[List[int]] = None,
ref_feat_shape: Optional[List[int]] = None,
grid_offset: float = 0.,
grid_indexing: str = 'ij',
):
super().__init__()
self.dim = dim
@@ -384,6 +412,8 @@ def __init__(
self.in_pixels = in_pixels
self.feat_shape = feat_shape
self.ref_feat_shape = ref_feat_shape
self.grid_offset = grid_offset
self.grid_indexing = grid_indexing

if feat_shape is None:
# only cache bands
@@ -414,6 +444,8 @@ def __init__(
linear_bands=linear_bands,
in_pixels=in_pixels,
ref_feat_shape=self.ref_feat_shape,
grid_offset=self.grid_offset,
grid_indexing=self.grid_indexing,
)
self.bands = None
self.register_buffer(
@@ -430,6 +462,8 @@ def get_embed(self, shape: Optional[List[int]] = None):
self.bands,
in_pixels=self.in_pixels,
ref_feat_shape=self.ref_feat_shape,
grid_offset=self.grid_offset,
grid_indexing=self.grid_indexing,
)
return torch.cat(embeds, -1)
elif self.pos_embed is not None:
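
To show how the new grid_offset / grid_indexing arguments thread through, a small sketch against build_rotary_pos_embed (values are illustrative, not taken from the diff):

import torch
from timm.layers.pos_embed_sincos import build_rotary_pos_embed, apply_rot_embed_cat

# 64-dim per-head rotation over a 14x14 feature grid, language-style (non-pixel) frequencies
sin_emb, cos_emb = build_rotary_pos_embed(
    feat_shape=[14, 14],
    dim=64,
    in_pixels=False,
    grid_offset=0.5,     # constant shift added to the integer grid before computing frequencies
    grid_indexing='xy',  # meshgrid indexing mode, 'ij' (default) or 'xy'
)
rope = torch.cat([sin_emb, cos_emb], dim=-1)  # concatenated form consumed by apply_rot_embed_cat
q = torch.randn(2, 12, 14 * 14, 64)           # (B, num_heads, N, head_dim)
q_rot = apply_rot_embed_cat(q, rope)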
542 changes: 448 additions & 94 deletions timm/models/eva.py

Large diffs are not rendered by default.
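
The bulk of the PE support lives in the eva.py rewrite that is truncated above, so the model names it registers are not visible here. A hedged way to discover and build them once this is merged (no specific names assumed):

import timm

# list whatever EVA/PE variants the installed timm exposes, then build one
pe_names = timm.list_models('*pe*')
print(pe_names)
if pe_names:
    model = timm.create_model(pe_names[0], pretrained=False)
    print(sum(p.numel() for p in model.parameters()))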