@@ -3,14 +3,17 @@
 import re
 from collections import defaultdict
 from itertools import chain
-from typing import Any, Callable, Dict, Iterator, Tuple, Type, Union
+from typing import Any, Callable, Dict, Iterator, Optional, Tuple, Type, Union
 
 import torch
+import torch.utils.checkpoint
 from torch import nn as nn
-from torch.utils.checkpoint import checkpoint
+
+from timm.layers import use_reentrant_ckpt
+
 
 __all__ = ['model_parameters', 'named_apply', 'named_modules', 'named_modules_with_params', 'adapt_input_conv',
-           'group_with_matcher', 'group_modules', 'group_parameters', 'flatten_modules', 'checkpoint_seq']
+           'group_with_matcher', 'group_modules', 'group_parameters', 'flatten_modules', 'checkpoint_seq', 'checkpoint']
 
 
 def model_parameters(model: nn.Module, exclude_head: bool = False):
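
The diff imports a use_reentrant_ckpt helper from timm.layers. As a rough sketch only (hypothetical names and layout, not the actual timm.layers source), such a module-level flag getter could look like the following, with a paired setter so callers can flip the global default:

# Hypothetical sketch, not the actual timm.layers implementation:
# a module-level default for gradient checkpointing's use_reentrant flag.
_USE_REENTRANT_CKPT: bool = False  # assumed default: non-reentrant checkpointing


def use_reentrant_ckpt() -> bool:
    # Return the global default consulted when use_reentrant is left as None.
    return _USE_REENTRANT_CKPT


def set_reentrant_ckpt(enable: bool) -> None:
    # Flip the global default; the name and existence of this setter are assumed.
    global _USE_REENTRANT_CKPT
    _USE_REENTRANT_CKPT = enable
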
@@ -183,13 +186,35 @@ def flatten_modules(
                 yield name, module
 
 
+def checkpoint(
+        function,
+        *args,
+        use_reentrant: Optional[bool] = None,
+        **kwargs,
+):
+    """ checkpoint wrapper fn
+
+    A thin wrapper around torch.utils.checkpoint.checkpoint that defaults
+    use_reentrant to False via the global use_reentrant_ckpt() setting.
+    """
+    if use_reentrant is None:
+        use_reentrant = use_reentrant_ckpt()
+
+    return torch.utils.checkpoint.checkpoint(
+        function,
+        *args,
+        use_reentrant=use_reentrant,
+        **kwargs,
+    )
+
+
 def checkpoint_seq(
         functions,
         x,
-        every=1,
-        flatten=False,
-        skip_last=False,
-        preserve_rng_state=True
+        every: int = 1,
+        flatten: bool = False,
+        skip_last: bool = False,
+        use_reentrant: Optional[bool] = None,
 ):
     r"""A helper function for checkpointing sequential models.
 
@@ -215,10 +240,9 @@ def checkpoint_seq(
         functions: A :class:`torch.nn.Sequential` or the list of modules or functions to run sequentially.
         x: A Tensor that is input to :attr:`functions`
         every: checkpoint every-n functions (default: 1)
-        flatten (bool): flatten nn.Sequential of nn.Sequentials
-        skip_last (bool): skip checkpointing the last function in the sequence if True
-        preserve_rng_state (bool, optional, default=True): Omit stashing and restoring
-            the RNG state during each checkpoint.
+        flatten: flatten a nested nn.Sequential of nn.Sequentials into a single sequence
+        skip_last: skip checkpointing the last function in the sequence if True
+        use_reentrant: use re-entrant checkpointing; if None, fall back to the global use_reentrant_ckpt() default
 
     Returns:
         Output of running :attr:`functions` sequentially on :attr:`*inputs`
@@ -227,6 +251,9 @@ def checkpoint_seq(
         >>> model = nn.Sequential(...)
         >>> input_var = checkpoint_seq(model, input_var, every=2)
     """
+    if use_reentrant is None:
+        use_reentrant = use_reentrant_ckpt()
+
     def run_function(start, end, functions):
         def forward(_x):
             for j in range(start, end + 1):
@@ -247,7 +274,11 @@ def forward(_x):
     end = -1
     for start in range(0, num_checkpointed, every):
         end = min(start + every - 1, num_checkpointed - 1)
-        x = checkpoint(run_function(start, end, functions), x, preserve_rng_state=preserve_rng_state)
+        x = torch.utils.checkpoint.checkpoint(
+            run_function(start, end, functions),
+            x,
+            use_reentrant=use_reentrant,
+        )
     if skip_last:
         return run_function(end + 1, len(functions) - 1, functions)(x)
     return x
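
With these changes, both the new checkpoint wrapper and checkpoint_seq resolve use_reentrant the same way: an explicit argument wins, otherwise the global use_reentrant_ckpt() default applies. A minimal usage sketch follows (the import path timm.models._manipulate is assumed from the file context; the model and shapes are placeholders):

import torch
from torch import nn

from timm.models._manipulate import checkpoint, checkpoint_seq  # import path assumed

blocks = nn.Sequential(*[nn.Linear(64, 64) for _ in range(8)])
x = torch.randn(2, 64, requires_grad=True)

# use_reentrant left as None: the wrapper falls back to the global default.
y = checkpoint(blocks[0], x)

# Checkpoint every two blocks, run the last group unwrapped, and request
# non-reentrant checkpointing explicitly.
y = checkpoint_seq(blocks, x, every=2, skip_last=True, use_reentrant=False)
y.sum().backward()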