Commit 88a04d4

Wire up CLI for scan and host offloading

1 parent d6f2452 · commit 88a04d4

18 files changed: +204 −112 lines

.github/workflows/e2e_test.yml

Lines changed: 35 additions & 3 deletions
@@ -17,6 +17,7 @@ jobs:
     outputs:
       llama-3-8b-name: ${{ steps.run-llama-3-8b.outputs.name }}
       llama-3_1-8b-sa-name: ${{ steps.run-llama-3_1-8b-SplashAttention.outputs.name }}
+      llama-3_1-8b-scan-offload-name: ${{ steps.run-llama-3_1-8b-scan-offload.outputs.name }}
       llama-3-8b-2d-name: ${{ steps.run-llama-3-8b-2d.outputs.name }}
       llama-3-8b-2-slice-name: ${{ steps.run-llama-3-8b-2-slice.outputs.name }}
       mixtral-8x7b-name: ${{ steps.run-mixtral-8x7b.outputs.name }}
@@ -62,6 +63,7 @@ jobs:
           dataset_config_name=wikitext-2-raw-v1 \
           profile_step=3 \
           max_steps=15
+
     - name: Run Llama 3.1 8B (Splash Attention)
       id: run-llama-3_1-8b-SplashAttention
       env:
@@ -81,6 +83,27 @@ jobs:
           dataset_config_name=wikitext-2-raw-v1 \
           profile_step=3 \
           max_steps=15
+
+    - name: Run Llama 3.1 8B (Scan + Offload)
+      id: run-llama-3_1-8b-scan-offload
+      env:
+        HF_TOKEN: ${{ secrets.HF_TOKEN }}
+        XLA_IR_DEBUG: 1
+        XLA_HLO_DEBUG: 1
+      run: |
+        name=$(e2e_testing/gen_name.py llama-3dot1-8b-sa)
+        echo "name=$name" >> "$GITHUB_OUTPUT"
+        tp run \
+          --name $name \
+          torchprime/torch_xla_models/train.py \
+          model=llama-3.1-8b \
+          model/remat=llama-scan-offload \
+          global_batch_size=8 \
+          ici_mesh.fsdp=4 \
+          dataset_config_name=wikitext-2-raw-v1 \
+          profile_step=3 \
+          max_steps=15
+
     - name: Run Llama 3.0 8B (2D sharding)
       id: run-llama-3-8b-2d
       env:
@@ -94,7 +117,7 @@ jobs:
           --name $name \
           torchprime/torch_xla_models/train.py \
           model=llama-3-8b \
-          model/scaling=llama-fsdp-tp \
+          model/sharding=llama-fsdp-tp \
           global_batch_size=8 \
           ici_mesh.fsdp=2 \
           ici_mesh.tensor=2 \
@@ -136,7 +159,7 @@ jobs:
           --num-slices 2 \
           torchprime/torch_xla_models/train.py \
           model=llama-3-8b \
-          model/scaling=llama-fsdp \
+          model/sharding=llama-fsdp \
          global_batch_size=16 \
          dcn_mesh.fsdp=2 \
          ici_mesh.fsdp=4 \
@@ -154,14 +177,23 @@ jobs:
     secrets: inherit

   llama-3_1-8b-sa:
-    name: Llama 3.1 8B (Spalsh Attention)
+    name: Llama 3.1 8B (Splash Attention)
     needs: tp-run
     uses: ./.github/workflows/reusable_e2e_check.yml
     with:
       jobset_name: ${{ needs.tp-run.outputs.llama-3_1-8b-sa-name }}
       artifact_dir: ${{ needs.tp-run.outputs.artifact-dir }}
     secrets: inherit

+  llama-3_1-8b-scan-offload:
+    name: Llama 3.1 8B (Scan + Offload)
+    needs: tp-run
+    uses: ./.github/workflows/reusable_e2e_check.yml
+    with:
+      jobset_name: ${{ needs.tp-run.outputs.llama-3_1-8b-scan-offload-name }}
+      artifact_dir: ${{ needs.tp-run.outputs.artifact-dir }}
+    secrets: inherit
+
   llama-3-8b-2d:
     name: Llama 3.0 8B (2D sharding)
     needs: tp-run

torchprime/torch_xla_models/README.md

Lines changed: 4 additions & 2 deletions
@@ -137,8 +137,10 @@ tp run torchprime/torch_xla_models/train.py \
 - `train.py`: Main training script that sets up the model, data, and training loop
 - `configs/base.yaml`: Configuration file for the training script
 - `configs/model`: Configuration files for models
-- `configs/model/scaling`: Configuration files for scaling the training of a model, e.g.
-  rematerialization, sharding tensors.
+- `configs/model/sharding`: Configuration files for distributing the training
+  over many chips.
+- `configs/model/remat`: Configuration files for rematerialization strategy e.g.
+  activation checkpointing, host offloading
 - `llama/model.py`: Implementation of the Llama model family
 - `mixtral/model.py`: Implementation of the Mixtral model family

torchprime/torch_xla_models/configs/default.yaml

Lines changed: 13 additions & 0 deletions
@@ -47,3 +47,16 @@ dcn_mesh:
   fsdp: 1
   tensor: 1
   expert: 1
+
+# These are default values for model activation rematerialization configuration.
+# They can be overridden on the command line or by importing one of the presets
+# in the `model/remat` directory.
+model:
+  remat:
+    # If specified, compile a module of type `HomogeneousSequential` located at the
+    # given path in the module tree using `torch_xla.experimental.scan_layers`.
+    scan_layers: null
+
+    # If specified, offload these tensors to host RAM during the forward pass and
+    # move them back during the backward pass.
+    offload_tensors: []
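
For orientation, here is a minimal sketch of what scan compilation means, assuming the `scan_layers(layers, input)` entry point in `torch_xla.experimental.scan_layers` (the toy layers below are illustrative, not torchprime's model code): a homogeneous stack of identical layers is traced once and iterated with XLA's scan primitive instead of being unrolled, which keeps the compiled graph small.

import torch
import torch_xla.core.xla_model as xm
from torch_xla.experimental.scan_layers import scan_layers

device = xm.xla_device()
# Stand-in for a HomogeneousSequential of identical decoder layers.
layers = [torch.nn.Linear(128, 128).to(device) for _ in range(8)]
x = torch.randn(4, 128, device=device)

# Functionally equivalent to applying the eight layers one after another,
# but the layer body is traced and compiled only once and reused by the scan.
y = scan_layers(layers, x)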

torchprime/torch_xla_models/configs/model/llama-3-8b.yaml

Lines changed: 2 additions & 1 deletion
@@ -1,6 +1,7 @@
 defaults:
   - _self_  # refers to this config file
-  - scaling: llama-fsdp  # refers to scaling/llama-fsdp.yaml
+  - sharding: llama-fsdp  # refers to sharding/llama-fsdp.yaml
+  - remat: llama  # refers to remat/llama.yaml

 model_class: llama.LlamaForCausalLM  # Used to import the model from this class
 vocab_size: 128256

torchprime/torch_xla_models/configs/model/llama-3.1-405b.yaml

Lines changed: 2 additions & 1 deletion
@@ -1,6 +1,7 @@
 defaults:
   - _self_
-  - scaling: llama-fsdp-tp
+  - sharding: llama-fsdp-tp
+  - remat: llama

 model_class: llama.LlamaForCausalLM
 vocab_size: 128256

torchprime/torch_xla_models/configs/model/llama-3.1-8b.yaml

Lines changed: 3 additions & 2 deletions
@@ -1,6 +1,7 @@
 defaults:
   - _self_  # refers to this config file
-  - scaling: llama-fsdp  # refers to scaling/llama-fsdp.yaml
+  - sharding: llama-fsdp  # refers to sharding/llama-fsdp.yaml
+  - remat: llama  # refers to remat/llama.yaml

 model_class: llama.LlamaForCausalLM  # Used to import the model from this class
 vocab_size: 128256
@@ -25,4 +26,4 @@ rope_scaling:
   factor: 8.0
   low_freq_factor: 1.0
   high_freq_factor: 4.0
-  original_context_len: 8192
+  original_context_len: 8192

torchprime/torch_xla_models/configs/model/mixtral-8x7b.yaml

Lines changed: 2 additions & 1 deletion
@@ -1,6 +1,7 @@
 defaults:
   - _self_  # refers to this config file
-  - scaling: mixtral-fsdp  # refers to scaling/mixtral-fsdp.yaml
+  - sharding: mixtral-fsdp  # refers to sharding/mixtral-fsdp.yaml
+  - remat: mixtral  # refers to remat/mixtral.yaml

 model_class: mixtral.MixtralForCausalLM  # Used to import the model from this class
 bos_token_id: 1

torchprime/torch_xla_models/configs/model/remat/llama-scan-offload.yaml

Lines changed: 6 additions & 0 deletions
@@ -0,0 +1,6 @@
+defaults:
+  - llama-scan
+  - _self_
+
+offload_tensors:
+  - decoder_input

torchprime/torch_xla_models/configs/model/remat/llama-scan.yaml

Lines changed: 6 additions & 0 deletions
@@ -0,0 +1,6 @@
+defaults:
+  - llama
+  - _self_
+
+# Compile the sequence of decoder layers at `model.layers` using scan.
+scan_layers: 'model.layers'

torchprime/torch_xla_models/configs/model/remat/llama.yaml

Lines changed: 6 additions & 0 deletions
@@ -0,0 +1,6 @@
+activation_checkpoint_layers:
+  - LlamaDecoderLayer
+
+# Refer to https://github.com/pytorch/xla/issues/6379 for backward optimization barrier info.
+optimization_barrier_layers:
+  - LlamaDecoderLayer
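
The three llama presets above form a chain: `llama-scan-offload` extends `llama-scan`, which extends `llama`. As a rough illustration of how Hydra composes that `defaults` chain (inlining the preset contents and merging them with OmegaConf, which Hydra builds on), the effective config selected by `model/remat=llama-scan-offload` looks like this:

from omegaconf import OmegaConf

# Preset contents inlined for illustration.
llama = OmegaConf.create({
    "activation_checkpoint_layers": ["LlamaDecoderLayer"],
    "optimization_barrier_layers": ["LlamaDecoderLayer"],
})
llama_scan = OmegaConf.merge(llama, {"scan_layers": "model.layers"})
llama_scan_offload = OmegaConf.merge(llama_scan, {"offload_tensors": ["decoder_input"]})

# The merged preset keeps activation checkpointing and optimization barriers on
# LlamaDecoderLayer, scan-compiles `model.layers`, and offloads `decoder_input`.
print(OmegaConf.to_yaml(llama_scan_offload))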

torchprime/torch_xla_models/configs/model/remat/mixtral.yaml

Lines changed: 6 additions & 0 deletions
@@ -0,0 +1,6 @@
+activation_checkpoint_layers:
+  - MixtralDecoderLayer
+
+# Refer to https://github.com/pytorch/xla/issues/6379 for backward optimization barrier info.
+optimization_barrier_layers:
+  - MixtralDecoderLayer
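
`activation_checkpoint_layers` names layer classes whose activations are recomputed during the backward pass instead of being kept in memory. A minimal sketch of what that wrapping amounts to, assuming torch_xla's `torch_xla.utils.checkpoint.checkpoint` helper (the wrapper class is illustrative, not torchprime's actual code):

import torch
from torch_xla.utils.checkpoint import checkpoint  # XLA-aware gradient checkpointing


class CheckpointedLayer(torch.nn.Module):
    """Recomputes the wrapped layer in the backward pass, trading compute for memory."""

    def __init__(self, layer: torch.nn.Module):
        super().__init__()
        self.layer = layer

    def forward(self, *args):
        return checkpoint(self.layer, *args)

`optimization_barrier_layers` addresses a related compiler issue: roughly, each listed layer is wrapped in `xm.optimization_barrier_` so XLA does not fuse across layer boundaries and undo the rematerialization; see the issue linked in the configs above for details.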

torchprime/torch_xla_models/configs/model/scaling/llama-fsdp-tp.yaml

Lines changed: 0 additions & 37 deletions
This file was deleted.

torchprime/torch_xla_models/configs/model/scaling/llama-fsdp.yaml

Lines changed: 0 additions & 25 deletions
This file was deleted.

torchprime/torch_xla_models/configs/model/scaling/mixtral-fsdp.yaml

Lines changed: 0 additions & 26 deletions
This file was deleted.

torchprime/torch_xla_models/configs/model/sharding/llama-fsdp-tp.yaml

Lines changed: 29 additions & 0 deletions
@@ -0,0 +1,29 @@
+# 2D (FSDP + TP) sharding configuration for Llama models.
+
+# Weights
+
+# TODO(https://github.com/AI-Hypercomputer/torchprime/issues/114): This
+# cannot be `[tensor, fsdp]`, or the gradients will sometimes become NaN.
+model.embed_tokens.weight: [fsdp, tensor]
+
+model.layers.*.self_attn.q_proj.weight: [tensor, fsdp]
+model.layers.*.self_attn.k_proj.weight: [tensor, fsdp]
+model.layers.*.self_attn.v_proj.weight: [tensor, fsdp]
+model.layers.*.self_attn.o_proj.weight: [fsdp, tensor]
+model.layers.*.mlp.gate_proj.weight: [tensor, fsdp]
+model.layers.*.mlp.up_proj.weight: [tensor, fsdp]
+model.layers.*.mlp.down_proj.weight: [fsdp, tensor]
+model.layers.*.input_layernorm.weight: [fsdp]
+model.layers.*.post_attention_layernorm.weight: [fsdp]
+model.norm.weight: [fsdp]
+lm_head.weight: [tensor, fsdp]
+
+# Activations
+model.layers.*.self_attn.q_proj: [fsdp, null, tensor]
+model.layers.*.self_attn.k_proj: [fsdp, null, tensor]
+model.layers.*.self_attn.v_proj: [fsdp, null, tensor]
+model.layers.*.self_attn.o_proj: [fsdp, null, tensor]
+model.layers.*.input_layernorm: [fsdp, null, tensor]
+model.layers.*.post_attention_layernorm: [fsdp, null, tensor]
+model.layers.*.mlp: [fsdp, null, tensor]
+model.layers.*: [fsdp, null, tensor]
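
Each entry above maps a parameter or module path to a partition spec with one element per tensor dimension: the named mesh axis that dimension is sharded along, or `null` for replication. A minimal sketch of how a single weight entry translates to XLA SPMD (illustrative code assuming a 4-chip fsdp×tensor mesh, not torchprime's actual sharding loop):

import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.spmd as xs
import torch_xla.runtime as xr

xr.use_spmd()

# A 2x2 device mesh with named axes, mirroring ici_mesh.fsdp=2, ici_mesh.tensor=2.
mesh = xs.Mesh(np.arange(4), (2, 2), ("fsdp", "tensor"))

weight = torch.empty(4096, 4096, device=xm.xla_device())
# model.layers.*.self_attn.q_proj.weight: [tensor, fsdp]
#   -> shard dim 0 along the `tensor` axis and dim 1 along the `fsdp` axis.
xs.mark_sharding(weight, mesh, ("tensor", "fsdp"))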

torchprime/torch_xla_models/configs/model/sharding/llama-fsdp.yaml

Lines changed: 17 additions & 0 deletions
@@ -0,0 +1,17 @@
+# Weights
+model.embed_tokens.weight: [fsdp, null]
+model.layers.*.self_attn.q_proj.weight: [fsdp, null]
+model.layers.*.self_attn.k_proj.weight: [null, fsdp]
+model.layers.*.self_attn.v_proj.weight: [null, fsdp]
+model.layers.*.self_attn.o_proj.weight: [fsdp, null]
+model.layers.*.mlp.gate_proj.weight: [fsdp, null]
+model.layers.*.mlp.up_proj.weight: [fsdp, null]
+model.layers.*.mlp.down_proj.weight: [null, fsdp]
+model.layers.*.input_layernorm.weight: [fsdp]
+model.layers.*.post_attention_layernorm.weight: [fsdp]
+model.norm.weight: [fsdp]
+lm_head.weight: [fsdp, null]
+
+# Activations
+model.layers.*: [fsdp, null, null]
+lm_head: [fsdp, null, null]

torchprime/torch_xla_models/configs/model/sharding/mixtral-fsdp.yaml

Lines changed: 18 additions & 0 deletions
@@ -0,0 +1,18 @@
+# Weights
+model.embed_tokens.weight: ['fsdp', null]
+model.layers.*.self_attn.q_proj.weight: ['fsdp', null]
+model.layers.*.self_attn.k_proj.weight: [null, 'fsdp']
+model.layers.*.self_attn.v_proj.weight: [null, 'fsdp']
+model.layers.*.self_attn.o_proj.weight: ['fsdp', null]
+model.layers.*.block_sparse_moe.gate.weight: [null, 'fsdp']
+model.layers.*.block_sparse_moe.experts.w1: [null, null, 'fsdp']
+model.layers.*.block_sparse_moe.experts.w2: [null, 'fsdp', null]
+model.layers.*.block_sparse_moe.experts.w3: [null, null, 'fsdp']
+model.layers.*.input_layernorm.weight: ['fsdp']
+model.layers.*.post_attention_layernorm.weight: ['fsdp']
+model.norm.weight: ['fsdp']
+lm_head.weight: ['fsdp', null]
+
+# Activations
+model.layers.*[0]: [fsdp, null, null]  # Shard the first output of the decoder layer
+lm_head: [fsdp, null, null]
