@@ -100,15 +100,51 @@ tp run torchprime/experimental/torchax_models/run.py global_batch_size=256
`tp run` will broadcast the specified command to all VMs in the XPK cluster,
which is the convention for running SPMD distributed workloads.

- #### Env var passed to the workload
+ #### Env vars passed to the workload

`tp run` will pick up these environment variables locally and proxy them
to the distributed workload, if found:

- `HF_TOKEN`: HuggingFace token
- `XLA_IR_DEBUG`: [torch_xla debugging flag][torch_xla_debug_env]
- `XLA_HLO_DEBUG`: [torch_xla debugging flag][torch_xla_debug_env]
- - `LIBTPU_INIT_ARGS`: xla flag
+ - `LIBTPU_INIT_ARGS`: XLA flags that affect compilation and execution behavior
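+
+ For example, with these variables set in your local shell (the values below
+ are placeholders for illustration), `tp run` would forward them to the
+ distributed workload:
+
+ ```sh
+ # Placeholder values for illustration only.
+ export HF_TOKEN=hf_your_token_here
+ export XLA_IR_DEBUG=1   # torch_xla debugging flag
+ export XLA_HLO_DEBUG=1  # torch_xla debugging flag
+ tp run torchprime/experimental/torchax_models/run.py global_batch_size=256
+ ```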
+
+ ## Model status
+
+ Here is the status of various models. In general, each model goes through five
+ stages:
+
+ - **TODO**: We need to implement the model.
+ - **Implemented**: The model runs either a training or an inference step.
+ - **Optimized**: We have found the best scaling configuration for the model on
+   one or more hardware platforms. One-off performance data is available.
+ - **Convergence**: We have tested that the training loss converges to a
+   reasonable value, or that the loss curve tracks an existing reference if one
+   exists.
+ - **Production**: The model is optimized and converges, and its performance is
+   also continuously monitored. This is a good state for using the model in
+   production.
+
+ All implemented models have at least unit tests to verify basic numerical
+ correctness, and the convergence verification stage serves as an additional
+ correctness guarantee.
+
+ If a model is at least implemented, you'll also find a training recipe linked
+ from the checkmark emoji in the table. If a model is optimized, you'll also
+ find MFU numbers linked from the table. Note that a model may continue to
+ receive ongoing optimization thereafter.
+
+ | **Model**            | **Implemented** | **Optimized** | **Converges** |
+ | -------------------- | --------------- | ------------- | ------------- |
+ | Llama 3.0 8B         | [✅](torchprime/torch_xla_models/README.md#llama-30-8b-on-v6e-256) | [✅](torchprime/torch_xla_models/README.md#llama-30-8b-on-v6e-256) | [TODO](https://github.com/AI-Hypercomputer/torchprime/issues/90) |
+ | Llama 3.1 8B         | [✅](torchprime/torch_xla_models/README.md#llama-31-8b-on-v6e-256) | [TODO](https://github.com/AI-Hypercomputer/torchprime/issues/133) | TODO |
+ | Llama 3.1 70B        | [TODO](https://github.com/AI-Hypercomputer/torchprime/issues/17) | TODO | TODO |
+ | Llama 3.1 405B       | [✅](torchprime/torch_xla_models/README.md#llama-31-405b-on-v6e-256) | [TODO](https://github.com/AI-Hypercomputer/torchprime/milestone/2) | TODO |
+ | Mixtral 8x7B         | [✅](torchprime/torch_xla_models/README.md#mixtral-8x7b-on-v6e-256) | [TODO](https://github.com/AI-Hypercomputer/torchprime/issues/44) | TODO |
+ | Mixtral 8x22B        | [TODO](https://github.com/AI-Hypercomputer/torchprime/issues/45) | TODO | TODO |
+ | DeepSeek V3/R1       | TODO | TODO | TODO |
+ | Stable Diffusion 2.0 | [TODO](https://github.com/AI-Hypercomputer/torchprime/issues/87) | TODO | TODO |
+ | Stable Diffusion 2.1 | [TODO](https://github.com/AI-Hypercomputer/torchprime/issues/88) | TODO | TODO |

## Structure
@@ -133,31 +169,6 @@ and attributes where this model code came from, if any. This also helps to
showcase what changes we have made to make it performant on TPU. The original
version is not expected to be run.

- ## Run huggingface transformer models
- Torchprime supports run with huggingface models by taking advantage of `tp run`.
- To use huggingface models, you can clone
- [huggingface/transformers](https://github.com/huggingface/transformers) under
- torchprime and name it as `local_transformers`. This allows you to pick any
- branch or make code modifications in transformers for experiment:
- ```
- git clone https://github.com/huggingface/transformers.git local_transformers
- ```
- If huggingface transformer doesn't exist, torchprime will automatically clone
- the repo and build the docker for experiment. To switch to huggingface models,
- add flag `--use-hf` to `tp run` command:
- ```
- tp run --use-hf torchprime/hf_models/train.py
- ```
-
- ## Run with local torch/torch_xla wheel
- Torchprime supports run with user specified torch and torch_xla wheels placed
- under `local_dist/` directory. The wheel will be automatically installed in the
- docker image when use `tp run` command. To use the wheel, add flag
- `--use-local-wheel` to `tp run` command:
- ```
- tp run --use-local-wheel torchprime/hf_models/train.py
- ```
-
## Contributing
Contributions are welcome! Please feel free to submit a pull request.
@@ -192,6 +203,21 @@ ruff check [--fix]
You can install a Ruff VSCode plugin to check errors and format files from
the editor.

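For reference, here is a minimal command-line sketch (assuming Ruff is
installed in your environment) for formatting and linting locally:

```sh
# Reformat all files in place.
ruff format .

# Lint the codebase and apply safe autofixes.
ruff check --fix .
```
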
+ ## Run distributed training with local torch/torch_xla wheel
+
+ Torchprime supports running with user-specified torch and torch_xla wheels
+ placed under the `local_dist/` directory. The wheels are automatically
+ installed in the docker image when using the `tp run` command. To use them,
+ add the `--use-local-wheel` flag to the `tp run` command:
+
+ ```sh
+ tp run --use-local-wheel torchprime/hf_models/train.py
+ ```
+
+ The wheels should be built inside a
+ [PyTorch/XLA development docker image][torch_xla_dev_docker] or the PyTorch/XLA
+ VSCode Dev Container to minimize compatibility issues.
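+
+ As a concrete sketch, the layout before launching might look like this (the
+ wheel filenames are hypothetical; yours will vary by version and platform):
+
+ ```sh
+ # Hypothetical wheel filenames; substitute the wheels you built.
+ mkdir -p local_dist
+ cp ~/pytorch/dist/torch-2.7.0-cp310-cp310-linux_x86_64.whl local_dist/
+ cp ~/pytorch/xla/dist/torch_xla-2.7.0-cp310-cp310-linux_x86_64.whl local_dist/
+ tp run --use-local-wheel torchprime/hf_models/train.py
+ ```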
## License

This project is licensed under the New BSD License - see the [LICENSE](LICENSE)
@@ -205,3 +231,4 @@ For more information on PyTorch/XLA, visit the
[xpk]: https://github.com/AI-Hypercomputer/xpk
[torch_xla_debug_env]: https://github.com/pytorch/xla/blob/master/docs/source/learn/troubleshoot.md#environment-variables
[hydra]: https://hydra.cc/docs/intro/
+ [torch_xla_dev_docker]: https://github.com/pytorch/xla/blob/master/CONTRIBUTING.md#manually-build-in-docker-container