
Commit 8bdeeea

Merge pull request #2501 from brianhou0208/grad_checkpointing
Support gradient checkpointing in `forward_intermediates()`
2 parents 2b9840c + b3a8773 commit 8bdeeea
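
For context, a minimal usage sketch (not part of the diff) of what this commit enables: calling forward_intermediates() on a model with gradient checkpointing turned on. The model name and input shape are illustrative assumptions; any timm model that implements set_grad_checkpointing() and forward_intermediates() would do.

import torch
import timm

# 'vit_base_patch16_224' is just an illustrative choice of model.
model = timm.create_model('vit_base_patch16_224', pretrained=False)
model.train()
model.set_grad_checkpointing()  # raises on models that don't support checkpointing

x = torch.randn(2, 3, 224, 224, requires_grad=True)
final, intermediates = model.forward_intermediates(x)
# Before this commit, forward_intermediates() ran the blocks directly and ignored the
# grad_checkpointing flag; now block activations are recomputed during the backward pass.
final.sum().backward()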


53 files changed, +237 -128 lines

tests/test_models.py

Lines changed: 27 additions & 1 deletion
@@ -186,6 +186,18 @@ def test_model_forward(model_name, batch_size):
     assert outputs.shape[0] == batch_size
     assert not torch.isnan(outputs).any(), 'Output included NaNs'
 
+    # Test that grad-checkpointing, if supported, doesn't cause model failures or change in output
+    try:
+        model.set_grad_checkpointing()
+    except:
+        # throws if not supported, that's fine
+        pass
+    else:
+        outputs2 = model(inputs)
+        if isinstance(outputs, tuple):
+            outputs2 = torch.cat(outputs2)
+        assert torch.allclose(outputs, outputs2, rtol=1e-4, atol=1e-5), 'Output does not match'
+
 
 @pytest.mark.base
 @pytest.mark.timeout(timeout120)
@@ -529,6 +541,20 @@ def test_model_forward_intermediates(model_name, batch_size):
     output2 = model.forward_features(inpt)
     assert torch.allclose(output, output2)
 
+    # Test that grad-checkpointing, if supported
+    try:
+        model.set_grad_checkpointing()
+    except:
+        # throws if not supported, that's fine
+        pass
+    else:
+        output3, _ = model.forward_intermediates(
+            inpt,
+            output_fmt=output_fmt,
+        )
+        assert torch.allclose(output, output3, rtol=1e-4, atol=1e-5), 'Output does not match'
+
+
 
 def _create_fx_model(model, train=False):
     # This block of code does a bit of juggling to handle any case where there are multiple outputs in train mode
@@ -717,4 +743,4 @@ def test_model_forward_torchscript_with_features_fx(model_name, batch_size):
 
     for tensor in outputs:
         assert tensor.shape[0] == batch_size
-        assert not torch.isnan(tensor).any(), 'Output included NaNs'
+        assert not torch.isnan(tensor).any(), 'Output included NaNs'
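
The check the new tests perform, as a standalone sketch: outputs with and without grad checkpointing should agree to within floating-point tolerance. The model name is an illustrative assumption; any model whose set_grad_checkpointing() doesn't raise will do.

import torch
import timm

model = timm.create_model('resnet50', pretrained=False).eval()
inputs = torch.randn(2, 3, 224, 224, requires_grad=True)
outputs = model(inputs)

try:
    model.set_grad_checkpointing()
except Exception:
    print('grad checkpointing not supported by this model')
else:
    outputs2 = model(inputs)
    assert torch.allclose(outputs, outputs2, rtol=1e-4, atol=1e-5), 'Output does not match'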

timm/models/beit.py

Lines changed: 4 additions & 1 deletion
@@ -615,7 +615,10 @@ def forward_intermediates(
         else:
             blocks = self.blocks[:max_index + 1]
         for i, blk in enumerate(blocks):
-            x = blk(x, shared_rel_pos_bias=rel_pos_bias)
+            if self.grad_checkpointing and not torch.jit.is_scripting():
+                x = checkpoint(blk, x, shared_rel_pos_bias=rel_pos_bias)
+            else:
+                x = blk(x, shared_rel_pos_bias=rel_pos_bias)
             if i in take_indices:
                 # normalize intermediates with final norm layer if enabled
                 intermediates.append(self.norm(x) if norm else x)
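
The change above is the pattern repeated in the model files below: wrap the per-block call in a checkpoint helper when grad checkpointing is enabled and the module is not being scripted. A self-contained sketch of that pattern, assuming a recent PyTorch; ToyModel and its fields are illustrative, not timm code, and torch.utils.checkpoint is used here in place of timm's own helper.

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class ToyModel(nn.Module):
    def __init__(self, depth: int = 4, dim: int = 32):
        super().__init__()
        self.blocks = nn.ModuleList(nn.Linear(dim, dim) for _ in range(depth))
        self.grad_checkpointing = False

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable: bool = True):
        self.grad_checkpointing = enable

    def forward_intermediates(self, x, take_indices=(1, 3)):
        intermediates = []
        for i, blk in enumerate(self.blocks):
            if self.grad_checkpointing and not torch.jit.is_scripting():
                # recompute this block's activations during backward instead of storing them
                x = checkpoint(blk, x, use_reentrant=False)
            else:
                x = blk(x)
            if i in take_indices:
                intermediates.append(x)
        return x, intermediates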

timm/models/byobnet.py

Lines changed: 4 additions & 1 deletion
@@ -1508,7 +1508,10 @@ def forward_intermediates(
             stages = self.stages[:max_index]
         for stage in stages:
             feat_idx += 1
-            x = stage(x)
+            if self.grad_checkpointing and not torch.jit.is_scripting():
+                x = checkpoint_seq(stage, x)
+            else:
+                x = stage(x)
             if not exclude_final_conv and feat_idx == last_idx:
                 # default feature_info for this model uses final_conv as the last feature output (if present)
                 x = self.final_conv(x)
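
Stage-based models (byobnet above, efficientnet and efficientvit below) wrap a whole stage with checkpoint_seq rather than a single block with checkpoint. A rough sketch of that call, assuming timm's checkpoint_seq accepts an nn.Sequential as it does in this diff; the toy stage is illustrative.

import torch
import torch.nn as nn
from timm.models._manipulate import checkpoint_seq

stage = nn.Sequential(
    nn.Conv2d(16, 16, 3, padding=1),
    nn.ReLU(),
    nn.Conv2d(16, 16, 3, padding=1),
)
x = torch.randn(2, 16, 32, 32, requires_grad=True)

y = checkpoint_seq(stage, x)  # same result as stage(x), with activations recomputed in backward
y.sum().backward()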

timm/models/cait.py

Lines changed: 5 additions & 2 deletions
@@ -18,7 +18,7 @@
 from timm.layers import PatchEmbed, Mlp, DropPath, trunc_normal_, use_fused_attn
 from ._builder import build_model_with_cfg
 from ._features import feature_take_indices
-from ._manipulate import checkpoint_seq
+from ._manipulate import checkpoint, checkpoint_seq
 from ._registry import register_model, generate_default_cfgs
 
 __all__ = ['Cait', 'ClassAttn', 'LayerScaleBlockClassAttn', 'LayerScaleBlock', 'TalkingHeadAttn']
@@ -373,7 +373,10 @@ def forward_intermediates(
         else:
             blocks = self.blocks[:max_index + 1]
         for i, blk in enumerate(blocks):
-            x = blk(x)
+            if self.grad_checkpointing and not torch.jit.is_scripting():
+                x = checkpoint(blk, x)
+            else:
+                x = blk(x)
             if i in take_indices:
                 # normalize intermediates with final norm layer if enabled
                 intermediates.append(self.norm(x) if norm else x)

timm/models/crossvit.py

Lines changed: 1 addition & 6 deletions
@@ -14,21 +14,16 @@
 NOTE: model names have been renamed from originals to represent actual input res all *_224 -> *_240 and *_384 -> *_408
 
 Modifications and additions for timm hacked together by / Copyright 2021, Ross Wightman
+Modified from Timm. https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
 """
 
 # Copyright IBM All Rights Reserved.
 # SPDX-License-Identifier: Apache-2.0
 
-
-"""
-Modified from Timm. https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
-
-"""
 from functools import partial
 from typing import List, Optional, Tuple
 
 import torch
-import torch.hub
 import torch.nn as nn
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD

timm/models/davit.py

Lines changed: 5 additions & 2 deletions
@@ -25,7 +25,7 @@
 from ._builder import build_model_with_cfg
 from ._features import feature_take_indices
 from ._features_fx import register_notrace_function
-from ._manipulate import checkpoint_seq
+from ._manipulate import checkpoint, checkpoint_seq
 from ._registry import generate_default_cfgs, register_model
 
 __all__ = ['DaVit']
@@ -671,7 +671,10 @@ def forward_intermediates(
             stages = self.stages[:max_index + 1]
 
         for feat_idx, stage in enumerate(stages):
-            x = stage(x)
+            if self.grad_checkpointing and not torch.jit.is_scripting():
+                x = checkpoint(stage, x)
+            else:
+                x = stage(x)
             if feat_idx in take_indices:
                 if norm and feat_idx == last_idx:
                     x_inter = self.norm_pre(x)  # applying final norm to last intermediate

timm/models/dla.py

Lines changed: 0 additions & 1 deletion
@@ -10,7 +10,6 @@
 
 import torch
 import torch.nn as nn
-import torch.nn.functional as F
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from timm.layers import create_classifier

timm/models/efficientnet.py

Lines changed: 5 additions & 3 deletions
@@ -259,9 +259,11 @@ def forward_intermediates(
             blocks = self.blocks
         else:
             blocks = self.blocks[:max_index]
-        for blk in blocks:
-            feat_idx += 1
-            x = blk(x)
+        for feat_idx, blk in enumerate(blocks, start=1):
+            if self.grad_checkpointing and not torch.jit.is_scripting():
+                x = checkpoint_seq(blk, x)
+            else:
+                x = blk(x)
             if feat_idx in take_indices:
                 intermediates.append(x)

timm/models/efficientvit_mit.py

Lines changed: 8 additions & 2 deletions
@@ -789,7 +789,10 @@ def forward_intermediates(
             stages = self.stages[:max_index + 1]
 
         for feat_idx, stage in enumerate(stages):
-            x = stage(x)
+            if self.grad_checkpointing and not torch.jit.is_scripting():
+                x = checkpoint_seq(stages, x)
+            else:
+                x = stage(x)
             if feat_idx in take_indices:
                 intermediates.append(x)
 
@@ -943,7 +946,10 @@ def forward_intermediates(
             stages = self.stages[:max_index + 1]
 
         for feat_idx, stage in enumerate(stages):
-            x = stage(x)
+            if self.grad_checkpointing and not torch.jit.is_scripting():
+                x = checkpoint_seq(stages, x)
+            else:
+                x = stage(x)
             if feat_idx in take_indices:
                 intermediates.append(x)

timm/models/efficientvit_msra.py

Lines changed: 5 additions & 2 deletions
@@ -18,7 +18,7 @@
 from timm.layers import SqueezeExcite, SelectAdaptivePool2d, trunc_normal_, _assert
 from ._builder import build_model_with_cfg
 from ._features import feature_take_indices
-from ._manipulate import checkpoint_seq
+from ._manipulate import checkpoint, checkpoint_seq
 from ._registry import register_model, generate_default_cfgs
 
 
@@ -510,7 +510,10 @@ def forward_intermediates(
             stages = self.stages[:max_index + 1]
 
         for feat_idx, stage in enumerate(stages):
-            x = stage(x)
+            if self.grad_checkpointing and not torch.jit.is_scripting():
+                x = checkpoint(stage, x)
+            else:
+                x = stage(x)
             if feat_idx in take_indices:
                 intermediates.append(x)
