Commit c4eeeb5

cblmemo and Michaelvll authored
[Core] Support TPU v6 (#4220)
* init
* fix
* nit
* format
* add readme
* add inference example
* nit
* add multi-host training
* rephrase catalog doc
* Update examples/tpu/v6e/README.md

Co-authored-by: Zhanghao Wu <zhanghao.wu@outlook.com>
1 parent 5dda9cf commit c4eeeb5

File tree

9 files changed: +290 -9 lines changed


examples/tpu/v6e/README.md

Lines changed: 128 additions & 0 deletions
@@ -0,0 +1,128 @@
# TPU v6e

Trillium (also referred to as v6e) is Cloud TPU's latest-generation AI accelerator. SkyPilot supports TPU v6e for provisioning, training, and serving.

## Catalogs

Currently, the public APIs for TPU v6e regions and pricing have not been released, and pricing info for `us-central1`, `us-central2`, and `us-south1` is not available. We set the price to `0.0` in those regions for now.
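To inspect what the catalog currently lists for v6e, you can use the standard `sky show-gpus` command:

```bash
$ sky show-gpus tpu-v6e-16 --cloud gcp
```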
## Provisioning

To provision TPU v6e, use the following command:

```bash
$ sky launch --gpus tpu-v6e-16 -c tpu-v6e
```

After that, you can SSH to the instance and start developing your model:

```bash
$ ssh tpu-v6e
```
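Once connected, a quick sanity check that PyTorch/XLA can see the TPU chips (a minimal sketch; it assumes `torch_xla` is installed on the VM, e.g. via the setup steps in the example YAMLs below):

```python
# Run on the TPU VM: list the XLA (TPU) devices visible to PyTorch/XLA.
import torch_xla.core.xla_model as xm

print(xm.get_xla_supported_devices())  # e.g. ['xla:0', 'xla:1', ...]
```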
## Training

The example in this directory (`train-llama3-8b.yaml`) shows how to use TPU v6e to train a Llama 3 8B model with PyTorch (XLA) on the wikitext dataset. To start the training, use the following command:

```bash
$ HF_TOKEN=hf_xxx sky launch train-llama3-8b.yaml -c train-llama3-8b --env HF_TOKEN
```
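The launch command streams the job output; if your terminal disconnects, the standard SkyPilot commands can re-attach to the logs and show job status:

```bash
$ sky logs train-llama3-8b    # re-attach to the running job's output
$ sky queue train-llama3-8b   # check job status on the cluster
```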
### Single-Host Training

The training throughput for a `tpu-v6e-8` instance should be around 0.5 samples/s:

```bash
(task, pid=17499) ***** train metrics *****
(task, pid=17499) epoch = 1.1765
(task, pid=17499) total_flos = 109935420GF
(task, pid=17499) train_loss = 10.6011
(task, pid=17499) train_runtime = 0:11:12.77
(task, pid=17499) train_samples = 282
(task, pid=17499) train_samples_per_second = 0.476
(task, pid=17499) train_steps_per_second = 0.03
INFO: Job finished (status: SUCCEEDED).
```
47+
48+
### Multi-Host Training
49+
50+
By changing the TPU type to `tpu-v6e-16` and the `--per_device_train_batch_size` to `32`, the training throughput increased to around 1 samples/s:
51+
52+
```bash
53+
(head, rank=0, pid=17894) ***** train metrics *****
54+
(head, rank=0, pid=17894) epoch = 2.5
55+
(head, rank=0, pid=17894) total_flos = 219870840GF
56+
(head, rank=0, pid=17894) train_loss = 10.1527
57+
(head, rank=0, pid=17894) train_runtime = 0:11:13.18
58+
(head, rank=0, pid=17894) train_samples = 282
59+
(head, rank=0, pid=17894) train_samples_per_second = 0.951
60+
(head, rank=0, pid=17894) train_steps_per_second = 0.03
61+
62+
(worker1, rank=1, pid=15406, ip=10.164.0.57) ***** train metrics *****
63+
(worker1, rank=1, pid=15406, ip=10.164.0.57) epoch = 2.5
64+
(worker1, rank=1, pid=15406, ip=10.164.0.57) total_flos = 219870840GF
65+
(worker1, rank=1, pid=15406, ip=10.164.0.57) train_loss = 10.1527
66+
(worker1, rank=1, pid=15406, ip=10.164.0.57) train_runtime = 0:11:15.08
67+
(worker1, rank=1, pid=15406, ip=10.164.0.57) train_samples = 282
68+
(worker1, rank=1, pid=15406, ip=10.164.0.57) train_samples_per_second = 0.948
69+
(worker1, rank=1, pid=15406, ip=10.164.0.57) train_steps_per_second = 0.03
70+
71+
(worker2, rank=2, pid=16552, ip=10.164.0.58) ***** train metrics *****
72+
(worker2, rank=2, pid=16552, ip=10.164.0.58) epoch = 2.5
73+
(worker2, rank=2, pid=16552, ip=10.164.0.58) total_flos = 219870840GF
74+
(worker2, rank=2, pid=16552, ip=10.164.0.58) train_loss = 10.1527
75+
(worker2, rank=2, pid=16552, ip=10.164.0.58) train_runtime = 0:11:15.61
76+
(worker2, rank=2, pid=16552, ip=10.164.0.58) train_samples = 282
77+
(worker2, rank=2, pid=16552, ip=10.164.0.58) train_samples_per_second = 0.947
78+
(worker2, rank=2, pid=16552, ip=10.164.0.58) train_steps_per_second = 0.03
79+
80+
(worker3, rank=3, pid=17469, ip=10.164.0.59) ***** train metrics *****
81+
(worker3, rank=3, pid=17469, ip=10.164.0.59) epoch = 2.5
82+
(worker3, rank=3, pid=17469, ip=10.164.0.59) total_flos = 219870840GF
83+
(worker3, rank=3, pid=17469, ip=10.164.0.59) train_loss = 10.1527
84+
(worker3, rank=3, pid=17469, ip=10.164.0.59) train_runtime = 0:11:15.10
85+
(worker3, rank=3, pid=17469, ip=10.164.0.59) train_samples = 282
86+
(worker3, rank=3, pid=17469, ip=10.164.0.59) train_samples_per_second = 0.948
87+
(worker3, rank=3, pid=17469, ip=10.164.0.59) train_steps_per_second = 0.03
88+
89+
INFO: Job finished (status: SUCCEEDED).
90+
```
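For reference, the multi-host run above differs from the single-host setup in just two places; a sketch of the edits to `train-llama3-8b.yaml` (not a separate file in this commit):

```yaml
resources:
  accelerators: tpu-v6e-16   # was: tpu-v6e-8

# ...and in the run section, raise the batch size flag:
#   --per_device_train_batch_size 32   # was: 16
```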
## Serving

TPU v6e also supports serving. The example in this directory (`serve-llama2-7b.yaml`) shows how to use TPU v6e to serve a Llama 2 7B model with PyTorch (XLA) and the JetStream library. To start serving, use the following command:

```bash
$ HF_TOKEN=hf_xxx sky launch serve-llama2-7b.yaml -c serve-llama2-7b --env HF_TOKEN
```

After the server is ready, you should see the following message:

```bash
(task, pid=26431) 2024-09-24 19:58:15,160 - root - INFO - Starting server on port 9000 with 64 threads
(task, pid=26431) I0924 19:58:15.160293 140454572087296 server_lib.py:155] Starting server on port 9000 with 64 threads
(task, pid=26431) 2024-09-24 19:58:15,161 - root - INFO - Not starting JAX profiler server: False
(task, pid=26431) I0924 19:58:15.161907 140454572087296 server_lib.py:164] Not starting JAX profiler server: False
(task, pid=26431) Started jetstream_server....
```
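If you want to reach the server from your local machine, the SkyPilot-managed SSH alias can forward the port (a generic SSH sketch; JetStream listens on port 9000 as shown above):

```bash
# Forward the JetStream port from the TPU VM to localhost.
$ ssh -L 9000:localhost:9000 serve-llama2-7b
```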
You can now start a benchmark to test the serving performance:

```bash
$ sky exec serve-llama2-7b benchmark-llama2-7b.yaml
... (emitted logs)
(task, pid=25491) Successful requests: 100
(task, pid=25491) Benchmark duration: 8.753792 s
(task, pid=25491) Total input tokens: 21888
(task, pid=25491) Total generated tokens: 18803
(task, pid=25491) Request throughput: 11.42 requests/s
(task, pid=25491) Input token throughput: 2500.40 tokens/s
(task, pid=25491) Output token throughput: 2147.98 tokens/s
(task, pid=25491) Mean TTFT: 1981.93 ms
(task, pid=25491) Median TTFT: 1829.33 ms
(task, pid=25491) P99 TTFT: 4511.95 ms
(task, pid=25491) Mean TPOT: 130.71 ms
(task, pid=25491) Median TPOT: 18.88 ms
(task, pid=25491) P99 TPOT: 2487.37 ms
```
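When you are done experimenting, the clusters created above can be torn down with the standard SkyPilot command (cluster names as used in this README):

```bash
$ sky down tpu-v6e train-llama3-8b serve-llama2-7b
```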

examples/tpu/v6e/benchmark-llama2-7b.yaml

Lines changed: 10 additions & 0 deletions
@@ -0,0 +1,10 @@
envs:
  model_name: llama-2
  tokenizer_path: /home/gcpuser/sky_workdir/ckpt/llama2-7b/original/tokenizer.model

run: |
  cd JetStream
  python benchmarks/benchmark_serving.py \
    --tokenizer=$tokenizer_path --num-prompts=100 \
    --dataset openorca --save-request-outputs \
    --warmup-mode=sampled --model=$model_name

examples/tpu/v6e/config-8B.json

Lines changed: 27 additions & 0 deletions
@@ -0,0 +1,27 @@
{
  "architectures": [
    "LlamaForCausalLM"
  ],
  "attention_bias": false,
  "attention_dropout": 0.0,
  "bos_token_id": 128000,
  "eos_token_id": 128001,
  "hidden_act": "silu",
  "hidden_size": 4096,
  "initializer_range": 0.02,
  "intermediate_size": 14336,
  "max_position_embeddings": 8192,
  "model_type": "llama",
  "num_attention_heads": 32,
  "num_hidden_layers": 32,
  "num_key_value_heads": 8,
  "pretraining_tp": 1,
  "rms_norm_eps": 1e-05,
  "rope_scaling": null,
  "rope_theta": 500000.0,
  "tie_word_embeddings": false,
  "torch_dtype": "bfloat16",
  "transformers_version": "4.40.0.dev0",
  "use_cache": true,
  "vocab_size": 128256
}

examples/tpu/v6e/fsdp_config.json

Lines changed: 8 additions & 0 deletions
@@ -0,0 +1,8 @@
{
  "fsdp_transformer_layer_cls_to_wrap": [
    "LlamaDecoderLayer"
  ],
  "xla": true,
  "xla_fsdp_v2": true,
  "xla_fsdp_grad_ckpt": true
}

examples/tpu/v6e/serve-llama2-7b.yaml

Lines changed: 60 additions & 0 deletions
@@ -0,0 +1,60 @@
resources:
  accelerators: tpu-v6e-8  # Fill in the accelerator type you want to use

envs:
  HF_TOKEN: # fill in your huggingface token
  HF_REPO_ID: meta-llama/Llama-2-7b
  model_name: llama-2
  input_ckpt_dir: /home/gcpuser/sky_workdir/ckpt/llama2-7b/original
  output_ckpt_dir: /home/gcpuser/sky_workdir/ckpt/llama2-7b/converted
  tokenizer_path: /home/gcpuser/sky_workdir/ckpt/llama2-7b/original/tokenizer.model

setup: |
  pip3 install huggingface_hub
  python3 -c "import huggingface_hub; huggingface_hub.login('${HF_TOKEN}')"

  # Setup TPU
  pip3 install cloud-tpu-client
  sudo apt update
  sudo apt install -y libopenblas-base
  pip3 install --pre torch==2.6.0.dev20240916+cpu torchvision==0.20.0.dev20240916+cpu \
    --index-url https://download.pytorch.org/whl/nightly/cpu
  pip install "torch_xla[tpu]@https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0.dev20240916-cp310-cp310-linux_x86_64.whl" \
    -f https://storage.googleapis.com/libtpu-releases/index.html
  pip install torch_xla[pallas] \
    -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
    -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html

  # Setup runtime for serving
  git clone https://github.com/google/JetStream.git
  cd JetStream
  git checkout main
  git pull origin main
  pip install -e .
  cd benchmarks
  pip install -r requirements.in
  cd ../..
  git clone https://github.com/google/jetstream-pytorch.git
  cd jetstream-pytorch/
  git checkout jetstream-v0.2.3
  source install_everything.sh
  pip3 install -U --pre jax jaxlib libtpu-nightly requests \
    -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
    -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

  # Prepare checkpoint, inside jetstream-pytorch repo
  mkdir -p ${input_ckpt_dir}
  python3 -c "import huggingface_hub; huggingface_hub.snapshot_download('${HF_REPO_ID}', local_dir='${input_ckpt_dir}')"
  mkdir -p ${output_ckpt_dir}
  python -m convert_checkpoints --model_name=$model_name \
    --input_checkpoint_dir=$input_ckpt_dir \
    --output_checkpoint_dir=$output_ckpt_dir

run: |
  cd jetstream-pytorch
  python run_server.py --model_name=$model_name \
    --size=7b --batch_size=24 --max_cache_length=2048 \
    --checkpoint_path=$output_ckpt_dir \
    --tokenizer_path=$tokenizer_path \
    --sharding_config="default_shardings/llama.yaml"

examples/tpu/v6e/train-llama3-8b.yaml

Lines changed: 53 additions & 0 deletions
@@ -0,0 +1,53 @@
resources:
  accelerators: tpu-v6e-8  # Fill in the accelerator type you want to use

envs:
  HF_TOKEN: # fill in your huggingface token

workdir: .

setup: |
  pip3 install huggingface_hub
  python3 -c "import huggingface_hub; huggingface_hub.login('${HF_TOKEN}')"

  # Setup TPU
  pip3 install cloud-tpu-client
  sudo apt update
  sudo apt install -y libopenblas-base
  pip3 install --pre torch==2.6.0.dev20240916+cpu torchvision==0.20.0.dev20240916+cpu \
    --index-url https://download.pytorch.org/whl/nightly/cpu
  pip install "torch_xla[tpu]@https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0.dev20240916-cp310-cp310-linux_x86_64.whl" \
    -f https://storage.googleapis.com/libtpu-releases/index.html
  pip install torch_xla[pallas] \
    -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
    -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html

  # Setup runtime for training
  git clone -b flash_attention https://github.com/pytorch-tpu/transformers.git
  cd transformers
  pip3 install -e .
  pip3 install datasets evaluate scikit-learn accelerate

run: |
  unset LD_PRELOAD
  PJRT_DEVICE=TPU XLA_USE_SPMD=1 ENABLE_PJRT_COMPATIBILITY=true \
  python3 transformers/examples/pytorch/language-modeling/run_clm.py \
    --dataset_name wikitext \
    --dataset_config_name wikitext-2-raw-v1 \
    --per_device_train_batch_size 16 \
    --do_train \
    --output_dir /home/$USER/tmp/test-clm \
    --overwrite_output_dir \
    --config_name /home/$USER/sky_workdir/config-8B.json \
    --cache_dir /home/$USER/cache \
    --tokenizer_name meta-llama/Meta-Llama-3-8B \
    --block_size 8192 \
    --optim adafactor \
    --save_strategy no \
    --logging_strategy no \
    --fsdp "full_shard" \
    --fsdp_config /home/$USER/sky_workdir/fsdp_config.json \
    --torch_dtype bfloat16 \
    --dataloader_drop_last yes \
    --flash_attention \
    --max_steps 20

sky/backends/cloud_vm_ray_backend.py

Lines changed: 1 addition & 1 deletion
@@ -2467,7 +2467,7 @@ def num_ips_per_node(self) -> int:
         """Returns number of IPs per node in the cluster, handling TPU Pod."""
         is_tpu_vm_pod = gcp_utils.is_tpu_vm_pod(self.launched_resources)
         if is_tpu_vm_pod:
-            num_ips = gcp_utils.get_num_tpu_devices(self.launched_resources)
+            num_ips = len(self.internal_ips())
         else:
             num_ips = 1
         return num_ips

sky/clouds/utils/gcp_utils.py

Lines changed: 0 additions & 8 deletions
@@ -49,14 +49,6 @@ def is_tpu_vm_pod(resources: Optional['resources_lib.Resources']) -> bool:
     return not acc.endswith('-8')
 
 
-def get_num_tpu_devices(resources: Optional['resources_lib.Resources']) -> int:
-    if resources is None or not is_tpu(resources):
-        raise ValueError('resources must be a valid TPU resource.')
-    acc, _ = list(resources.accelerators.items())[0]
-    num_tpu_devices = int(int(acc.split('-')[2]) / 8)
-    return num_tpu_devices
-
-
 @dataclasses.dataclass
 class SpecificReservation:
     count: int
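Taken together with the backend change above, this removes the heuristic that derived the TPU host count from the accelerator name (chip count divided by 8) in favor of counting the launched nodes' internal IPs. A small illustrative sketch of the difference (not SkyPilot code; values are hypothetical):

```python
def hosts_from_name(acc: str) -> int:
    """Old heuristic: parse the chip count from e.g. 'tpu-v6e-16' and assume 8 chips per host."""
    return int(acc.split('-')[2]) // 8

def hosts_from_ips(internal_ips: list) -> int:
    """New approach: count the host VMs actually provisioned for the slice."""
    return len(internal_ips)

print(hosts_from_name('tpu-v6e-16'))            # 2, correct only if every host has 8 chips
print(hosts_from_ips(['10.0.0.1', '10.0.0.2',
                      '10.0.0.3', '10.0.0.4']))  # 4, reflecting the real topology
```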

sky/resources.py

Lines changed: 3 additions & 0 deletions
@@ -602,6 +602,9 @@ def _get_default_runtime_version() -> str:
                 # TPU V5 requires a newer runtime version.
                 if acc.startswith('tpu-v5'):
                     return 'v2-alpha-tpuv5'
+                # TPU V6e requires a newer runtime version.
+                if acc.startswith('tpu-v6e'):
+                    return 'v2-alpha-tpuv6e'
                 return 'tpu-vm-base'
 
             accelerator_args['runtime_version'] = (
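With this change, tasks requesting `tpu-v6e-*` default to the `v2-alpha-tpuv6e` TPU VM runtime. If a different runtime is ever needed, it can still be pinned in a task YAML via `accelerator_args` (a sketch using the standard SkyPilot TPU fields):

```yaml
resources:
  accelerators: tpu-v6e-8
  accelerator_args:
    runtime_version: v2-alpha-tpuv6e  # override the default chosen by SkyPilot
```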
