
Commit 6f80214

Merge pull request #2394 from huggingface/non_reentrant_ckpt
Wrap torch checkpoint() fn to default use_reentrant flag to False and allow env var override
2 parents 131518c + 155f6e7

24 files changed: +94 −54 lines
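
In practice, any timm model with gradient checkpointing enabled now routes through this wrapper and runs non-reentrant checkpointing unless `TIMM_REENTRANT_CKPT` is set in the environment (e.g. `TIMM_REENTRANT_CKPT=1 python train.py`). A minimal sketch of the default path, assuming a model that supports timm's `set_grad_checkpointing()`:

```python
import torch
import timm

# vit_small is just one example of a model with grad-checkpointing support
model = timm.create_model('vit_small_patch16_224', pretrained=False)
model.set_grad_checkpointing(True)

x = torch.randn(2, 3, 224, 224)
model(x).sum().backward()  # checkpointed blocks default to use_reentrant=False
```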

README.md

Lines changed: 3 additions & 0 deletions
@@ -12,6 +12,9 @@

 ## What's New

+## Jan 6, 2025
+* Add `torch.utils.checkpoint.checkpoint()` wrapper in `timm.models` that defaults `use_reentrant=False`, unless `TIMM_REENTRANT_CKPT=1` is set in env.
+
 ## Dec 31, 2024
 * `convnext_nano` 384x384 ImageNet-12k pretrain & fine-tune. https://huggingface.co/models?search=convnext_nano%20r384
 * Add AIM-v2 encoders from https://github.com/apple/ml-aim, see on Hub: https://huggingface.co/models?search=timm%20aimv2

timm/layers/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -8,7 +8,8 @@
 from .classifier import create_classifier, ClassifierHead, NormMlpClassifierHead, ClNormMlpClassifierHead
 from .cond_conv2d import CondConv2d, get_condconv_initializer
 from .config import is_exportable, is_scriptable, is_no_jit, use_fused_attn, \
-    set_exportable, set_scriptable, set_no_jit, set_layer_config, set_fused_attn
+    set_exportable, set_scriptable, set_no_jit, set_layer_config, set_fused_attn, \
+    set_reentrant_ckpt, use_reentrant_ckpt
 from .conv2d_same import Conv2dSame, conv2d_same
 from .conv_bn_act import ConvNormAct, ConvNormActAa, ConvBnAct
 from .create_act import create_act_layer, get_act_layer, get_act_fn

timm/layers/config.py

Lines changed: 17 additions & 1 deletion
@@ -8,7 +8,8 @@

 __all__ = [
     'is_exportable', 'is_scriptable', 'is_no_jit', 'use_fused_attn',
-    'set_exportable', 'set_scriptable', 'set_no_jit', 'set_layer_config', 'set_fused_attn'
+    'set_exportable', 'set_scriptable', 'set_no_jit', 'set_layer_config', 'set_fused_attn',
+    'set_reentrant_ckpt', 'use_reentrant_ckpt'
 ]

 # Set to True if prefer to have layers with no jit optimization (includes activations)
@@ -34,6 +35,12 @@
 _USE_FUSED_ATTN = 1  # 0 == off, 1 == on (for tested use), 2 == on (for experimental use)


+if 'TIMM_REENTRANT_CKPT' in os.environ:
+    _USE_REENTRANT_CKPT = bool(os.environ['TIMM_REENTRANT_CKPT'])
+else:
+    _USE_REENTRANT_CKPT = False  # defaults to disabled (off)
+
+
 def is_no_jit():
     return _NO_JIT

@@ -147,3 +154,12 @@ def set_fused_attn(enable: bool = True, experimental: bool = False):
         _USE_FUSED_ATTN = 1
     else:
         _USE_FUSED_ATTN = 0
+
+
+def use_reentrant_ckpt() -> bool:
+    return _USE_REENTRANT_CKPT
+
+
+def set_reentrant_ckpt(enable: bool = True):
+    global _USE_REENTRANT_CKPT
+    _USE_REENTRANT_CKPT = enable
timm/models/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -91,7 +91,7 @@
 from ._helpers import clean_state_dict, load_state_dict, load_checkpoint, remap_state_dict, resume_checkpoint
 from ._hub import load_model_config_from_hf, load_state_dict_from_hf, push_to_hf_hub
 from ._manipulate import model_parameters, named_apply, named_modules, named_modules_with_params, \
-    group_modules, group_parameters, checkpoint_seq, adapt_input_conv
+    group_modules, group_parameters, checkpoint_seq, checkpoint, adapt_input_conv
 from ._pretrained import PretrainedCfg, DefaultCfg, filter_pretrained_cfg
 from ._prune import adapt_model_from_string
 from ._registry import split_model_name_tag, get_arch_name, generate_default_cfgs, register_model, \

timm/models/_features.py

Lines changed: 1 addition & 2 deletions
@@ -15,10 +15,9 @@

 import torch
 import torch.nn as nn
-from torch.utils.checkpoint import checkpoint

 from timm.layers import Format, _assert
-
+from ._manipulate import checkpoint

 __all__ = [
     'FeatureInfo', 'FeatureHooks', 'FeatureDictNet', 'FeatureListNet', 'FeatureHookNet', 'FeatureGetterNet',

timm/models/_manipulate.py

Lines changed: 43 additions & 12 deletions
@@ -3,14 +3,17 @@
 import re
 from collections import defaultdict
 from itertools import chain
-from typing import Any, Callable, Dict, Iterator, Tuple, Type, Union
+from typing import Any, Callable, Dict, Iterator, Optional, Tuple, Type, Union

 import torch
+import torch.utils.checkpoint
 from torch import nn as nn
-from torch.utils.checkpoint import checkpoint
+
+from timm.layers import use_reentrant_ckpt
+

 __all__ = ['model_parameters', 'named_apply', 'named_modules', 'named_modules_with_params', 'adapt_input_conv',
-           'group_with_matcher', 'group_modules', 'group_parameters', 'flatten_modules', 'checkpoint_seq']
+           'group_with_matcher', 'group_modules', 'group_parameters', 'flatten_modules', 'checkpoint_seq', 'checkpoint']


 def model_parameters(model: nn.Module, exclude_head: bool = False):
@@ -183,13 +186,35 @@ def flatten_modules(
         yield name, module


+def checkpoint(
+        function,
+        *args,
+        use_reentrant: Optional[bool] = None,
+        **kwargs,
+):
+    """ checkpoint wrapper fn
+
+    A thin wrapper around torch.utils.checkpoint.checkpoint to default
+    use_reentrant to False
+    """
+    if use_reentrant is None:
+        use_reentrant = use_reentrant_ckpt()
+
+    return torch.utils.checkpoint.checkpoint(
+        function,
+        *args,
+        use_reentrant=use_reentrant,
+        **kwargs,
+    )
+
+
 def checkpoint_seq(
         functions,
         x,
-        every=1,
-        flatten=False,
-        skip_last=False,
-        preserve_rng_state=True
+        every: int = 1,
+        flatten: bool = False,
+        skip_last: bool = False,
+        use_reentrant: Optional[bool] = None,
 ):
     r"""A helper function for checkpointing sequential models.
@@ -215,10 +240,9 @@ def checkpoint_seq(
         functions: A :class:`torch.nn.Sequential` or the list of modules or functions to run sequentially.
         x: A Tensor that is input to :attr:`functions`
         every: checkpoint every-n functions (default: 1)
-        flatten (bool): flatten nn.Sequential of nn.Sequentials
-        skip_last (bool): skip checkpointing the last function in the sequence if True
-        preserve_rng_state (bool, optional, default=True): Omit stashing and restoring
-            the RNG state during each checkpoint.
+        flatten: flatten nn.Sequential of nn.Sequentials
+        skip_last: skip checkpointing the last function in the sequence if True
+        use_reentrant: Use re-entrant checkpointing

     Returns:
         Output of running :attr:`functions` sequentially on :attr:`*inputs`
@@ -227,6 +251,9 @@
     >>> model = nn.Sequential(...)
     >>> input_var = checkpoint_seq(model, input_var, every=2)
     """
+    if use_reentrant is None:
+        use_reentrant = use_reentrant_ckpt()
+
     def run_function(start, end, functions):
         def forward(_x):
             for j in range(start, end + 1):
@@ -247,7 +274,11 @@ def forward(_x):
     end = -1
     for start in range(0, num_checkpointed, every):
         end = min(start + every - 1, num_checkpointed - 1)
-        x = checkpoint(run_function(start, end, functions), x, preserve_rng_state=preserve_rng_state)
+        x = torch.utils.checkpoint.checkpoint(
+            run_function(start, end, functions),
+            x,
+            use_reentrant=use_reentrant,
+        )
     if skip_last:
         return run_function(end + 1, len(functions) - 1, functions)(x)
     return x
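
The wrapper keeps the signature of `torch.utils.checkpoint.checkpoint` and only fills in `use_reentrant` when the caller leaves it as `None`, which also avoids the warning newer PyTorch releases emit when the flag is unspecified. A minimal usage sketch with a toy module (not from the commit):

```python
import torch
import torch.nn as nn
from timm.models import checkpoint, checkpoint_seq

blocks = nn.Sequential(nn.Linear(8, 8), nn.GELU(), nn.Linear(8, 8))
x = torch.randn(2, 8, requires_grad=True)

y = checkpoint(blocks, x)               # defaults to use_reentrant=False
y = checkpoint_seq(blocks, x, every=2)  # checkpoint every 2 sub-modules
y.sum().backward()
```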

timm/models/beit.py

Lines changed: 1 addition & 2 deletions
@@ -44,15 +44,14 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from torch.utils.checkpoint import checkpoint

 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from timm.layers import PatchEmbed, Mlp, SwiGLU, LayerNorm, DropPath, trunc_normal_, use_fused_attn
 from timm.layers import resample_patch_embed, resample_abs_pos_embed, resize_rel_pos_bias_table, ndgrid

-
 from ._builder import build_model_with_cfg
 from ._features import feature_take_indices
+from ._manipulate import checkpoint
 from ._registry import generate_default_cfgs, register_model

 __all__ = ['Beit']

timm/models/densenet.py

Lines changed: 2 additions & 3 deletions
@@ -8,13 +8,12 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-import torch.utils.checkpoint as cp
 from torch.jit.annotations import List

 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from timm.layers import BatchNormAct2d, get_norm_act_layer, BlurPool2d, create_classifier
 from ._builder import build_model_with_cfg
-from ._manipulate import MATCH_PREV_GROUP
+from ._manipulate import MATCH_PREV_GROUP, checkpoint
 from ._registry import register_model, generate_default_cfgs, register_model_deprecations

 __all__ = ['DenseNet']
@@ -60,7 +59,7 @@ def call_checkpoint_bottleneck(self, x):
         def closure(*xs):
             return self.bottleneck_fn(xs)

-        return cp.checkpoint(closure, *x)
+        return checkpoint(closure, *x)

     @torch.jit._overload_method  # noqa: F811
     def forward(self, x):
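
The densenet change only swaps the checkpoint entry point; the closure pattern around `bottleneck_fn` is unchanged. For context, a minimal sketch of why the closure exists (simplified, `bottleneck_fn` here is a stand-in): `checkpoint()` passes tensors positionally, while the bottleneck expects a single sequence of tensors, so the closure re-packs varargs into a tuple.

```python
import torch
import torch.utils.checkpoint

def bottleneck_fn(xs):        # expects one sequence of tensors
    return torch.cat(xs, dim=1)

def closure(*xs):             # checkpoint passes tensors positionally...
    return bottleneck_fn(xs)  # ...so re-pack them into a tuple

feats = [torch.randn(1, 4, 8, 8, requires_grad=True) for _ in range(2)]
out = torch.utils.checkpoint.checkpoint(closure, *feats, use_reentrant=False)
```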

timm/models/efficientnet.py

Lines changed: 1 addition & 2 deletions
@@ -41,7 +41,6 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from torch.utils.checkpoint import checkpoint

 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
 from timm.layers import create_conv2d, create_classifier, get_norm_act_layer, LayerType, \
@@ -51,7 +50,7 @@
 from ._efficientnet_builder import BlockArgs, EfficientNetBuilder, decode_arch_def, efficientnet_init_weights, \
     round_channels, resolve_bn_args, resolve_act_layer, BN_EPS_TF_DEFAULT
 from ._features import FeatureInfo, FeatureHooks, feature_take_indices
-from ._manipulate import checkpoint_seq
+from ._manipulate import checkpoint_seq, checkpoint
 from ._registry import generate_default_cfgs, register_model, register_model_deprecations

 __all__ = ['EfficientNet', 'EfficientNetFeatures']

timm/models/eva.py

Lines changed: 1 addition & 1 deletion
@@ -30,7 +30,6 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from torch.utils.checkpoint import checkpoint

 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
 from timm.layers import PatchEmbed, Mlp, GluMlp, SwiGLU, LayerNorm, DropPath, PatchDropout, RotaryEmbeddingCat, \
@@ -39,6 +38,7 @@

 from ._builder import build_model_with_cfg
 from ._features import feature_take_indices
+from ._manipulate import checkpoint
 from ._registry import generate_default_cfgs, register_model

 __all__ = ['Eva']
