|
10 | 10 | """
|
11 | 11 | import logging
|
12 | 12 | import math
|
13 |
| -from typing import Callable, List, Optional, Tuple, Union |
| 13 | +from typing import Callable, Dict, List, Optional, Tuple, Union |
14 | 14 |
|
15 | 15 | import torch
|
16 | 16 | from torch import nn as nn
|
@@ -180,7 +180,8 @@ def forward(self, x) -> Tuple[torch.Tensor, List[int]]:
|
180 | 180 | return x, feat_size
|
181 | 181 |
|
182 | 182 |
|
183 |
| -def resample_patch_embed( |
| 183 | +# FIXME to remove, keeping for comparison for now |
| 184 | +def resample_patch_embed_old( |
184 | 185 | patch_embed,
|
185 | 186 | new_size: List[int],
|
186 | 187 | interpolation: str = 'bicubic',
|
@@ -250,6 +251,191 @@ def resample_kernel(kernel):
|
250 | 251 | return patch_embed
|
251 | 252 |
|
252 | 253 |
|
| 254 | +DTYPE_INTERMEDIATE = torch.float32 |
| 255 | + |
| 256 | + |
| 257 | +def _compute_resize_matrix( |
| 258 | + old_size: Tuple[int, int], |
| 259 | + new_size: Tuple[int, int], |
| 260 | + interpolation: str, |
| 261 | + antialias: bool, |
| 262 | + device: torch.device, |
| 263 | + dtype: torch.dtype = DTYPE_INTERMEDIATE |
| 264 | +) -> torch.Tensor: |
| 265 | + """Computes the resize matrix basis vectors and interpolates them to new_size.""" |
| 266 | + old_h, old_w = old_size |
| 267 | + new_h, new_w = new_size |
| 268 | + old_total = old_h * old_w |
| 269 | + new_total = new_h * new_w |
| 270 | + |
| 271 | + eye_matrix = torch.eye(old_total, device=device, dtype=dtype) |
| 272 | + basis_vectors_batch = eye_matrix.reshape(old_total, 1, old_h, old_w) |
| 273 | + |
| 274 | + resized_basis_vectors_batch = F.interpolate( |
| 275 | + basis_vectors_batch, |
| 276 | + size=new_size, |
| 277 | + mode=interpolation, |
| 278 | + antialias=antialias, |
| 279 | + align_corners=False |
| 280 | + ) # Output shape: (old_total, 1, new_h, new_w) |
| 281 | + |
| 282 | + resize_matrix = resized_basis_vectors_batch.squeeze(1).reshape(old_total, new_total).T |
| 283 | + return resize_matrix # Shape: (new_total, old_total) |
| 284 | + |
| 285 | + |
| 286 | +def _compute_pinv_for_resampling(resize_matrix: torch.Tensor) -> torch.Tensor: |
| 287 | + """Calculates the pseudoinverse matrix used for the resampling operation.""" |
| 288 | + pinv_matrix = torch.linalg.pinv(resize_matrix.T) # Shape: (new_total, old_total) |
| 289 | + return pinv_matrix |
| 290 | + |
| 291 | + |
| 292 | +def _apply_resampling( |
| 293 | + patch_embed: torch.Tensor, |
| 294 | + pinv_matrix: torch.Tensor, |
| 295 | + new_size_tuple: Tuple[int, int], |
| 296 | + orig_dtype: torch.dtype, |
| 297 | + intermediate_dtype: torch.dtype = DTYPE_INTERMEDIATE |
| 298 | +) -> torch.Tensor: |
| 299 | + """Applies the precomputed pinv_matrix to resample the patch_embed tensor.""" |
| 300 | + try: |
| 301 | + from torch import vmap |
| 302 | + except ImportError: |
| 303 | + from functorch import vmap |
| 304 | + |
| 305 | + def resample_kernel(kernel: torch.Tensor) -> torch.Tensor: |
| 306 | + kernel_flat = kernel.reshape(-1).to(intermediate_dtype) |
| 307 | + resampled_kernel_flat = pinv_matrix @ kernel_flat |
| 308 | + return resampled_kernel_flat.reshape(new_size_tuple) |
| 309 | + |
| 310 | + resample_kernel_vmap = vmap(vmap(resample_kernel, in_dims=0, out_dims=0), in_dims=0, out_dims=0) |
| 311 | + patch_embed_float = patch_embed.to(intermediate_dtype) |
| 312 | + resampled_patch_embed = resample_kernel_vmap(patch_embed_float) |
| 313 | + return resampled_patch_embed.to(orig_dtype) |
| 314 | + |
| 315 | + |
| 316 | +def resample_patch_embed( |
| 317 | + patch_embed: torch.Tensor, |
| 318 | + new_size: List[int], |
| 319 | + interpolation: str = 'bicubic', |
| 320 | + antialias: bool = True, |
| 321 | + verbose: bool = False, |
| 322 | +): |
| 323 | + """ Standalone function (computes matrix on each call). """ |
| 324 | + assert len(patch_embed.shape) == 4, "Input tensor should be 4D (out_c, in_c, h, w)" |
| 325 | + assert len(new_size) == 2, "New shape should only be hw (height, width)" |
| 326 | + |
| 327 | + old_size_tuple: Tuple[int, int] = tuple(patch_embed.shape[-2:]) |
| 328 | + new_size_tuple: Tuple[int, int] = tuple(new_size) |
| 329 | + |
| 330 | + if old_size_tuple == new_size_tuple: |
| 331 | + return patch_embed |
| 332 | + |
| 333 | + device = patch_embed.device |
| 334 | + orig_dtype = patch_embed.dtype |
| 335 | + |
| 336 | + resize_mat = _compute_resize_matrix( |
| 337 | + old_size_tuple, new_size_tuple, interpolation, antialias, device, DTYPE_INTERMEDIATE |
| 338 | + ) |
| 339 | + pinv_matrix = _compute_pinv_for_resampling(resize_mat) |
| 340 | + resampled_patch_embed = _apply_resampling( |
| 341 | + patch_embed, pinv_matrix, new_size_tuple, orig_dtype, DTYPE_INTERMEDIATE |
| 342 | + ) |
| 343 | + return resampled_patch_embed |
| 344 | + |
| 345 | + |
| 346 | +class PatchEmbedResamplerFixedOrigSize(nn.Module): |
| 347 | + """ |
| 348 | + Resample patch embedding weights from a fixed original size, |
| 349 | + caching the pseudoinverse matrix based on the target size. |
| 350 | + """ |
| 351 | + def __init__( |
| 352 | + self, |
| 353 | + orig_size: Tuple[int, int], |
| 354 | + interpolation: str = 'bicubic', |
| 355 | + antialias: bool = True |
| 356 | + ): |
| 357 | + """ |
| 358 | + Args: |
| 359 | + orig_size (Tuple[int, int]): The expected original (height, width) of input patch_embed tensors. |
| 360 | + interpolation (str): Interpolation mode. |
| 361 | + antialias (bool): Use anti-aliasing filter in resize. |
| 362 | + """ |
| 363 | + super().__init__() |
| 364 | + assert isinstance(orig_size, tuple) and len(orig_size) == 2, \ |
| 365 | + "`orig_size` must be a tuple of (height, width)" |
| 366 | + self.orig_size = orig_size # expected original size |
| 367 | + self.interpolation = interpolation |
| 368 | + self.antialias = antialias |
| 369 | + # Cache map key is the target new_size tuple |
| 370 | + self._pinv_cache_map: Dict[Tuple[int, int], str] = {} |
| 371 | + |
| 372 | + def _get_or_create_pinv_matrix( |
| 373 | + self, |
| 374 | + new_size: Tuple[int, int], |
| 375 | + device: torch.device, |
| 376 | + dtype: torch.dtype = DTYPE_INTERMEDIATE |
| 377 | + ) -> torch.Tensor: |
| 378 | + """Retrieves the cached pinv matrix or computes and caches it for the given new_size.""" |
| 379 | + cache_key = new_size |
| 380 | + buffer_name = self._pinv_cache_map.get(cache_key) |
| 381 | + |
| 382 | + if buffer_name and hasattr(self, buffer_name): |
| 383 | + pinv_matrix = getattr(self, buffer_name) |
| 384 | + if pinv_matrix.device == device and pinv_matrix.dtype == dtype: |
| 385 | + return pinv_matrix |
| 386 | + |
| 387 | + # Calculate the matrix if not cached or needs update |
| 388 | + resize_mat = _compute_resize_matrix( |
| 389 | + self.orig_size, new_size, self.interpolation, self.antialias, device, dtype |
| 390 | + ) |
| 391 | + pinv_matrix = _compute_pinv_for_resampling(resize_mat) |
| 392 | + |
| 393 | + # Cache using register_buffer |
| 394 | + buffer_name = f"pinv_{new_size[0]}x{new_size[1]}" |
| 395 | + if hasattr(self, buffer_name): |
| 396 | + delattr(self, buffer_name) |
| 397 | + self.register_buffer(buffer_name, pinv_matrix) |
| 398 | + self._pinv_cache_map[cache_key] = buffer_name # Map new_size key to buffer name |
| 399 | + |
| 400 | + return pinv_matrix |
| 401 | + |
| 402 | + def forward(self, patch_embed: torch.Tensor, new_size: List[int]) -> torch.Tensor: |
| 403 | + """ Resamples the patch embedding weights to new_size. |
| 404 | +
|
| 405 | + Args: |
| 406 | + patch_embed (torch.Tensor): Original weights (out_ch, in_ch, H_orig, W_orig). |
| 407 | + new_size (List[int]): Target [height, width]. |
| 408 | +
|
| 409 | + Returns: |
| 410 | + torch.Tensor: Resampled weights. |
| 411 | + """ |
| 412 | + assert len(patch_embed.shape) == 4 |
| 413 | + assert len(new_size) == 2 |
| 414 | + |
| 415 | + # Input Validation |
| 416 | + input_size = tuple(patch_embed.shape[-2:]) |
| 417 | + assert input_size == self.orig_size, \ |
| 418 | + f"Input patch_embed spatial size {input_size} does not match " \ |
| 419 | + f"module's expected original size {self.orig_size}" |
| 420 | + |
| 421 | + new_size_tuple: Tuple[int, int] = tuple(new_size) |
| 422 | + |
| 423 | + # Check no-op case against self.orig_size |
| 424 | + if self.orig_size == new_size_tuple: |
| 425 | + return patch_embed |
| 426 | + |
| 427 | + device = patch_embed.device |
| 428 | + orig_dtype = patch_embed.dtype |
| 429 | + |
| 430 | + # Get or compute the required pseudoinverse matrix |
| 431 | + pinv_matrix = self._get_or_create_pinv_matrix(new_size_tuple, device) |
| 432 | + |
| 433 | + # Apply the resampling |
| 434 | + resampled_patch_embed = _apply_resampling(patch_embed, pinv_matrix, new_size_tuple, orig_dtype) |
| 435 | + |
| 436 | + return resampled_patch_embed |
| 437 | + |
| 438 | + |
253 | 439 | # def divs(n, m=None):
|
254 | 440 | # m = m or n // 2
|
255 | 441 | # if m == 1:
|
|
0 commit comments