
Commit e100715

Support of Splash Attention using xla_builder.call_jax (#145)
* Support of Splash Attention using xla_builder.call_jax
* nit
* nit update
* Support align sharding spec with JAX
* Update model config and enable e2e test for sa
* name nit
* fix test arg
* enable e2e test for sa
* Support run splash attention without repeat KV heads
* fix e2e test
* fix mixtral test
* update splash attention kernel block size config to improve performance
* nit test fix
* lru_cache for func arg to call_jax to prevent jit recompilation
* debug with profile
* caching the call_jax to reduce tracing overhead
* clean up code regarding shard spec
* fix e2e test naming
* nit
* fix e2e test naming again
* fix name for xpk
1 parent 64a6e1a commit e100715
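
As a rough illustration of the mechanism the commit message describes, the sketch below shows how a JAX splash-attention kernel could be invoked from torch_xla through xla_builder.call_jax, with functools.lru_cache keeping the kernel construction (and hence the JIT trace) from being repeated on every step. The helper names, kernel arguments, and shapes are assumptions for illustration, not the code added in this commit; the sharding-spec alignment with JAX mentioned in the message is omitted here.

```python
# Illustrative sketch only -- not the torchprime implementation from this commit.
# Assumes the Pallas splash-attention API under jax.experimental and that
# xla_builder.call_jax accepts a JAX callable plus a tuple of XLA tensors.
import functools
import math

import jax
import torch
import torch_xla.core.xla_builder as xb
from jax.experimental.pallas.ops.tpu.splash_attention import (
    splash_attention_kernel,
    splash_attention_mask,
)


@functools.lru_cache(maxsize=None)
def _build_splash_kernel(num_heads: int, seq_len: int):
    # Cached by (num_heads, seq_len) so the Pallas kernel is built once per
    # static shape instead of on every training step.
    mask = splash_attention_mask.MultiHeadMask(
        [splash_attention_mask.CausalMask((seq_len, seq_len)) for _ in range(num_heads)]
    )
    return splash_attention_kernel.make_splash_mha(
        mask=mask, head_shards=1, q_seq_shards=1
    )


def splash_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    """q, k, v: [batch, num_heads, seq_len, head_dim] tensors on an XLA device."""
    kernel = _build_splash_kernel(q.shape[1], q.shape[2])

    def jax_fn(q, k, v):
        # The splash kernel works per example and expects pre-scaled queries,
        # so scale here and vmap over the batch dimension.
        q = q / math.sqrt(q.shape[-1])
        return jax.vmap(kernel)(q, k, v)

    # call_jax stages the JAX computation into the surrounding XLA graph.
    return xb.call_jax(jax_fn, (q, k, v))
```

Per the commit message, the lru_cache on the kernel builder and the caching of the call_jax wrapper are what prevent repeated jit recompilation and reduce tracing overhead; the snippet above only shows the first of the two.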

11 files changed, +869 −46 lines changed


.github/workflows/e2e_test.yml

Lines changed: 29 additions & 1 deletion
@@ -16,6 +16,7 @@ jobs:
       ARTIFACT_DIR: gs://torchprime-e2e-tests/${{ github.job }}/${{ github.run_id }}-${{ github.run_attempt }}
     outputs:
       llama-3-8b-name: ${{ steps.run-llama-3-8b.outputs.name }}
+      llama-3_1-8b-sa-name: ${{ steps.run-llama-3_1-8b-SplashAttention.outputs.name }}
       llama-3-8b-2d-name: ${{ steps.run-llama-3-8b-2d.outputs.name }}
       llama-3-8b-2-slice-name: ${{ steps.run-llama-3-8b-2-slice.outputs.name }}
       mixtral-8x7b-name: ${{ steps.run-mixtral-8x7b.outputs.name }}
@@ -61,7 +62,25 @@ jobs:
             dataset_config_name=wikitext-2-raw-v1 \
             profile_step=3 \
             max_steps=15
-
+      - name: Run Llama 3.1 8B (Splash Attention)
+        id: run-llama-3_1-8b-SplashAttention
+        env:
+          HF_TOKEN: ${{ secrets.HF_TOKEN }}
+          XLA_IR_DEBUG: 1
+          XLA_HLO_DEBUG: 1
+        run: |
+          name=$(e2e_testing/gen_name.py llama-3dot1-8b-sa)
+          echo "name=$name" >> "$GITHUB_OUTPUT"
+          tp run \
+            --name $name \
+            torchprime/torch_xla_models/train.py \
+            model=llama-3.1-8b \
+            model.attention_kernel=splash_attention \
+            global_batch_size=8 \
+            ici_mesh.fsdp=4 \
+            dataset_config_name=wikitext-2-raw-v1 \
+            profile_step=3 \
+            max_steps=15
       - name: Run Llama 3.0 8B (2D sharding)
         id: run-llama-3-8b-2d
         env:
@@ -134,6 +153,15 @@ jobs:
       artifact_dir: ${{ needs.tp-run.outputs.artifact-dir }}
     secrets: inherit

+  llama-3_1-8b-sa:
+    name: Llama 3.1 8B (Splash Attention)
+    needs: tp-run
+    uses: ./.github/workflows/reusable_e2e_check.yml
+    with:
+      jobset_name: ${{ needs.tp-run.outputs.llama-3_1-8b-sa-name }}
+      artifact_dir: ${{ needs.tp-run.outputs.artifact-dir }}
+    secrets: inherit
+
   llama-3-8b-2d:
     name: Llama 3.0 8B (2D sharding)
     needs: tp-run
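
The new step exercises the splash-attention path entirely through configuration overrides passed to tp run. Assuming a Hydra/OmegaConf-style config stack (not shown in this diff), the dotted overrides compose over the YAML defaults roughly as sketched here; the default values in the snippet are placeholders.

```python
# Illustration only: how dotted CLI overrides might compose over YAML defaults.
# Assumes an OmegaConf-based stack; torchprime's actual config loading may differ.
from omegaconf import OmegaConf

defaults = OmegaConf.create({
    "model": {"attention_kernel": "flash_attention"},
    "ici_mesh": {"fsdp": 1},
    "global_batch_size": 256,
})
overrides = OmegaConf.from_dotlist([
    "model.attention_kernel=splash_attention",
    "ici_mesh.fsdp=4",
    "global_batch_size=8",
])
cfg = OmegaConf.merge(defaults, overrides)
assert cfg.model.attention_kernel == "splash_attention"
assert cfg.ici_mesh.fsdp == 4
```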

torchprime/torch_xla_models/configs/model/llama-3-8b.yaml

Lines changed: 2 additions & 1 deletion
@@ -18,5 +18,6 @@ initializer_range: 0.02
 rms_norm_eps: 1.0e-05
 attention_dropout: false
 attention_bias: false
-flash_attention: true
+# choose attention_kernel from: [flash_attention, splash_attention, null]
+attention_kernel: flash_attention
 rope_theta: 500000.0
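
The boolean flash_attention flag becomes a three-way attention_kernel choice. A minimal sketch of how an attention layer might dispatch on it follows; splash_attention refers to the call_jax sketch near the top of this page, the flash path stands in for torch_xla's existing Pallas flash-attention wrapper, and all names are assumptions rather than torchprime's actual module structure.

```python
# Hypothetical dispatch on the new attention_kernel config value.
import math

import torch


def eager_attention(q, k, v):
    # Reference path used when attention_kernel is null.
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.shape[-1])
    return torch.softmax(scores, dim=-1) @ v


def attention(q, k, v, attention_kernel=None):
    if attention_kernel == "splash_attention":
        # Pallas splash-attention kernel routed through xla_builder.call_jax
        # (see the earlier sketch).
        return splash_attention(q, k, v)
    if attention_kernel == "flash_attention":
        # Existing flash-attention path, e.g. torch_xla's Pallas wrapper.
        from torch_xla.experimental.custom_kernel import flash_attention
        return flash_attention(q, k, v, causal=True)
    return eager_attention(q, k, v)
```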

torchprime/torch_xla_models/configs/model/llama-3.1-405b.yaml

Lines changed: 2 additions & 1 deletion
@@ -18,7 +18,8 @@ initializer_range: 0.02
 rms_norm_eps: 1.0e-05
 attention_dropout: false
 attention_bias: false
-flash_attention: true
+# choose attention_kernel from: [flash_attention, splash_attention, null]
+attention_kernel: flash_attention
 rope_theta: 500000.0
 rope_scaling:
   factor: 8.0

torchprime/torch_xla_models/configs/model/llama-3.1-8b.yaml

Lines changed: 2 additions & 1 deletion
@@ -18,7 +18,8 @@ initializer_range: 0.02
 rms_norm_eps: 1.0e-05
 attention_dropout: false
 attention_bias: false
-flash_attention: true
+# choose attention_kernel from: [flash_attention, splash_attention, null]
+attention_kernel: flash_attention
 rope_theta: 500000.0
 rope_scaling:
   factor: 8.0

torchprime/torch_xla_models/configs/model/mixtral-8x7b.yaml

Lines changed: 2 additions & 1 deletion
@@ -20,6 +20,7 @@ router_aux_loss_coef: 0.02
 vocab_size: 32000
 attention_bias: false
 attention_dropout: 0.0
-flash_attention: true
+# choose attention_kernel from: [flash_attention, splash_attention, null]
+attention_kernel: flash_attention
 moe_implementation: gmm
 tokenizer_name: mistralai/Mixtral-8x7B-v0.1
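
Beyond these YAML defaults, the commit message notes tuning the splash-attention kernel's block-size configuration for performance and supporting runs without repeating KV heads. As a hedged sketch of what the block-size knobs look like in the Pallas splash-attention API (the sizes below are placeholders, not the values this commit settles on):

```python
# Placeholder block-size configuration for the splash-attention kernel; the
# actual tuned values from this commit are not reproduced here.
from jax.experimental.pallas.ops.tpu.splash_attention import (
    splash_attention_kernel,
    splash_attention_mask,
)

seq_len, num_heads = 8192, 32
block_sizes = splash_attention_kernel.BlockSizes(
    block_q=512,
    block_kv=512,
    block_kv_compute=512,
    block_q_dkv=512,
    block_kv_dkv=512,
    block_kv_dkv_compute=512,
    block_q_dq=512,
    block_kv_dq=512,
)
mask = splash_attention_mask.MultiHeadMask(
    [splash_attention_mask.CausalMask((seq_len, seq_len)) for _ in range(num_heads)]
)
kernel = splash_attention_kernel.make_splash_mha(
    mask=mask, head_shards=1, q_seq_shards=1, block_sizes=block_sizes
)
```

Larger blocks generally trade VMEM pressure for fewer kernel iterations, which is presumably the axis the "improve performance" item was tuned along.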
