
Wire up CLI for scan and host offloading #156

Merged (1 commit) on Mar 19, 2025
38 changes: 35 additions & 3 deletions .github/workflows/e2e_test.yml
@@ -17,6 +17,7 @@ jobs:
outputs:
llama-3-8b-name: ${{ steps.run-llama-3-8b.outputs.name }}
llama-3_1-8b-sa-name: ${{ steps.run-llama-3_1-8b-SplashAttention.outputs.name }}
llama-3_1-8b-scan-offload-name: ${{ steps.run-llama-3_1-8b-scan-offload.outputs.name }}
llama-3-8b-2d-name: ${{ steps.run-llama-3-8b-2d.outputs.name }}
llama-3-8b-2-slice-name: ${{ steps.run-llama-3-8b-2-slice.outputs.name }}
mixtral-8x7b-name: ${{ steps.run-mixtral-8x7b.outputs.name }}
@@ -62,6 +63,7 @@ jobs:
dataset_config_name=wikitext-2-raw-v1 \
profile_step=3 \
max_steps=15

- name: Run Llama 3.1 8B (Splash Attention)
id: run-llama-3_1-8b-SplashAttention
env:
@@ -81,6 +83,27 @@
dataset_config_name=wikitext-2-raw-v1 \
profile_step=3 \
max_steps=15

- name: Run Llama 3.1 8B (Scan + Offload)
id: run-llama-3_1-8b-scan-offload
env:
HF_TOKEN: ${{ secrets.HF_TOKEN }}
XLA_IR_DEBUG: 1
XLA_HLO_DEBUG: 1
run: |
name=$(e2e_testing/gen_name.py llama-3dot1-8b-scan-offload)
echo "name=$name" >> "$GITHUB_OUTPUT"
tp run \
--name $name \
torchprime/torch_xla_models/train.py \
model=llama-3.1-8b \
model/remat=llama-scan-offload \
global_batch_size=8 \
ici_mesh.fsdp=4 \
dataset_config_name=wikitext-2-raw-v1 \
profile_step=3 \
max_steps=15

- name: Run Llama 3.0 8B (2D sharding)
id: run-llama-3-8b-2d
env:
@@ -94,7 +117,7 @@
--name $name \
torchprime/torch_xla_models/train.py \
model=llama-3-8b \
model/scaling=llama-fsdp-tp \
model/sharding=llama-fsdp-tp \
global_batch_size=8 \
ici_mesh.fsdp=2 \
ici_mesh.tensor=2 \
@@ -136,7 +159,7 @@ jobs:
--num-slices 2 \
torchprime/torch_xla_models/train.py \
model=llama-3-8b \
model/scaling=llama-fsdp \
model/sharding=llama-fsdp \
global_batch_size=16 \
dcn_mesh.fsdp=2 \
ici_mesh.fsdp=4 \
@@ -154,14 +177,23 @@
secrets: inherit

llama-3_1-8b-sa:
name: Llama 3.1 8B (Spalsh Attention)
name: Llama 3.1 8B (Splash Attention)
needs: tp-run
uses: ./.github/workflows/reusable_e2e_check.yml
with:
jobset_name: ${{ needs.tp-run.outputs.llama-3_1-8b-sa-name }}
artifact_dir: ${{ needs.tp-run.outputs.artifact-dir }}
secrets: inherit

llama-3_1-8b-scan-offload:
name: Llama 3.1 8B (Scan + Offload)
needs: tp-run
uses: ./.github/workflows/reusable_e2e_check.yml
with:
jobset_name: ${{ needs.tp-run.outputs.llama-3_1-8b-scan-offload-name }}
artifact_dir: ${{ needs.tp-run.outputs.artifact-dir }}
secrets: inherit

llama-3-8b-2d:
name: Llama 3.0 8B (2D sharding)
needs: tp-run
6 changes: 4 additions & 2 deletions torchprime/torch_xla_models/README.md
@@ -137,8 +137,10 @@ tp run torchprime/torch_xla_models/train.py \
- `train.py`: Main training script that sets up the model, data, and training loop
- `configs/base.yaml`: Configuration file for the training script
- `configs/model`: Configuration files for models
- `configs/model/scaling`: Configuration files for scaling the training of a model, e.g.
rematerialization, sharding tensors.
- `configs/model/sharding`: Configuration files for distributing the training
over many chips
- `configs/model/remat`: Configuration files for the rematerialization strategy, e.g.
activation checkpointing, host offloading
- `llama/model.py`: Implementation of the Llama model family
- `mixtral/model.py`: Implementation of the Mixtral model family
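
The new config groups are selected with ordinary Hydra overrides. For example, mirroring the e2e workflow step added in this PR, a scan-plus-offload run of Llama 3.1 8B might look like the following (flag values copied from that workflow step; adjust them for your own setup):

```bash
tp run torchprime/torch_xla_models/train.py \
  model=llama-3.1-8b \
  model/remat=llama-scan-offload \
  global_batch_size=8 \
  ici_mesh.fsdp=4 \
  dataset_config_name=wikitext-2-raw-v1 \
  max_steps=15
```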

23 changes: 23 additions & 0 deletions torchprime/torch_xla_models/configs/default.yaml
@@ -47,3 +47,26 @@ dcn_mesh:
fsdp: 1
tensor: 1
expert: 1

# These are default values for model activation rematerialization configuration.
# They can be overridden on the command line or by importing one of the presets
# in the `model/remat` directory.
model:
remat:
# The class names of model layers whose intermediate activations should be
# recomputed during the backward pass (i.e. activation checkpointing).
activation_checkpoint_layers: []

# If not null, compile a module of type `HomogeneousSequential` located at the
# given path in the module tree using `torch_xla.experimental.scan_layers`.
scan_layers: null

# If specified, offload these tensors to host RAM during the forward pass and
# move them back during the backward pass.
#
# The tensors to be offloaded should be given a name by wrapping them with the
# `torchprime.torch_xla_models.offloading.offload_name` call. Then the same
# name could be specified here to offload that tensor.
#
# Currently in order to offload tensors, `scan_layers` must also be enabled.
offload_tensors: []
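
The offloading hook is name-based: as the comment above says, the model wraps a tensor with `torchprime.torch_xla_models.offloading.offload_name`, and the config lists that name under `offload_tensors`. Below is a minimal sketch of the model-side tagging, assuming a signature along the lines of `offload_name(tensor, name)` that returns the tagged tensor; the actual implementation is not part of this diff.

```python
import torch
from torchprime.torch_xla_models.offloading import offload_name


def run_decoder(hidden_states: torch.Tensor, layers: torch.nn.Module) -> torch.Tensor:
    # Assumed signature: offload_name(tensor, name) returns the same tensor,
    # tagged so that `offload_tensors: [decoder_input]` offloads it to host RAM
    # during the forward pass and brings it back for the backward pass.
    hidden_states = offload_name(hidden_states, "decoder_input")
    # `layers` is the scanned HomogeneousSequential at `model.layers`
    # (scan_layers must also be enabled for offloading to take effect).
    return layers(hidden_states)
```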
3 changes: 2 additions & 1 deletion torchprime/torch_xla_models/configs/model/llama-3-8b.yaml
@@ -1,6 +1,7 @@
defaults:
- _self_ # refers to this config file
- scaling: llama-fsdp # refers to scaling/llama-fsdp.yaml
- sharding: llama-fsdp # refers to sharding/llama-fsdp.yaml
- remat: llama # refers to remat/llama.yaml

model_class: llama.LlamaForCausalLM # Used to import the model from this class
vocab_size: 128256
@@ -1,6 +1,7 @@
defaults:
- _self_
- scaling: llama-fsdp-tp
- sharding: llama-fsdp-tp
- remat: llama

model_class: llama.LlamaForCausalLM
vocab_size: 128256
5 changes: 3 additions & 2 deletions torchprime/torch_xla_models/configs/model/llama-3.1-8b.yaml
@@ -1,6 +1,7 @@
defaults:
- _self_ # refers to this config file
- scaling: llama-fsdp # refers to scaling/llama-fsdp.yaml
- sharding: llama-fsdp # refers to sharding/llama-fsdp.yaml
- remat: llama # refers to remat/llama.yaml

model_class: llama.LlamaForCausalLM # Used to import the model from this class
vocab_size: 128256
@@ -25,4 +26,4 @@ rope_scaling:
factor: 8.0
low_freq_factor: 1.0
high_freq_factor: 4.0
original_context_len: 8192
original_context_len: 8192
3 changes: 2 additions & 1 deletion torchprime/torch_xla_models/configs/model/mixtral-8x7b.yaml
@@ -1,6 +1,7 @@
defaults:
- _self_ # refers to this config file
- scaling: mixtral-fsdp # refers to scaling/mixtral-fsdp.yaml
- sharding: mixtral-fsdp # refers to sharding/mixtral-fsdp.yaml
- remat: mixtral # refers to remat/mixtral.yaml

model_class: mixtral.MixtralForCausalLM # Used to import the model from this class
bos_token_id: 1
torchprime/torch_xla_models/configs/model/remat/llama-scan-offload.yaml
@@ -0,0 +1,6 @@
defaults:
- llama-scan
- _self_

offload_tensors:
- decoder_input
torchprime/torch_xla_models/configs/model/remat/llama-scan.yaml
@@ -0,0 +1,6 @@
defaults:
- llama
- _self_

# Compile the sequence of decoder layers at `model.layers` using scan.
scan_layers: 'model.layers'
6 changes: 6 additions & 0 deletions torchprime/torch_xla_models/configs/model/remat/llama.yaml
@@ -0,0 +1,6 @@
activation_checkpoint_layers:
- LlamaDecoderLayer

# Refer to https://github.com/pytorch/xla/issues/6379 for backward optimization barrier info.
optimization_barrier_layers:
- LlamaDecoderLayer
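
Taken together, the three Llama remat presets above compose through their Hydra `defaults` lists (`llama`, then `llama-scan`, then `llama-scan-offload`), so selecting `model/remat=llama-scan-offload` should resolve to roughly the following merged `model.remat` config. This is a sketch of the expected composition, not output captured from Hydra:

```yaml
# Approximate effective model.remat config for model/remat=llama-scan-offload
activation_checkpoint_layers:
  - LlamaDecoderLayer
optimization_barrier_layers:
  - LlamaDecoderLayer
scan_layers: 'model.layers'
offload_tensors:
  - decoder_input
```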
6 changes: 6 additions & 0 deletions torchprime/torch_xla_models/configs/model/remat/mixtral.yaml
@@ -0,0 +1,6 @@
activation_checkpoint_layers:
- MixtralDecoderLayer

# Refer to https://github.com/pytorch/xla/issues/6379 for backward optimization barrier info.
optimization_barrier_layers:
- MixtralDecoderLayer

torchprime/torch_xla_models/configs/model/scaling/llama-fsdp-tp.yaml
This file was deleted.

25 changes: 0 additions & 25 deletions torchprime/torch_xla_models/configs/model/scaling/llama-fsdp.yaml
This file was deleted.

torchprime/torch_xla_models/configs/model/scaling/mixtral-fsdp.yaml
This file was deleted.

torchprime/torch_xla_models/configs/model/sharding/llama-fsdp-tp.yaml
@@ -0,0 +1,29 @@
# 2D (FSDP + TP) sharding configuration for Llama models.

# Weights

# TODO(https://github.com/AI-Hypercomputer/torchprime/issues/114): This
# cannot be `[tensor, fsdp]`, or the gradients will sometimes become NaN.
model.embed_tokens.weight: [fsdp, tensor]

model.layers.*.self_attn.q_proj.weight: [tensor, fsdp]
model.layers.*.self_attn.k_proj.weight: [tensor, fsdp]
model.layers.*.self_attn.v_proj.weight: [tensor, fsdp]
model.layers.*.self_attn.o_proj.weight: [fsdp, tensor]
model.layers.*.mlp.gate_proj.weight: [tensor, fsdp]
model.layers.*.mlp.up_proj.weight: [tensor, fsdp]
model.layers.*.mlp.down_proj.weight: [fsdp, tensor]
model.layers.*.input_layernorm.weight: [fsdp]
model.layers.*.post_attention_layernorm.weight: [fsdp]
model.norm.weight: [fsdp]
lm_head.weight: [tensor, fsdp]

# Activations
model.layers.*.self_attn.q_proj: [fsdp, null, tensor]
model.layers.*.self_attn.k_proj: [fsdp, null, tensor]
model.layers.*.self_attn.v_proj: [fsdp, null, tensor]
model.layers.*.self_attn.o_proj: [fsdp, null, tensor]
model.layers.*.input_layernorm: [fsdp, null, tensor]
model.layers.*.post_attention_layernorm: [fsdp, null, tensor]
model.layers.*.mlp: [fsdp, null, tensor]
model.layers.*: [fsdp, null, tensor]
17 changes: 17 additions & 0 deletions torchprime/torch_xla_models/configs/model/sharding/llama-fsdp.yaml
@@ -0,0 +1,17 @@
# Weights
model.embed_tokens.weight: [fsdp, null]
model.layers.*.self_attn.q_proj.weight: [fsdp, null]
model.layers.*.self_attn.k_proj.weight: [null, fsdp]
model.layers.*.self_attn.v_proj.weight: [null, fsdp]
model.layers.*.self_attn.o_proj.weight: [fsdp, null]
model.layers.*.mlp.gate_proj.weight: [fsdp, null]
model.layers.*.mlp.up_proj.weight: [fsdp, null]
model.layers.*.mlp.down_proj.weight: [null, fsdp]
model.layers.*.input_layernorm.weight: [fsdp]
model.layers.*.post_attention_layernorm.weight: [fsdp]
model.norm.weight: [fsdp]
lm_head.weight: [fsdp, null]

# Activations
model.layers.*: [fsdp, null, null]
lm_head: [fsdp, null, null]
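
Each key in these sharding files is a module or parameter path (with `*` wildcards) and each value is a partition spec over the mesh axes built from `ici_mesh`/`dcn_mesh`. How torchprime walks the model and applies the specs is not shown in this diff, but conceptually an entry such as `model.layers.*.self_attn.q_proj.weight: [fsdp, null]` corresponds to a torch_xla SPMD annotation along these lines (a sketch using the public torch_xla API with a stand-in tensor, not torchprime's actual code):

```python
import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.spmd as xs
import torch_xla.runtime as xr

xr.use_spmd()  # enable the XLA SPMD partitioner

# A 1D mesh with a single 'fsdp' axis over all devices
# (e.g. ici_mesh.fsdp=4 in the e2e job above).
num_devices = xr.global_runtime_device_count()
mesh = xs.Mesh(np.arange(num_devices), (num_devices,), ("fsdp",))

# Stand-in for model.layers.*.self_attn.q_proj.weight (shape [out, in]).
q_proj_weight = torch.empty(4096, 4096, device=xm.xla_device())

# `[fsdp, null]` in the YAML: shard dim 0 across the fsdp mesh axis,
# replicate dim 1.
xs.mark_sharding(q_proj_weight, mesh, ("fsdp", None))
```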
torchprime/torch_xla_models/configs/model/sharding/mixtral-fsdp.yaml
@@ -0,0 +1,18 @@
# Weights
model.embed_tokens.weight: ['fsdp', null]
model.layers.*.self_attn.q_proj.weight: ['fsdp', null]
model.layers.*.self_attn.k_proj.weight: [null, 'fsdp']
model.layers.*.self_attn.v_proj.weight: [null, 'fsdp']
model.layers.*.self_attn.o_proj.weight: ['fsdp', null]
model.layers.*.block_sparse_moe.gate.weight: [null, 'fsdp']
model.layers.*.block_sparse_moe.experts.w1: [null, null, 'fsdp']
model.layers.*.block_sparse_moe.experts.w2: [null, 'fsdp', null]
model.layers.*.block_sparse_moe.experts.w3: [null, null, 'fsdp']
model.layers.*.input_layernorm.weight: ['fsdp']
model.layers.*.post_attention_layernorm.weight: ['fsdp']
model.norm.weight: ['fsdp']
lm_head.weight: ['fsdp', null]

# Activations
model.layers.*[0]: [fsdp, null, null] # Shard the first output of the decoder layer
lm_head: [fsdp, null, null]