diff --git a/.github/workflows/e2e_test.yml b/.github/workflows/e2e_test.yml
index 90547b8c..e2f0bbe6 100644
--- a/.github/workflows/e2e_test.yml
+++ b/.github/workflows/e2e_test.yml
@@ -17,6 +17,7 @@ jobs:
     outputs:
       llama-3-8b-name: ${{ steps.run-llama-3-8b.outputs.name }}
       llama-3_1-8b-sa-name: ${{ steps.run-llama-3_1-8b-SplashAttention.outputs.name }}
+      llama-3_1-8b-scan-offload-name: ${{ steps.run-llama-3_1-8b-scan-offload.outputs.name }}
       llama-3-8b-2d-name: ${{ steps.run-llama-3-8b-2d.outputs.name }}
       llama-3-8b-2-slice-name: ${{ steps.run-llama-3-8b-2-slice.outputs.name }}
       mixtral-8x7b-name: ${{ steps.run-mixtral-8x7b.outputs.name }}
@@ -62,6 +63,7 @@ jobs:
             dataset_config_name=wikitext-2-raw-v1 \
             profile_step=3 \
             max_steps=15
+
       - name: Run Llama 3.1 8B (Splash Attention)
         id: run-llama-3_1-8b-SplashAttention
         env:
@@ -81,6 +83,27 @@ jobs:
             dataset_config_name=wikitext-2-raw-v1 \
             profile_step=3 \
             max_steps=15
+
+      - name: Run Llama 3.1 8B (Scan + Offload)
+        id: run-llama-3_1-8b-scan-offload
+        env:
+          HF_TOKEN: ${{ secrets.HF_TOKEN }}
+          XLA_IR_DEBUG: 1
+          XLA_HLO_DEBUG: 1
+        run: |
+          name=$(e2e_testing/gen_name.py llama-3dot1-8b-sa)
+          echo "name=$name" >> "$GITHUB_OUTPUT"
+          tp run \
+            --name $name \
+            torchprime/torch_xla_models/train.py \
+            model=llama-3.1-8b \
+            model/remat=llama-scan-offload \
+            global_batch_size=8 \
+            ici_mesh.fsdp=4 \
+            dataset_config_name=wikitext-2-raw-v1 \
+            profile_step=3 \
+            max_steps=15
+
       - name: Run Llama 3.0 8B (2D sharding)
         id: run-llama-3-8b-2d
         env:
@@ -94,7 +117,7 @@ jobs:
             --name $name \
             torchprime/torch_xla_models/train.py \
             model=llama-3-8b \
-            model/scaling=llama-fsdp-tp \
+            model/sharding=llama-fsdp-tp \
             global_batch_size=8 \
             ici_mesh.fsdp=2 \
             ici_mesh.tensor=2 \
@@ -136,7 +159,7 @@ jobs:
            --num-slices 2 \
            torchprime/torch_xla_models/train.py \
            model=llama-3-8b \
-           model/scaling=llama-fsdp \
+           model/sharding=llama-fsdp \
            global_batch_size=16 \
            dcn_mesh.fsdp=2 \
            ici_mesh.fsdp=4 \
@@ -154,7 +177,7 @@ jobs:
     secrets: inherit

   llama-3_1-8b-sa:
-    name: Llama 3.1 8B (Spalsh Attention)
+    name: Llama 3.1 8B (Splash Attention)
     needs: tp-run
     uses: ./.github/workflows/reusable_e2e_check.yml
     with:
@@ -162,6 +185,15 @@ jobs:
       artifact_dir: ${{ needs.tp-run.outputs.artifact-dir }}
     secrets: inherit

+  llama-3_1-8b-scan-offload:
+    name: Llama 3.1 8B (Scan + Offload)
+    needs: tp-run
+    uses: ./.github/workflows/reusable_e2e_check.yml
+    with:
+      jobset_name: ${{ needs.tp-run.outputs.llama-3_1-8b-scan-offload-name }}
+      artifact_dir: ${{ needs.tp-run.outputs.artifact-dir }}
+    secrets: inherit
+
   llama-3-8b-2d:
     name: Llama 3.0 8B (2D sharding)
     needs: tp-run
diff --git a/torchprime/torch_xla_models/README.md b/torchprime/torch_xla_models/README.md
index d3cb1249..639ab18e 100644
--- a/torchprime/torch_xla_models/README.md
+++ b/torchprime/torch_xla_models/README.md
@@ -137,8 +137,10 @@ tp run torchprime/torch_xla_models/train.py \
 - `train.py`: Main training script that sets up the model, data, and training loop
 - `configs/base.yaml`: Configuration file for the training script
 - `configs/model`: Configuration files for models
-- `configs/model/scaling`: Configuration files for scaling the training of a model, e.g.
-  rematerialization, sharding tensors.
+- `configs/model/sharding`: Configuration files for distributing the training
+  across many chips
+- `configs/model/remat`: Configuration files for rematerialization strategies, e.g.
+  activation checkpointing and host offloading
 - `llama/model.py`: Implementation of the Llama model family
 - `mixtral/model.py`: Implementation of the Mixtral model family
 
diff --git a/torchprime/torch_xla_models/configs/default.yaml b/torchprime/torch_xla_models/configs/default.yaml
index c035a39f..4bdae4c9 100644
--- a/torchprime/torch_xla_models/configs/default.yaml
+++ b/torchprime/torch_xla_models/configs/default.yaml
@@ -47,3 +47,26 @@ dcn_mesh:
   fsdp: 1
   tensor: 1
   expert: 1
+
+# These are default values for model activation rematerialization configuration.
+# They can be overridden on the command line or by importing one of the presets
+# in the `model/remat` directory.
+model:
+  remat:
+    # The class names of model layers whose intermediate activations should be
+    # recomputed during the backward pass (i.e. activation checkpointing).
+    activation_checkpoint_layers: []
+
+    # If not null, compile a module of type `HomogeneousSequential` located at the
+    # given path in the module tree using `torch_xla.experimental.scan_layers`.
+    scan_layers: null
+
+    # If specified, offload these tensors to host RAM during the forward pass and
+    # move them back during the backward pass.
+    #
+    # The tensors to be offloaded should be given a name by wrapping them with the
+    # `torchprime.torch_xla_models.offloading.offload_name` call. Then the same
+    # name can be specified here to offload that tensor.
+    #
+    # Currently, in order to offload tensors, `scan_layers` must also be enabled.
+    offload_tensors: []
diff --git a/torchprime/torch_xla_models/configs/model/llama-3-8b.yaml b/torchprime/torch_xla_models/configs/model/llama-3-8b.yaml
index e7e29b60..ad6b90fb 100644
--- a/torchprime/torch_xla_models/configs/model/llama-3-8b.yaml
+++ b/torchprime/torch_xla_models/configs/model/llama-3-8b.yaml
@@ -1,6 +1,7 @@
 defaults:
   - _self_ # refers to this config file
-  - scaling: llama-fsdp # refers to scaling/llama-fsdp.yaml
+  - sharding: llama-fsdp # refers to sharding/llama-fsdp.yaml
+  - remat: llama # refers to remat/llama.yaml
 
 model_class: llama.LlamaForCausalLM # Used to import the model from this class
 vocab_size: 128256
diff --git a/torchprime/torch_xla_models/configs/model/llama-3.1-405b.yaml b/torchprime/torch_xla_models/configs/model/llama-3.1-405b.yaml
index ff2896c5..072adb34 100644
--- a/torchprime/torch_xla_models/configs/model/llama-3.1-405b.yaml
+++ b/torchprime/torch_xla_models/configs/model/llama-3.1-405b.yaml
@@ -1,6 +1,7 @@
 defaults:
   - _self_
-  - scaling: llama-fsdp-tp
+  - sharding: llama-fsdp-tp
+  - remat: llama
 
 model_class: llama.LlamaForCausalLM
 vocab_size: 128256
diff --git a/torchprime/torch_xla_models/configs/model/llama-3.1-8b.yaml b/torchprime/torch_xla_models/configs/model/llama-3.1-8b.yaml
index eeb7b41a..dfbc9ca6 100644
--- a/torchprime/torch_xla_models/configs/model/llama-3.1-8b.yaml
+++ b/torchprime/torch_xla_models/configs/model/llama-3.1-8b.yaml
@@ -1,6 +1,7 @@
 defaults:
   - _self_ # refers to this config file
-  - scaling: llama-fsdp # refers to scaling/llama-fsdp.yaml
+  - sharding: llama-fsdp # refers to sharding/llama-fsdp.yaml
+  - remat: llama # refers to remat/llama.yaml
 
 model_class: llama.LlamaForCausalLM # Used to import the model from this class
 vocab_size: 128256
@@ -25,4 +26,4 @@ rope_scaling:
   factor: 8.0
   low_freq_factor: 1.0
   high_freq_factor: 4.0
-  original_context_len: 8192
\ No newline at end of file
+  original_context_len: 8192
diff --git a/torchprime/torch_xla_models/configs/model/mixtral-8x7b.yaml b/torchprime/torch_xla_models/configs/model/mixtral-8x7b.yaml
index 739342f9..f82e4eae 100644
--- a/torchprime/torch_xla_models/configs/model/mixtral-8x7b.yaml
+++ b/torchprime/torch_xla_models/configs/model/mixtral-8x7b.yaml
@@ -1,6 +1,7 @@
 defaults:
   - _self_ # refers to this config file
-  - scaling: mixtral-fsdp # refers to scaling/mixtral-fsdp.yaml
+  - sharding: mixtral-fsdp # refers to sharding/mixtral-fsdp.yaml
+  - remat: mixtral # refers to remat/mixtral.yaml
 
 model_class: mixtral.MixtralForCausalLM # Used to import the model from this class
 bos_token_id: 1
diff --git a/torchprime/torch_xla_models/configs/model/remat/llama-scan-offload.yaml b/torchprime/torch_xla_models/configs/model/remat/llama-scan-offload.yaml
new file mode 100644
index 00000000..4fd5716d
--- /dev/null
+++ b/torchprime/torch_xla_models/configs/model/remat/llama-scan-offload.yaml
@@ -0,0 +1,6 @@
+defaults:
+  - llama-scan
+  - _self_
+
+offload_tensors:
+  - decoder_input
diff --git a/torchprime/torch_xla_models/configs/model/remat/llama-scan.yaml b/torchprime/torch_xla_models/configs/model/remat/llama-scan.yaml
new file mode 100644
index 00000000..d7505854
--- /dev/null
+++ b/torchprime/torch_xla_models/configs/model/remat/llama-scan.yaml
@@ -0,0 +1,6 @@
+defaults:
+  - llama
+  - _self_
+
+# Compile the sequence of decoder layers at `model.layers` using scan.
+scan_layers: 'model.layers'
diff --git a/torchprime/torch_xla_models/configs/model/remat/llama.yaml b/torchprime/torch_xla_models/configs/model/remat/llama.yaml
new file mode 100644
index 00000000..8232cbdb
--- /dev/null
+++ b/torchprime/torch_xla_models/configs/model/remat/llama.yaml
@@ -0,0 +1,6 @@
+activation_checkpoint_layers:
+  - LlamaDecoderLayer
+
+# Refer to https://github.com/pytorch/xla/issues/6379 for backward optimization barrier info.
+optimization_barrier_layers:
+  - LlamaDecoderLayer
diff --git a/torchprime/torch_xla_models/configs/model/remat/mixtral.yaml b/torchprime/torch_xla_models/configs/model/remat/mixtral.yaml
new file mode 100644
index 00000000..d0910a81
--- /dev/null
+++ b/torchprime/torch_xla_models/configs/model/remat/mixtral.yaml
@@ -0,0 +1,6 @@
+activation_checkpoint_layers:
+  - MixtralDecoderLayer
+
+# Refer to https://github.com/pytorch/xla/issues/6379 for backward optimization barrier info.
+optimization_barrier_layers:
+  - MixtralDecoderLayer
diff --git a/torchprime/torch_xla_models/configs/model/scaling/llama-fsdp-tp.yaml b/torchprime/torch_xla_models/configs/model/scaling/llama-fsdp-tp.yaml
deleted file mode 100644
index 1147af4c..00000000
--- a/torchprime/torch_xla_models/configs/model/scaling/llama-fsdp-tp.yaml
+++ /dev/null
@@ -1,37 +0,0 @@
-# 2D (FSDP + TP) sharding configuration for Llama models.
-
-activation_checkpoint_layers:
-  - LlamaDecoderLayer
-
-# Refer to https://github.com/pytorch/xla/issues/6379 for backward optimization barrier info.
-optimization_barrier_layers:
-  - LlamaDecoderLayer
-
-sharding:
-  # Weights
-
-  # TODO(https://github.com/AI-Hypercomputer/torchprime/issues/114): This
-  # cannot be `[tensor, fsdp]`, or the gradients will sometimes become NaN.
-  model.embed_tokens.weight: [fsdp, tensor]
-
-  model.layers.*.self_attn.q_proj.weight: [tensor, fsdp]
-  model.layers.*.self_attn.k_proj.weight: [tensor, fsdp]
-  model.layers.*.self_attn.v_proj.weight: [tensor, fsdp]
-  model.layers.*.self_attn.o_proj.weight: [fsdp, tensor]
-  model.layers.*.mlp.gate_proj.weight: [tensor, fsdp]
-  model.layers.*.mlp.up_proj.weight: [tensor, fsdp]
-  model.layers.*.mlp.down_proj.weight: [fsdp, tensor]
-  model.layers.*.input_layernorm.weight: [fsdp]
-  model.layers.*.post_attention_layernorm.weight: [fsdp]
-  model.norm.weight: [fsdp]
-  lm_head.weight: [tensor, fsdp]
-
-  # Activations
-  model.layers.*.self_attn.q_proj: [fsdp, null, tensor]
-  model.layers.*.self_attn.k_proj: [fsdp, null, tensor]
-  model.layers.*.self_attn.v_proj: [fsdp, null, tensor]
-  model.layers.*.self_attn.o_proj: [fsdp, null, tensor]
-  model.layers.*.input_layernorm: [fsdp, null, tensor]
-  model.layers.*.post_attention_layernorm: [fsdp, null, tensor]
-  model.layers.*.mlp: [fsdp, null, tensor]
-  model.layers.*: [fsdp, null, tensor]
diff --git a/torchprime/torch_xla_models/configs/model/scaling/llama-fsdp.yaml b/torchprime/torch_xla_models/configs/model/scaling/llama-fsdp.yaml
deleted file mode 100644
index 2aeca9c6..00000000
--- a/torchprime/torch_xla_models/configs/model/scaling/llama-fsdp.yaml
+++ /dev/null
@@ -1,25 +0,0 @@
-activation_checkpoint_layers:
-  - LlamaDecoderLayer
-
-# Refer to https://github.com/pytorch/xla/issues/6379 for backward optimization barrier info.
-optimization_barrier_layers:
-  - LlamaDecoderLayer
-
-sharding:
-  # Weights
-  model.embed_tokens.weight: [fsdp, null]
-  model.layers.*.self_attn.q_proj.weight: [fsdp, null]
-  model.layers.*.self_attn.k_proj.weight: [null, fsdp]
-  model.layers.*.self_attn.v_proj.weight: [null, fsdp]
-  model.layers.*.self_attn.o_proj.weight: [fsdp, null]
-  model.layers.*.mlp.gate_proj.weight: [fsdp, null]
-  model.layers.*.mlp.up_proj.weight: [fsdp, null]
-  model.layers.*.mlp.down_proj.weight: [null, fsdp]
-  model.layers.*.input_layernorm.weight: [fsdp]
-  model.layers.*.post_attention_layernorm.weight: [fsdp]
-  model.norm.weight: [fsdp]
-  lm_head.weight: [fsdp, null]
-
-  # Activations
-  model.layers.*: [fsdp, null, null]
-  lm_head: [fsdp, null, null]
diff --git a/torchprime/torch_xla_models/configs/model/scaling/mixtral-fsdp.yaml b/torchprime/torch_xla_models/configs/model/scaling/mixtral-fsdp.yaml
deleted file mode 100644
index b52cafa5..00000000
--- a/torchprime/torch_xla_models/configs/model/scaling/mixtral-fsdp.yaml
+++ /dev/null
@@ -1,26 +0,0 @@
-activation_checkpoint_layers:
-  - MixtralDecoderLayer
-
-# Refer to https://github.com/pytorch/xla/issues/6379 for backward optimization barrier info.
-optimization_barrier_layers:
-  - MixtralDecoderLayer
-
-sharding:
-  # Weights
-  model.embed_tokens.weight: ['fsdp', null]
-  model.layers.*.self_attn.q_proj.weight: ['fsdp', null]
-  model.layers.*.self_attn.k_proj.weight: [null, 'fsdp']
-  model.layers.*.self_attn.v_proj.weight: [null, 'fsdp']
-  model.layers.*.self_attn.o_proj.weight: ['fsdp', null]
-  model.layers.*.block_sparse_moe.gate.weight: [null, 'fsdp']
-  model.layers.*.block_sparse_moe.experts.w1: [null, null, 'fsdp']
-  model.layers.*.block_sparse_moe.experts.w2: [null, 'fsdp', null]
-  model.layers.*.block_sparse_moe.experts.w3: [null, null, 'fsdp']
-  model.layers.*.input_layernorm.weight: ['fsdp']
-  model.layers.*.post_attention_layernorm.weight: ['fsdp']
-  model.norm.weight: ['fsdp']
-  lm_head.weight: ['fsdp', null]
-
-  # Activations
-  model.layers.*[0]: [fsdp, null, null] # Shard the first output of the decoder layer
-  lm_head: [fsdp, null, null]
diff --git a/torchprime/torch_xla_models/configs/model/sharding/llama-fsdp-tp.yaml b/torchprime/torch_xla_models/configs/model/sharding/llama-fsdp-tp.yaml
new file mode 100644
index 00000000..bd10c7db
--- /dev/null
+++ b/torchprime/torch_xla_models/configs/model/sharding/llama-fsdp-tp.yaml
@@ -0,0 +1,29 @@
+# 2D (FSDP + TP) sharding configuration for Llama models.
+
+# Weights
+
+# TODO(https://github.com/AI-Hypercomputer/torchprime/issues/114): This
+# cannot be `[tensor, fsdp]`, or the gradients will sometimes become NaN.
+model.embed_tokens.weight: [fsdp, tensor]
+
+model.layers.*.self_attn.q_proj.weight: [tensor, fsdp]
+model.layers.*.self_attn.k_proj.weight: [tensor, fsdp]
+model.layers.*.self_attn.v_proj.weight: [tensor, fsdp]
+model.layers.*.self_attn.o_proj.weight: [fsdp, tensor]
+model.layers.*.mlp.gate_proj.weight: [tensor, fsdp]
+model.layers.*.mlp.up_proj.weight: [tensor, fsdp]
+model.layers.*.mlp.down_proj.weight: [fsdp, tensor]
+model.layers.*.input_layernorm.weight: [fsdp]
+model.layers.*.post_attention_layernorm.weight: [fsdp]
+model.norm.weight: [fsdp]
+lm_head.weight: [tensor, fsdp]
+
+# Activations
+model.layers.*.self_attn.q_proj: [fsdp, null, tensor]
+model.layers.*.self_attn.k_proj: [fsdp, null, tensor]
+model.layers.*.self_attn.v_proj: [fsdp, null, tensor]
+model.layers.*.self_attn.o_proj: [fsdp, null, tensor]
+model.layers.*.input_layernorm: [fsdp, null, tensor]
+model.layers.*.post_attention_layernorm: [fsdp, null, tensor]
+model.layers.*.mlp: [fsdp, null, tensor]
+model.layers.*: [fsdp, null, tensor]
diff --git a/torchprime/torch_xla_models/configs/model/sharding/llama-fsdp.yaml b/torchprime/torch_xla_models/configs/model/sharding/llama-fsdp.yaml
new file mode 100644
index 00000000..28db781a
--- /dev/null
+++ b/torchprime/torch_xla_models/configs/model/sharding/llama-fsdp.yaml
@@ -0,0 +1,17 @@
+# Weights
+model.embed_tokens.weight: [fsdp, null]
+model.layers.*.self_attn.q_proj.weight: [fsdp, null]
+model.layers.*.self_attn.k_proj.weight: [null, fsdp]
+model.layers.*.self_attn.v_proj.weight: [null, fsdp]
+model.layers.*.self_attn.o_proj.weight: [fsdp, null]
+model.layers.*.mlp.gate_proj.weight: [fsdp, null]
+model.layers.*.mlp.up_proj.weight: [fsdp, null]
+model.layers.*.mlp.down_proj.weight: [null, fsdp]
+model.layers.*.input_layernorm.weight: [fsdp]
+model.layers.*.post_attention_layernorm.weight: [fsdp]
+model.norm.weight: [fsdp]
+lm_head.weight: [fsdp, null]
+
+# Activations
+model.layers.*: [fsdp, null, null]
+lm_head: [fsdp, null, null]
diff --git a/torchprime/torch_xla_models/configs/model/sharding/mixtral-fsdp.yaml b/torchprime/torch_xla_models/configs/model/sharding/mixtral-fsdp.yaml
new file mode 100644
index 00000000..2e9480a9
--- /dev/null
+++ b/torchprime/torch_xla_models/configs/model/sharding/mixtral-fsdp.yaml
@@ -0,0 +1,18 @@
+# Weights
+model.embed_tokens.weight: ['fsdp', null]
+model.layers.*.self_attn.q_proj.weight: ['fsdp', null]
+model.layers.*.self_attn.k_proj.weight: [null, 'fsdp']
+model.layers.*.self_attn.v_proj.weight: [null, 'fsdp']
+model.layers.*.self_attn.o_proj.weight: ['fsdp', null]
+model.layers.*.block_sparse_moe.gate.weight: [null, 'fsdp']
+model.layers.*.block_sparse_moe.experts.w1: [null, null, 'fsdp']
+model.layers.*.block_sparse_moe.experts.w2: [null, 'fsdp', null]
+model.layers.*.block_sparse_moe.experts.w3: [null, null, 'fsdp']
+model.layers.*.input_layernorm.weight: ['fsdp']
+model.layers.*.post_attention_layernorm.weight: ['fsdp']
+model.norm.weight: ['fsdp']
+lm_head.weight: ['fsdp', null]
+
+# Activations
+model.layers.*[0]: [fsdp, null, null] # Shard the first output of the decoder layer
+lm_head: [fsdp, null, null]
diff --git a/torchprime/torch_xla_models/train.py b/torchprime/torch_xla_models/train.py
index 93c02aa8..b17a1d3a 100644
--- a/torchprime/torch_xla_models/train.py
+++ b/torchprime/torch_xla_models/train.py
@@ -4,6 +4,7 @@
 import math
 import sys
 from contextlib import contextmanager
+from functools import partial
 from timeit import default_timer as timer
 
 # Third-party library imports
@@ -37,11 +38,13 @@
 from transformers.trainer_pt_utils import get_module_class_from_name
 from transformers.utils import check_min_version
 
+from torchprime.layers.sequential import HomogeneousSequential
 from torchprime.metrics.step_duration import step_duration_from_latest_profile
 from torchprime.sharding.shard_model import (
   shard_torch_xla_model_from_config,
   wrap_module,
 )
+from torchprime.torch_xla_models import offloading, remat_all, scan_layers
 from torchprime.torch_xla_models.topology import get_mesh, is_1d_sharding
 
 check_min_version("4.39.3")
@@ -90,16 +93,14 @@ def __init__(
 
     # Annotate model weights and activations with sharding constraints to distribute
     # the training across devices following the SPMD paradigm.
-    sharding_config = OmegaConf.to_container(
-      self.config.model.scaling.sharding, resolve=True
-    )
+    sharding_config = OmegaConf.to_container(self.config.model.sharding, resolve=True)
     assert isinstance(
       sharding_config, dict
     ), f"Sharding config {sharding_config} must be a dict"
     model = shard_torch_xla_model_from_config(model, config=sharding_config)
 
     # Rematerialize forward computation during the backward pass if requested.
-    model = self._checkpoint_model(model)
+    model = self._add_checkpoint_offload_scan_model(model)
     model = self._add_optimization_barrier_model(model)
     self.model = model
 
@@ -171,25 +172,65 @@ def _get_train_dataloader(self):
     )
     return loader
 
-  def _checkpoint_model(self, model: nn.Module):
-    classes = self._get_classes_by_names(
-      model, self.config.model.scaling.get("activation_checkpoint_layers", [])
+  def _add_checkpoint_offload_scan_model(self, model: nn.Module):
+    remat_classes = self._get_classes_by_names(
+      model, self.config.model.remat.get("activation_checkpoint_layers", [])
     )
-    if not classes:
-      return model
-
-    logger.info(f"Enabling activation checkpointing on {classes}")
+    layers_to_scan = self.config.model.remat.get("scan_layers", None)
+    offload_tensors = self.config.model.remat.get("offload_tensors", [])
+
+    # Checking preconditions and logging.
+    if remat_classes:
+      logger.info(f"Enabling activation checkpointing on {remat_classes}")
+    if layers_to_scan:
+      assert isinstance(layers_to_scan, str)
+      logger.info(f"Compiling module `{layers_to_scan}` with scan")
+    if len(offload_tensors):
+      logger.info(f"Will offload these tensors to host RAM: {offload_tensors}")
+      if layers_to_scan is None:
+        raise NotImplementedError("Host offloading requires scan")
+      if len(remat_classes) != 1:
+        raise NotImplementedError(
+          "Host offloading requires checkpointing exactly one layer"
+        )
 
     def maybe_checkpoint(mod, _name):
-      if isinstance(mod, tuple(classes)):
+      if isinstance(mod, tuple(remat_classes)):
         return checkpoint_module(mod)
       return mod
 
-    return wrap_module(model, maybe_checkpoint)
+    if layers_to_scan is None:
+      # Implement activation checkpointing without scan by wrapping modules.
+      if not remat_classes:
+        return model
+      return wrap_module(model, maybe_checkpoint)
+
+    if not remat_classes:
+      # Scan without activation checkpointing.
+      return scan_layers.compile(model, layers_to_scan)
+
+    # Implement activation checkpointing and host offloading under scan via
+    # a graph partitioner instead of `checkpoint_module`.
+    seq = model.get_submodule(layers_to_scan)
+    assert isinstance(seq, HomogeneousSequential)
+    if len(remat_classes) != 1 or list(remat_classes)[0] != seq.repeated_layer:
+      raise NotImplementedError(
+        "When compiling decoder layers with scan and activation checkpointing "
+        f"is also requested, we only support checkpointing {seq.repeated_layer}, "
+        "i.e. the layer being scanned."
+      )
+    if not len(offload_tensors):
+      partition_fn = remat_all.remat_all_partition_fn
+    else:
+      partition_fn = partial(
+        offloading.remat_all_and_offload_these_inputs,
+        names_to_offload=offload_tensors,
+      )
+    return scan_layers.compile(model, layers_to_scan, partition_fn=partition_fn)
 
   def _add_optimization_barrier_model(self, model: nn.Module):
     classes = self._get_classes_by_names(
-      model, self.config.model.scaling.get("optimization_barrier_layers", [])
+      model, self.config.model.remat.get("optimization_barrier_layers", [])
    )
     if not classes:
       return model
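
For a quick local sanity check of the new presets, the same Hydra overrides exercised by the new e2e step can be passed to `tp run` directly. The sketch below mirrors the workflow step above; it assumes a configured `tp` CLI, an HF_TOKEN in the environment, and a v4-8-sized mesh, so adjust `ici_mesh.fsdp` and `global_batch_size` to your topology. Using `model/remat=llama-scan` instead enables scan over the decoder layers without host offloading.

  # Scan + host offloading (offloads the "decoder_input" tensors named in
  # configs/model/remat/llama-scan-offload.yaml) -- illustrative invocation only.
  tp run torchprime/torch_xla_models/train.py \
    model=llama-3.1-8b \
    model/remat=llama-scan-offload \
    global_batch_size=8 \
    ici_mesh.fsdp=4 \
    max_steps=15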