Skip to content

Commit 97341fe

Browse files
committed
A much faster resample_patch_embed, can be used at train/validation time
1 parent b4bb0f4 commit 97341fe

File tree

1 file changed

+188
-2
lines changed

1 file changed

+188
-2
lines changed

timm/layers/patch_embed.py

Lines changed: 188 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
"""
1111
import logging
1212
import math
13-
from typing import Callable, List, Optional, Tuple, Union
13+
from typing import Callable, Dict, List, Optional, Tuple, Union
1414

1515
import torch
1616
from torch import nn as nn
@@ -180,7 +180,8 @@ def forward(self, x) -> Tuple[torch.Tensor, List[int]]:
180180
return x, feat_size
181181

182182

183-
def resample_patch_embed(
183+
# FIXME to remove, keeping for comparison for now
184+
def resample_patch_embed_old(
184185
patch_embed,
185186
new_size: List[int],
186187
interpolation: str = 'bicubic',
@@ -250,6 +251,191 @@ def resample_kernel(kernel):
250251
return patch_embed
251252

252253

254+
DTYPE_INTERMEDIATE = torch.float32
255+
256+
257+
def _compute_resize_matrix(
258+
old_size: Tuple[int, int],
259+
new_size: Tuple[int, int],
260+
interpolation: str,
261+
antialias: bool,
262+
device: torch.device,
263+
dtype: torch.dtype = DTYPE_INTERMEDIATE
264+
) -> torch.Tensor:
265+
"""Computes the resize matrix basis vectors and interpolates them to new_size."""
266+
old_h, old_w = old_size
267+
new_h, new_w = new_size
268+
old_total = old_h * old_w
269+
new_total = new_h * new_w
270+
271+
eye_matrix = torch.eye(old_total, device=device, dtype=dtype)
272+
basis_vectors_batch = eye_matrix.reshape(old_total, 1, old_h, old_w)
273+
274+
resized_basis_vectors_batch = F.interpolate(
275+
basis_vectors_batch,
276+
size=new_size,
277+
mode=interpolation,
278+
antialias=antialias,
279+
align_corners=False
280+
) # Output shape: (old_total, 1, new_h, new_w)
281+
282+
resize_matrix = resized_basis_vectors_batch.squeeze(1).reshape(old_total, new_total).T
283+
return resize_matrix # Shape: (new_total, old_total)
284+
285+
286+
def _compute_pinv_for_resampling(resize_matrix: torch.Tensor) -> torch.Tensor:
287+
"""Calculates the pseudoinverse matrix used for the resampling operation."""
288+
pinv_matrix = torch.linalg.pinv(resize_matrix.T) # Shape: (new_total, old_total)
289+
return pinv_matrix
290+
291+
292+
def _apply_resampling(
293+
patch_embed: torch.Tensor,
294+
pinv_matrix: torch.Tensor,
295+
new_size_tuple: Tuple[int, int],
296+
orig_dtype: torch.dtype,
297+
intermediate_dtype: torch.dtype = DTYPE_INTERMEDIATE
298+
) -> torch.Tensor:
299+
"""Applies the precomputed pinv_matrix to resample the patch_embed tensor."""
300+
try:
301+
from torch import vmap
302+
except ImportError:
303+
from functorch import vmap
304+
305+
def resample_kernel(kernel: torch.Tensor) -> torch.Tensor:
306+
kernel_flat = kernel.reshape(-1).to(intermediate_dtype)
307+
resampled_kernel_flat = pinv_matrix @ kernel_flat
308+
return resampled_kernel_flat.reshape(new_size_tuple)
309+
310+
resample_kernel_vmap = vmap(vmap(resample_kernel, in_dims=0, out_dims=0), in_dims=0, out_dims=0)
311+
patch_embed_float = patch_embed.to(intermediate_dtype)
312+
resampled_patch_embed = resample_kernel_vmap(patch_embed_float)
313+
return resampled_patch_embed.to(orig_dtype)
314+
315+
316+
def resample_patch_embed(
317+
patch_embed: torch.Tensor,
318+
new_size: List[int],
319+
interpolation: str = 'bicubic',
320+
antialias: bool = True,
321+
verbose: bool = False,
322+
):
323+
""" Standalone function (computes matrix on each call). """
324+
assert len(patch_embed.shape) == 4, "Input tensor should be 4D (out_c, in_c, h, w)"
325+
assert len(new_size) == 2, "New shape should only be hw (height, width)"
326+
327+
old_size_tuple: Tuple[int, int] = tuple(patch_embed.shape[-2:])
328+
new_size_tuple: Tuple[int, int] = tuple(new_size)
329+
330+
if old_size_tuple == new_size_tuple:
331+
return patch_embed
332+
333+
device = patch_embed.device
334+
orig_dtype = patch_embed.dtype
335+
336+
resize_mat = _compute_resize_matrix(
337+
old_size_tuple, new_size_tuple, interpolation, antialias, device, DTYPE_INTERMEDIATE
338+
)
339+
pinv_matrix = _compute_pinv_for_resampling(resize_mat)
340+
resampled_patch_embed = _apply_resampling(
341+
patch_embed, pinv_matrix, new_size_tuple, orig_dtype, DTYPE_INTERMEDIATE
342+
)
343+
return resampled_patch_embed
344+
345+
346+
class PatchEmbedResamplerFixedOrigSize(nn.Module):
347+
"""
348+
Resample patch embedding weights from a fixed original size,
349+
caching the pseudoinverse matrix based on the target size.
350+
"""
351+
def __init__(
352+
self,
353+
orig_size: Tuple[int, int],
354+
interpolation: str = 'bicubic',
355+
antialias: bool = True
356+
):
357+
"""
358+
Args:
359+
orig_size (Tuple[int, int]): The expected original (height, width) of input patch_embed tensors.
360+
interpolation (str): Interpolation mode.
361+
antialias (bool): Use anti-aliasing filter in resize.
362+
"""
363+
super().__init__()
364+
assert isinstance(orig_size, tuple) and len(orig_size) == 2, \
365+
"`orig_size` must be a tuple of (height, width)"
366+
self.orig_size = orig_size # expected original size
367+
self.interpolation = interpolation
368+
self.antialias = antialias
369+
# Cache map key is the target new_size tuple
370+
self._pinv_cache_map: Dict[Tuple[int, int], str] = {}
371+
372+
def _get_or_create_pinv_matrix(
373+
self,
374+
new_size: Tuple[int, int],
375+
device: torch.device,
376+
dtype: torch.dtype = DTYPE_INTERMEDIATE
377+
) -> torch.Tensor:
378+
"""Retrieves the cached pinv matrix or computes and caches it for the given new_size."""
379+
cache_key = new_size
380+
buffer_name = self._pinv_cache_map.get(cache_key)
381+
382+
if buffer_name and hasattr(self, buffer_name):
383+
pinv_matrix = getattr(self, buffer_name)
384+
if pinv_matrix.device == device and pinv_matrix.dtype == dtype:
385+
return pinv_matrix
386+
387+
# Calculate the matrix if not cached or needs update
388+
resize_mat = _compute_resize_matrix(
389+
self.orig_size, new_size, self.interpolation, self.antialias, device, dtype
390+
)
391+
pinv_matrix = _compute_pinv_for_resampling(resize_mat)
392+
393+
# Cache using register_buffer
394+
buffer_name = f"pinv_{new_size[0]}x{new_size[1]}"
395+
if hasattr(self, buffer_name):
396+
delattr(self, buffer_name)
397+
self.register_buffer(buffer_name, pinv_matrix)
398+
self._pinv_cache_map[cache_key] = buffer_name # Map new_size key to buffer name
399+
400+
return pinv_matrix
401+
402+
def forward(self, patch_embed: torch.Tensor, new_size: List[int]) -> torch.Tensor:
403+
""" Resamples the patch embedding weights to new_size.
404+
405+
Args:
406+
patch_embed (torch.Tensor): Original weights (out_ch, in_ch, H_orig, W_orig).
407+
new_size (List[int]): Target [height, width].
408+
409+
Returns:
410+
torch.Tensor: Resampled weights.
411+
"""
412+
assert len(patch_embed.shape) == 4
413+
assert len(new_size) == 2
414+
415+
# Input Validation
416+
input_size = tuple(patch_embed.shape[-2:])
417+
assert input_size == self.orig_size, \
418+
f"Input patch_embed spatial size {input_size} does not match " \
419+
f"module's expected original size {self.orig_size}"
420+
421+
new_size_tuple: Tuple[int, int] = tuple(new_size)
422+
423+
# Check no-op case against self.orig_size
424+
if self.orig_size == new_size_tuple:
425+
return patch_embed
426+
427+
device = patch_embed.device
428+
orig_dtype = patch_embed.dtype
429+
430+
# Get or compute the required pseudoinverse matrix
431+
pinv_matrix = self._get_or_create_pinv_matrix(new_size_tuple, device)
432+
433+
# Apply the resampling
434+
resampled_patch_embed = _apply_resampling(patch_embed, pinv_matrix, new_size_tuple, orig_dtype)
435+
436+
return resampled_patch_embed
437+
438+
253439
# def divs(n, m=None):
254440
# m = m or n // 2
255441
# if m == 1:

0 commit comments

Comments
 (0)