
Commit 848b8c3

Support features_only
1 parent b37f0f7 commit 848b8c3

File tree

1 file changed (+133, -35 lines)

timm/models/tnt.py

Lines changed: 133 additions & 35 deletions
@@ -10,14 +10,15 @@
 https://github.com/huawei-noah/Efficient-AI-Backbones/tree/master/tnt_pytorch
 """
 import math
-from typing import Optional
+from typing import List, Optional, Tuple, Union

 import torch
 import torch.nn as nn

 from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
 from timm.layers import Mlp, DropPath, trunc_normal_, _assert, to_2tuple, resample_abs_pos_embed
 from ._builder import build_model_with_cfg
+from ._features import feature_take_indices
 from ._manipulate import checkpoint
 from ._registry import register_model

@@ -172,7 +173,16 @@ def __init__(self, img_size=224, patch_size=16, in_chans=3, in_dim=48, stride=4,
         else:
             self.unfold = nn.Unfold(kernel_size=patch_size, stride=patch_size)

-    def forward(self, x, pixel_pos):
+    def feat_ratio(self, as_scalar=True) -> Union[Tuple[int, int], int]:
+        if as_scalar:
+            return max(self.patch_size)
+        else:
+            return self.patch_size
+
+    def dynamic_feat_size(self, img_size: Tuple[int, int]) -> Tuple[int, int]:
+        return img_size[0] // self.patch_size[0], img_size[1] // self.patch_size[1]
+
+    def forward(self, x: torch.Tensor, pixel_pos: torch.Tensor) -> torch.Tensor:
         B, C, H, W = x.shape
         _assert(H == self.img_size[0],
                 f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]}).")
@@ -222,6 +232,7 @@ def __init__(
         self.num_classes = num_classes
         self.global_pool = global_pool
         self.num_features = self.head_hidden_size = self.embed_dim = embed_dim  # for consistency with other models
+        self.num_prefix_tokens = 1
         self.grad_checkpointing = False

         self.pixel_embed = PixelEmbed(
@@ -233,6 +244,7 @@ def __init__(
             legacy=legacy,
         )
         num_patches = self.pixel_embed.num_patches
+        r = self.pixel_embed.feat_ratio() if hasattr(self.pixel_embed, 'feat_ratio') else patch_size
         self.num_patches = num_patches
         new_patch_size = self.pixel_embed.new_patch_size
         num_pixel = new_patch_size[0] * new_patch_size[1]
@@ -264,8 +276,10 @@ def __init__(
                 legacy=legacy,
             ))
         self.blocks = nn.ModuleList(blocks)
+        self.feature_info = [
+            dict(module=f'blocks.{i}', num_chs=embed_dim, reduction=r) for i in range(depth)]
+
         self.norm = norm_layer(embed_dim)
-
         self.head_drop = nn.Dropout(drop_rate)
         self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()

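Since every TNT block emits tokens at the same width and stride, `feature_info` is uniform across depth; only the module path changes. For tnt_s (embed_dim=384, depth=12, patch 16 — values from the model configs rather than this hunk) the list would look like:

    [{'module': 'blocks.0',  'num_chs': 384, 'reduction': 16},
     {'module': 'blocks.1',  'num_chs': 384, 'reduction': 16},
     ...
     {'module': 'blocks.11', 'num_chs': 384, 'reduction': 16}]
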
@@ -313,6 +327,92 @@ def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
         self.global_pool = global_pool
         self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()

+    def forward_intermediates(
+            self,
+            x: torch.Tensor,
+            indices: Optional[Union[int, List[int]]] = None,
+            return_prefix_tokens: bool = False,
+            norm: bool = False,
+            stop_early: bool = False,
+            output_fmt: str = 'NCHW',
+            intermediates_only: bool = False,
+    ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
+        """ Forward features that returns intermediates.
+
+        Args:
+            x: Input image tensor
+            indices: Take last n blocks if an int; if a sequence, select by matching indices
+            return_prefix_tokens: Return both prefix and spatial intermediate tokens
+            norm: Apply norm layer to all intermediates
+            stop_early: Stop iterating over blocks when last desired intermediate hit
+            output_fmt: Shape of intermediate feature outputs
+            intermediates_only: Only return intermediate features
+        Returns:
+            Final patch tokens and/or a list of intermediate features
+        """
+        assert output_fmt in ('NCHW', 'NLC'), 'Output format must be one of NCHW or NLC.'
+        reshape = output_fmt == 'NCHW'
+        intermediates = []
+        take_indices, max_index = feature_take_indices(len(self.blocks), indices)
+
+        # forward pass
+        B, _, height, width = x.shape
+
+        pixel_embed = self.pixel_embed(x, self.pixel_pos)
+
+        patch_embed = self.norm2_proj(self.proj(self.norm1_proj(pixel_embed.reshape(B, self.num_patches, -1))))
+        patch_embed = torch.cat((self.cls_token.expand(B, -1, -1), patch_embed), dim=1)
+        patch_embed = patch_embed + self.patch_pos
+        patch_embed = self.pos_drop(patch_embed)
+
+        if torch.jit.is_scripting() or not stop_early:  # can't slice blocks in torchscript
+            blocks = self.blocks
+        else:
+            blocks = self.blocks[:max_index + 1]
+
+        for i, blk in enumerate(blocks):
+            pixel_embed, patch_embed = blk(pixel_embed, patch_embed)
+            if i in take_indices:
+                # normalize intermediates with final norm layer if enabled
+                intermediates.append(self.norm(patch_embed) if norm else patch_embed)
+
+        # process intermediates
+        if self.num_prefix_tokens:
+            # split prefix (e.g. class, distill) and spatial feature tokens
+            prefix_tokens = [y[:, 0:self.num_prefix_tokens] for y in intermediates]
+            intermediates = [y[:, self.num_prefix_tokens:] for y in intermediates]
+
+        if reshape:
+            # reshape to BCHW output format
+            H, W = self.pixel_embed.dynamic_feat_size((height, width))
+            intermediates = [y.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() for y in intermediates]
+        if not torch.jit.is_scripting() and return_prefix_tokens:
+            # return_prefix not supported in torchscript due to poor type handling
+            intermediates = list(zip(intermediates, prefix_tokens))
+
+        if intermediates_only:
+            return intermediates
+
+        patch_embed = self.norm(patch_embed)
+
+        return patch_embed, intermediates
+
+    def prune_intermediate_layers(
+            self,
+            indices: Union[int, List[int]] = 1,
+            prune_norm: bool = False,
+            prune_head: bool = True,
+    ):
+        """ Prune layers not required for specified intermediates.
+        """
+        take_indices, max_index = feature_take_indices(len(self.blocks), indices)
+        self.blocks = self.blocks[:max_index + 1]  # truncate blocks
+        if prune_norm:
+            self.norm = nn.Identity()
+        if prune_head:
+            self.reset_classifier(0, '')
+        return take_indices
+
     def forward_features(self, x):
         B = x.shape[0]
         pixel_embed = self.pixel_embed(x, self.pixel_pos)
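
Taken together, the two new methods make intermediate features reachable directly on a model instance. A usage sketch (shapes assume tnt_s at 224x224; the example itself is not part of the commit):

    import torch
    import timm

    model = timm.create_model('tnt_s_patch16_224', pretrained=False)
    x = torch.randn(1, 3, 224, 224)

    # last two block outputs as NCHW maps, with the final norm applied
    feats = model.forward_intermediates(x, indices=2, norm=True, intermediates_only=True)
    for f in feats:
        print(f.shape)  # torch.Size([1, 384, 14, 14])

    # afterwards, unused trailing blocks and the head can be dropped
    model.prune_intermediate_layers(indices=2, prune_head=True)
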
@@ -322,19 +422,18 @@ def forward_features(self, x):
         patch_embed = patch_embed + self.patch_pos
         patch_embed = self.pos_drop(patch_embed)

-        if self.grad_checkpointing and not torch.jit.is_scripting():
-            for blk in self.blocks:
+        for blk in self.blocks:
+            if self.grad_checkpointing and not torch.jit.is_scripting():
                 pixel_embed, patch_embed = checkpoint(blk, pixel_embed, patch_embed)
-        else:
-            for blk in self.blocks:
+            else:
                 pixel_embed, patch_embed = blk(pixel_embed, patch_embed)

         patch_embed = self.norm(patch_embed)
         return patch_embed

     def forward_head(self, x, pre_logits: bool = False):
         if self.global_pool:
-            x = x[:, 1:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0]
+            x = x[:, self.num_prefix_tokens:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0]
         x = self.head_drop(x)
         return x if pre_logits else self.head(x)

@@ -344,6 +443,30 @@ def forward(self, x):
         return x


+def _cfg(url='', **kwargs):
+    return {
+        'url': url,
+        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
+        'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
+        'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD,
+        'first_conv': 'pixel_embed.proj', 'classifier': 'head',
+        **kwargs
+    }
+
+
+default_cfgs = {
+    'tnt_s_patch16_224.in1k': _cfg(
+        # hf_hub_id='timm/',
+        # url='https://github.com/contrastive/pytorch-image-models/releases/download/TNT/tnt_s_patch16_224.pth.tar',
+        url='https://github.com/huawei-noah/Efficient-AI-Backbones/releases/download/tnt/tnt_s_81.5.pth.tar',
+    ),
+    'tnt_b_patch16_224.in1k': _cfg(
+        # hf_hub_id='timm/',
+        url='https://github.com/huawei-noah/Efficient-AI-Backbones/releases/download/tnt/tnt_b_82.9.pth.tar',
+    ),
+}
+
+
 def checkpoint_filter_fn(state_dict, model):
     state_dict.pop('outer_tokens', None)

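Note the config keys gained an explicit `.in1k` pretrained tag as part of the move. Assuming timm's usual name.tag resolution in the registry, both forms below would address the same weights:

    model = timm.create_model('tnt_s_patch16_224.in1k', pretrained=True)
    model = timm.create_model('tnt_s_patch16_224', pretrained=True)  # default tag resolved
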
@@ -380,40 +503,15 @@ def checkpoint_filter_fn(state_dict, model):


 def _create_tnt(variant, pretrained=False, **kwargs):
-    if kwargs.get('features_only', None):
-        raise RuntimeError('features_only not implemented for Vision Transformer models.')
-
+    out_indices = kwargs.pop('out_indices', 3)
     model = build_model_with_cfg(
         TNT, variant, pretrained,
         pretrained_filter_fn=checkpoint_filter_fn,
+        feature_cfg=dict(out_indices=out_indices, feature_cls='getter'),
         **kwargs)
     return model


-def _cfg(url='', **kwargs):
-    return {
-        'url': url,
-        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
-        'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
-        'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD,
-        'first_conv': 'pixel_embed.proj', 'classifier': 'head',
-        **kwargs
-    }
-
-
-default_cfgs = {
-    'tnt_s_patch16_224': _cfg(
-        # hf_hub_id='timm/',
-        # url='https://github.com/contrastive/pytorch-image-models/releases/download/TNT/tnt_s_patch16_224.pth.tar',
-        url='https://github.com/huawei-noah/Efficient-AI-Backbones/releases/download/tnt/tnt_s_81.5.pth.tar',
-    ),
-    'tnt_b_patch16_224': _cfg(
-        # hf_hub_id='timm/',
-        url='https://github.com/huawei-noah/Efficient-AI-Backbones/releases/download/tnt/tnt_b_82.9.pth.tar',
-    ),
-}
-
-
 @register_model
 def tnt_s_patch16_224(pretrained=False, **kwargs) -> TNT:
     model_cfg = dict(
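
With the old guard gone and a feature_cfg supplied, `features_only=True` now works for TNT: `feature_cls='getter'` makes the builder wrap the model so the wrapper's forward pulls features via `forward_intermediates` on the selected indices (defaulting to the last 3 blocks here). A usage sketch, with illustrative out_indices:

    import torch
    import timm

    # last two of tnt_s's 12 blocks
    model = timm.create_model('tnt_s_patch16_224', features_only=True, out_indices=(10, 11))
    x = torch.randn(1, 3, 224, 224)
    for f in model(x):
        print(f.shape)  # torch.Size([1, 384, 14, 14]) per selected block
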
