10
10
https://github.com/huawei-noah/Efficient-AI-Backbones/tree/master/tnt_pytorch
11
11
"""
12
12
import math
13
- from typing import Optional
13
+ from typing import List , Optional , Tuple , Union
14
14
15
15
import torch
16
16
import torch .nn as nn
17
17
18
18
from timm .data import IMAGENET_INCEPTION_MEAN , IMAGENET_INCEPTION_STD
19
19
from timm .layers import Mlp , DropPath , trunc_normal_ , _assert , to_2tuple , resample_abs_pos_embed
20
20
from ._builder import build_model_with_cfg
21
+ from ._features import feature_take_indices
21
22
from ._manipulate import checkpoint
22
23
from ._registry import register_model
23
24
@@ -172,7 +173,16 @@ def __init__(self, img_size=224, patch_size=16, in_chans=3, in_dim=48, stride=4,
172
173
else :
173
174
self .unfold = nn .Unfold (kernel_size = patch_size , stride = patch_size )
174
175
175
- def forward (self , x , pixel_pos ):
176
+ def feat_ratio (self , as_scalar = True ) -> Union [Tuple [int , int ], int ]:
177
+ if as_scalar :
178
+ return max (self .patch_size )
179
+ else :
180
+ return self .patch_size
181
+
182
+ def dynamic_feat_size (self , img_size : Tuple [int , int ]) -> Tuple [int , int ]:
183
+ return img_size [0 ] // self .patch_size [0 ], img_size [1 ] // self .patch_size [1 ]
184
+
185
+ def forward (self , x : torch .Tensor , pixel_pos : torch .Tensor ) -> torch .Tensor :
176
186
B , C , H , W = x .shape
177
187
_assert (H == self .img_size [0 ],
178
188
f"Input image size ({ H } *{ W } ) doesn't match model ({ self .img_size [0 ]} *{ self .img_size [1 ]} )." )
@@ -222,6 +232,7 @@ def __init__(
222
232
self .num_classes = num_classes
223
233
self .global_pool = global_pool
224
234
self .num_features = self .head_hidden_size = self .embed_dim = embed_dim # for consistency with other models
235
+ self .num_prefix_tokens = 1
225
236
self .grad_checkpointing = False
226
237
227
238
self .pixel_embed = PixelEmbed (
@@ -233,6 +244,7 @@ def __init__(
233
244
legacy = legacy ,
234
245
)
235
246
num_patches = self .pixel_embed .num_patches
247
+ r = self .pixel_embed .feat_ratio () if hasattr (self .pixel_embed , 'feat_ratio' ) else patch_size
236
248
self .num_patches = num_patches
237
249
new_patch_size = self .pixel_embed .new_patch_size
238
250
num_pixel = new_patch_size [0 ] * new_patch_size [1 ]
@@ -264,8 +276,10 @@ def __init__(
264
276
legacy = legacy ,
265
277
))
266
278
self .blocks = nn .ModuleList (blocks )
279
+ self .feature_info = [
280
+ dict (module = f'blocks.{ i } ' , num_chs = embed_dim , reduction = r ) for i in range (depth )]
281
+
267
282
self .norm = norm_layer (embed_dim )
268
-
269
283
self .head_drop = nn .Dropout (drop_rate )
270
284
self .head = nn .Linear (embed_dim , num_classes ) if num_classes > 0 else nn .Identity ()
271
285
@@ -313,6 +327,92 @@ def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
313
327
self .global_pool = global_pool
314
328
self .head = nn .Linear (self .embed_dim , num_classes ) if num_classes > 0 else nn .Identity ()
315
329
330
+ def forward_intermediates (
331
+ self ,
332
+ x : torch .Tensor ,
333
+ indices : Optional [Union [int , List [int ]]] = None ,
334
+ return_prefix_tokens : bool = False ,
335
+ norm : bool = False ,
336
+ stop_early : bool = False ,
337
+ output_fmt : str = 'NCHW' ,
338
+ intermediates_only : bool = False ,
339
+ ) -> Union [List [torch .Tensor ], Tuple [torch .Tensor , List [torch .Tensor ]]]:
340
+ """ Forward features that returns intermediates.
341
+
342
+ Args:
343
+ x: Input image tensor
344
+ indices: Take last n blocks if an int, if is a sequence, select by matching indices
345
+ return_prefix_tokens: Return both prefix and spatial intermediate tokens
346
+ norm: Apply norm layer to all intermediates
347
+ stop_early: Stop iterating over blocks when last desired intermediate hit
348
+ output_fmt: Shape of intermediate feature outputs
349
+ intermediates_only: Only return intermediate features
350
+ Returns:
351
+
352
+ """
353
+ assert output_fmt in ('NCHW' , 'NLC' ), 'Output format must be one of NCHW or NLC.'
354
+ reshape = output_fmt == 'NCHW'
355
+ intermediates = []
356
+ take_indices , max_index = feature_take_indices (len (self .blocks ), indices )
357
+
358
+ # forward pass
359
+ B , _ , height , width = x .shape
360
+
361
+ pixel_embed = self .pixel_embed (x , self .pixel_pos )
362
+
363
+ patch_embed = self .norm2_proj (self .proj (self .norm1_proj (pixel_embed .reshape (B , self .num_patches , - 1 ))))
364
+ patch_embed = torch .cat ((self .cls_token .expand (B , - 1 , - 1 ), patch_embed ), dim = 1 )
365
+ patch_embed = patch_embed + self .patch_pos
366
+ patch_embed = self .pos_drop (patch_embed )
367
+
368
+ if torch .jit .is_scripting () or not stop_early : # can't slice blocks in torchscript
369
+ blocks = self .blocks
370
+ else :
371
+ blocks = self .blocks [:max_index + 1 ]
372
+
373
+ for i , blk in enumerate (blocks ):
374
+ pixel_embed , patch_embed = blk (pixel_embed , patch_embed )
375
+ if i in take_indices :
376
+ # normalize intermediates with final norm layer if enabled
377
+ intermediates .append (self .norm (patch_embed ) if norm else patch_embed )
378
+
379
+ # process intermediates
380
+ if self .num_prefix_tokens :
381
+ # split prefix (e.g. class, distill) and spatial feature tokens
382
+ prefix_tokens = [y [:, 0 :self .num_prefix_tokens ] for y in intermediates ]
383
+ intermediates = [y [:, self .num_prefix_tokens :] for y in intermediates ]
384
+
385
+ if reshape :
386
+ # reshape to BCHW output format
387
+ H , W = self .pixel_embed .dynamic_feat_size ((height , width ))
388
+ intermediates = [y .reshape (B , H , W , - 1 ).permute (0 , 3 , 1 , 2 ).contiguous () for y in intermediates ]
389
+ if not torch .jit .is_scripting () and return_prefix_tokens :
390
+ # return_prefix not support in torchscript due to poor type handling
391
+ intermediates = list (zip (intermediates , prefix_tokens ))
392
+
393
+ if intermediates_only :
394
+ return intermediates
395
+
396
+ patch_embed = self .norm (patch_embed )
397
+
398
+ return patch_embed , intermediates
399
+
400
+ def prune_intermediate_layers (
401
+ self ,
402
+ indices : Union [int , List [int ]] = 1 ,
403
+ prune_norm : bool = False ,
404
+ prune_head : bool = True ,
405
+ ):
406
+ """ Prune layers not required for specified intermediates.
407
+ """
408
+ take_indices , max_index = feature_take_indices (len (self .blocks ), indices )
409
+ self .blocks = self .blocks [:max_index + 1 ] # truncate blocks
410
+ if prune_norm :
411
+ self .norm = nn .Identity ()
412
+ if prune_head :
413
+ self .reset_classifier (0 , '' )
414
+ return take_indices
415
+
316
416
def forward_features (self , x ):
317
417
B = x .shape [0 ]
318
418
pixel_embed = self .pixel_embed (x , self .pixel_pos )
@@ -322,19 +422,18 @@ def forward_features(self, x):
322
422
patch_embed = patch_embed + self .patch_pos
323
423
patch_embed = self .pos_drop (patch_embed )
324
424
325
- if self . grad_checkpointing and not torch . jit . is_scripting () :
326
- for blk in self . blocks :
425
+ for blk in self . blocks :
426
+ if self . grad_checkpointing and not torch . jit . is_scripting () :
327
427
pixel_embed , patch_embed = checkpoint (blk , pixel_embed , patch_embed )
328
- else :
329
- for blk in self .blocks :
428
+ else :
330
429
pixel_embed , patch_embed = blk (pixel_embed , patch_embed )
331
430
332
431
patch_embed = self .norm (patch_embed )
333
432
return patch_embed
334
433
335
434
def forward_head (self , x , pre_logits : bool = False ):
336
435
if self .global_pool :
337
- x = x [:, 1 :].mean (dim = 1 ) if self .global_pool == 'avg' else x [:, 0 ]
436
+ x = x [:, self . num_prefix_tokens :].mean (dim = 1 ) if self .global_pool == 'avg' else x [:, 0 ]
338
437
x = self .head_drop (x )
339
438
return x if pre_logits else self .head (x )
340
439
@@ -344,6 +443,30 @@ def forward(self, x):
344
443
return x
345
444
346
445
446
+ def _cfg (url = '' , ** kwargs ):
447
+ return {
448
+ 'url' : url ,
449
+ 'num_classes' : 1000 , 'input_size' : (3 , 224 , 224 ), 'pool_size' : None ,
450
+ 'crop_pct' : .9 , 'interpolation' : 'bicubic' , 'fixed_input_size' : True ,
451
+ 'mean' : IMAGENET_INCEPTION_MEAN , 'std' : IMAGENET_INCEPTION_STD ,
452
+ 'first_conv' : 'pixel_embed.proj' , 'classifier' : 'head' ,
453
+ ** kwargs
454
+ }
455
+
456
+
457
+ default_cfgs = {
458
+ 'tnt_s_patch16_224.in1k' : _cfg (
459
+ # hf_hub_id='timm/',
460
+ # url='https://github.com/contrastive/pytorch-image-models/releases/download/TNT/tnt_s_patch16_224.pth.tar',
461
+ url = 'https://github.com/huawei-noah/Efficient-AI-Backbones/releases/download/tnt/tnt_s_81.5.pth.tar' ,
462
+ ),
463
+ 'tnt_b_patch16_224.in1k' : _cfg (
464
+ # hf_hub_id='timm/',
465
+ url = 'https://github.com/huawei-noah/Efficient-AI-Backbones/releases/download/tnt/tnt_b_82.9.pth.tar' ,
466
+ ),
467
+ }
468
+
469
+
347
470
def checkpoint_filter_fn (state_dict , model ):
348
471
state_dict .pop ('outer_tokens' , None )
349
472
@@ -380,40 +503,15 @@ def checkpoint_filter_fn(state_dict, model):
380
503
381
504
382
505
def _create_tnt (variant , pretrained = False , ** kwargs ):
383
- if kwargs .get ('features_only' , None ):
384
- raise RuntimeError ('features_only not implemented for Vision Transformer models.' )
385
-
506
+ out_indices = kwargs .pop ('out_indices' , 3 )
386
507
model = build_model_with_cfg (
387
508
TNT , variant , pretrained ,
388
509
pretrained_filter_fn = checkpoint_filter_fn ,
510
+ feature_cfg = dict (out_indices = out_indices , feature_cls = 'getter' ),
389
511
** kwargs )
390
512
return model
391
513
392
514
393
- def _cfg (url = '' , ** kwargs ):
394
- return {
395
- 'url' : url ,
396
- 'num_classes' : 1000 , 'input_size' : (3 , 224 , 224 ), 'pool_size' : None ,
397
- 'crop_pct' : .9 , 'interpolation' : 'bicubic' , 'fixed_input_size' : True ,
398
- 'mean' : IMAGENET_INCEPTION_MEAN , 'std' : IMAGENET_INCEPTION_STD ,
399
- 'first_conv' : 'pixel_embed.proj' , 'classifier' : 'head' ,
400
- ** kwargs
401
- }
402
-
403
-
404
- default_cfgs = {
405
- 'tnt_s_patch16_224' : _cfg (
406
- # hf_hub_id='timm/',
407
- # url='https://github.com/contrastive/pytorch-image-models/releases/download/TNT/tnt_s_patch16_224.pth.tar',
408
- url = 'https://github.com/huawei-noah/Efficient-AI-Backbones/releases/download/tnt/tnt_s_81.5.pth.tar' ,
409
- ),
410
- 'tnt_b_patch16_224' : _cfg (
411
- # hf_hub_id='timm/',
412
- url = 'https://github.com/huawei-noah/Efficient-AI-Backbones/releases/download/tnt/tnt_b_82.9.pth.tar' ,
413
- ),
414
- }
415
-
416
-
417
515
@register_model
418
516
def tnt_s_patch16_224 (pretrained = False , ** kwargs ) -> TNT :
419
517
model_cfg = dict (
0 commit comments