From 0f6e29019c16ef09d3121ca57752a9d8a38fa238 Mon Sep 17 00:00:00 2001 From: berniebear Date: Thu, 24 Apr 2025 22:39:32 +0000 Subject: [PATCH 01/15] pe integration --- timm/models/__init__.py | 1 + timm/models/pe.py | 915 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 916 insertions(+) create mode 100644 timm/models/pe.py diff --git a/timm/models/__init__.py b/timm/models/__init__.py index 238c1ccca5..66b0ff857e 100644 --- a/timm/models/__init__.py +++ b/timm/models/__init__.py @@ -45,6 +45,7 @@ from .nest import * from .nextvit import * from .nfnet import * +from .pe import * from .pit import * from .pnasnet import * from .pvt_v2 import * diff --git a/timm/models/pe.py b/timm/models/pe.py new file mode 100644 index 0000000000..a102f83e21 --- /dev/null +++ b/timm/models/pe.py @@ -0,0 +1,915 @@ +from collections import OrderedDict +from functools import partial +from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union, Literal + +import torch +import torch.nn as nn +from einops import rearrange, repeat +from torch import nn, Tensor, broadcast_tensors, einsum +from torch.nn import functional as F +from torch.nn import Module, ModuleList +from torch.nn.init import constant_, xavier_normal_, xavier_uniform_ +from torch.nn.parameter import Parameter +from torch.amp import autocast +from torch.utils.checkpoint import checkpoint + +### Import timm layers +from timm.layers import PatchEmbed, Mlp, DropPath, AttentionPoolLatent, RmsNorm, PatchDropout, SwiGLUPacked, SwiGLU, \ + trunc_normal_, lecun_normal_, resample_patch_embed, resample_abs_pos_embed, use_fused_attn, \ + get_act_layer, get_norm_layer, LayerType, LayerScale + +from ._builder import build_model_with_cfg +from ._features import feature_take_indices +from ._manipulate import named_apply, checkpoint_seq, adapt_input_conv +from ._registry import generate_default_cfgs, register_model, register_model_deprecations + + +def exists(val): + return val is not None + + +def default(val, d): + return val if exists(val) else d + + +def rotate_half(x): + x = rearrange(x, "... (d r) -> ... d r", r=2) + x1, x2 = x.unbind(dim=-1) + x = torch.stack((-x2, x1), dim=-1) + return rearrange(x, "... d r -> ... 
(d r)") + + +@autocast("cuda", enabled=False) +def apply_rotary_emb(freqs, t, start_index=0, scale=1.0, seq_dim=-2): + dtype = t.dtype + + if t.ndim == 3: + seq_len = t.shape[seq_dim] + freqs = freqs[-seq_len:] + + rot_dim = freqs.shape[-1] + end_index = start_index + rot_dim + + assert ( + rot_dim <= t.shape[-1] + ), f"feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}" + + t_left, t, t_right = ( + t[..., :start_index], + t[..., start_index:end_index], + t[..., end_index:], + ) + t = (t * freqs.cos() * scale) + (rotate_half(t) * freqs.sin() * scale) + out = torch.cat((t_left, t, t_right), dim=-1) + + return out.type(dtype) + + +class RotaryEmbedding(Module): + def __init__( + self, + dim, + custom_freqs: Optional[Tensor] = None, + freqs_for: Union[ + Literal["lang"], Literal["pixel"], Literal["constant"] + ] = "lang", + theta=10000, + max_freq=10, + num_freqs=1, + learned_freq=False, + use_xpos=False, + xpos_scale_base=512, + interpolate_factor=1.0, + theta_rescale_factor=1.0, + seq_before_head_dim=False, + cache_if_possible=True, + ): + super().__init__() + # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning + # has some connection to NTK literature + # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/ + + theta *= theta_rescale_factor ** (dim / (dim - 2)) + + self.freqs_for = freqs_for + + if exists(custom_freqs): + freqs = custom_freqs + elif freqs_for == "lang": + freqs = 1.0 / ( + theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim) + ) + elif freqs_for == "pixel": + freqs = torch.linspace(1.0, max_freq / 2, dim // 2) * pi + elif freqs_for == "constant": + freqs = torch.ones(num_freqs).float() + + self.cache_if_possible = cache_if_possible + + self.tmp_store("cached_freqs", None) + self.tmp_store("cached_scales", None) + + self.freqs = nn.Parameter(freqs, requires_grad=learned_freq) + + self.learned_freq = learned_freq + + # dummy for device + + self.tmp_store("dummy", torch.tensor(0)) + + # default sequence dimension + + self.seq_before_head_dim = seq_before_head_dim + self.default_seq_dim = -3 if seq_before_head_dim else -2 + + # interpolation factors + + assert interpolate_factor >= 1.0 + self.interpolate_factor = interpolate_factor + + # xpos + + self.use_xpos = use_xpos + if not use_xpos: + self.tmp_store("scale", None) + return + + scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim) + + self.scale_base = xpos_scale_base + self.tmp_store("scale", scale) + + # add apply_rotary_emb as static method + + self.apply_rotary_emb = staticmethod(apply_rotary_emb) + + @property + def device(self): + return self.dummy.device + + def tmp_store(self, key, value): + self.register_buffer(key, value, persistent=False) + + def get_seq_pos(self, seq_len, device, dtype, offset=0): + return ( + torch.arange(seq_len, device=device, dtype=dtype) + offset + ) / self.interpolate_factor + + def rotate_queries_or_keys(self, t, seq_dim=None, offset=0): + seq_dim = default(seq_dim, self.default_seq_dim) + + assert ( + not self.use_xpos + ), "you must use `.rotate_queries_and_keys` method instead and pass in both queries and keys, for length extrapolatable rotary embeddings" + + device, dtype, seq_len = t.device, t.dtype, t.shape[seq_dim] + + freqs = self.forward( + self.get_seq_pos(seq_len, device=device, dtype=dtype, offset=offset), + seq_len=seq_len, + offset=offset, + ) + + if seq_dim == -3: + freqs = rearrange(freqs, "n d -> n 1 
d") + + return apply_rotary_emb(freqs, t, seq_dim=seq_dim) + + def rotate_queries_with_cached_keys(self, q, k, seq_dim=None, offset=0): + seq_dim = default(seq_dim, self.default_seq_dim) + + q_len, k_len = q.shape[seq_dim], k.shape[seq_dim] + assert q_len <= k_len + + rotated_q = self.rotate_queries_or_keys( + q, seq_dim=seq_dim, offset=k_len - q_len + offset + ) + rotated_k = self.rotate_queries_or_keys(k, seq_dim=seq_dim, offset=offset) + + rotated_q = rotated_q.type(q.dtype) + rotated_k = rotated_k.type(k.dtype) + + return rotated_q, rotated_k + + def rotate_queries_and_keys(self, q, k, seq_dim=None): + seq_dim = default(seq_dim, self.default_seq_dim) + + assert self.use_xpos + device, dtype, seq_len = q.device, q.dtype, q.shape[seq_dim] + + seq = self.get_seq_pos(seq_len, dtype=dtype, device=device) + + freqs = self.forward(seq, seq_len=seq_len) + scale = self.get_scale(seq, seq_len=seq_len).to(dtype) + + if seq_dim == -3: + freqs = rearrange(freqs, "n d -> n 1 d") + scale = rearrange(scale, "n d -> n 1 d") + + rotated_q = apply_rotary_emb(freqs, q, scale=scale, seq_dim=seq_dim) + rotated_k = apply_rotary_emb(freqs, k, scale=scale**-1, seq_dim=seq_dim) + + rotated_q = rotated_q.type(q.dtype) + rotated_k = rotated_k.type(k.dtype) + + return rotated_q, rotated_k + + def get_scale(self, t: Tensor, seq_len: Optional[int] = None, offset=0): + assert self.use_xpos + + should_cache = self.cache_if_possible and exists(seq_len) + + if ( + should_cache + and exists(self.cached_scales) + and (seq_len + offset) <= self.cached_scales.shape[0] + ): + return self.cached_scales[offset : (offset + seq_len)] + + scale = 1.0 + if self.use_xpos: + power = (t - len(t) // 2) / self.scale_base + scale = self.scale ** rearrange(power, "n -> n 1") + scale = torch.cat((scale, scale), dim=-1) + + if should_cache: + self.tmp_store("cached_scales", scale) + + return scale + + def get_axial_freqs(self, *dims): + Colon = slice(None) + all_freqs = [] + + for ind, dim in enumerate(dims): + if self.freqs_for == "pixel": + pos = torch.linspace(-1, 1, steps=dim, device=self.device) + else: + pos = torch.arange(dim, device=self.device) + + freqs = self.forward(pos, seq_len=dim) + + all_axis = [None] * len(dims) + all_axis[ind] = Colon + + new_axis_slice = (Ellipsis, *all_axis, Colon) + all_freqs.append(freqs[new_axis_slice]) + + all_freqs = broadcast_tensors(*all_freqs) + return torch.cat(all_freqs, dim=-1) + + @autocast("cuda", enabled=False) + def forward(self, t: Tensor, seq_len=None, offset=0): + should_cache = ( + self.cache_if_possible + and not self.learned_freq + and exists(seq_len) + and self.freqs_for != "pixel" + ) + + if ( + should_cache + and exists(self.cached_freqs) + and (offset + seq_len) <= self.cached_freqs.shape[0] + ): + return self.cached_freqs[offset : (offset + seq_len)].detach() + + freqs = self.freqs + + freqs = einsum("..., f -> ... f", t.type(freqs.dtype), freqs) + freqs = repeat(freqs, "... n -> ... (n r)", r=2) + + if should_cache: + self.tmp_store("cached_freqs", freqs.detach()) + + return freqs + + +class Rope2D: + """ Helper class to apply RoPE2D as well as interpolate on the fly. 
""" + + def __init__(self, dim, use_cls_token=False): + self.dim = dim + self.use_cls_token = use_cls_token + self.grid_size = None + self.freq = None + + def init_tensors(self): + self.rope = RotaryEmbedding(self.dim // 2) + + def update_grid(self, device, grid_h, grid_w): + if self.grid_size != (grid_h, grid_w): + self.grid_size = (grid_h, grid_w) + + self.rope = self.rope.to(device) + + if self.use_cls_token: + # +1 to leave space for the cls token to be (0, 0) + grid_y_range = torch.arange(grid_h, device=device) + 1 + grid_x_range = torch.arange(grid_w, device=device) + 1 + else: + grid_y_range = torch.arange(grid_h, device=device) + grid_x_range = torch.arange(grid_w, device=device) + + freqs_y = self.rope(grid_y_range)[:, None].expand(grid_h, grid_w, -1) + freqs_x = self.rope(grid_x_range)[None, :].expand(grid_h, grid_w, -1) + freq = torch.cat([freqs_x, freqs_y], dim=-1).reshape(grid_h * grid_w, -1) + + if self.use_cls_token: + freq = torch.cat( + [freq, torch.zeros(1, freq.shape[-1], device=device)], dim=0 + ) + + self.freq = freq[None, ...] + + self.freq = self.freq.to(device) + + def __call__(self, q, k): + # batch, heads, seq, dim = q.shape + q = apply_rotary_emb(self.freq[:, None, :, :], q) + k = apply_rotary_emb(self.freq[:, None, :, :], k) + + return q, k + + +class AttentionPooling(nn.Module): + def __init__( + self, + embed_dim: int, + num_heads: int, + num_probe: int = 1, + mlp_ratio: int = 4, + act_layer: Callable = nn.GELU, + norm_layer: Callable = nn.LayerNorm, + ): + super().__init__() + + self.embed_dim = embed_dim + self.num_heads = num_heads + + assert ( + self.embed_dim % num_heads == 0 + ), "embed_dim must be divisible by num_heads" + + self.probe = nn.Parameter(torch.randn(1, num_probe, self.embed_dim)) + self.attn = nn.MultiheadAttention( + self.embed_dim, self.num_heads, batch_first=True + ) + + self.layernorm = norm_layer(embed_dim) + self.mlp_width = int(embed_dim * mlp_ratio) + self.mlp = nn.Sequential( + OrderedDict( + [ + ("c_fc", nn.Linear(self.embed_dim, self.mlp_width)), + ("gelu", act_layer()), + ("c_proj", nn.Linear(self.mlp_width, self.embed_dim)), + ] + ) + ) + + def forward(self, x: torch.Tensor): + batch, _, _ = x.shape + + q = self.probe.repeat((batch, 1, 1)).to(x.dtype) + x = self.attn(q, x, x, need_weights=False)[0] + x = x + self.mlp(self.layernorm(x)) + + return x + + +class SelfAttention(nn.Module): + r""" + Implements sequence packed attention and RoPe + """ + + def __init__( + self, + embed_dim: int, + num_heads: int, + rope: Optional[nn.Module] = None, + ): + super(SelfAttention, self).__init__() + self.embed_dim = embed_dim + + self.num_heads = num_heads + self.head_dim = embed_dim // num_heads + assert ( + self.head_dim * num_heads == self.embed_dim + ), "embed_dim must be divisible by num_heads" + + # To make this compatibile with nn.MultiHeadAttention + self.in_proj_weight = Parameter(torch.empty(3 * embed_dim, embed_dim)) + self.in_proj_bias = Parameter(torch.empty(3 * embed_dim)) + self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True) + + self.rope = rope + self.scale = self.head_dim ** (-0.5) + + def init_tensors(self): + xavier_uniform_(self.in_proj_weight) + constant_(self.in_proj_bias, 0.0) + constant_(self.out_proj.bias, 0.0) + + def forward(self, x, attn_mask=None): + batch, seq, embed_dim = x.shape + proj = F.linear(x, self.in_proj_weight, self.in_proj_bias) + + # reshape to 3, E and not E, 3 is deliberate for better memory coalescing and keeping same order as chunk() + proj = ( + proj.unflatten(-1, (3, embed_dim)) + 
.unsqueeze(0) + .transpose(0, -2) + .squeeze(-2) + .contiguous() + ) + q, k, v = proj[0], proj[1], proj[2] + + # Use "q_" so that we don't accidentally quit in pdb :) + q = rearrange(q, "b s (h d) -> b h s d", h=self.num_heads) + k = rearrange(k, "b s (h d) -> b h s d", h=self.num_heads) + v = rearrange(v, "b s (h d) -> b h s d", h=self.num_heads) + + if self.rope: + q, k = self.rope(q, k) + + attn = F.scaled_dot_product_attention( + q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False, scale=self.scale + ) + attn = rearrange(attn, "b h s d -> b s (h d)") + + return F.linear(attn, self.out_proj.weight, self.out_proj.bias) + + +class ResidualAttentionBlock(nn.Module): + def __init__( + self, + d_model: int, + n_head: int, + mlp_ratio: float = 4.0, + ls_init_value: float = None, + act_layer: Callable = nn.GELU, + norm_layer: Callable = nn.LayerNorm, + drop_path: float = 0.0, + rope: Optional[nn.Module] = None, + ): + super().__init__() + + if rope: + self.attn = SelfAttention(d_model, n_head, rope=rope) + else: + self.attn = nn.MultiheadAttention(d_model, n_head, batch_first=True) + + self.ls_1 = ( + LayerScale(d_model, ls_init_value) + if ls_init_value is not None + else nn.Identity() + ) + self.ls_2 = ( + LayerScale(d_model, ls_init_value) + if ls_init_value is not None + else nn.Identity() + ) + + self.ln_1 = norm_layer(d_model) + self.ln_2 = norm_layer(d_model) + + self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + mlp_width = int(d_model * mlp_ratio) + self.mlp = nn.Sequential( + OrderedDict( + [ + ("c_fc", nn.Linear(d_model, mlp_width)), + ("gelu", act_layer()), + ("c_proj", nn.Linear(mlp_width, d_model)), + ] + ) + ) + + def _call_attn( + self, + q_x: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + ): + + if attn_mask is not None: + # Leave boolean masks as is + if not attn_mask.dtype == torch.bool: + attn_mask = attn_mask.to(q_x.dtype) + + if isinstance(self.attn, SelfAttention): + return self.attn(q_x, attn_mask=attn_mask) + else: + return self.attn(q_x, q_x, q_x, attn_mask=attn_mask, need_weights=False)[0] + + def forward( + self, + x: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + ): + x = x + self.drop_path1( + self.ls_1(self._call_attn(self.ln_1(x), attn_mask=attn_mask)) + ) + x = x + self.drop_path2(self.ls_2(self.mlp(self.ln_2(x)))) + return x + + +class Transformer(nn.Module): + def __init__( + self, + width: int, + layers: int, + heads: int, + mlp_ratio: float = 4.0, + ls_init_value: float = None, + act_layer: Callable = nn.GELU, + norm_layer: Callable = nn.LayerNorm, + drop_path: float = 0.0, + rope: Optional[nn.Module] = None, + ): + super().__init__() + self.width = width + self.layers = layers + self.grad_checkpointing = False + + self.resblocks = nn.ModuleList( + [ + ResidualAttentionBlock( + width, + heads, + mlp_ratio, + ls_init_value=ls_init_value, + act_layer=act_layer, + norm_layer=norm_layer, + drop_path=drop_path, + rope=rope, + ) + for _ in range(layers) + ] + ) + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + self.grad_checkpointing = enable + + @torch.jit.ignore + def truncate(self, layer_idx: int): + """ Delete layers so the last layer is the given layer index. 
""" + self.layers = ((self.layers + layer_idx) % self.layers) + 1 + self.resblocks = nn.ModuleList(self.resblocks[:self.layers]) + + def forward( + self, + x: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + layer_idx: int = -1, + ): + stop_idx = (self.layers + layer_idx) % self.layers + + for i, r in enumerate(self.resblocks): + if self.grad_checkpointing and not torch.jit.is_scripting(): + # TODO: handle kwargs https://github.com/pytorch/pytorch/issues/79887#issuecomment-1161758372 + x = checkpoint(r, x, None, None, attn_mask) + else: + x = r(x, attn_mask=attn_mask) + + if i == stop_idx: + break + + return x + + +#class VisionTransformer(nn.Module): +class PE(nn.Module): + def __init__( + self, + patch_size: int, + width: int, + layers: int, + heads: int, + mlp_ratio: float, + act_layer: Callable = nn.GELU, + norm_layer: Callable = partial(nn.LayerNorm, eps=1e-5), + use_ln_pre: bool = True, + use_ln_post: bool = True, + ls_init_value: float = None, + drop_path: float = 0.0, + image_size: int = 448, # Pretrain image size only; you can pass in any image size + use_abs_posemb: bool = True, + use_rope2d: bool = True, + use_cls_token: bool = False, + output_dim: Optional[int] = 1280, + attn_pooler_heads: int = 8, + pool_type: Literal["attn", "tok", "avg", "none"] = "attn", + num_classes: int = 1000, # no use for now + in_chans: int = 3, + ): + super().__init__() + assert pool_type in ("attn", "tok", "avg", "none") + self.pool_type = pool_type + self.patch_size = patch_size + + self.output_dim = output_dim or width + self.proj_dim = output_dim + self.heads = heads + self.width = width + self.layers = layers + + self.use_abs_posemb = use_abs_posemb + self.use_cls_token = use_cls_token + self.use_rope2d = use_rope2d + self.image_size = image_size + + self.conv1 = nn.Conv2d( + in_channels=3, + out_channels=width, + kernel_size=patch_size, + stride=patch_size, + bias=False, + ) + self.rope = ( + Rope2D( + dim=width // heads, + use_cls_token=self.use_cls_token, + ) + if self.use_rope2d + else None + ) + + self.ln_pre = norm_layer(width) if use_ln_pre else nn.Identity() + self.ln_post = norm_layer(self.width) if use_ln_post else nn.Identity() + + self.transformer = Transformer( + width, + layers, + heads, + mlp_ratio, + ls_init_value=ls_init_value, + act_layer=act_layer, + norm_layer=norm_layer, + drop_path=drop_path, + rope=self.rope, + ) + + if pool_type == "attn": + self.attn_pool = AttentionPooling( + embed_dim=width, + num_heads=attn_pooler_heads, + act_layer=act_layer, + norm_layer=norm_layer, + ) + else: + self.attn_pool = None + + self.init_tensors() + + + def init_tensors(self): + def init_submodule_tensors(module): + for name, child in module.named_children(): + if hasattr(child, "init_tensors"): + #logger.debug(f"Initializing tensors for submodule: {name}") + child.init_tensors() + init_submodule_tensors(child) + + init_submodule_tensors(self) + self.rope.init_tensors() + + # class embeddings and positional embeddings + init_scale = self.width**-0.5 + + if self.use_cls_token: + self.class_embedding = nn.Parameter(init_scale * torch.randn(self.width)) + + if self.use_abs_posemb: + self.posemb_grid_size = self.image_size // self.patch_size + self.positional_embedding = nn.Parameter( + init_scale + * torch.randn( + int(self.use_cls_token) + self.posemb_grid_size**2, self.width + ) + ) + + if self.proj_dim is not None: + self.proj = nn.Parameter( + init_scale * torch.randn(self.width, self.proj_dim) + ) + + def truncate(self, layer_idx: int): + """ Delete layers so the last layer 
is the given layer index. """ + self.transformer.truncate(layer_idx) + self.layers = self.transformer.layers + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + self.transformer.set_grad_checkpointing(enable=enable) + + def _sample_abs_posemb(self, grid_h: int, grid_w: int): + """Interpolates the absolute position embedding if necessary.""" + if self.posemb_grid_size == grid_h and self.posemb_grid_size == grid_w: + return self.positional_embedding[None, ...] + + pos_embed = self.positional_embedding + if self.use_cls_token: + cls_token_embed, pos_embed = pos_embed[:1], pos_embed[1:] + + pos_embed = ( + pos_embed.reshape(1, self.posemb_grid_size, self.posemb_grid_size, -1) + .permute(0, 3, 1, 2) + .contiguous() + ) + pos_embed = F.interpolate( + pos_embed, size=(grid_h, grid_w), mode="bilinear", align_corners=False + ) + pos_embed = pos_embed.permute(0, 2, 3, 1).reshape(-1, self.width).contiguous() + + if self.use_cls_token: + pos_embed = torch.cat([cls_token_embed, pos_embed], dim=0) + + return pos_embed[None, ...] + + def _pool(self, x: torch.Tensor): + if self.pool_type == "tok": + return x[:, 0] + elif self.pool_type == "avg": + return x.mean(dim=1) + elif self.pool_type == "attn": + return self.attn_pool(x).squeeze(1) + elif self.pool_type == "none": + return x + else: + raise NotImplementedError + + def forward_features( + self, + x: torch.Tensor, + norm: bool = False, + layer_idx: int = -1, + strip_cls_token: bool = False + ): + batch, _, h, w = x.shape + grid_h, grid_w = h // self.patch_size, w // self.patch_size + + x = self.conv1(x) + x = x.permute(0, 2, 3, 1).reshape(batch, -1, self.width) + + if self.use_cls_token: + x = torch.cat( + [self.class_embedding.view(1, 1, -1).expand(batch, -1, -1), x], + dim=1, + ) + + if self.use_abs_posemb: + x = x + self._sample_abs_posemb(grid_h, grid_w) + + if self.use_rope2d: + self.rope.update_grid(x.device, grid_h, grid_w) + + x = self.ln_pre(x) + x = self.transformer(x, layer_idx=layer_idx) + + if norm: + x = self.ln_post(x) + + if strip_cls_token and self.use_cls_token: + x = x[:, 1:, :] + + return x + + def forward(self, x: torch.Tensor, **kwargs): + x = self.forward_features(x, norm=True, **kwargs) + x = self._pool(x) + + if self.proj_dim is not None: + x = x @ self.proj + + return x + + +def checkpoint_filter_fn( + state_dict: Dict[str, torch.Tensor], + model: PE, + adapt_layer_scale: bool = False, + interpolation: str = 'bicubic', + antialias: bool = True, +) -> Dict[str, torch.Tensor]: + """ convert patch embedding weight from manual patchify + linear proj to conv""" + import re + state_dict = state_dict.get('model', state_dict) + state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()} + if any(k.startswith("visual.") for k in state_dict): + state_dict = {k.replace("visual.", ""): v for k, v in state_dict.items() if "visual" in k} + return state_dict + +def _create_pe(variant: str, pretrained: bool = False, **kwargs) -> PE: + out_indices = kwargs.pop('out_indices', 3) + + return build_model_with_cfg( + PE, + variant, + pretrained, + pretrained_filter_fn=checkpoint_filter_fn, + pretrained_strict=True, + feature_cfg=dict(out_indices=out_indices, feature_cls='getter'), + **kwargs, + ) + + +@register_model +def pe_core_b16_224(pretrained=False, **kwargs): + model_args = dict( + image_size = 224, + patch_size = 16, + width = 768, + layers = 12, + heads = 12, + mlp_ratio = 4.0, + output_dim = 1024, + use_cls_token = True, + pool_type = 'attn', + ) + return _create_pe('pe_core_b16_224', 
pretrained=pretrained, **dict(model_args, **kwargs)) + +@register_model +def pe_core_l14_336(pretrained=False, **kwargs): + model_args = dict( + image_size = 336, + patch_size = 14, + width = 1024, + layers = 24, + heads = 16, + mlp_ratio = 4.0, + output_dim = 1024, + use_cls_token = True, + pool_type = 'attn', + ) + return _create_pe('pe_core_l14_336', pretrained=pretrained, **dict(model_args, **kwargs)) + + +@register_model +def pe_core_G14_448(pretrained=False, **kwargs): + model_args = dict( + image_size = 448, + patch_size = 14, + width = 1536, + layers = 50, + heads = 16, + mlp_ratio = 8960 / 1536, + output_dim = 1280, + use_cls_token = False, + pool_type = 'attn', + ) + return _create_pe('pe_core_G14_448', pretrained=pretrained, **dict(model_args, **kwargs)) + +@register_model +def pe_lang_G14_448(pretrained=False, **kwargs): + model_args = dict( + image_size = 448, + patch_size = 14, + width = 1536, + layers = 47, + heads = 16, + mlp_ratio = 8960 / 1536, + output_dim = None, + use_cls_token = False, + use_ln_post = False, + pool_type = 'none', + ls_init_value = 0.1, + ) + return _create_pe('pe_lang_G14_448', pretrained=pretrained, **dict(model_args, **kwargs)) + +@register_model +def pe_lang_l14_448(pretrained=False, **kwargs): + model_args = dict( + image_size = 448, + patch_size = 14, + width = 1024, + layers = 23, + heads = 16, + mlp_ratio = 4.0, + output_dim = None, + use_cls_token = True, + use_ln_post = False, + pool_type = 'none', + ls_init_value = 0.1, + ) + return _create_pe('pe_lang_l14_448', pretrained=pretrained, **dict(model_args, **kwargs)) + +@register_model +def pe_spatial_G14_448(pretrained=False, **kwargs): + model_args = dict( + image_size = 448, + patch_size = 14, + width = 1536, + layers = 50, + heads = 16, + mlp_ratio = 8960 / 1536, + output_dim = None, + use_cls_token = False, + use_ln_post = False, + pool_type = 'none', + ls_init_value = 0.1, + ) + return _create_pe('pe_spatial_G14_448', pretrained=pretrained, **dict(model_args, **kwargs)) From 8eafe2c21e700646ffb04c566c26854f82a331ad Mon Sep 17 00:00:00 2001 From: berniebear Date: Fri, 25 Apr 2025 06:50:29 +0000 Subject: [PATCH 02/15] update pe to reuse timm layers --- timm/models/pe.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/timm/models/pe.py b/timm/models/pe.py index a102f83e21..147cecee7b 100644 --- a/timm/models/pe.py +++ b/timm/models/pe.py @@ -16,7 +16,8 @@ ### Import timm layers from timm.layers import PatchEmbed, Mlp, DropPath, AttentionPoolLatent, RmsNorm, PatchDropout, SwiGLUPacked, SwiGLU, \ trunc_normal_, lecun_normal_, resample_patch_embed, resample_abs_pos_embed, use_fused_attn, \ - get_act_layer, get_norm_layer, LayerType, LayerScale + get_act_layer, get_norm_layer, LayerType, LayerScale +#from timm.layers import RotaryEmbeddingCat, RotaryEmbedding # not compatible from ._builder import build_model_with_cfg from ._features import feature_take_indices @@ -302,7 +303,6 @@ def update_grid(self, device, grid_h, grid_w): self.grid_size = (grid_h, grid_w) self.rope = self.rope.to(device) - if self.use_cls_token: # +1 to leave space for the cls token to be (0, 0) grid_y_range = torch.arange(grid_h, device=device) + 1 @@ -310,9 +310,8 @@ def update_grid(self, device, grid_h, grid_w): else: grid_y_range = torch.arange(grid_h, device=device) grid_x_range = torch.arange(grid_w, device=device) - freqs_y = self.rope(grid_y_range)[:, None].expand(grid_h, grid_w, -1) - freqs_x = self.rope(grid_x_range)[None, :].expand(grid_h, grid_w, -1) + freqs_x = 
self.rope(grid_x_range)[None, :].expand(grid_h, grid_w, -1) freq = torch.cat([freqs_x, freqs_y], dim=-1).reshape(grid_h * grid_w, -1) if self.use_cls_token: @@ -581,7 +580,6 @@ def forward( return x -#class VisionTransformer(nn.Module): class PE(nn.Module): def __init__( self, From e6bbf9fd56dffa4e43028878bf27c1f8f5db7e81 Mon Sep 17 00:00:00 2001 From: berniebear Date: Fri, 25 Apr 2025 07:55:00 +0000 Subject: [PATCH 03/15] add default config --- timm/models/pe.py | 47 ++++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 42 insertions(+), 5 deletions(-) diff --git a/timm/models/pe.py b/timm/models/pe.py index 147cecee7b..717215c1e8 100644 --- a/timm/models/pe.py +++ b/timm/models/pe.py @@ -18,28 +18,30 @@ trunc_normal_, lecun_normal_, resample_patch_embed, resample_abs_pos_embed, use_fused_attn, \ get_act_layer, get_norm_layer, LayerType, LayerScale #from timm.layers import RotaryEmbeddingCat, RotaryEmbedding # not compatible +from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD from ._builder import build_model_with_cfg from ._features import feature_take_indices -from ._manipulate import named_apply, checkpoint_seq, adapt_input_conv from ._registry import generate_default_cfgs, register_model, register_model_deprecations +__all__ = ['PE'] + + +####### PE's Rope ######## + def exists(val): return val is not None - def default(val, d): return val if exists(val) else d - def rotate_half(x): x = rearrange(x, "... (d r) -> ... d r", r=2) x1, x2 = x.unbind(dim=-1) x = torch.stack((-x2, x1), dim=-1) return rearrange(x, "... d r -> ... (d r)") - @autocast("cuda", enabled=False) def apply_rotary_emb(freqs, t, start_index=0, scale=1.0, seq_dim=-2): dtype = t.dtype @@ -330,6 +332,7 @@ def __call__(self, q, k): return q, k +####### PE's Modules ######## class AttentionPooling(nn.Module): def __init__( @@ -801,6 +804,41 @@ def checkpoint_filter_fn( state_dict = {k.replace("visual.", ""): v for k, v in state_dict.items() if "visual" in k} return state_dict + +default_cfgs = generate_default_cfgs({ + 'pe_core_b16_224': _cfg( + hf_hub_id='timm/', + license='apache-2.0', + mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, num_classes=0, + input_size=(3, 224, 224)), + 'pe_core_l14_336': _cfg( + hf_hub_id='timm/', + license='apache-2.0', + mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, num_classes=0, + input_size=(3, 336, 336)), + 'pe_core_G14_448': _cfg( + hf_hub_id='timm/', + license='apache-2.0', + mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, num_classes=0, + input_size=(3, 448, 448)), + 'pe_lang_l14_448': _cfg( + hf_hub_id='timm/', + license='apache-2.0', + mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, num_classes=0, + input_size=(3, 448, 448)), + 'pe_lang_G14_448': _cfg( + hf_hub_id='timm/', + license='apache-2.0', + mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, num_classes=0, + input_size=(3, 448, 448)), + 'pe_spatial_G14_448': _cfg( + hf_hub_id='timm/', + license='apache-2.0', + mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, num_classes=0, + input_size=(3, 448, 448)), +}) + + def _create_pe(variant: str, pretrained: bool = False, **kwargs) -> PE: out_indices = kwargs.pop('out_indices', 3) @@ -814,7 +852,6 @@ def _create_pe(variant: str, pretrained: bool = False, **kwargs) -> PE: **kwargs, ) - @register_model def pe_core_b16_224(pretrained=False, **kwargs): model_args = dict( From 3af564f2fb3727cac106a20bb9124f8f75d3f306 Mon Sep 17 00:00:00 2001 From: berniebear Date: Fri, 25 Apr 2025 09:19:18 +0000 Subject: 
[PATCH 04/15] fix default config --- timm/models/pe.py | 23 +++++++++++------------ 1 file changed, 11 insertions(+), 12 deletions(-) diff --git a/timm/models/pe.py b/timm/models/pe.py index 717215c1e8..953a990681 100644 --- a/timm/models/pe.py +++ b/timm/models/pe.py @@ -805,36 +805,35 @@ def checkpoint_filter_fn( return state_dict +def _cfg(url='', **kwargs): + return { + 'license': 'apache-2.0', + 'num_classes': 0, + 'interpolation': 'bilinear', + 'fixed_input_size': True, + 'mean': IMAGENET_INCEPTION_MEAN, + 'std': IMAGENET_INCEPTION_STD, + **kwargs + } + default_cfgs = generate_default_cfgs({ 'pe_core_b16_224': _cfg( hf_hub_id='timm/', - license='apache-2.0', - mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, num_classes=0, input_size=(3, 224, 224)), 'pe_core_l14_336': _cfg( hf_hub_id='timm/', - license='apache-2.0', - mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, num_classes=0, input_size=(3, 336, 336)), 'pe_core_G14_448': _cfg( hf_hub_id='timm/', - license='apache-2.0', - mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, num_classes=0, input_size=(3, 448, 448)), 'pe_lang_l14_448': _cfg( hf_hub_id='timm/', - license='apache-2.0', - mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, num_classes=0, input_size=(3, 448, 448)), 'pe_lang_G14_448': _cfg( hf_hub_id='timm/', - license='apache-2.0', - mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, num_classes=0, input_size=(3, 448, 448)), 'pe_spatial_G14_448': _cfg( hf_hub_id='timm/', - license='apache-2.0', - mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, num_classes=0, input_size=(3, 448, 448)), }) From b82f9e869cb88cfa64089e64e2a33b3f19b6b8df Mon Sep 17 00:00:00 2001 From: berniebear Date: Fri, 25 Apr 2025 21:26:34 +0000 Subject: [PATCH 05/15] renaming models and update checkpoint to vit-only --- timm/models/pe.py | 346 ++++++++++++++++++++-------------------------- 1 file changed, 147 insertions(+), 199 deletions(-) diff --git a/timm/models/pe.py b/timm/models/pe.py index 953a990681..62ebfa7bc3 100644 --- a/timm/models/pe.py +++ b/timm/models/pe.py @@ -14,10 +14,27 @@ from torch.utils.checkpoint import checkpoint ### Import timm layers -from timm.layers import PatchEmbed, Mlp, DropPath, AttentionPoolLatent, RmsNorm, PatchDropout, SwiGLUPacked, SwiGLU, \ - trunc_normal_, lecun_normal_, resample_patch_embed, resample_abs_pos_embed, use_fused_attn, \ - get_act_layer, get_norm_layer, LayerType, LayerScale -#from timm.layers import RotaryEmbeddingCat, RotaryEmbedding # not compatible +from timm.layers import ( + PatchEmbed, + Mlp, + DropPath, + AttentionPoolLatent, + RmsNorm, + PatchDropout, + SwiGLUPacked, + SwiGLU, + trunc_normal_, + lecun_normal_, + resample_patch_embed, + resample_abs_pos_embed, + use_fused_attn, + get_act_layer, + get_norm_layer, + LayerType, + LayerScale, +) + +# from timm.layers import RotaryEmbeddingCat, RotaryEmbedding # not compatible from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD from ._builder import build_model_with_cfg @@ -28,20 +45,22 @@ __all__ = ['PE'] -####### PE's Rope ######## - +######## PE's Rope ######## def exists(val): return val is not None + def default(val, d): return val if exists(val) else d + def rotate_half(x): x = rearrange(x, "... (d r) -> ... d r", r=2) x1, x2 = x.unbind(dim=-1) x = torch.stack((-x2, x1), dim=-1) return rearrange(x, "... d r -> ... 
(d r)") + @autocast("cuda", enabled=False) def apply_rotary_emb(freqs, t, start_index=0, scale=1.0, seq_dim=-2): dtype = t.dtype @@ -73,9 +92,7 @@ def __init__( self, dim, custom_freqs: Optional[Tensor] = None, - freqs_for: Union[ - Literal["lang"], Literal["pixel"], Literal["constant"] - ] = "lang", + freqs_for: Union[Literal["lang"], Literal["pixel"], Literal["constant"]] = "lang", theta=10000, max_freq=10, num_freqs=1, @@ -99,9 +116,7 @@ def __init__( if exists(custom_freqs): freqs = custom_freqs elif freqs_for == "lang": - freqs = 1.0 / ( - theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim) - ) + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) elif freqs_for == "pixel": freqs = torch.linspace(1.0, max_freq / 2, dim // 2) * pi elif freqs_for == "constant": @@ -154,9 +169,7 @@ def tmp_store(self, key, value): self.register_buffer(key, value, persistent=False) def get_seq_pos(self, seq_len, device, dtype, offset=0): - return ( - torch.arange(seq_len, device=device, dtype=dtype) + offset - ) / self.interpolate_factor + return (torch.arange(seq_len, device=device, dtype=dtype) + offset) / self.interpolate_factor def rotate_queries_or_keys(self, t, seq_dim=None, offset=0): seq_dim = default(seq_dim, self.default_seq_dim) @@ -184,9 +197,7 @@ def rotate_queries_with_cached_keys(self, q, k, seq_dim=None, offset=0): q_len, k_len = q.shape[seq_dim], k.shape[seq_dim] assert q_len <= k_len - rotated_q = self.rotate_queries_or_keys( - q, seq_dim=seq_dim, offset=k_len - q_len + offset - ) + rotated_q = self.rotate_queries_or_keys(q, seq_dim=seq_dim, offset=k_len - q_len + offset) rotated_k = self.rotate_queries_or_keys(k, seq_dim=seq_dim, offset=offset) rotated_q = rotated_q.type(q.dtype) @@ -222,11 +233,7 @@ def get_scale(self, t: Tensor, seq_len: Optional[int] = None, offset=0): should_cache = self.cache_if_possible and exists(seq_len) - if ( - should_cache - and exists(self.cached_scales) - and (seq_len + offset) <= self.cached_scales.shape[0] - ): + if should_cache and exists(self.cached_scales) and (seq_len + offset) <= self.cached_scales.shape[0]: return self.cached_scales[offset : (offset + seq_len)] scale = 1.0 @@ -264,17 +271,10 @@ def get_axial_freqs(self, *dims): @autocast("cuda", enabled=False) def forward(self, t: Tensor, seq_len=None, offset=0): should_cache = ( - self.cache_if_possible - and not self.learned_freq - and exists(seq_len) - and self.freqs_for != "pixel" + self.cache_if_possible and not self.learned_freq and exists(seq_len) and self.freqs_for != "pixel" ) - if ( - should_cache - and exists(self.cached_freqs) - and (offset + seq_len) <= self.cached_freqs.shape[0] - ): + if should_cache and exists(self.cached_freqs) and (offset + seq_len) <= self.cached_freqs.shape[0]: return self.cached_freqs[offset : (offset + seq_len)].detach() freqs = self.freqs @@ -289,7 +289,7 @@ def forward(self, t: Tensor, seq_len=None, offset=0): class Rope2D: - """ Helper class to apply RoPE2D as well as interpolate on the fly. 
""" + """Helper class to apply RoPE2D as well as interpolate on the fly.""" def __init__(self, dim, use_cls_token=False): self.dim = dim @@ -313,13 +313,11 @@ def update_grid(self, device, grid_h, grid_w): grid_y_range = torch.arange(grid_h, device=device) grid_x_range = torch.arange(grid_w, device=device) freqs_y = self.rope(grid_y_range)[:, None].expand(grid_h, grid_w, -1) - freqs_x = self.rope(grid_x_range)[None, :].expand(grid_h, grid_w, -1) + freqs_x = self.rope(grid_x_range)[None, :].expand(grid_h, grid_w, -1) freq = torch.cat([freqs_x, freqs_y], dim=-1).reshape(grid_h * grid_w, -1) if self.use_cls_token: - freq = torch.cat( - [freq, torch.zeros(1, freq.shape[-1], device=device)], dim=0 - ) + freq = torch.cat([freq, torch.zeros(1, freq.shape[-1], device=device)], dim=0) self.freq = freq[None, ...] @@ -332,8 +330,8 @@ def __call__(self, q, k): return q, k -####### PE's Modules ######## +######## PE Modules ######## class AttentionPooling(nn.Module): def __init__( self, @@ -349,14 +347,10 @@ def __init__( self.embed_dim = embed_dim self.num_heads = num_heads - assert ( - self.embed_dim % num_heads == 0 - ), "embed_dim must be divisible by num_heads" + assert self.embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads" self.probe = nn.Parameter(torch.randn(1, num_probe, self.embed_dim)) - self.attn = nn.MultiheadAttention( - self.embed_dim, self.num_heads, batch_first=True - ) + self.attn = nn.MultiheadAttention(self.embed_dim, self.num_heads, batch_first=True) self.layernorm = norm_layer(embed_dim) self.mlp_width = int(embed_dim * mlp_ratio) @@ -396,9 +390,7 @@ def __init__( self.num_heads = num_heads self.head_dim = embed_dim // num_heads - assert ( - self.head_dim * num_heads == self.embed_dim - ), "embed_dim must be divisible by num_heads" + assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads" # To make this compatibile with nn.MultiHeadAttention self.in_proj_weight = Parameter(torch.empty(3 * embed_dim, embed_dim)) @@ -418,13 +410,7 @@ def forward(self, x, attn_mask=None): proj = F.linear(x, self.in_proj_weight, self.in_proj_bias) # reshape to 3, E and not E, 3 is deliberate for better memory coalescing and keeping same order as chunk() - proj = ( - proj.unflatten(-1, (3, embed_dim)) - .unsqueeze(0) - .transpose(0, -2) - .squeeze(-2) - .contiguous() - ) + proj = proj.unflatten(-1, (3, embed_dim)).unsqueeze(0).transpose(0, -2).squeeze(-2).contiguous() q, k, v = proj[0], proj[1], proj[2] # Use "q_" so that we don't accidentally quit in pdb :) @@ -435,9 +421,7 @@ def forward(self, x, attn_mask=None): if self.rope: q, k = self.rope(q, k) - attn = F.scaled_dot_product_attention( - q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False, scale=self.scale - ) + attn = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False, scale=self.scale) attn = rearrange(attn, "b h s d -> b s (h d)") return F.linear(attn, self.out_proj.weight, self.out_proj.bias) @@ -462,16 +446,8 @@ def __init__( else: self.attn = nn.MultiheadAttention(d_model, n_head, batch_first=True) - self.ls_1 = ( - LayerScale(d_model, ls_init_value) - if ls_init_value is not None - else nn.Identity() - ) - self.ls_2 = ( - LayerScale(d_model, ls_init_value) - if ls_init_value is not None - else nn.Identity() - ) + self.ls_1 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity() + self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity() self.ln_1 = norm_layer(d_model) 
self.ln_2 = norm_layer(d_model) @@ -511,9 +487,7 @@ def forward( x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None, ): - x = x + self.drop_path1( - self.ls_1(self._call_attn(self.ln_1(x), attn_mask=attn_mask)) - ) + x = x + self.drop_path1(self.ls_1(self._call_attn(self.ln_1(x), attn_mask=attn_mask))) x = x + self.drop_path2(self.ls_2(self.mlp(self.ln_2(x)))) return x @@ -558,9 +532,9 @@ def set_grad_checkpointing(self, enable=True): @torch.jit.ignore def truncate(self, layer_idx: int): - """ Delete layers so the last layer is the given layer index. """ + """Delete layers so the last layer is the given layer index.""" self.layers = ((self.layers + layer_idx) % self.layers) + 1 - self.resblocks = nn.ModuleList(self.resblocks[:self.layers]) + self.resblocks = nn.ModuleList(self.resblocks[: self.layers]) def forward( self, @@ -576,7 +550,7 @@ def forward( x = checkpoint(r, x, None, None, attn_mask) else: x = r(x, attn_mask=attn_mask) - + if i == stop_idx: break @@ -604,7 +578,7 @@ def __init__( output_dim: Optional[int] = 1280, attn_pooler_heads: int = 8, pool_type: Literal["attn", "tok", "avg", "none"] = "attn", - num_classes: int = 1000, # no use for now + num_classes: int = 1000, # no use for now in_chans: int = 3, ): super().__init__() @@ -666,12 +640,11 @@ def __init__( self.init_tensors() - def init_tensors(self): def init_submodule_tensors(module): for name, child in module.named_children(): if hasattr(child, "init_tensors"): - #logger.debug(f"Initializing tensors for submodule: {name}") + # logger.debug(f"Initializing tensors for submodule: {name}") child.init_tensors() init_submodule_tensors(child) @@ -687,19 +660,14 @@ def init_submodule_tensors(module): if self.use_abs_posemb: self.posemb_grid_size = self.image_size // self.patch_size self.positional_embedding = nn.Parameter( - init_scale - * torch.randn( - int(self.use_cls_token) + self.posemb_grid_size**2, self.width - ) + init_scale * torch.randn(int(self.use_cls_token) + self.posemb_grid_size**2, self.width) ) if self.proj_dim is not None: - self.proj = nn.Parameter( - init_scale * torch.randn(self.width, self.proj_dim) - ) + self.proj = nn.Parameter(init_scale * torch.randn(self.width, self.proj_dim)) def truncate(self, layer_idx: int): - """ Delete layers so the last layer is the given layer index. 
""" + """Delete layers so the last layer is the given layer index.""" self.transformer.truncate(layer_idx) self.layers = self.transformer.layers @@ -717,13 +685,9 @@ def _sample_abs_posemb(self, grid_h: int, grid_w: int): cls_token_embed, pos_embed = pos_embed[:1], pos_embed[1:] pos_embed = ( - pos_embed.reshape(1, self.posemb_grid_size, self.posemb_grid_size, -1) - .permute(0, 3, 1, 2) - .contiguous() - ) - pos_embed = F.interpolate( - pos_embed, size=(grid_h, grid_w), mode="bilinear", align_corners=False + pos_embed.reshape(1, self.posemb_grid_size, self.posemb_grid_size, -1).permute(0, 3, 1, 2).contiguous() ) + pos_embed = F.interpolate(pos_embed, size=(grid_h, grid_w), mode="bilinear", align_corners=False) pos_embed = pos_embed.permute(0, 2, 3, 1).reshape(-1, self.width).contiguous() if self.use_cls_token: @@ -743,13 +707,7 @@ def _pool(self, x: torch.Tensor): else: raise NotImplementedError - def forward_features( - self, - x: torch.Tensor, - norm: bool = False, - layer_idx: int = -1, - strip_cls_token: bool = False - ): + def forward_features(self, x: torch.Tensor, norm: bool = False, layer_idx: int = -1, strip_cls_token: bool = False): batch, _, h, w = x.shape grid_h, grid_w = h // self.patch_size, w // self.patch_size @@ -790,14 +748,8 @@ def forward(self, x: torch.Tensor, **kwargs): def checkpoint_filter_fn( - state_dict: Dict[str, torch.Tensor], - model: PE, - adapt_layer_scale: bool = False, - interpolation: str = 'bicubic', - antialias: bool = True, + state_dict: Dict[str, torch.Tensor], ) -> Dict[str, torch.Tensor]: - """ convert patch embedding weight from manual patchify + linear proj to conv""" - import re state_dict = state_dict.get('model', state_dict) state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()} if any(k.startswith("visual.") for k in state_dict): @@ -805,42 +757,33 @@ def checkpoint_filter_fn( return state_dict +######## PE Config ######## def _cfg(url='', **kwargs): return { 'license': 'apache-2.0', 'num_classes': 0, 'interpolation': 'bilinear', 'fixed_input_size': True, - 'mean': IMAGENET_INCEPTION_MEAN, + 'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD, - **kwargs + **kwargs, } -default_cfgs = generate_default_cfgs({ - 'pe_core_b16_224': _cfg( - hf_hub_id='timm/', - input_size=(3, 224, 224)), - 'pe_core_l14_336': _cfg( - hf_hub_id='timm/', - input_size=(3, 336, 336)), - 'pe_core_G14_448': _cfg( - hf_hub_id='timm/', - input_size=(3, 448, 448)), - 'pe_lang_l14_448': _cfg( - hf_hub_id='timm/', - input_size=(3, 448, 448)), - 'pe_lang_G14_448': _cfg( - hf_hub_id='timm/', - input_size=(3, 448, 448)), - 'pe_spatial_G14_448': _cfg( - hf_hub_id='timm/', - input_size=(3, 448, 448)), -}) + +default_cfgs = generate_default_cfgs( + { + 'vit_pe_core_base_patch16_224': _cfg(hf_hub_id='timm/', input_size=(3, 224, 224)), + 'vit_pe_core_large_patch14_336': _cfg(hf_hub_id='timm/', input_size=(3, 336, 336)), + 'vit_pe_core_gigantic_patch14_448': _cfg(hf_hub_id='timm/', input_size=(3, 448, 448)), + 'vit_pe_lang_large_patch14_448': _cfg(hf_hub_id='timm/', input_size=(3, 448, 448)), + 'vit_pe_lang_gigantic_patch14_448': _cfg(hf_hub_id='timm/', input_size=(3, 448, 448)), + 'vit_pe_spatial_gigantic_patch14_448': _cfg(hf_hub_id='timm/', input_size=(3, 448, 448)), + } +) def _create_pe(variant: str, pretrained: bool = False, **kwargs) -> PE: out_indices = kwargs.pop('out_indices', 3) - return build_model_with_cfg( PE, variant, @@ -851,99 +794,104 @@ def _create_pe(variant: str, pretrained: bool = False, **kwargs) -> PE: **kwargs, ) + @register_model 
-def pe_core_b16_224(pretrained=False, **kwargs): +def vit_pe_core_base_patch16_224(pretrained=False, **kwargs): model_args = dict( - image_size = 224, - patch_size = 16, - width = 768, - layers = 12, - heads = 12, - mlp_ratio = 4.0, - output_dim = 1024, - use_cls_token = True, - pool_type = 'attn', + image_size=224, + patch_size=16, + width=768, + layers=12, + heads=12, + mlp_ratio=4.0, + output_dim=1024, + use_cls_token=True, + pool_type='attn', ) - return _create_pe('pe_core_b16_224', pretrained=pretrained, **dict(model_args, **kwargs)) + return _create_pe('vit_pe_core_base_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs)) + @register_model -def pe_core_l14_336(pretrained=False, **kwargs): +def vit_pe_core_large_patch14_336(pretrained=False, **kwargs): model_args = dict( - image_size = 336, - patch_size = 14, - width = 1024, - layers = 24, - heads = 16, - mlp_ratio = 4.0, - output_dim = 1024, - use_cls_token = True, - pool_type = 'attn', + image_size=336, + patch_size=14, + width=1024, + layers=24, + heads=16, + mlp_ratio=4.0, + output_dim=1024, + use_cls_token=True, + pool_type='attn', ) - return _create_pe('pe_core_l14_336', pretrained=pretrained, **dict(model_args, **kwargs)) + return _create_pe('vit_pe_core_large_patch14_336', pretrained=pretrained, **dict(model_args, **kwargs)) @register_model -def pe_core_G14_448(pretrained=False, **kwargs): +def vit_pe_core_gigantic_patch14_448(pretrained=False, **kwargs): model_args = dict( - image_size = 448, - patch_size = 14, - width = 1536, - layers = 50, - heads = 16, - mlp_ratio = 8960 / 1536, - output_dim = 1280, - use_cls_token = False, - pool_type = 'attn', + image_size=448, + patch_size=14, + width=1536, + layers=50, + heads=16, + mlp_ratio=8960 / 1536, + output_dim=1280, + use_cls_token=False, + pool_type='attn', ) - return _create_pe('pe_core_G14_448', pretrained=pretrained, **dict(model_args, **kwargs)) + return _create_pe('vit_pe_core_gigantic_patch14_448', pretrained=pretrained, **dict(model_args, **kwargs)) + @register_model -def pe_lang_G14_448(pretrained=False, **kwargs): +def vit_pe_lang_large_patch14_448(pretrained=False, **kwargs): model_args = dict( - image_size = 448, - patch_size = 14, - width = 1536, - layers = 47, - heads = 16, - mlp_ratio = 8960 / 1536, - output_dim = None, - use_cls_token = False, - use_ln_post = False, - pool_type = 'none', - ls_init_value = 0.1, + image_size=448, + patch_size=14, + width=1024, + layers=23, + heads=16, + mlp_ratio=4.0, + output_dim=None, + use_cls_token=True, + use_ln_post=False, + pool_type='none', + ls_init_value=0.1, ) - return _create_pe('pe_lang_G14_448', pretrained=pretrained, **dict(model_args, **kwargs)) + return _create_pe('vit_pe_lang_large_patch14_448', pretrained=pretrained, **dict(model_args, **kwargs)) + @register_model -def pe_lang_l14_448(pretrained=False, **kwargs): +def vit_pe_lang_gigantic_patch14_448(pretrained=False, **kwargs): model_args = dict( - image_size = 448, - patch_size = 14, - width = 1024, - layers = 23, - heads = 16, - mlp_ratio = 4.0, - output_dim = None, - use_cls_token = True, - use_ln_post = False, - pool_type = 'none', - ls_init_value = 0.1, + image_size=448, + patch_size=14, + width=1536, + layers=47, + heads=16, + mlp_ratio=8960 / 1536, + output_dim=None, + use_cls_token=False, + use_ln_post=False, + pool_type='none', + ls_init_value=0.1, ) - return _create_pe('pe_lang_l14_448', pretrained=pretrained, **dict(model_args, **kwargs)) + return _create_pe('vit_pe_lang_gigantic_patch14_448', pretrained=pretrained, **dict(model_args, 
**kwargs)) + @register_model -def pe_spatial_G14_448(pretrained=False, **kwargs): +def vit_pe_spatial_gigantic_patch14_448(pretrained=False, **kwargs): model_args = dict( - image_size = 448, - patch_size = 14, - width = 1536, - layers = 50, - heads = 16, - mlp_ratio = 8960 / 1536, - output_dim = None, - use_cls_token = False, - use_ln_post = False, - pool_type = 'none', - ls_init_value = 0.1, + image_size=448, + patch_size=14, + width=1536, + layers=50, + heads=16, + mlp_ratio=8960 / 1536, + output_dim=None, + use_cls_token=False, + use_ln_post=False, + pool_type='none', + ls_init_value=0.1, ) - return _create_pe('pe_spatial_G14_448', pretrained=pretrained, **dict(model_args, **kwargs)) + return _create_pe('vit_pe_spatial_gigantic_patch14_448', pretrained=pretrained, **dict(model_args, **kwargs)) From c5c437ae37db3fd174bc2140b5f6d48d82fd10ca Mon Sep 17 00:00:00 2001 From: berniebear Date: Sun, 27 Apr 2025 06:39:21 +0000 Subject: [PATCH 06/15] remove einops dependencies and reimplement with torch functions --- timm/models/pe.py | 30 ++++++++++++------------------ 1 file changed, 12 insertions(+), 18 deletions(-) diff --git a/timm/models/pe.py b/timm/models/pe.py index 62ebfa7bc3..951bffb438 100644 --- a/timm/models/pe.py +++ b/timm/models/pe.py @@ -4,7 +4,6 @@ import torch import torch.nn as nn -from einops import rearrange, repeat from torch import nn, Tensor, broadcast_tensors, einsum from torch.nn import functional as F from torch.nn import Module, ModuleList @@ -49,17 +48,14 @@ def exists(val): return val is not None - def default(val, d): return val if exists(val) else d - def rotate_half(x): - x = rearrange(x, "... (d r) -> ... d r", r=2) - x1, x2 = x.unbind(dim=-1) + x = x.view(*x.shape[:-1], -1, 2) + x1, x2 = x[..., 0], x[..., 1] x = torch.stack((-x2, x1), dim=-1) - return rearrange(x, "... d r -> ... 
(d r)") - + return x.view(*x.shape[:-2], -1) @autocast("cuda", enabled=False) def apply_rotary_emb(freqs, t, start_index=0, scale=1.0, seq_dim=-2): @@ -86,7 +82,6 @@ def apply_rotary_emb(freqs, t, start_index=0, scale=1.0, seq_dim=-2): return out.type(dtype) - class RotaryEmbedding(Module): def __init__( self, @@ -187,7 +182,7 @@ def rotate_queries_or_keys(self, t, seq_dim=None, offset=0): ) if seq_dim == -3: - freqs = rearrange(freqs, "n d -> n 1 d") + freqs = freqs.unsqueeze(1) return apply_rotary_emb(freqs, t, seq_dim=seq_dim) @@ -217,8 +212,8 @@ def rotate_queries_and_keys(self, q, k, seq_dim=None): scale = self.get_scale(seq, seq_len=seq_len).to(dtype) if seq_dim == -3: - freqs = rearrange(freqs, "n d -> n 1 d") - scale = rearrange(scale, "n d -> n 1 d") + freqs = freqs.unsqueeze(1) + scale = scale.unsqueeze(1) rotated_q = apply_rotary_emb(freqs, q, scale=scale, seq_dim=seq_dim) rotated_k = apply_rotary_emb(freqs, k, scale=scale**-1, seq_dim=seq_dim) @@ -230,7 +225,6 @@ def rotate_queries_and_keys(self, q, k, seq_dim=None): def get_scale(self, t: Tensor, seq_len: Optional[int] = None, offset=0): assert self.use_xpos - should_cache = self.cache_if_possible and exists(seq_len) if should_cache and exists(self.cached_scales) and (seq_len + offset) <= self.cached_scales.shape[0]: @@ -239,7 +233,7 @@ def get_scale(self, t: Tensor, seq_len: Optional[int] = None, offset=0): scale = 1.0 if self.use_xpos: power = (t - len(t) // 2) / self.scale_base - scale = self.scale ** rearrange(power, "n -> n 1") + scale = self.scale ** power.unsqueeze(-1) scale = torch.cat((scale, scale), dim=-1) if should_cache: @@ -280,7 +274,7 @@ def forward(self, t: Tensor, seq_len=None, offset=0): freqs = self.freqs freqs = einsum("..., f -> ... f", t.type(freqs.dtype), freqs) - freqs = repeat(freqs, "... n -> ... 
(n r)", r=2) + freqs = freqs.repeat_interleave(2, dim=-1) if should_cache: self.tmp_store("cached_freqs", freqs.detach()) @@ -414,15 +408,15 @@ def forward(self, x, attn_mask=None): q, k, v = proj[0], proj[1], proj[2] # Use "q_" so that we don't accidentally quit in pdb :) - q = rearrange(q, "b s (h d) -> b h s d", h=self.num_heads) - k = rearrange(k, "b s (h d) -> b h s d", h=self.num_heads) - v = rearrange(v, "b s (h d) -> b h s d", h=self.num_heads) + q = q.view(batch, seq, self.num_heads, self.head_dim).permute(0, 2, 1, 3) + k = k.view(batch, seq, self.num_heads, self.head_dim).permute(0, 2, 1, 3) + v = v.view(batch, seq, self.num_heads, self.head_dim).permute(0, 2, 1, 3) if self.rope: q, k = self.rope(q, k) attn = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False, scale=self.scale) - attn = rearrange(attn, "b h s d -> b s (h d)") + attn = attn.permute(0, 2, 1, 3).contiguous().view(batch, seq, -1) return F.linear(attn, self.out_proj.weight, self.out_proj.bias) From 2327ecc2e5350ad75e4f5c4931f28c65855cb880 Mon Sep 17 00:00:00 2001 From: berniebear Date: Mon, 28 Apr 2025 19:32:29 +0000 Subject: [PATCH 07/15] fix config --- timm/models/pe.py | 31 ++++++++++++++++--------------- 1 file changed, 16 insertions(+), 15 deletions(-) diff --git a/timm/models/pe.py b/timm/models/pe.py index 951bffb438..9b7bee5619 100644 --- a/timm/models/pe.py +++ b/timm/models/pe.py @@ -281,7 +281,6 @@ def forward(self, t: Tensor, seq_len=None, offset=0): return freqs - class Rope2D: """Helper class to apply RoPE2D as well as interpolate on the fly.""" @@ -565,14 +564,14 @@ def __init__( use_ln_post: bool = True, ls_init_value: float = None, drop_path: float = 0.0, - image_size: int = 448, # Pretrain image size only; you can pass in any image size + img_size: int = 448, # Pretrain image size only; you can pass in any image size use_abs_posemb: bool = True, use_rope2d: bool = True, use_cls_token: bool = False, output_dim: Optional[int] = 1280, attn_pooler_heads: int = 8, pool_type: Literal["attn", "tok", "avg", "none"] = "attn", - num_classes: int = 1000, # no use for now + num_classes: int = 0, # no use for PE in_chans: int = 3, ): super().__init__() @@ -589,7 +588,9 @@ def __init__( self.use_abs_posemb = use_abs_posemb self.use_cls_token = use_cls_token self.use_rope2d = use_rope2d - self.image_size = image_size + if isinstance(img_size, (tuple, list)): + img_size = img_size[0] + self.img_size = img_size self.conv1 = nn.Conv2d( in_channels=3, @@ -652,7 +653,7 @@ def init_submodule_tensors(module): self.class_embedding = nn.Parameter(init_scale * torch.randn(self.width)) if self.use_abs_posemb: - self.posemb_grid_size = self.image_size // self.patch_size + self.posemb_grid_size = self.img_size // self.patch_size self.positional_embedding = nn.Parameter( init_scale * torch.randn(int(self.use_cls_token) + self.posemb_grid_size**2, self.width) ) @@ -731,8 +732,8 @@ def forward_features(self, x: torch.Tensor, norm: bool = False, layer_idx: int = return x - def forward(self, x: torch.Tensor, **kwargs): - x = self.forward_features(x, norm=True, **kwargs) + def forward(self, x: torch.Tensor, layer_idx: int = -1, strip_cls_token: bool = False): + x = self.forward_features(x, norm=True, layer_idx=layer_idx, strip_cls_token=strip_cls_token) x = self._pool(x) if self.proj_dim is not None: @@ -758,8 +759,8 @@ def _cfg(url='', **kwargs): 'num_classes': 0, 'interpolation': 'bilinear', 'fixed_input_size': True, - 'mean': IMAGENET_INCEPTION_MEAN, - 'std': IMAGENET_INCEPTION_STD, + 'mean': 
IMAGENET_INCEPTION_MEAN, # (0.5, 0.5, 0.5) + 'std': IMAGENET_INCEPTION_STD, # (0.5, 0.5, 0.5) **kwargs, } @@ -792,7 +793,7 @@ def _create_pe(variant: str, pretrained: bool = False, **kwargs) -> PE: @register_model def vit_pe_core_base_patch16_224(pretrained=False, **kwargs): model_args = dict( - image_size=224, + img_size=224, patch_size=16, width=768, layers=12, @@ -808,7 +809,7 @@ def vit_pe_core_base_patch16_224(pretrained=False, **kwargs): @register_model def vit_pe_core_large_patch14_336(pretrained=False, **kwargs): model_args = dict( - image_size=336, + img_size=336, patch_size=14, width=1024, layers=24, @@ -824,7 +825,7 @@ def vit_pe_core_large_patch14_336(pretrained=False, **kwargs): @register_model def vit_pe_core_gigantic_patch14_448(pretrained=False, **kwargs): model_args = dict( - image_size=448, + img_size=448, patch_size=14, width=1536, layers=50, @@ -840,7 +841,7 @@ def vit_pe_core_gigantic_patch14_448(pretrained=False, **kwargs): @register_model def vit_pe_lang_large_patch14_448(pretrained=False, **kwargs): model_args = dict( - image_size=448, + img_size=448, patch_size=14, width=1024, layers=23, @@ -858,7 +859,7 @@ def vit_pe_lang_large_patch14_448(pretrained=False, **kwargs): @register_model def vit_pe_lang_gigantic_patch14_448(pretrained=False, **kwargs): model_args = dict( - image_size=448, + img_size=448, patch_size=14, width=1536, layers=47, @@ -876,7 +877,7 @@ def vit_pe_lang_gigantic_patch14_448(pretrained=False, **kwargs): @register_model def vit_pe_spatial_gigantic_patch14_448(pretrained=False, **kwargs): model_args = dict( - image_size=448, + img_size=448, patch_size=14, width=1536, layers=50, From 8ecc7ca06f238225093f3618e198f9761000b986 Mon Sep 17 00:00:00 2001 From: berniebear Date: Tue, 29 Apr 2025 08:55:37 +0000 Subject: [PATCH 08/15] refactor Rope2D class for torchscript compatiability. init rope freqs in modules. 
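
For reference, the hunk below strips the xpos and caching machinery and keeps only the frequency initialisation; the "lang" branch that PE actually uses reduces to the standard RoPE inverse-frequency schedule. A minimal sketch of that schedule follows (the dim and theta values are illustrative only, not taken from the patch):

    import torch

    dim, theta = 64, 10000  # per-axis rotary dim and base frequency (illustrative)
    inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
    # inv_freq has shape (dim // 2,); Rope2D evaluates one such table per
    # spatial axis and concatenates the x/y halves for every grid position.
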
--- timm/models/pe.py | 377 +++++++++++----------------------------------- 1 file changed, 89 insertions(+), 288 deletions(-) diff --git a/timm/models/pe.py b/timm/models/pe.py index 9b7bee5619..a18e6323a5 100644 --- a/timm/models/pe.py +++ b/timm/models/pe.py @@ -11,6 +11,7 @@ from torch.nn.parameter import Parameter from torch.amp import autocast from torch.utils.checkpoint import checkpoint +from torch import pi ### Import timm layers from timm.layers import ( @@ -44,283 +45,105 @@ __all__ = ['PE'] -######## PE's Rope ######## -def exists(val): - return val is not None - -def default(val, d): - return val if exists(val) else d - -def rotate_half(x): - x = x.view(*x.shape[:-1], -1, 2) - x1, x2 = x[..., 0], x[..., 1] - x = torch.stack((-x2, x1), dim=-1) - return x.view(*x.shape[:-2], -1) - -@autocast("cuda", enabled=False) -def apply_rotary_emb(freqs, t, start_index=0, scale=1.0, seq_dim=-2): - dtype = t.dtype - - if t.ndim == 3: - seq_len = t.shape[seq_dim] - freqs = freqs[-seq_len:] - - rot_dim = freqs.shape[-1] - end_index = start_index + rot_dim - - assert ( - rot_dim <= t.shape[-1] - ), f"feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}" - - t_left, t, t_right = ( - t[..., :start_index], - t[..., start_index:end_index], - t[..., end_index:], - ) - t = (t * freqs.cos() * scale) + (rotate_half(t) * freqs.sin() * scale) - out = torch.cat((t_left, t, t_right), dim=-1) - - return out.type(dtype) - +######## PE Rope (Simplified) ######## class RotaryEmbedding(Module): def __init__( self, dim, - custom_freqs: Optional[Tensor] = None, - freqs_for: Union[Literal["lang"], Literal["pixel"], Literal["constant"]] = "lang", + freqs_for: Union[Literal["lang"], Literal["pixel"], Literal["constant"]] = "lang", theta=10000, max_freq=10, num_freqs=1, - learned_freq=False, - use_xpos=False, - xpos_scale_base=512, - interpolate_factor=1.0, + learned_freq=False, theta_rescale_factor=1.0, - seq_before_head_dim=False, - cache_if_possible=True, ): super().__init__() # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning # has some connection to NTK literature # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/ - theta *= theta_rescale_factor ** (dim / (dim - 2)) - - self.freqs_for = freqs_for - - if exists(custom_freqs): - freqs = custom_freqs - elif freqs_for == "lang": + if freqs_for == "lang": freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) elif freqs_for == "pixel": freqs = torch.linspace(1.0, max_freq / 2, dim // 2) * pi elif freqs_for == "constant": freqs = torch.ones(num_freqs).float() - - self.cache_if_possible = cache_if_possible - - self.tmp_store("cached_freqs", None) - self.tmp_store("cached_scales", None) - self.freqs = nn.Parameter(freqs, requires_grad=learned_freq) - self.learned_freq = learned_freq - - # dummy for device - - self.tmp_store("dummy", torch.tensor(0)) - - # default sequence dimension - - self.seq_before_head_dim = seq_before_head_dim - self.default_seq_dim = -3 if seq_before_head_dim else -2 - - # interpolation factors - - assert interpolate_factor >= 1.0 - self.interpolate_factor = interpolate_factor - - # xpos - - self.use_xpos = use_xpos - if not use_xpos: - self.tmp_store("scale", None) - return - - scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim) - - self.scale_base = xpos_scale_base - self.tmp_store("scale", scale) - - # add apply_rotary_emb as static method - - 
self.apply_rotary_emb = staticmethod(apply_rotary_emb) - - @property - def device(self): - return self.dummy.device - - def tmp_store(self, key, value): - self.register_buffer(key, value, persistent=False) - - def get_seq_pos(self, seq_len, device, dtype, offset=0): - return (torch.arange(seq_len, device=device, dtype=dtype) + offset) / self.interpolate_factor - - def rotate_queries_or_keys(self, t, seq_dim=None, offset=0): - seq_dim = default(seq_dim, self.default_seq_dim) - - assert ( - not self.use_xpos - ), "you must use `.rotate_queries_and_keys` method instead and pass in both queries and keys, for length extrapolatable rotary embeddings" - - device, dtype, seq_len = t.device, t.dtype, t.shape[seq_dim] - - freqs = self.forward( - self.get_seq_pos(seq_len, device=device, dtype=dtype, offset=offset), - seq_len=seq_len, - offset=offset, - ) - - if seq_dim == -3: - freqs = freqs.unsqueeze(1) - - return apply_rotary_emb(freqs, t, seq_dim=seq_dim) - - def rotate_queries_with_cached_keys(self, q, k, seq_dim=None, offset=0): - seq_dim = default(seq_dim, self.default_seq_dim) - - q_len, k_len = q.shape[seq_dim], k.shape[seq_dim] - assert q_len <= k_len - - rotated_q = self.rotate_queries_or_keys(q, seq_dim=seq_dim, offset=k_len - q_len + offset) - rotated_k = self.rotate_queries_or_keys(k, seq_dim=seq_dim, offset=offset) - - rotated_q = rotated_q.type(q.dtype) - rotated_k = rotated_k.type(k.dtype) - - return rotated_q, rotated_k - - def rotate_queries_and_keys(self, q, k, seq_dim=None): - seq_dim = default(seq_dim, self.default_seq_dim) - - assert self.use_xpos - device, dtype, seq_len = q.device, q.dtype, q.shape[seq_dim] - - seq = self.get_seq_pos(seq_len, dtype=dtype, device=device) - - freqs = self.forward(seq, seq_len=seq_len) - scale = self.get_scale(seq, seq_len=seq_len).to(dtype) - - if seq_dim == -3: - freqs = freqs.unsqueeze(1) - scale = scale.unsqueeze(1) - - rotated_q = apply_rotary_emb(freqs, q, scale=scale, seq_dim=seq_dim) - rotated_k = apply_rotary_emb(freqs, k, scale=scale**-1, seq_dim=seq_dim) - - rotated_q = rotated_q.type(q.dtype) - rotated_k = rotated_k.type(k.dtype) - - return rotated_q, rotated_k - - def get_scale(self, t: Tensor, seq_len: Optional[int] = None, offset=0): - assert self.use_xpos - should_cache = self.cache_if_possible and exists(seq_len) - - if should_cache and exists(self.cached_scales) and (seq_len + offset) <= self.cached_scales.shape[0]: - return self.cached_scales[offset : (offset + seq_len)] - - scale = 1.0 - if self.use_xpos: - power = (t - len(t) // 2) / self.scale_base - scale = self.scale ** power.unsqueeze(-1) - scale = torch.cat((scale, scale), dim=-1) - - if should_cache: - self.tmp_store("cached_scales", scale) - - return scale - - def get_axial_freqs(self, *dims): - Colon = slice(None) - all_freqs = [] - - for ind, dim in enumerate(dims): - if self.freqs_for == "pixel": - pos = torch.linspace(-1, 1, steps=dim, device=self.device) - else: - pos = torch.arange(dim, device=self.device) - - freqs = self.forward(pos, seq_len=dim) - - all_axis = [None] * len(dims) - all_axis[ind] = Colon - - new_axis_slice = (Ellipsis, *all_axis, Colon) - all_freqs.append(freqs[new_axis_slice]) - - all_freqs = broadcast_tensors(*all_freqs) - return torch.cat(all_freqs, dim=-1) - - @autocast("cuda", enabled=False) - def forward(self, t: Tensor, seq_len=None, offset=0): - should_cache = ( - self.cache_if_possible and not self.learned_freq and exists(seq_len) and self.freqs_for != "pixel" - ) - - if should_cache and exists(self.cached_freqs) and (offset + seq_len) 
<= self.cached_freqs.shape[0]: - return self.cached_freqs[offset : (offset + seq_len)].detach() - + def forward(self, t: Tensor): #, seq_len=None, offset=0): freqs = self.freqs - - freqs = einsum("..., f -> ... f", t.type(freqs.dtype), freqs) + freqs = t.type(freqs.dtype).unsqueeze(-1) * freqs freqs = freqs.repeat_interleave(2, dim=-1) - - if should_cache: - self.tmp_store("cached_freqs", freqs.detach()) - return freqs -class Rope2D: - """Helper class to apply RoPE2D as well as interpolate on the fly.""" - def __init__(self, dim, use_cls_token=False): +class Rope2D(Module): + def __init__(self, dim, grid_size, use_cls_token=False): + super().__init__() self.dim = dim self.use_cls_token = use_cls_token - self.grid_size = None - self.freq = None - - def init_tensors(self): + self.grid_size = grid_size self.rope = RotaryEmbedding(self.dim // 2) + self.init_tensors() + + def init_tensors(self): + self.update_grid(self.grid_size[0], self.grid_size[1]) + + def update_grid(self, grid_h, grid_w): + if self.use_cls_token: + # +1 to leave space for the cls token to be (0, 0) + grid_y_range = torch.arange(grid_h) + 1 + grid_x_range = torch.arange(grid_w) + 1 + else: + grid_y_range = torch.arange(grid_h) + grid_x_range = torch.arange(grid_w) + freqs_y = self.rope(grid_y_range)[:, None].expand(grid_h, grid_w, -1) + freqs_x = self.rope(grid_x_range)[None, :].expand(grid_h, grid_w, -1) + freq = torch.cat([freqs_x, freqs_y], dim=-1).reshape(grid_h * grid_w, -1) - def update_grid(self, device, grid_h, grid_w): - if self.grid_size != (grid_h, grid_w): - self.grid_size = (grid_h, grid_w) - - self.rope = self.rope.to(device) - if self.use_cls_token: - # +1 to leave space for the cls token to be (0, 0) - grid_y_range = torch.arange(grid_h, device=device) + 1 - grid_x_range = torch.arange(grid_w, device=device) + 1 - else: - grid_y_range = torch.arange(grid_h, device=device) - grid_x_range = torch.arange(grid_w, device=device) - freqs_y = self.rope(grid_y_range)[:, None].expand(grid_h, grid_w, -1) - freqs_x = self.rope(grid_x_range)[None, :].expand(grid_h, grid_w, -1) - freq = torch.cat([freqs_x, freqs_y], dim=-1).reshape(grid_h * grid_w, -1) + if self.use_cls_token: + freq = torch.cat([freq, torch.zeros(1, freq.shape[-1])], dim=0) + + self.freq = Parameter(freq[None, ...]) # remark: using Parameter instead of tensor for device consistency + + def rotate_half(self, x): + shape = x.shape + x = x.view(shape[:-1] + (-1, 2)) + x1, x2 = x[..., 0], x[..., 1] + x = torch.stack((-x2, x1), dim=-1) + return x.view(shape[:-1] + (-1,)) + + def apply_rotary_emb(self, freqs, t): + start_index=0 + scale=1.0 + seq_dim=-2 + dtype = t.dtype + + if t.ndim == 3: + seq_len = t.shape[seq_dim] + freqs = freqs[-seq_len:] + + rot_dim = freqs.shape[-1] + end_index = start_index + rot_dim - if self.use_cls_token: - freq = torch.cat([freq, torch.zeros(1, freq.shape[-1], device=device)], dim=0) + assert ( + rot_dim <= t.shape[-1] + ), f"feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}" - self.freq = freq[None, ...] 
+ t_left, t, t_right = ( + t[..., :start_index], + t[..., start_index:end_index], + t[..., end_index:], + ) + t = (t * freqs.cos() * scale) + (self.rotate_half(t) * freqs.sin() * scale) + out = torch.cat((t_left, t, t_right), dim=-1) - self.freq = self.freq.to(device) + return out.type(dtype) - def __call__(self, q, k): + def forward(self, q, k): # batch, heads, seq, dim = q.shape - q = apply_rotary_emb(self.freq[:, None, :, :], q) - k = apply_rotary_emb(self.freq[:, None, :, :], k) - + q = self.apply_rotary_emb(self.freq[:, None, :, :], q) + k = self.apply_rotary_emb(self.freq[:, None, :, :], k) return q, k @@ -359,11 +182,9 @@ def __init__( def forward(self, x: torch.Tensor): batch, _, _ = x.shape - q = self.probe.repeat((batch, 1, 1)).to(x.dtype) x = self.attn(q, x, x, need_weights=False)[0] x = x + self.mlp(self.layernorm(x)) - return x @@ -371,7 +192,6 @@ class SelfAttention(nn.Module): r""" Implements sequence packed attention and RoPe """ - def __init__( self, embed_dim: int, @@ -398,7 +218,10 @@ def init_tensors(self): constant_(self.in_proj_bias, 0.0) constant_(self.out_proj.bias, 0.0) - def forward(self, x, attn_mask=None): + def forward(self, + x: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + ): batch, seq, embed_dim = x.shape proj = F.linear(x, self.in_proj_weight, self.in_proj_bias) @@ -411,7 +234,7 @@ def forward(self, x, attn_mask=None): k = k.view(batch, seq, self.num_heads, self.head_dim).permute(0, 2, 1, 3) v = v.view(batch, seq, self.num_heads, self.head_dim).permute(0, 2, 1, 3) - if self.rope: + if self.rope is not None: q, k = self.rope(q, k) attn = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False, scale=self.scale) @@ -462,9 +285,8 @@ def __init__( def _call_attn( self, q_x: torch.Tensor, - attn_mask: Optional[torch.Tensor] = None, + attn_mask: Optional[torch.Tensor] = None ): - if attn_mask is not None: # Leave boolean masks as is if not attn_mask.dtype == torch.bool: @@ -478,7 +300,7 @@ def _call_attn( def forward( self, x: torch.Tensor, - attn_mask: Optional[torch.Tensor] = None, + attn_mask: Optional[torch.Tensor] = None, ): x = x + self.drop_path1(self.ls_1(self._call_attn(self.ln_1(x), attn_mask=attn_mask))) x = x + self.drop_path2(self.ls_2(self.mlp(self.ln_2(x)))) @@ -532,21 +354,18 @@ def truncate(self, layer_idx: int): def forward( self, x: torch.Tensor, - attn_mask: Optional[torch.Tensor] = None, - layer_idx: int = -1, + attn_mask: Optional[torch.Tensor] = None, + # layer_idx=-1, #: int = -1, # torchscript emits iterations over modules as unrolled loops. 
so dynamic layer_idx is not supported as in orig pe ): - stop_idx = (self.layers + layer_idx) % self.layers - + #stop_idx = (self.layers + layer_idx) % self.layers for i, r in enumerate(self.resblocks): if self.grad_checkpointing and not torch.jit.is_scripting(): # TODO: handle kwargs https://github.com/pytorch/pytorch/issues/79887#issuecomment-1161758372 x = checkpoint(r, x, None, None, attn_mask) else: x = r(x, attn_mask=attn_mask) - - if i == stop_idx: - break - + # if i == stop_idx: + # break return x @@ -603,6 +422,7 @@ def __init__( Rope2D( dim=width // heads, use_cls_token=self.use_cls_token, + grid_size = (img_size // patch_size, img_size // patch_size), ) if self.use_rope2d else None @@ -670,26 +490,6 @@ def truncate(self, layer_idx: int): def set_grad_checkpointing(self, enable=True): self.transformer.set_grad_checkpointing(enable=enable) - def _sample_abs_posemb(self, grid_h: int, grid_w: int): - """Interpolates the absolute position embedding if necessary.""" - if self.posemb_grid_size == grid_h and self.posemb_grid_size == grid_w: - return self.positional_embedding[None, ...] - - pos_embed = self.positional_embedding - if self.use_cls_token: - cls_token_embed, pos_embed = pos_embed[:1], pos_embed[1:] - - pos_embed = ( - pos_embed.reshape(1, self.posemb_grid_size, self.posemb_grid_size, -1).permute(0, 3, 1, 2).contiguous() - ) - pos_embed = F.interpolate(pos_embed, size=(grid_h, grid_w), mode="bilinear", align_corners=False) - pos_embed = pos_embed.permute(0, 2, 3, 1).reshape(-1, self.width).contiguous() - - if self.use_cls_token: - pos_embed = torch.cat([cls_token_embed, pos_embed], dim=0) - - return pos_embed[None, ...] - def _pool(self, x: torch.Tensor): if self.pool_type == "tok": return x[:, 0] @@ -702,9 +502,10 @@ def _pool(self, x: torch.Tensor): else: raise NotImplementedError - def forward_features(self, x: torch.Tensor, norm: bool = False, layer_idx: int = -1, strip_cls_token: bool = False): + #def forward_features(self, x: torch.Tensor, norm: bool = False, layer_idx: int = -1, strip_cls_token: bool = False): + def forward_features(self, x: torch.Tensor, norm: bool = False, strip_cls_token: bool = False): + #: layer_idx = -1, # torchscript emits iterations over modules as unrolled loops. so dynamic layer_idx is not supported in timm as in orig pe batch, _, h, w = x.shape - grid_h, grid_w = h // self.patch_size, w // self.patch_size x = self.conv1(x) x = x.permute(0, 2, 3, 1).reshape(batch, -1, self.width) @@ -716,13 +517,10 @@ def forward_features(self, x: torch.Tensor, norm: bool = False, layer_idx: int = ) if self.use_abs_posemb: - x = x + self._sample_abs_posemb(grid_h, grid_w) - - if self.use_rope2d: - self.rope.update_grid(x.device, grid_h, grid_w) + x = x + self.positional_embedding[None, ...] 
x = self.ln_pre(x) - x = self.transformer(x, layer_idx=layer_idx) + x = self.transformer(x) if norm: x = self.ln_post(x) @@ -732,8 +530,8 @@ def forward_features(self, x: torch.Tensor, norm: bool = False, layer_idx: int = return x - def forward(self, x: torch.Tensor, layer_idx: int = -1, strip_cls_token: bool = False): - x = self.forward_features(x, norm=True, layer_idx=layer_idx, strip_cls_token=strip_cls_token) + def forward(self, x: torch.Tensor, strip_cls_token: bool = False): + x = self.forward_features(x, norm=True, strip_cls_token=strip_cls_token) x = self._pool(x) if self.proj_dim is not None: @@ -784,7 +582,9 @@ def _create_pe(variant: str, pretrained: bool = False, **kwargs) -> PE: variant, pretrained, pretrained_filter_fn=checkpoint_filter_fn, - pretrained_strict=True, + pretrained_strict=False, + # Remakr: strict=False since original pretrained models don't have rope freqs in nn.modules and samples rope on-the-fly w/ dynamic grid + # torchscript/timm doesn't support dynamic grid so sample once during model init without overwritten by ckpt loading. feature_cfg=dict(out_indices=out_indices, feature_cls='getter'), **kwargs, ) @@ -890,3 +690,4 @@ def vit_pe_spatial_gigantic_patch14_448(pretrained=False, **kwargs): ls_init_value=0.1, ) return _create_pe('vit_pe_spatial_gigantic_patch14_448', pretrained=pretrained, **dict(model_args, **kwargs)) + From 9dbb47d55585b408ca1425c8389b9ccc5d54955b Mon Sep 17 00:00:00 2001 From: berniebear Date: Wed, 30 Apr 2025 02:25:29 +0000 Subject: [PATCH 09/15] fix jit.trace() issues. add classifier for timm and fix config to pass cfg/fxforward/fxbackward unitest --- tests/test_models.py | 2 +- timm/models/pe.py | 158 ++++++++++++++++++++++++++----------------- 2 files changed, 98 insertions(+), 62 deletions(-) diff --git a/tests/test_models.py b/tests/test_models.py index 3ba3615db4..0d3801b660 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -61,7 +61,7 @@ 'vit_*', 'tnt_*', 'pit_*', 'coat_*', 'cait_*', '*mixer_*', 'gmlp_*', 'resmlp_*', 'twins_*', 'convit_*', 'levit*', 'visformer*', 'deit*', 'xcit_*', 'crossvit_*', 'beit*', 'aimv2*', 'poolformer_*', 'volo_*', 'sequencer2d_*', 'mvitv2*', 'gcvit*', 'efficientformer*', 'sam_hiera*', - 'eva_*', 'flexivit*', 'eva02*', 'samvit_*', 'efficientvit_m*', 'tiny_vit_*', 'hiera_*', 'vitamin*', 'test_vit*', + 'eva_*', 'flexivit*', 'eva02*', 'samvit_*', 'efficientvit_m*', 'tiny_vit_*', 'hiera_*', 'vitamin*', 'test_vit*', 'pe_*' ] NUM_NON_STD = len(NON_STD_FILTERS) diff --git a/timm/models/pe.py b/timm/models/pe.py index a18e6323a5..84fdc535c3 100644 --- a/timm/models/pe.py +++ b/timm/models/pe.py @@ -11,25 +11,11 @@ from torch.nn.parameter import Parameter from torch.amp import autocast from torch.utils.checkpoint import checkpoint -from torch import pi ### Import timm layers from timm.layers import ( - PatchEmbed, - Mlp, DropPath, AttentionPoolLatent, - RmsNorm, - PatchDropout, - SwiGLUPacked, - SwiGLU, - trunc_normal_, - lecun_normal_, - resample_patch_embed, - resample_abs_pos_embed, - use_fused_attn, - get_act_layer, - get_norm_layer, LayerType, LayerScale, ) @@ -40,12 +26,14 @@ from ._builder import build_model_with_cfg from ._features import feature_take_indices from ._registry import generate_default_cfgs, register_model, register_model_deprecations +from ._features_fx import register_notrace_module __all__ = ['PE'] ######## PE Rope (Simplified) ######## +@register_notrace_module class RotaryEmbedding(Module): def __init__( self, @@ -65,18 +53,24 @@ def __init__( if freqs_for == "lang": freqs = 
1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) elif freqs_for == "pixel": - freqs = torch.linspace(1.0, max_freq / 2, dim // 2) * pi + freqs = torch.linspace(1.0, max_freq / 2, dim // 2) * torch.pi elif freqs_for == "constant": freqs = torch.ones(num_freqs).float() - self.freqs = nn.Parameter(freqs, requires_grad=learned_freq) + else: + assert False + if learned_freq: + self.freqs = nn.Parameter(freqs) + else: + self.freqs = nn.Buffer(freqs, persistent=False) - def forward(self, t: Tensor): #, seq_len=None, offset=0): + def forward(self, t: Tensor): freqs = self.freqs freqs = t.type(freqs.dtype).unsqueeze(-1) * freqs freqs = freqs.repeat_interleave(2, dim=-1) return freqs +@register_notrace_module class Rope2D(Module): def __init__(self, dim, grid_size, use_cls_token=False): super().__init__() @@ -103,8 +97,7 @@ def update_grid(self, grid_h, grid_w): if self.use_cls_token: freq = torch.cat([freq, torch.zeros(1, freq.shape[-1])], dim=0) - - self.freq = Parameter(freq[None, ...]) # remark: using Parameter instead of tensor for device consistency + self.freq = nn.Buffer(freq[None, ...], persistent=False) def rotate_half(self, x): shape = x.shape @@ -114,22 +107,18 @@ def rotate_half(self, x): return x.view(shape[:-1] + (-1,)) def apply_rotary_emb(self, freqs, t): - start_index=0 - scale=1.0 - seq_dim=-2 + start_index = 0 + scale = 1.0 + seq_dim = -2 dtype = t.dtype - - if t.ndim == 3: - seq_len = t.shape[seq_dim] - freqs = freqs[-seq_len:] + + # if len(t.shape) == 3: + # seq_len = t.shape[seq_dim] + # freqs = freqs[-seq_len:] rot_dim = freqs.shape[-1] end_index = start_index + rot_dim - assert ( - rot_dim <= t.shape[-1] - ), f"feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}" - t_left, t, t_right = ( t[..., :start_index], t[..., start_index:end_index], @@ -387,37 +376,51 @@ def __init__( use_abs_posemb: bool = True, use_rope2d: bool = True, use_cls_token: bool = False, + use_proj: bool = True, output_dim: Optional[int] = 1280, + num_classes: int = 0, attn_pooler_heads: int = 8, pool_type: Literal["attn", "tok", "avg", "none"] = "attn", - num_classes: int = 0, # no use for PE in_chans: int = 3, ): super().__init__() assert pool_type in ("attn", "tok", "avg", "none") self.pool_type = pool_type - self.patch_size = patch_size - self.output_dim = output_dim or width - self.proj_dim = output_dim + self.patch_size = patch_size self.heads = heads self.width = width self.layers = layers + self.in_chans = in_chans + + self.num_intermediate_features = width # the dim before PE projection layer (vit output) + self.proj_dim = output_dim # the output_dim after PE projection layer + self.use_proj = use_proj + if self.use_proj: + self.head_hidden_size = self.proj_dim + self.num_features = self.proj_dim + else: + self.head_hidden_size = self.num_intermediate_features + self.num_features = self.num_intermediate_features + + self.num_classes = num_classes self.use_abs_posemb = use_abs_posemb self.use_cls_token = use_cls_token self.use_rope2d = use_rope2d + if isinstance(img_size, (tuple, list)): img_size = img_size[0] self.img_size = img_size self.conv1 = nn.Conv2d( - in_channels=3, + in_channels=in_chans, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False, ) + self.rope = ( Rope2D( dim=width // heads, @@ -478,8 +481,19 @@ def init_submodule_tensors(module): init_scale * torch.randn(int(self.use_cls_token) + self.posemb_grid_size**2, self.width) ) - if self.proj_dim is not None: + # PE's: Transfomer(x) -> pool -> 
proj -> head (for timm). (PE contains an additional projection layer) + if self.use_proj: self.proj = nn.Parameter(init_scale * torch.randn(self.width, self.proj_dim)) + if self.num_classes > 0: + self.head = nn.Linear(self.proj_dim, self.num_classes) + else: + self.head = nn.Identity() + else: # no projection (eg PE-lang and PE-spatial) + self.proj = nn.Identity() + if self.num_classes > 0: + self.head = nn.Linear(self.width, self.num_classes) + else: + self.head = nn.Identity() def truncate(self, layer_idx: int): """Delete layers so the last layer is the given layer index.""" @@ -490,20 +504,24 @@ def truncate(self, layer_idx: int): def set_grad_checkpointing(self, enable=True): self.transformer.set_grad_checkpointing(enable=enable) - def _pool(self, x: torch.Tensor): + def forward_pool_and_proj(self, x: torch.Tensor): if self.pool_type == "tok": - return x[:, 0] + x = x[:, 0] elif self.pool_type == "avg": - return x.mean(dim=1) + x = x.mean(dim=1) elif self.pool_type == "attn": - return self.attn_pool(x).squeeze(1) + x = self.attn_pool(x).squeeze(1) elif self.pool_type == "none": - return x - else: - raise NotImplementedError + x = x + if self.use_proj: + x = x @ self.proj + return x + + def forward_head(self, x: torch.Tensor, pre_logits: bool = False): + return x if pre_logits else self.head(x) #def forward_features(self, x: torch.Tensor, norm: bool = False, layer_idx: int = -1, strip_cls_token: bool = False): - def forward_features(self, x: torch.Tensor, norm: bool = False, strip_cls_token: bool = False): + def forward_features(self, x: torch.Tensor, norm: bool = False): #: layer_idx = -1, # torchscript emits iterations over modules as unrolled loops. so dynamic layer_idx is not supported in timm as in orig pe batch, _, h, w = x.shape @@ -521,24 +539,30 @@ def forward_features(self, x: torch.Tensor, norm: bool = False, strip_cls_token: x = self.ln_pre(x) x = self.transformer(x) - if norm: x = self.ln_post(x) - if strip_cls_token and self.use_cls_token: - x = x[:, 1:, :] + # if strip_cls_token and self.use_cls_token: + # x = x[:, 1:, :] + x = self.forward_pool_and_proj(x) return x - def forward(self, x: torch.Tensor, strip_cls_token: bool = False): - x = self.forward_features(x, norm=True, strip_cls_token=strip_cls_token) - x = self._pool(x) - - if self.proj_dim is not None: - x = x @ self.proj - + def forward(self, x: torch.Tensor): + x = self.forward_features(x, norm=True) + x = self.forward_head(x) return x + def reset_classifier(self, num_classes: int): + self.num_classes = num_classes + if num_classes > 0: + if self.proj_dim > 0: + self.head = nn.Parameter(self.proj_dim, num_classes) + else: # no projection (eg PE-lang and PE-spatial) + self.head = nn.Parameter(self.width, num_classes) + else: + self.head = nn.Identity() + def checkpoint_filter_fn( state_dict: Dict[str, torch.Tensor], @@ -559,6 +583,8 @@ def _cfg(url='', **kwargs): 'fixed_input_size': True, 'mean': IMAGENET_INCEPTION_MEAN, # (0.5, 0.5, 0.5) 'std': IMAGENET_INCEPTION_STD, # (0.5, 0.5, 0.5) + 'first_conv': 'conv1', + 'classifier': 'head', **kwargs, } @@ -582,9 +608,7 @@ def _create_pe(variant: str, pretrained: bool = False, **kwargs) -> PE: variant, pretrained, pretrained_filter_fn=checkpoint_filter_fn, - pretrained_strict=False, - # Remakr: strict=False since original pretrained models don't have rope freqs in nn.modules and samples rope on-the-fly w/ dynamic grid - # torchscript/timm doesn't support dynamic grid so sample once during model init without overwritten by ckpt loading. 
+ pretrained_strict=True, feature_cfg=dict(out_indices=out_indices, feature_cls='getter'), **kwargs, ) @@ -600,8 +624,10 @@ def vit_pe_core_base_patch16_224(pretrained=False, **kwargs): heads=12, mlp_ratio=4.0, output_dim=1024, + num_classes=0, use_cls_token=True, pool_type='attn', + use_proj=True, ) return _create_pe('vit_pe_core_base_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs)) @@ -616,8 +642,10 @@ def vit_pe_core_large_patch14_336(pretrained=False, **kwargs): heads=16, mlp_ratio=4.0, output_dim=1024, + num_classes=0, use_cls_token=True, pool_type='attn', + use_proj=True, ) return _create_pe('vit_pe_core_large_patch14_336', pretrained=pretrained, **dict(model_args, **kwargs)) @@ -632,8 +660,10 @@ def vit_pe_core_gigantic_patch14_448(pretrained=False, **kwargs): heads=16, mlp_ratio=8960 / 1536, output_dim=1280, + num_classes=0, use_cls_token=False, pool_type='attn', + use_proj=True, ) return _create_pe('vit_pe_core_gigantic_patch14_448', pretrained=pretrained, **dict(model_args, **kwargs)) @@ -647,11 +677,13 @@ def vit_pe_lang_large_patch14_448(pretrained=False, **kwargs): layers=23, heads=16, mlp_ratio=4.0, - output_dim=None, - use_cls_token=True, + output_dim=1024, + num_classes=0, + use_cls_token=False, use_ln_post=False, pool_type='none', ls_init_value=0.1, + use_proj=False, ) return _create_pe('vit_pe_lang_large_patch14_448', pretrained=pretrained, **dict(model_args, **kwargs)) @@ -665,11 +697,13 @@ def vit_pe_lang_gigantic_patch14_448(pretrained=False, **kwargs): layers=47, heads=16, mlp_ratio=8960 / 1536, - output_dim=None, + output_dim=1536, + num_classes=0, use_cls_token=False, use_ln_post=False, pool_type='none', ls_init_value=0.1, + use_proj=False, ) return _create_pe('vit_pe_lang_gigantic_patch14_448', pretrained=pretrained, **dict(model_args, **kwargs)) @@ -683,11 +717,13 @@ def vit_pe_spatial_gigantic_patch14_448(pretrained=False, **kwargs): layers=50, heads=16, mlp_ratio=8960 / 1536, - output_dim=None, + output_dim=1536, + num_classes=0, use_cls_token=False, use_ln_post=False, pool_type='none', ls_init_value=0.1, + use_proj=False, ) return _create_pe('vit_pe_spatial_gigantic_patch14_448', pretrained=pretrained, **dict(model_args, **kwargs)) From a04019165dde5e0ba58dfd47b6740ed667cdb06d Mon Sep 17 00:00:00 2001 From: berniebear Date: Wed, 30 Apr 2025 06:49:38 +0000 Subject: [PATCH 10/15] add forward_intermediates support --- tests/test_models.py | 5 +- timm/models/pe.py | 108 +++++++++++++++++++++++++++++++++++++++---- 2 files changed, 102 insertions(+), 11 deletions(-) diff --git a/tests/test_models.py b/tests/test_models.py index 0d3801b660..bfff9d626c 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -53,7 +53,7 @@ 'vision_transformer', 'vision_transformer_sam', 'vision_transformer_hybrid', 'vision_transformer_relpos', 'beit', 'mvitv2', 'eva', 'cait', 'xcit', 'volo', 'twins', 'deit', 'swin_transformer', 'swin_transformer_v2', 'swin_transformer_v2_cr', 'maxxvit', 'efficientnet', 'mobilenetv3', 'levit', 'efficientformer', 'resnet', - 'regnet', 'byobnet', 'byoanet', 'mlp_mixer', 'hiera', 'fastvit', 'hieradet_sam2', 'aimv2*' + 'regnet', 'byobnet', 'byoanet', 'mlp_mixer', 'hiera', 'fastvit', 'hieradet_sam2', 'aimv2*', 'pe' ] # transformer / hybrid models don't support full set of spatial / feature APIs and/or have spatial output. 
@@ -61,7 +61,7 @@ 'vit_*', 'tnt_*', 'pit_*', 'coat_*', 'cait_*', '*mixer_*', 'gmlp_*', 'resmlp_*', 'twins_*', 'convit_*', 'levit*', 'visformer*', 'deit*', 'xcit_*', 'crossvit_*', 'beit*', 'aimv2*', 'poolformer_*', 'volo_*', 'sequencer2d_*', 'mvitv2*', 'gcvit*', 'efficientformer*', 'sam_hiera*', - 'eva_*', 'flexivit*', 'eva02*', 'samvit_*', 'efficientvit_m*', 'tiny_vit_*', 'hiera_*', 'vitamin*', 'test_vit*', 'pe_*' + 'eva_*', 'flexivit*', 'eva02*', 'samvit_*', 'efficientvit_m*', 'tiny_vit_*', 'hiera_*', 'vitamin*', 'test_vit*', ] NUM_NON_STD = len(NON_STD_FILTERS) @@ -224,6 +224,7 @@ def test_model_backward(model_name, batch_size): timm.models.MobileNetV3, timm.models.RepGhostNet, timm.models.VGG, + timm.models.pe, ) @pytest.mark.cfg diff --git a/timm/models/pe.py b/timm/models/pe.py index 84fdc535c3..b70fbda0f0 100644 --- a/timm/models/pe.py +++ b/timm/models/pe.py @@ -393,15 +393,21 @@ def __init__( self.layers = layers self.in_chans = in_chans - self.num_intermediate_features = width # the dim before PE projection layer (vit output) - self.proj_dim = output_dim # the output_dim after PE projection layer + # PE contains an (optional) projection layer + # Flow: x -> Transfomer(x) -> pool -> proj -> head (for timm). + # forward_features: x -> Transfomer(x) + # forward_head: pool -> proj -> head + # output_dim is the final output dim of the model (keep it for clarity) self.use_proj = use_proj if self.use_proj: + self.proj_dim = output_dim self.head_hidden_size = self.proj_dim - self.num_features = self.proj_dim + self.num_features = width # self.proj_dim else: - self.head_hidden_size = self.num_intermediate_features - self.num_features = self.num_intermediate_features + self.proj_dim = 0 + assert output_dim == width + self.head_hidden_size = width + self.num_features = width self.num_classes = num_classes @@ -446,6 +452,9 @@ def __init__( rope=self.rope, ) + self.feature_info = [ + dict(module=f'blocks.{i}', num_chs=width, reduction=patch_size) for i in range(layers)] + if pool_type == "attn": self.attn_pool = AttentionPooling( embed_dim=width, @@ -491,7 +500,7 @@ def init_submodule_tensors(module): else: # no projection (eg PE-lang and PE-spatial) self.proj = nn.Identity() if self.num_classes > 0: - self.head = nn.Linear(self.width, self.num_classes) + self.head = nn.Linear(self.width, self.num_classes) # no proj. input dim = self.width (pooled) else: self.head = nn.Identity() @@ -518,6 +527,9 @@ def forward_pool_and_proj(self, x: torch.Tensor): return x def forward_head(self, x: torch.Tensor, pre_logits: bool = False): + # PE has an additional proj layer: Transfomer(x) -> pool -> proj -> head (for timm). 
+ # Ideally pool To discuss with Ross where to split + x = self.forward_pool_and_proj(x) return x if pre_logits else self.head(x) #def forward_features(self, x: torch.Tensor, norm: bool = False, layer_idx: int = -1, strip_cls_token: bool = False): @@ -544,8 +556,6 @@ def forward_features(self, x: torch.Tensor, norm: bool = False): # if strip_cls_token and self.use_cls_token: # x = x[:, 1:, :] - - x = self.forward_pool_and_proj(x) return x def forward(self, x: torch.Tensor): @@ -563,6 +573,86 @@ def reset_classifier(self, num_classes: int): else: self.head = nn.Identity() + def forward_intermediates( + self, + x: torch.Tensor, + indices: Optional[Union[int, List[int]]] = None, + return_prefix_tokens: bool = False, + norm: bool = False, + stop_early: bool = False, + output_fmt: str = 'NCHW', + intermediates_only: bool = False, + ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]: + """ Forward features that returns intermediates. + + Args: + x: Input image tensor + indices: Take last n blocks if int, all if None, select matching indices if sequence + return_prefix_tokens: Return both prefix and spatial intermediate tokens + norm: Apply norm layer to all intermediates + stop_early: Stop iterating over blocks when last desired intermediate hit + output_fmt: Shape of intermediate feature outputs + intermediates_only: Only return intermediate features + Returns: + + """ + assert output_fmt in ('NCHW', 'NLC'), 'Output format must be one of NCHW or NLC.' + reshape = output_fmt == 'NCHW' + intermediates = [] + take_indices, max_index = feature_take_indices(self.layers, indices) + + # forward pass + B, _, height, width = x.shape + + + x = self.conv1(x) + x = x.permute(0, 2, 3, 1).reshape(B, -1, self.width) # NLC + + if self.use_cls_token: + x = torch.cat( + [self.class_embedding.view(1, 1, -1).expand(B, -1, -1), x], + dim=1, + ) + + if self.use_abs_posemb: + x = x + self.positional_embedding[None, ...] 
+ + x = self.ln_pre(x) + + if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript + blocks = self.transformer.resblocks + else: + blocks = self.transformer.resblocks[:max_index + 1] + + for i, blk in enumerate(blocks): + x = blk(x) + if i in take_indices: + # normalize intermediates with final norm layer if enabled + intermediates.append(self.norm(x) if norm else x) + + # process intermediates + if self.use_cls_token: + prefix_tokens = [y[:, 0] for y in intermediates] + intermediates = [y[:, 1:] for y in intermediates] + else: + prefix_tokens = None + + if reshape: + # reshape to BCHW output format + H = W = self.posemb_grid_size + intermediates = [y.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() for y in intermediates] + if not torch.jit.is_scripting() and return_prefix_tokens and prefix_tokens is not None: + # return_prefix not support in torchscript due to poor type handling + intermediates = list(zip(intermediates, prefix_tokens)) + + if intermediates_only: + return intermediates + + x = self.ln_post(x) + + return x, intermediates + + def checkpoint_filter_fn( state_dict: Dict[str, torch.Tensor], @@ -679,7 +769,7 @@ def vit_pe_lang_large_patch14_448(pretrained=False, **kwargs): mlp_ratio=4.0, output_dim=1024, num_classes=0, - use_cls_token=False, + use_cls_token=True, use_ln_post=False, pool_type='none', ls_init_value=0.1, From 414b775352d98900d6450112740301fa6718d904 Mon Sep 17 00:00:00 2001 From: berniebear Date: Wed, 30 Apr 2025 08:08:57 +0000 Subject: [PATCH 11/15] torchscript for L/G models at higher resolution --- timm/models/pe.py | 82 +++++++++++++++++++++-------------------------- 1 file changed, 36 insertions(+), 46 deletions(-) diff --git a/timm/models/pe.py b/timm/models/pe.py index b70fbda0f0..a495c9620e 100644 --- a/timm/models/pe.py +++ b/timm/models/pe.py @@ -246,10 +246,10 @@ def __init__( ): super().__init__() - if rope: - self.attn = SelfAttention(d_model, n_head, rope=rope) - else: - self.attn = nn.MultiheadAttention(d_model, n_head, batch_first=True) + #if rope: + self.attn = SelfAttention(d_model, n_head, rope=rope) + #else: + # self.attn = nn.MultiheadAttention(d_model, n_head, batch_first=True) self.ls_1 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity() self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity() @@ -281,10 +281,10 @@ def _call_attn( if not attn_mask.dtype == torch.bool: attn_mask = attn_mask.to(q_x.dtype) - if isinstance(self.attn, SelfAttention): - return self.attn(q_x, attn_mask=attn_mask) - else: - return self.attn(q_x, q_x, q_x, attn_mask=attn_mask, need_weights=False)[0] + #if isinstance(self.attn, SelfAttention): + return self.attn(q_x, attn_mask=attn_mask) + #else: + # return self.attn(q_x, q_x, q_x, attn_mask=attn_mask, need_weights=False)[0] def forward( self, @@ -380,19 +380,16 @@ def __init__( output_dim: Optional[int] = 1280, num_classes: int = 0, attn_pooler_heads: int = 8, - pool_type: Literal["attn", "tok", "avg", "none"] = "attn", + use_attn_pool: bool = True, in_chans: int = 3, ): super().__init__() - assert pool_type in ("attn", "tok", "avg", "none") - self.pool_type = pool_type - self.patch_size = patch_size self.heads = heads self.width = width self.layers = layers self.in_chans = in_chans - + # PE contains an (optional) projection layer # Flow: x -> Transfomer(x) -> pool -> proj -> head (for timm). 
# forward_features: x -> Transfomer(x) @@ -418,6 +415,7 @@ def __init__( if isinstance(img_size, (tuple, list)): img_size = img_size[0] self.img_size = img_size + self.grid_size = self.img_size // self.patch_size self.conv1 = nn.Conv2d( in_channels=in_chans, @@ -455,7 +453,7 @@ def __init__( self.feature_info = [ dict(module=f'blocks.{i}', num_chs=width, reduction=patch_size) for i in range(layers)] - if pool_type == "attn": + if use_attn_pool: self.attn_pool = AttentionPooling( embed_dim=width, num_heads=attn_pooler_heads, @@ -483,12 +481,15 @@ def init_submodule_tensors(module): if self.use_cls_token: self.class_embedding = nn.Parameter(init_scale * torch.randn(self.width)) + else: + self.class_embedding = None if self.use_abs_posemb: - self.posemb_grid_size = self.img_size // self.patch_size self.positional_embedding = nn.Parameter( - init_scale * torch.randn(int(self.use_cls_token) + self.posemb_grid_size**2, self.width) + init_scale * torch.randn(int(self.use_cls_token) + self.grid_size**2, self.width) ) + else: + self.positional_embedding = None # PE's: Transfomer(x) -> pool -> proj -> head (for timm). (PE contains an additional projection layer) if self.use_proj: @@ -498,7 +499,7 @@ def init_submodule_tensors(module): else: self.head = nn.Identity() else: # no projection (eg PE-lang and PE-spatial) - self.proj = nn.Identity() + self.proj = None if self.num_classes > 0: self.head = nn.Linear(self.width, self.num_classes) # no proj. input dim = self.width (pooled) else: @@ -514,15 +515,9 @@ def set_grad_checkpointing(self, enable=True): self.transformer.set_grad_checkpointing(enable=enable) def forward_pool_and_proj(self, x: torch.Tensor): - if self.pool_type == "tok": - x = x[:, 0] - elif self.pool_type == "avg": - x = x.mean(dim=1) - elif self.pool_type == "attn": + if self.attn_pool is not None: x = self.attn_pool(x).squeeze(1) - elif self.pool_type == "none": - x = x - if self.use_proj: + if self.proj is not None: x = x @ self.proj return x @@ -532,21 +527,19 @@ def forward_head(self, x: torch.Tensor, pre_logits: bool = False): x = self.forward_pool_and_proj(x) return x if pre_logits else self.head(x) - #def forward_features(self, x: torch.Tensor, norm: bool = False, layer_idx: int = -1, strip_cls_token: bool = False): def forward_features(self, x: torch.Tensor, norm: bool = False): - #: layer_idx = -1, # torchscript emits iterations over modules as unrolled loops. so dynamic layer_idx is not supported in timm as in orig pe batch, _, h, w = x.shape x = self.conv1(x) x = x.permute(0, 2, 3, 1).reshape(batch, -1, self.width) - if self.use_cls_token: + if self.class_embedding is not None: x = torch.cat( [self.class_embedding.view(1, 1, -1).expand(batch, -1, -1), x], dim=1, ) - - if self.use_abs_posemb: + + if self.positional_embedding is not None: x = x + self.positional_embedding[None, ...] 
x = self.ln_pre(x) @@ -554,8 +547,6 @@ def forward_features(self, x: torch.Tensor, norm: bool = False): if norm: x = self.ln_post(x) - # if strip_cls_token and self.use_cls_token: - # x = x[:, 1:, :] return x def forward(self, x: torch.Tensor): @@ -566,7 +557,7 @@ def forward(self, x: torch.Tensor): def reset_classifier(self, num_classes: int): self.num_classes = num_classes if num_classes > 0: - if self.proj_dim > 0: + if self.proj is not None: self.head = nn.Parameter(self.proj_dim, num_classes) else: # no projection (eg PE-lang and PE-spatial) self.head = nn.Parameter(self.width, num_classes) @@ -603,18 +594,17 @@ def forward_intermediates( # forward pass B, _, height, width = x.shape - - + # patch embedgging x = self.conv1(x) x = x.permute(0, 2, 3, 1).reshape(B, -1, self.width) # NLC - if self.use_cls_token: + if self.class_embedding is not None: x = torch.cat( [self.class_embedding.view(1, 1, -1).expand(B, -1, -1), x], dim=1, ) - if self.use_abs_posemb: + if self.positional_embedding is not None: x = x + self.positional_embedding[None, ...] x = self.ln_pre(x) @@ -631,15 +621,15 @@ def forward_intermediates( intermediates.append(self.norm(x) if norm else x) # process intermediates - if self.use_cls_token: - prefix_tokens = [y[:, 0] for y in intermediates] + if self.class_embedding is not None: + prefix_tokens = [y[:, 0] for y in intermediates] # only one cls token in PE intermediates = [y[:, 1:] for y in intermediates] else: prefix_tokens = None if reshape: # reshape to BCHW output format - H = W = self.posemb_grid_size + H = W = self.grid_size intermediates = [y.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() for y in intermediates] if not torch.jit.is_scripting() and return_prefix_tokens and prefix_tokens is not None: # return_prefix not support in torchscript due to poor type handling @@ -716,7 +706,7 @@ def vit_pe_core_base_patch16_224(pretrained=False, **kwargs): output_dim=1024, num_classes=0, use_cls_token=True, - pool_type='attn', + use_attn_pool=True, use_proj=True, ) return _create_pe('vit_pe_core_base_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs)) @@ -734,7 +724,7 @@ def vit_pe_core_large_patch14_336(pretrained=False, **kwargs): output_dim=1024, num_classes=0, use_cls_token=True, - pool_type='attn', + use_attn_pool=True, use_proj=True, ) return _create_pe('vit_pe_core_large_patch14_336', pretrained=pretrained, **dict(model_args, **kwargs)) @@ -752,7 +742,7 @@ def vit_pe_core_gigantic_patch14_448(pretrained=False, **kwargs): output_dim=1280, num_classes=0, use_cls_token=False, - pool_type='attn', + use_attn_pool=True, use_proj=True, ) return _create_pe('vit_pe_core_gigantic_patch14_448', pretrained=pretrained, **dict(model_args, **kwargs)) @@ -771,7 +761,7 @@ def vit_pe_lang_large_patch14_448(pretrained=False, **kwargs): num_classes=0, use_cls_token=True, use_ln_post=False, - pool_type='none', + use_attn_pool=False, ls_init_value=0.1, use_proj=False, ) @@ -791,7 +781,7 @@ def vit_pe_lang_gigantic_patch14_448(pretrained=False, **kwargs): num_classes=0, use_cls_token=False, use_ln_post=False, - pool_type='none', + use_attn_pool=False, ls_init_value=0.1, use_proj=False, ) @@ -811,7 +801,7 @@ def vit_pe_spatial_gigantic_patch14_448(pretrained=False, **kwargs): num_classes=0, use_cls_token=False, use_ln_post=False, - pool_type='none', + use_attn_pool=False, ls_init_value=0.1, use_proj=False, ) From 89d348df22c9b707354fca0fa2a8f308534a257c Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Wed, 30 Apr 2025 15:30:29 -0700 Subject: [PATCH 12/15] PE model 
working with timm train script, fix nn.Buffer -> register_buffer, add drop_rate arg --- timm/models/pe.py | 23 ++++++++++++----------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/timm/models/pe.py b/timm/models/pe.py index a495c9620e..8d9f9bfd0a 100644 --- a/timm/models/pe.py +++ b/timm/models/pe.py @@ -61,7 +61,7 @@ def __init__( if learned_freq: self.freqs = nn.Parameter(freqs) else: - self.freqs = nn.Buffer(freqs, persistent=False) + self.register_buffer('freqs', freqs, persistent=False) def forward(self, t: Tensor): freqs = self.freqs @@ -97,7 +97,7 @@ def update_grid(self, grid_h, grid_w): if self.use_cls_token: freq = torch.cat([freq, torch.zeros(1, freq.shape[-1])], dim=0) - self.freq = nn.Buffer(freq[None, ...], persistent=False) + self.register_buffer('freq', freq[None, ...], persistent=False) def rotate_half(self, x): shape = x.shape @@ -382,6 +382,7 @@ def __init__( attn_pooler_heads: int = 8, use_attn_pool: bool = True, in_chans: int = 3, + drop_rate: float = 0., # Expected to be here, TODO add a final drop layer once head finalized ): super().__init__() self.patch_size = patch_size @@ -389,6 +390,8 @@ def __init__( self.width = width self.layers = layers self.in_chans = in_chans + self.num_classes = num_classes + self.drop_rate = drop_rate # PE contains an (optional) projection layer # Flow: x -> Transfomer(x) -> pool -> proj -> head (for timm). @@ -494,16 +497,13 @@ def init_submodule_tensors(module): # PE's: Transfomer(x) -> pool -> proj -> head (for timm). (PE contains an additional projection layer) if self.use_proj: self.proj = nn.Parameter(init_scale * torch.randn(self.width, self.proj_dim)) - if self.num_classes > 0: - self.head = nn.Linear(self.proj_dim, self.num_classes) - else: - self.head = nn.Identity() else: # no projection (eg PE-lang and PE-spatial) self.proj = None - if self.num_classes > 0: - self.head = nn.Linear(self.width, self.num_classes) # no proj. input dim = self.width (pooled) - else: - self.head = nn.Identity() + + if self.num_classes > 0: + self.head = nn.Linear(self.head_hidden_size, self.num_classes) # no proj. 
input dim = self.width (pooled) + else: + self.head = nn.Identity() def truncate(self, layer_idx: int): """Delete layers so the last layer is the given layer index.""" @@ -671,7 +671,8 @@ def _cfg(url='', **kwargs): default_cfgs = generate_default_cfgs( { - 'vit_pe_core_base_patch16_224': _cfg(hf_hub_id='timm/', input_size=(3, 224, 224)), + # TODO finalize locations + 'vit_pe_core_base_patch16_224': _cfg(hf_hub_id='facebook/pe_core_base_patch16_224_timm', input_size=(3, 224, 224)), 'vit_pe_core_large_patch14_336': _cfg(hf_hub_id='timm/', input_size=(3, 336, 336)), 'vit_pe_core_gigantic_patch14_448': _cfg(hf_hub_id='timm/', input_size=(3, 448, 448)), 'vit_pe_lang_large_patch14_448': _cfg(hf_hub_id='timm/', input_size=(3, 448, 448)), From 936d20ea766b4db18df888592f80b295e72310a5 Mon Sep 17 00:00:00 2001 From: berniebear Date: Thu, 1 May 2025 11:18:00 +0000 Subject: [PATCH 13/15] reuse fused_attn from timm, add activation between proj and cls_head --- timm/models/pe.py | 35 +++++++++++++++++++++++++---------- 1 file changed, 25 insertions(+), 10 deletions(-) diff --git a/timm/models/pe.py b/timm/models/pe.py index 8d9f9bfd0a..0e080a49ff 100644 --- a/timm/models/pe.py +++ b/timm/models/pe.py @@ -11,6 +11,8 @@ from torch.nn.parameter import Parameter from torch.amp import autocast from torch.utils.checkpoint import checkpoint +from torch.jit import Final + ### Import timm layers from timm.layers import ( @@ -18,6 +20,7 @@ AttentionPoolLatent, LayerType, LayerScale, + use_fused_attn, ) # from timm.layers import RotaryEmbeddingCat, RotaryEmbedding # not compatible @@ -70,6 +73,7 @@ def forward(self, t: Tensor): return freqs + @register_notrace_module class Rope2D(Module): def __init__(self, dim, grid_size, use_cls_token=False): @@ -181,6 +185,8 @@ class SelfAttention(nn.Module): r""" Implements sequence packed attention and RoPe """ + fused_attn: Final[bool] + def __init__( self, embed_dim: int, @@ -201,12 +207,14 @@ def __init__( self.rope = rope self.scale = self.head_dim ** (-0.5) + self.fused_attn = use_fused_attn() def init_tensors(self): xavier_uniform_(self.in_proj_weight) constant_(self.in_proj_bias, 0.0) constant_(self.out_proj.bias, 0.0) + def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None, @@ -226,12 +234,21 @@ def forward(self, if self.rope is not None: q, k = self.rope(q, k) - attn = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False, scale=self.scale) + if self.fused_attn: + attn = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False, scale=self.scale) + else: + q = q * self.scale + attn = q @ k.transpose(-2, -1) + attn = attn.softmax(dim=-1) + attn = attn @ v + attn = attn.permute(0, 2, 1, 3).contiguous().view(batch, seq, -1) return F.linear(attn, self.out_proj.weight, self.out_proj.bias) + + class ResidualAttentionBlock(nn.Module): def __init__( self, @@ -246,10 +263,7 @@ def __init__( ): super().__init__() - #if rope: self.attn = SelfAttention(d_model, n_head, rope=rope) - #else: - # self.attn = nn.MultiheadAttention(d_model, n_head, batch_first=True) self.ls_1 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity() self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity() @@ -281,10 +295,7 @@ def _call_attn( if not attn_mask.dtype == torch.bool: attn_mask = attn_mask.to(q_x.dtype) - #if isinstance(self.attn, SelfAttention): return self.attn(q_x, attn_mask=attn_mask) - #else: - # return self.attn(q_x, q_x, q_x, 
attn_mask=attn_mask, need_weights=False)[0] def forward( self, @@ -392,6 +403,7 @@ def __init__( self.in_chans = in_chans self.num_classes = num_classes self.drop_rate = drop_rate + self.emb_dim = width # PE contains an (optional) projection layer # Flow: x -> Transfomer(x) -> pool -> proj -> head (for timm). @@ -410,6 +422,7 @@ def __init__( self.num_features = width self.num_classes = num_classes + self.output_dim = output_dim self.use_abs_posemb = use_abs_posemb self.use_cls_token = use_cls_token @@ -466,6 +479,7 @@ def __init__( else: self.attn_pool = None + self.act_layer_cfg = act_layer self.init_tensors() def init_tensors(self): @@ -523,8 +537,10 @@ def forward_pool_and_proj(self, x: torch.Tensor): def forward_head(self, x: torch.Tensor, pre_logits: bool = False): # PE has an additional proj layer: Transfomer(x) -> pool -> proj -> head (for timm). - # Ideally pool To discuss with Ross where to split + # To discuss with Ross where to split x = self.forward_pool_and_proj(x) + if self.head_act_layer is not None: + x = self.head_act_layer(x) return x if pre_logits else self.head(x) def forward_features(self, x: torch.Tensor, norm: bool = False): @@ -806,5 +822,4 @@ def vit_pe_spatial_gigantic_patch14_448(pretrained=False, **kwargs): ls_init_value=0.1, use_proj=False, ) - return _create_pe('vit_pe_spatial_gigantic_patch14_448', pretrained=pretrained, **dict(model_args, **kwargs)) - + return _create_pe('vit_pe_spatial_gigantic_patch14_448', pretrained=pretrained, **dict(model_args, **kwargs)) \ No newline at end of file From 4746845543021d9baf81efd1a6563a830c7bb9fb Mon Sep 17 00:00:00 2001 From: berniebear Date: Thu, 1 May 2025 11:57:38 +0000 Subject: [PATCH 14/15] fix register, implement all proj and cls_head configurations for model init/reset, fix fused_attn --- timm/models/pe.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/timm/models/pe.py b/timm/models/pe.py index 0e080a49ff..0d9a7baec0 100644 --- a/timm/models/pe.py +++ b/timm/models/pe.py @@ -479,7 +479,7 @@ def __init__( else: self.attn_pool = None - self.act_layer_cfg = act_layer + self.head_act_layer = None # =act_layer if to add an additional activation between fc1(proj) and fc2(head) self.init_tensors() def init_tensors(self): From 2ba13ee1bbab8556a49806ba441caf3e1b813219 Mon Sep 17 00:00:00 2001 From: berniebear Date: Fri, 2 May 2025 07:26:48 +0000 Subject: [PATCH 15/15] fix rope bug --- timm/models/pe.py | 116 +++++++++++++++++++++------------------------- 1 file changed, 54 insertions(+), 62 deletions(-) diff --git a/timm/models/pe.py b/timm/models/pe.py index 0d9a7baec0..121630009d 100644 --- a/timm/models/pe.py +++ b/timm/models/pe.py @@ -41,11 +41,11 @@ class RotaryEmbedding(Module): def __init__( self, dim, - freqs_for: Union[Literal["lang"], Literal["pixel"], Literal["constant"]] = "lang", + freqs_for: Union[Literal["lang"], Literal["pixel"], Literal["constant"]] = "lang", theta=10000, max_freq=10, num_freqs=1, - learned_freq=False, + learned_freq=False, theta_rescale_factor=1.0, ): super().__init__() @@ -73,7 +73,6 @@ def forward(self, t: Tensor): return freqs - @register_notrace_module class Rope2D(Module): def __init__(self, dim, grid_size, use_cls_token=False): @@ -83,10 +82,10 @@ def __init__(self, dim, grid_size, use_cls_token=False): self.grid_size = grid_size self.rope = RotaryEmbedding(self.dim // 2) self.init_tensors() - + def init_tensors(self): self.update_grid(self.grid_size[0], self.grid_size[1]) - + def update_grid(self, grid_h, grid_w): if self.use_cls_token: # +1 to leave space 
for the cls token to be (0, 0) @@ -100,22 +99,22 @@ def update_grid(self, grid_h, grid_w): freq = torch.cat([freqs_x, freqs_y], dim=-1).reshape(grid_h * grid_w, -1) if self.use_cls_token: - freq = torch.cat([freq, torch.zeros(1, freq.shape[-1])], dim=0) + freq = torch.cat([torch.zeros(1, freq.shape[-1]), freq], dim=0) self.register_buffer('freq', freq[None, ...], persistent=False) def rotate_half(self, x): - shape = x.shape + shape = x.shape x = x.view(shape[:-1] + (-1, 2)) x1, x2 = x[..., 0], x[..., 1] x = torch.stack((-x2, x1), dim=-1) return x.view(shape[:-1] + (-1,)) - + def apply_rotary_emb(self, freqs, t): start_index = 0 scale = 1.0 seq_dim = -2 dtype = t.dtype - + # if len(t.shape) == 3: # seq_len = t.shape[seq_dim] # freqs = freqs[-seq_len:] @@ -185,6 +184,7 @@ class SelfAttention(nn.Module): r""" Implements sequence packed attention and RoPe """ + fused_attn: Final[bool] def __init__( @@ -214,11 +214,11 @@ def init_tensors(self): constant_(self.in_proj_bias, 0.0) constant_(self.out_proj.bias, 0.0) - - def forward(self, - x: torch.Tensor, - attn_mask: Optional[torch.Tensor] = None, - ): + def forward( + self, + x: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + ): batch, seq, embed_dim = x.shape proj = F.linear(x, self.in_proj_weight, self.in_proj_bias) @@ -235,7 +235,9 @@ def forward(self, q, k = self.rope(q, k) if self.fused_attn: - attn = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False, scale=self.scale) + attn = F.scaled_dot_product_attention( + q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False, scale=self.scale + ) else: q = q * self.scale attn = q @ k.transpose(-2, -1) @@ -247,8 +249,6 @@ def forward(self, return F.linear(attn, self.out_proj.weight, self.out_proj.bias) - - class ResidualAttentionBlock(nn.Module): def __init__( self, @@ -285,11 +285,7 @@ def __init__( ) ) - def _call_attn( - self, - q_x: torch.Tensor, - attn_mask: Optional[torch.Tensor] = None - ): + def _call_attn(self, q_x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): if attn_mask is not None: # Leave boolean masks as is if not attn_mask.dtype == torch.bool: @@ -300,7 +296,7 @@ def _call_attn( def forward( self, x: torch.Tensor, - attn_mask: Optional[torch.Tensor] = None, + attn_mask: Optional[torch.Tensor] = None, ): x = x + self.drop_path1(self.ls_1(self._call_attn(self.ln_1(x), attn_mask=attn_mask))) x = x + self.drop_path2(self.ls_2(self.mlp(self.ln_2(x)))) @@ -354,18 +350,14 @@ def truncate(self, layer_idx: int): def forward( self, x: torch.Tensor, - attn_mask: Optional[torch.Tensor] = None, - # layer_idx=-1, #: int = -1, # torchscript emits iterations over modules as unrolled loops. 
so dynamic layer_idx is not supported as in orig pe + attn_mask: Optional[torch.Tensor] = None, ): - #stop_idx = (self.layers + layer_idx) % self.layers for i, r in enumerate(self.resblocks): if self.grad_checkpointing and not torch.jit.is_scripting(): # TODO: handle kwargs https://github.com/pytorch/pytorch/issues/79887#issuecomment-1161758372 x = checkpoint(r, x, None, None, attn_mask) else: x = r(x, attn_mask=attn_mask) - # if i == stop_idx: - # break return x @@ -389,11 +381,11 @@ def __init__( use_cls_token: bool = False, use_proj: bool = True, output_dim: Optional[int] = 1280, - num_classes: int = 0, + num_classes: int = 0, attn_pooler_heads: int = 8, use_attn_pool: bool = True, in_chans: int = 3, - drop_rate: float = 0., # Expected to be here, TODO add a final drop layer once head finalized + drop_rate: float = 0.0, # Expected to be here, TODO add a final drop layer once head finalized ): super().__init__() self.patch_size = patch_size @@ -404,7 +396,7 @@ def __init__( self.num_classes = num_classes self.drop_rate = drop_rate self.emb_dim = width - + # PE contains an (optional) projection layer # Flow: x -> Transfomer(x) -> pool -> proj -> head (for timm). # forward_features: x -> Transfomer(x) @@ -414,10 +406,10 @@ def __init__( if self.use_proj: self.proj_dim = output_dim self.head_hidden_size = self.proj_dim - self.num_features = width # self.proj_dim + self.num_features = width # self.proj_dim else: - self.proj_dim = 0 - assert output_dim == width + self.proj_dim = 0 + assert output_dim == width self.head_hidden_size = width self.num_features = width @@ -445,7 +437,7 @@ def __init__( Rope2D( dim=width // heads, use_cls_token=self.use_cls_token, - grid_size = (img_size // patch_size, img_size // patch_size), + grid_size=(img_size // patch_size, img_size // patch_size), ) if self.use_rope2d else None @@ -466,8 +458,7 @@ def __init__( rope=self.rope, ) - self.feature_info = [ - dict(module=f'blocks.{i}', num_chs=width, reduction=patch_size) for i in range(layers)] + self.feature_info = [dict(module=f'blocks.{i}', num_chs=width, reduction=patch_size) for i in range(layers)] if use_attn_pool: self.attn_pool = AttentionPooling( @@ -479,7 +470,7 @@ def __init__( else: self.attn_pool = None - self.head_act_layer = None # =act_layer if to add an additional activation between fc1(proj) and fc2(head) + self.head_act_layer = None # =act_layer if to add an additional activation between fc1(proj) and fc2(head) self.init_tensors() def init_tensors(self): @@ -511,11 +502,11 @@ def init_submodule_tensors(module): # PE's: Transfomer(x) -> pool -> proj -> head (for timm). (PE contains an additional projection layer) if self.use_proj: self.proj = nn.Parameter(init_scale * torch.randn(self.width, self.proj_dim)) - else: # no projection (eg PE-lang and PE-spatial) + else: # no projection (eg PE-lang and PE-spatial) self.proj = None if self.num_classes > 0: - self.head = nn.Linear(self.head_hidden_size, self.num_classes) # no proj. input dim = self.width (pooled) + self.head = nn.Linear(self.head_hidden_size, self.num_classes) # no proj. input dim = self.width (pooled) else: self.head = nn.Identity() @@ -536,8 +527,8 @@ def forward_pool_and_proj(self, x: torch.Tensor): return x def forward_head(self, x: torch.Tensor, pre_logits: bool = False): - # PE has an additional proj layer: Transfomer(x) -> pool -> proj -> head (for timm). - # To discuss with Ross where to split + # PE has an additional proj layer: Transfomer(x) -> pool -> proj -> head (for timm). 
+        # To discuss with Ross where to split
         x = self.forward_pool_and_proj(x)
         if self.head_act_layer is not None:
             x = self.head_act_layer(x)
@@ -554,7 +545,7 @@ def forward_features(self, x: torch.Tensor, norm: bool = False):
                 [self.class_embedding.view(1, 1, -1).expand(batch, -1, -1), x],
                 dim=1,
             )
-        
+
         if self.positional_embedding is not None:
             x = x + self.positional_embedding[None, ...]

@@ -575,22 +566,22 @@ def reset_classifier(self, num_classes: int):
         if num_classes > 0:
             if self.proj is not None:
                 self.head = nn.Parameter(self.proj_dim, num_classes)
-            else: # no projection (eg PE-lang and PE-spatial)
+            else:  # no projection (eg PE-lang and PE-spatial)
                 self.head = nn.Parameter(self.width, num_classes)
         else:
             self.head = nn.Identity()

     def forward_intermediates(
-            self,
-            x: torch.Tensor,
-            indices: Optional[Union[int, List[int]]] = None,
-            return_prefix_tokens: bool = False,
-            norm: bool = False,
-            stop_early: bool = False,
-            output_fmt: str = 'NCHW',
-            intermediates_only: bool = False,
+        self,
+        x: torch.Tensor,
+        indices: Optional[Union[int, List[int]]] = None,
+        return_prefix_tokens: bool = False,
+        norm: bool = False,
+        stop_early: bool = False,
+        output_fmt: str = 'NCHW',
+        intermediates_only: bool = False,
     ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
-        """ Forward features that returns intermediates.
+        """Forward features that returns intermediates.

         Args:
             x: Input image tensor
@@ -612,7 +603,7 @@ def forward_intermediates(
         B, _, height, width = x.shape
         # patch embedgging
         x = self.conv1(x)
-        x = x.permute(0, 2, 3, 1).reshape(B, -1, self.width) # NLC
+        x = x.permute(0, 2, 3, 1).reshape(B, -1, self.width)  # NLC

         if self.class_embedding is not None:
             x = torch.cat(
@@ -628,7 +619,7 @@ def forward_intermediates(
         if torch.jit.is_scripting() or not stop_early:  # can't slice blocks in torchscript
             blocks = self.transformer.resblocks
         else:
-            blocks = self.transformer.resblocks[:max_index + 1]
+            blocks = self.transformer.resblocks[: max_index + 1]

         for i, blk in enumerate(blocks):
             x = blk(x)
@@ -638,7 +629,7 @@ def forward_intermediates(

         # process intermediates
         if self.class_embedding is not None:
-            prefix_tokens = [y[:, 0] for y in intermediates] # only one cls token in PE
+            prefix_tokens = [y[:, 0] for y in intermediates]  # only one cls token in PE
             intermediates = [y[:, 1:] for y in intermediates]
         else:
             prefix_tokens = None
@@ -657,7 +648,6 @@ def forward_intermediates(
             x = self.ln_post(x)

         return x, intermediates


-
 def checkpoint_filter_fn(
@@ -677,10 +667,10 @@ def _cfg(url='', **kwargs):
         'num_classes': 0,
         'interpolation': 'bilinear',
         'fixed_input_size': True,
-        'mean': IMAGENET_INCEPTION_MEAN, # (0.5, 0.5, 0.5)
-        'std': IMAGENET_INCEPTION_STD, # (0.5, 0.5, 0.5)
+        'mean': IMAGENET_INCEPTION_MEAN,  # (0.5, 0.5, 0.5)
+        'std': IMAGENET_INCEPTION_STD,  # (0.5, 0.5, 0.5)
         'first_conv': 'conv1',
-        'classifier': 'head', 
+        'classifier': 'head',
         **kwargs,
     }

@@ -688,7 +678,9 @@ def _cfg(url='', **kwargs):
 default_cfgs = generate_default_cfgs(
     {
         # TODO finalize locations
-        'vit_pe_core_base_patch16_224': _cfg(hf_hub_id='facebook/pe_core_base_patch16_224_timm', input_size=(3, 224, 224)),
+        'vit_pe_core_base_patch16_224': _cfg(
+            hf_hub_id='facebook/pe_core_base_patch16_224_timm', input_size=(3, 224, 224)
+        ),
         'vit_pe_core_large_patch14_336': _cfg(hf_hub_id='timm/', input_size=(3, 336, 336)),
         'vit_pe_core_gigantic_patch14_448': _cfg(hf_hub_id='timm/', input_size=(3, 448, 448)),
         'vit_pe_lang_large_patch14_448': _cfg(hf_hub_id='timm/', input_size=(3, 448, 448)),
@@ -822,4 +814,4 @@ def vit_pe_spatial_gigantic_patch14_448(pretrained=False, **kwargs):
         ls_init_value=0.1,
         use_proj=False,
     )
-    return _create_pe('vit_pe_spatial_gigantic_patch14_448', pretrained=pretrained, **dict(model_args, **kwargs))
\ No newline at end of file
+    return _create_pe('vit_pe_spatial_gigantic_patch14_448', pretrained=pretrained, **dict(model_args, **kwargs))
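
[Reviewer note, not part of the patch] A minimal smoke-test sketch for trying the PE entrypoints after applying this series. It relies only on names visible in the diff (the registered vit_pe_core_base_patch16_224 config, forward_features, forward_intermediates) plus timm's standard create_model API; pretrained=False avoids the hub locations the TODO above marks as not yet finalized, and the chosen block indices are illustrative, not taken from the patch.

    import torch
    import timm  # assumes this patch series is applied to the local timm checkout

    # Build the smallest registered PE variant without downloading weights.
    model = timm.create_model('vit_pe_core_base_patch16_224', pretrained=False)
    model.eval()

    x = torch.randn(1, 3, 224, 224)  # matches input_size in the default cfg above
    with torch.no_grad():
        tokens = model.forward_features(x)  # unpooled token sequence (NLC)
        pooled = model(x)  # assumes the usual timm forward = forward_features + forward_head chaining
        mids = model.forward_intermediates(x, indices=[3, 7, 11], intermediates_only=True)

    print(tokens.shape, pooled.shape, [m.shape for m in mids])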