
Commit 2b251fb

Wrap torch checkpoint() fn to default use_reentrant flag to False and allow env var override
1 parent 131518c · commit 2b251fb

22 files changed: +91 −54 lines changed

timm/layers/__init__.py

Lines changed: 2 additions & 1 deletion

@@ -8,7 +8,8 @@
 from .classifier import create_classifier, ClassifierHead, NormMlpClassifierHead, ClNormMlpClassifierHead
 from .cond_conv2d import CondConv2d, get_condconv_initializer
 from .config import is_exportable, is_scriptable, is_no_jit, use_fused_attn, \
-    set_exportable, set_scriptable, set_no_jit, set_layer_config, set_fused_attn
+    set_exportable, set_scriptable, set_no_jit, set_layer_config, set_fused_attn, \
+    set_reentrant_ckpt, use_reentrant_ckpt
 from .conv2d_same import Conv2dSame, conv2d_same
 from .conv_bn_act import ConvNormAct, ConvNormActAa, ConvBnAct
 from .create_act import create_act_layer, get_act_layer, get_act_fn

timm/layers/config.py

Lines changed: 17 additions & 1 deletion

@@ -8,7 +8,8 @@
 
 __all__ = [
     'is_exportable', 'is_scriptable', 'is_no_jit', 'use_fused_attn',
-    'set_exportable', 'set_scriptable', 'set_no_jit', 'set_layer_config', 'set_fused_attn'
+    'set_exportable', 'set_scriptable', 'set_no_jit', 'set_layer_config', 'set_fused_attn',
+    'set_reentrant_ckpt', 'use_reentrant_ckpt'
 ]
 
 # Set to True if prefer to have layers with no jit optimization (includes activations)
@@ -34,6 +35,12 @@
 _USE_FUSED_ATTN = 1  # 0 == off, 1 == on (for tested use), 2 == on (for experimental use)
 
 
+if 'TIMM_REENTRANT_CKPT' in os.environ:
+    _USE_REENTRANT_CKPT = bool(os.environ['TIMM_REENTRANT_CKPT'])
+else:
+    _USE_REENTRANT_CKPT = False  # defaults to disabled (off)
+
+
 def is_no_jit():
     return _NO_JIT
 
@@ -147,3 +154,12 @@ def set_fused_attn(enable: bool = True, experimental: bool = False):
         _USE_FUSED_ATTN = 1
     else:
         _USE_FUSED_ATTN = 0
+
+
+def use_reentrant_ckpt() -> bool:
+    return _USE_REENTRANT_CKPT
+
+
+def set_reentrant_ckpt(enable: bool = True):
+    global _USE_REENTRANT_CKPT
+    _USE_REENTRANT_CKPT = enable
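
The override is read once when `timm.layers.config` is imported, so the environment variable must be set before timm is imported. Note that `bool()` on an env string is truthy for any non-empty value, so even `TIMM_REENTRANT_CKPT=0` selects re-entrant mode; leave the variable unset to keep the new default. A minimal sketch of both control paths (illustrative usage, not part of the diff):

import os
os.environ['TIMM_REENTRANT_CKPT'] = '1'  # any non-empty value; must be set before importing timm

from timm.layers import set_reentrant_ckpt, use_reentrant_ckpt

print(use_reentrant_ckpt())  # True, picked up from the env var at import time
set_reentrant_ckpt(False)    # runtime override back to the commit's new default
print(use_reentrant_ckpt())  # False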

timm/models/_features.py

Lines changed: 1 addition & 2 deletions

@@ -15,10 +15,9 @@
 
 import torch
 import torch.nn as nn
-from torch.utils.checkpoint import checkpoint
 
 from timm.layers import Format, _assert
-
+from ._manipulate import checkpoint
 
 __all__ = [
     'FeatureInfo', 'FeatureHooks', 'FeatureDictNet', 'FeatureListNet', 'FeatureHookNet', 'FeatureGetterNet',

timm/models/_manipulate.py

Lines changed: 43 additions & 12 deletions

@@ -3,14 +3,17 @@
 import re
 from collections import defaultdict
 from itertools import chain
-from typing import Any, Callable, Dict, Iterator, Tuple, Type, Union
+from typing import Any, Callable, Dict, Iterator, Optional, Tuple, Type, Union
 
 import torch
+import torch.utils.checkpoint
 from torch import nn as nn
-from torch.utils.checkpoint import checkpoint
+
+from timm.layers import use_reentrant_ckpt
+
 
 __all__ = ['model_parameters', 'named_apply', 'named_modules', 'named_modules_with_params', 'adapt_input_conv',
-           'group_with_matcher', 'group_modules', 'group_parameters', 'flatten_modules', 'checkpoint_seq']
+           'group_with_matcher', 'group_modules', 'group_parameters', 'flatten_modules', 'checkpoint_seq', 'checkpoint']
 
 
 def model_parameters(model: nn.Module, exclude_head: bool = False):
@@ -183,13 +186,35 @@ def flatten_modules(
         yield name, module
 
 
+def checkpoint(
+        function,
+        *args,
+        use_reentrant: Optional[bool] = None,
+        **kwargs,
+):
+    """ checkpoint wrapper fn
+
+    A thin wrapper around torch.utils.checkpoint.checkpoint to default
+    use_reentrant to False
+    """
+    if use_reentrant is None:
+        use_reentrant = use_reentrant_ckpt()
+
+    return torch.utils.checkpoint.checkpoint(
+        function,
+        *args,
+        use_reentrant=use_reentrant,
+        **kwargs,
+    )
+
+
 def checkpoint_seq(
         functions,
         x,
-        every=1,
-        flatten=False,
-        skip_last=False,
-        preserve_rng_state=True
+        every: int = 1,
+        flatten: bool = False,
+        skip_last: bool = False,
+        use_reentrant: Optional[bool] = None,
 ):
     r"""A helper function for checkpointing sequential models.
 
@@ -215,10 +240,9 @@ def checkpoint_seq(
         functions: A :class:`torch.nn.Sequential` or the list of modules or functions to run sequentially.
         x: A Tensor that is input to :attr:`functions`
         every: checkpoint every-n functions (default: 1)
-        flatten (bool): flatten nn.Sequential of nn.Sequentials
-        skip_last (bool): skip checkpointing the last function in the sequence if True
-        preserve_rng_state (bool, optional, default=True): Omit stashing and restoring
-            the RNG state during each checkpoint.
+        flatten: flatten nn.Sequential of nn.Sequentials
+        skip_last: skip checkpointing the last function in the sequence if True
+        use_reentrant: Use re-entrant checkpointing
 
     Returns:
         Output of running :attr:`functions` sequentially on :attr:`*inputs`
@@ -227,6 +251,9 @@ def checkpoint_seq(
         >>> model = nn.Sequential(...)
         >>> input_var = checkpoint_seq(model, input_var, every=2)
     """
+    if use_reentrant is None:
+        use_reentrant = use_reentrant_ckpt()
+
     def run_function(start, end, functions):
         def forward(_x):
             for j in range(start, end + 1):
@@ -247,7 +274,11 @@ def forward(_x):
     end = -1
     for start in range(0, num_checkpointed, every):
         end = min(start + every - 1, num_checkpointed - 1)
-        x = checkpoint(run_function(start, end, functions), x, preserve_rng_state=preserve_rng_state)
+        x = torch.utils.checkpoint.checkpoint(
+            run_function(start, end, functions),
+            x,
+            use_reentrant=use_reentrant,
+        )
     if skip_last:
         return run_function(end + 1, len(functions) - 1, functions)(x)
     return x
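
With the wrapper in place, call sites get `use_reentrant=False` (or the env/global override) without threading the flag through every caller. A hedged usage sketch of both helpers; the toy `nn.Sequential` and tensor shapes below are illustrative assumptions, not part of the commit:

import torch
import torch.nn as nn
from timm.models._manipulate import checkpoint, checkpoint_seq

blocks = nn.Sequential(*[nn.Linear(16, 16) for _ in range(4)])
x = torch.randn(2, 16, requires_grad=True)

y = checkpoint(blocks[0], x)  # single call, non-reentrant unless overridden
y = checkpoint_seq(blocks, y, every=2, skip_last=True)  # checkpoint in segments of 2, run the last segment normally
y.sum().backward()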

timm/models/beit.py

Lines changed: 1 addition & 2 deletions

@@ -44,15 +44,14 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from torch.utils.checkpoint import checkpoint
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from timm.layers import PatchEmbed, Mlp, SwiGLU, LayerNorm, DropPath, trunc_normal_, use_fused_attn
 from timm.layers import resample_patch_embed, resample_abs_pos_embed, resize_rel_pos_bias_table, ndgrid
 
-
 from ._builder import build_model_with_cfg
 from ._features import feature_take_indices
+from ._manipulate import checkpoint
 from ._registry import generate_default_cfgs, register_model
 
 __all__ = ['Beit']

timm/models/densenet.py

Lines changed: 2 additions & 3 deletions

@@ -8,13 +8,12 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-import torch.utils.checkpoint as cp
 from torch.jit.annotations import List
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from timm.layers import BatchNormAct2d, get_norm_act_layer, BlurPool2d, create_classifier
 from ._builder import build_model_with_cfg
-from ._manipulate import MATCH_PREV_GROUP
+from ._manipulate import MATCH_PREV_GROUP, checkpoint
 from ._registry import register_model, generate_default_cfgs, register_model_deprecations
 
 __all__ = ['DenseNet']
@@ -60,7 +59,7 @@ def call_checkpoint_bottleneck(self, x):
         def closure(*xs):
             return self.bottleneck_fn(xs)
 
-        return cp.checkpoint(closure, *x)
+        return checkpoint(closure, *x)
 
     @torch.jit._overload_method  # noqa: F811
     def forward(self, x):

timm/models/efficientnet.py

Lines changed: 1 addition & 2 deletions

@@ -41,7 +41,6 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from torch.utils.checkpoint import checkpoint
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
 from timm.layers import create_conv2d, create_classifier, get_norm_act_layer, LayerType, \
@@ -51,7 +50,7 @@
 from ._efficientnet_builder import BlockArgs, EfficientNetBuilder, decode_arch_def, efficientnet_init_weights, \
     round_channels, resolve_bn_args, resolve_act_layer, BN_EPS_TF_DEFAULT
 from ._features import FeatureInfo, FeatureHooks, feature_take_indices
-from ._manipulate import checkpoint_seq
+from ._manipulate import checkpoint_seq, checkpoint
 from ._registry import generate_default_cfgs, register_model, register_model_deprecations
 
 __all__ = ['EfficientNet', 'EfficientNetFeatures']
__all__ = ['EfficientNet', 'EfficientNetFeatures']

timm/models/eva.py

Lines changed: 1 addition & 1 deletion

@@ -30,7 +30,6 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from torch.utils.checkpoint import checkpoint
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
 from timm.layers import PatchEmbed, Mlp, GluMlp, SwiGLU, LayerNorm, DropPath, PatchDropout, RotaryEmbeddingCat, \
@@ -39,6 +38,7 @@
 
 from ._builder import build_model_with_cfg
 from ._features import feature_take_indices
+from ._manipulate import checkpoint
 from ._registry import generate_default_cfgs, register_model
 
 __all__ = ['Eva']

timm/models/focalnet.py

Lines changed: 1 addition & 2 deletions

@@ -22,12 +22,11 @@
 
 import torch
 import torch.nn as nn
-import torch.utils.checkpoint as checkpoint
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from timm.layers import Mlp, DropPath, LayerNorm2d, trunc_normal_, ClassifierHead, NormMlpClassifierHead
 from ._builder import build_model_with_cfg
-from ._manipulate import named_apply
+from ._manipulate import named_apply, checkpoint
 from ._registry import generate_default_cfgs, register_model
 
 __all__ = ['FocalNet']

timm/models/gcvit.py

Lines changed: 1 addition & 2 deletions

@@ -25,14 +25,13 @@
 
 import torch
 import torch.nn as nn
-import torch.utils.checkpoint as checkpoint
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from timm.layers import DropPath, to_2tuple, to_ntuple, Mlp, ClassifierHead, LayerNorm2d, \
     get_attn, get_act_layer, get_norm_layer, RelPosBias, _assert
 from ._builder import build_model_with_cfg
 from ._features_fx import register_notrace_function
-from ._manipulate import named_apply
+from ._manipulate import named_apply, checkpoint
 from ._registry import register_model, generate_default_cfgs
 
 __all__ = ['GlobalContextVit']

timm/models/hiera.py

Lines changed: 1 addition & 2 deletions

@@ -29,7 +29,6 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from torch.utils.checkpoint import checkpoint
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from timm.layers import DropPath, Mlp, LayerScale, ClNormMlpClassifierHead, use_fused_attn, \
@@ -39,7 +38,7 @@
 from ._builder import build_model_with_cfg
 from ._features import feature_take_indices
 from ._features_fx import register_notrace_function
-from ._manipulate import named_apply
+from ._manipulate import named_apply, checkpoint
 
 
 __all__ = ['Hiera']

timm/models/mobilenetv3.py

Lines changed: 1 addition & 2 deletions

@@ -12,7 +12,6 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from torch.utils.checkpoint import checkpoint
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
 from timm.layers import SelectAdaptivePool2d, Linear, LayerType, PadType, create_conv2d, get_norm_act_layer
@@ -21,7 +20,7 @@
 from ._efficientnet_builder import BlockArgs, EfficientNetBuilder, decode_arch_def, efficientnet_init_weights, \
     round_channels, resolve_bn_args, resolve_act_layer, BN_EPS_TF_DEFAULT
 from ._features import FeatureInfo, FeatureHooks, feature_take_indices
-from ._manipulate import checkpoint_seq
+from ._manipulate import checkpoint_seq, checkpoint
 from ._registry import generate_default_cfgs, register_model, register_model_deprecations
 
 __all__ = ['MobileNetV3', 'MobileNetV3Features']

timm/models/mvitv2.py

Lines changed: 2 additions & 2 deletions

@@ -20,15 +20,15 @@
 from typing import Union, List, Tuple, Optional
 
 import torch
-import torch.utils.checkpoint as checkpoint
 from torch import nn
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from timm.layers import Mlp, DropPath, trunc_normal_tf_, get_norm_layer, to_2tuple
 from ._builder import build_model_with_cfg
 from ._features import feature_take_indices
 from ._features_fx import register_notrace_function
-from ._registry import register_model, register_model_deprecations, generate_default_cfgs
+from ._manipulate import checkpoint
+from ._registry import register_model, generate_default_cfgs
 
 __all__ = ['MultiScaleVit', 'MultiScaleVitCfg']  # model_registry will add each entrypoint fn to this
 

timm/models/pvt_v2.py

Lines changed: 1 addition & 1 deletion

@@ -21,11 +21,11 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-import torch.utils.checkpoint as checkpoint
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from timm.layers import DropPath, to_2tuple, to_ntuple, trunc_normal_, LayerNorm, use_fused_attn
 from ._builder import build_model_with_cfg
+from ._manipulate import checkpoint
 from ._registry import register_model, generate_default_cfgs
 
 __all__ = ['PyramidVisionTransformerV2']

timm/models/swin_transformer_v2.py

Lines changed: 2 additions & 2 deletions

@@ -18,14 +18,14 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-import torch.utils.checkpoint as checkpoint
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from timm.layers import PatchEmbed, Mlp, DropPath, to_2tuple, trunc_normal_, _assert, ClassifierHead,\
+from timm.layers import PatchEmbed, Mlp, DropPath, to_2tuple, trunc_normal_, ClassifierHead,\
     resample_patch_embed, ndgrid, get_act_layer, LayerType
 from ._builder import build_model_with_cfg
 from ._features import feature_take_indices
 from ._features_fx import register_notrace_function
+from ._manipulate import checkpoint
 from ._registry import generate_default_cfgs, register_model, register_model_deprecations
 
 __all__ = ['SwinTransformerV2']  # model_registry will add each entrypoint fn to this

timm/models/swin_transformer_v2_cr.py

Lines changed: 1 addition & 2 deletions

@@ -34,14 +34,13 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-import torch.utils.checkpoint as checkpoint
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from timm.layers import DropPath, Mlp, ClassifierHead, to_2tuple, _assert, ndgrid
 from ._builder import build_model_with_cfg
 from ._features import feature_take_indices
 from ._features_fx import register_notrace_function
-from ._manipulate import named_apply
+from ._manipulate import named_apply, checkpoint
 from ._registry import generate_default_cfgs, register_model
 
 __all__ = ['SwinTransformerV2Cr']  # model_registry will add each entrypoint fn to this

timm/models/tnt.py

Lines changed: 8 additions & 5 deletions

@@ -11,13 +11,13 @@
 
 import torch
 import torch.nn as nn
-from torch.utils.checkpoint import checkpoint
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from timm.layers import Mlp, DropPath, trunc_normal_, _assert, to_2tuple
+from timm.layers import Mlp, DropPath, trunc_normal_, _assert, to_2tuple, resample_abs_pos_embed
 from ._builder import build_model_with_cfg
+from ._manipulate import checkpoint
 from ._registry import register_model
-from .vision_transformer import resize_pos_embed
+
 
 __all__ = ['TNT']  # model_registry will add each entrypoint fn to this
 
@@ -340,8 +340,11 @@ def forward(self, x):
 def checkpoint_filter_fn(state_dict, model):
     """ convert patch embedding weight from manual patchify + linear proj to conv"""
     if state_dict['patch_pos'].shape != model.patch_pos.shape:
-        state_dict['patch_pos'] = resize_pos_embed(state_dict['patch_pos'],
-            model.patch_pos, getattr(model, 'num_tokens', 1), model.pixel_embed.grid_size)
+        state_dict['patch_pos'] = resample_abs_pos_embed(
+            state_dict['patch_pos'],
+            new_size=model.pixel_embed.grid_size,
+            num_prefix_tokens=1,
+        )
     return state_dict
 
 
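For reference, `resample_abs_pos_embed` interpolates a `[B, prefix + H*W, C]` position table to a new grid, which is what the rewritten filter fn relies on. A small sketch under assumed shapes (the 14x14 source grid and 384-dim embedding are illustrative):

import torch
from timm.layers import resample_abs_pos_embed

patch_pos = torch.randn(1, 1 + 14 * 14, 384)  # one prefix token plus a 14x14 grid
resampled = resample_abs_pos_embed(
    patch_pos,
    new_size=(20, 20),    # target grid, e.g. model.pixel_embed.grid_size
    num_prefix_tokens=1,  # matches the filter fn above
)
print(resampled.shape)    # torch.Size([1, 401, 384])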