Skip to content

Commit 725071e

Browse files
committed
More forward_intermediate specific grad checkpointing fixes
1 parent 1f9eb66 commit 725071e

File tree

2 files changed

+2
-2
lines changed

2 files changed

+2
-2
lines changed

timm/models/starnet.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -199,7 +199,7 @@ def forward_intermediates(
199199

200200
for feat_idx, stage in enumerate(stages):
201201
if self.grad_checkpointing and not torch.jit.is_scripting():
202-
x = checkpoint_seq(stages, x)
202+
x = checkpoint_seq(stage, x)
203203
else:
204204
x = stage(x)
205205
if feat_idx in take_indices:

timm/models/tiny_vit.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -571,7 +571,7 @@ def forward_intermediates(
571571

572572
for feat_idx, stage in enumerate(stages):
573573
if self.grad_checkpointing and not torch.jit.is_scripting():
574-
x = checkpoint(stages, x)
574+
x = checkpoint(stage, x)
575575
else:
576576
x = stage(x)
577577
if feat_idx in take_indices:

0 commit comments

Comments
 (0)