Skip to content

Commit 941ec01

Browse files
committed
update some model
1 parent 2af810f commit 941ec01

File tree

7 files changed

+104
-29
lines changed

7 files changed

+104
-29
lines changed

timm/models/convnext.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -452,29 +452,29 @@ def forward_intermediates(
452452
"""
453453
assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
454454
intermediates = []
455-
take_indices, max_index = feature_take_indices(len(self.stages) + 1, indices)
455+
take_indices, max_index = feature_take_indices(len(self.stages), indices)
456456

457457
# forward pass
458-
feat_idx = 0 # stem is index 0
459458
x = self.stem(x)
460-
if feat_idx in take_indices:
461-
intermediates.append(x)
462459

460+
last_idx = len(self.stages) - 1
463461
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
464462
stages = self.stages
465463
else:
466-
stages = self.stages[:max_index]
467-
for stage in stages:
468-
feat_idx += 1
464+
stages = self.stages[:max_index + 1]
465+
for feat_idx, stage in enumerate(stages):
469466
x = stage(x)
470467
if feat_idx in take_indices:
471-
# NOTE not bothering to apply norm_pre when norm=True as almost no models have it enabled
472-
intermediates.append(x)
468+
if norm and feat_idx == last_idx:
469+
intermediates.append(self.norm_pre(x))
470+
else:
471+
intermediates.append(x)
473472

474473
if intermediates_only:
475474
return intermediates
476475

477-
x = self.norm_pre(x)
476+
if feat_idx == last_idx:
477+
x = self.norm_pre(x)
478478

479479
return x, intermediates
480480

timm/models/focalnet.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -491,7 +491,7 @@ def forward_intermediates(
491491
else:
492492
stages = self.layers[:max_index + 1]
493493

494-
last_idx = len(self.layers)
494+
last_idx = len(self.layers) - 1
495495
for feat_idx, stage in enumerate(stages):
496496
x = stage(x)
497497
if feat_idx in take_indices:

timm/models/mvitv2.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -870,10 +870,11 @@ def forward_intermediates(
870870
if self.pos_embed is not None:
871871
x = x + self.pos_embed
872872

873-
for i, stage in enumerate(self.stages):
873+
last_idx = len(self.stages) - 1
874+
for feat_idx, stage in enumerate(self.stages):
874875
x, feat_size = stage(x, feat_size)
875-
if i in take_indices:
876-
if norm and i == (len(self.stages) - 1):
876+
if feat_idx in take_indices:
877+
if norm and feat_idx == last_idx:
877878
x_inter = self.norm(x) # applying final norm last intermediate
878879
else:
879880
x_inter = x
@@ -887,7 +888,8 @@ def forward_intermediates(
887888
if intermediates_only:
888889
return intermediates
889890

890-
x = self.norm(x)
891+
if feat_idx == last_idx:
892+
x = self.norm(x)
891893

892894
return x, intermediates
893895

timm/models/pit.py

Lines changed: 67 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,15 @@
1414
import math
1515
import re
1616
from functools import partial
17-
from typing import Optional, Sequence, Tuple
17+
from typing import List, Optional, Sequence, Tuple, Union
1818

1919
import torch
2020
from torch import nn
2121

2222
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
2323
from timm.layers import trunc_normal_, to_2tuple
2424
from ._builder import build_model_with_cfg
25+
from ._features import feature_take_indices
2526
from ._registry import register_model, generate_default_cfgs
2627
from .vision_transformer import Block
2728

@@ -254,6 +255,71 @@ def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
254255
if self.head_dist is not None:
255256
self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()
256257

258+
def forward_intermediates(
259+
self,
260+
x: torch.Tensor,
261+
indices: Optional[Union[int, List[int]]] = None,
262+
norm: bool = False,
263+
stop_early: bool = False,
264+
output_fmt: str = 'NCHW',
265+
intermediates_only: bool = False,
266+
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
267+
""" Forward features that returns intermediates.
268+
269+
Args:
270+
x: Input image tensor
271+
indices: Take last n blocks if int, all if None, select matching indices if sequence
272+
norm: Apply norm layer to compatible intermediates
273+
stop_early: Stop iterating over blocks when last desired intermediate hit
274+
output_fmt: Shape of intermediate feature outputs
275+
intermediates_only: Only return intermediate features
276+
Returns:
277+
278+
"""
279+
assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
280+
intermediates = []
281+
take_indices, max_index = feature_take_indices(len(self.transformers), indices)
282+
283+
# forward pass
284+
x = self.patch_embed(x)
285+
x = self.pos_drop(x + self.pos_embed)
286+
cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)
287+
288+
last_idx = len(self.transformers) - 1
289+
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
290+
stages = self.transformers
291+
else:
292+
stages = self.transformers[:max_index + 1]
293+
294+
for feat_idx, stage in enumerate(stages):
295+
x, cls_tokens = stage((x, cls_tokens))
296+
if feat_idx in take_indices:
297+
intermediates.append(x)
298+
299+
if intermediates_only:
300+
return intermediates
301+
302+
if feat_idx == last_idx:
303+
cls_tokens = self.norm(cls_tokens)
304+
305+
return cls_tokens, intermediates
306+
307+
def prune_intermediate_layers(
308+
self,
309+
indices: Union[int, List[int]] = 1,
310+
prune_norm: bool = False,
311+
prune_head: bool = True,
312+
):
313+
""" Prune layers not required for specified intermediates.
314+
"""
315+
take_indices, max_index = feature_take_indices(len(self.transformers), indices)
316+
self.stages = self.stages[:max_index + 1] # truncate blocks w/ stem as idx 0
317+
if prune_norm:
318+
self.norm = nn.Identity()
319+
if prune_head:
320+
self.reset_classifier(0, '')
321+
return take_indices
322+
257323
def forward_features(self, x):
258324
x = self.patch_embed(x)
259325
x = self.pos_drop(x + self.pos_embed)

timm/models/rdnet.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -302,20 +302,20 @@ def forward_intermediates(
302302
"""
303303
assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
304304
intermediates = []
305-
take_indices, max_index = feature_take_indices(len(self.dense_stages) + 1, indices)
305+
stage_ends = [int(info['module'].split('.')[-1]) for info in self.feature_info]
306+
take_indices, max_index = feature_take_indices(len(stage_ends), indices)
307+
take_indices = [stage_ends[i] for i in take_indices]
308+
max_index = stage_ends[max_index]
306309

307310
# forward pass
308-
feat_idx = 0 # stem is index 0
309311
x = self.stem(x)
310-
if feat_idx in take_indices:
311-
intermediates.append(x)
312-
last_idx = len(self.dense_stages)
312+
313+
last_idx = len(self.dense_stages) - 1
313314
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
314315
dense_stages = self.dense_stages
315316
else:
316-
dense_stages = self.dense_stages[:max_index]
317-
for stage in dense_stages:
318-
feat_idx += 1
317+
dense_stages = self.dense_stages[:max_index + 1]
318+
for feat_idx, stage in enumerate(dense_stages):
319319
x = stage(x)
320320
if feat_idx in take_indices:
321321
if norm and feat_idx == last_idx:
@@ -340,8 +340,10 @@ def prune_intermediate_layers(
340340
):
341341
""" Prune layers not required for specified intermediates.
342342
"""
343-
take_indices, max_index = feature_take_indices(len(self.dense_stages) + 1, indices)
344-
self.dense_stages = self.dense_stages[:max_index] # truncate blocks w/ stem as idx 0
343+
stage_ends = [int(info['module'].split('.')[-1]) for info in self.feature_info]
344+
take_indices, max_index = feature_take_indices(len(stage_ends), indices)
345+
max_index = stage_ends[max_index]
346+
self.dense_stages = self.dense_stages[:max_index + 1] # truncate blocks w/ stem as idx 0
345347
if prune_norm:
346348
self.norm_pre = nn.Identity()
347349
if prune_head:

timm/models/resnetv2.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -571,9 +571,13 @@ def forward_intermediates(
571571

572572
# forward pass
573573
feat_idx = 0
574-
x = self.stem(x)
574+
H, W = x.shape[-2:]
575+
for stem in self.stem:
576+
x = stem(x)
577+
if x.shape[-2:] == (H //2, W //2):
578+
x_down = x
575579
if feat_idx in take_indices:
576-
intermediates.append(x)
580+
intermediates.append(x_down)
577581
last_idx = len(self.stages)
578582
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
579583
stages = self.stages

timm/models/xcit.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -494,7 +494,8 @@ def forward_intermediates(
494494
# NOTE not supporting return of class tokens
495495
x = torch.cat((self.cls_token.expand(B, -1, -1), x), dim=1)
496496
for blk in self.cls_attn_blocks:
497-
x = blk(x)
497+
x = blk(x)
498+
498499
x = self.norm(x)
499500

500501
return x, intermediates

0 commit comments

Comments
 (0)