Commit c0a0c75

Activation checkpointing with offloading to disk with prefetch (axolotl-ai-cloud#2663)

* offload activations to disk instead of CPU RAM
* add prefetch
* Disco :dance:
* include offload_disk in e2e test for AC
* document and make sure to cleanup
* fix annotation to match docs
* fix docs build
* address PR feedback

1 parent 7fa1089 commit c0a0c75

File tree

8 files changed, +577 -11 lines changed

_quarto.yml

Lines changed: 2 additions & 1 deletion

@@ -139,7 +139,8 @@ quartodoc:
         - utils.optimizers.adopt
         - utils.data.pretraining
         - utils.data.sft
-        - utils.gradient_checkpointing.unsloth
+        - utils.gradient_checkpointing.offload_cpu
+        - utils.gradient_checkpointing.offload_disk
     - title: Schemas
       desc: Pydantic data models for Axolotl config
       contents:

docs/config.qmd

Lines changed: 1 addition & 1 deletion

@@ -539,7 +539,7 @@ train_on_inputs: false
 # Note that training loss may have an oscillating pattern with this enabled.
 group_by_length: false
 
-# Whether to use gradient checkpointing. Available options are: true, false, "offload".
+# Whether to use gradient checkpointing. Available options are: true, false, "offload", "offload_disk".
 # https://huggingface.co/docs/transformers/v4.18.0/en/performance#gradient-checkpointing
 gradient_checkpointing: false
 # additional kwargs to pass to the trainer for gradient checkpointing
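
Usage note: with this change, setting gradient_checkpointing: offload_disk in an axolotl config selects the new disk-offloading path, while the existing "offload" value keeps checkpointed activations in CPU RAM (per the commit message and the offload_cpu / offload_disk module split below).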

src/axolotl/utils/gradient_checkpointing/__init__.py

Lines changed: 26 additions & 4 deletions

@@ -5,8 +5,11 @@
 
 from packaging import version
 
-from axolotl.utils.gradient_checkpointing.unsloth import (
-    Unsloth_Offloaded_Gradient_Checkpointer,
+from axolotl.utils.gradient_checkpointing.offload_cpu import (
+    CPU_Offloaded_Gradient_Checkpointer,
+)
+from axolotl.utils.gradient_checkpointing.offload_disk import (
+    Disco,
 )
 
 transformers_version = version.parse(importlib.metadata.version("transformers"))

@@ -26,12 +29,31 @@ def hf_grad_checkpoint_offload_wrapper(
     decoder_layer, *args, use_reentrant=None
 ):  # pylint: disable=unused-argument
     if uses_gc_layers(decoder_layer):
-        return Unsloth_Offloaded_Gradient_Checkpointer.apply(
+        return CPU_Offloaded_Gradient_Checkpointer.apply(
+            decoder_layer,
+            *args,
+        )
+
+    return CPU_Offloaded_Gradient_Checkpointer.apply(
+        (
+            decoder_layer.func.__self__
+            if isinstance(decoder_layer, partial)
+            else decoder_layer.__self__
+        ),
+        *args,
+    )
+
+
+def hf_grad_checkpoint_disk_offload_wrapper(
+    decoder_layer, *args, use_reentrant=None
+):  # pylint: disable=unused-argument
+    if uses_gc_layers(decoder_layer):
+        return Disco.apply(
             decoder_layer,
             *args,
         )
 
-    return Unsloth_Offloaded_Gradient_Checkpointer.apply(
+    return Disco.apply(
         (
             decoder_layer.func.__self__
             if isinstance(decoder_layer, partial)
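
Both wrappers share the same dispatch: the layer (or, when it is wrapped in a functools.partial or bound method, its underlying module) is handed to the corresponding checkpointer's apply(). As a rough illustration of the disk-offloading technique these wrappers call into, here is a minimal sketch. It is not the actual Disco implementation from offload_disk.py; the class name, the temp-file location, and the omission of prefetching and RNG-state handling are assumptions made for brevity.

# sketch_disk_offload.py -- illustrative only, not the commit's Disco class
import os
import tempfile
import uuid

import torch


class DiskOffloadedCheckpointer(torch.autograd.Function):  # hypothetical name
    """Write the checkpointed activation to disk in forward; reload and recompute in backward."""

    @staticmethod
    def forward(ctx, forward_function, hidden_states, *args):
        # Ordinary checkpointing: run the layer without building an autograd graph.
        with torch.no_grad():
            outputs = forward_function(hidden_states, *args)

        # Offload the saved activation to disk instead of GPU memory or CPU RAM.
        path = os.path.join(tempfile.gettempdir(), f"act_{uuid.uuid4().hex}.pt")
        torch.save(hidden_states.cpu(), path)

        ctx.forward_function = forward_function
        ctx.args = args
        ctx.path = path
        ctx.device = hidden_states.device
        return outputs

    @staticmethod
    def backward(ctx, *grad_outputs):
        # Read the activation back from disk. A real implementation would prefetch
        # this file ahead of time so the read overlaps with earlier layers' backward.
        hidden_states = torch.load(ctx.path).to(ctx.device)
        hidden_states = hidden_states.detach().requires_grad_(True)
        os.remove(ctx.path)  # clean up the temporary file

        # Recompute the forward pass with autograd enabled, then backprop through it.
        # (RNG-state restoration for dropout is omitted for brevity.)
        with torch.enable_grad():
            outputs = ctx.forward_function(hidden_states, *ctx.args)
        output = outputs[0] if isinstance(outputs, tuple) else outputs
        torch.autograd.backward(output, grad_outputs[0])

        # One gradient slot per forward input: the callable gets None, extra args get None.
        return (None, hidden_states.grad) + (None,) * len(ctx.args)

As the commit message notes, the real implementation also prefetches, so the next activation is read back from disk while earlier layers are still running their backward pass; the sketch above omits that overlap.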

src/axolotl/utils/gradient_checkpointing/unsloth.py renamed to src/axolotl/utils/gradient_checkpointing/offload_cpu.py

Lines changed: 2 additions & 2 deletions

@@ -1,4 +1,4 @@
-"""Unsloth checkpointing"""
+"""CPU offloaded checkpointing"""
 
 # Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
 #

@@ -26,7 +26,7 @@
 torch_cuda_amp_custom_bwd = torch.amp.custom_bwd(device_type="cuda")
 
 
-class Unsloth_Offloaded_Gradient_Checkpointer(  # pylint: disable=invalid-name
+class CPU_Offloaded_Gradient_Checkpointer(  # pylint: disable=invalid-name
     torch.autograd.Function
 ):
     """
