Skip to content

Commit f23d3cf

Browse files
authored
Support torch/torch_xla wheel locally (#142)
* Support torch/torch_xla wheel locally * nit * nit
1 parent d1acfa9 commit f23d3cf

File tree

5 files changed

+42
-3
lines changed

5 files changed

+42
-3
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ pip-wheel-metadata/
3535
# Ignore build and distribution directories
3636
build/
3737
dist/
38+
local_dist/
3839
*.egg-info/
3940

4041
# Ignore environment configuration files

README.md

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,15 @@ add flag `--use-hf` to `tp run` command:
149149
tp run --use-hf torchprime/hf_models/train.py
150150
```
151151

152+
## Run with local torch/torch_xla wheel
153+
Torchprime supports running with user-specified torch and torch_xla wheels placed
154+
under the `local_dist/` directory. The wheels will be automatically installed in the
155+
docker image when you use the `tp run` command. To use the wheels, add the flag
156+
`--use-local-wheel` to the `tp run` command:
157+
```
158+
tp run --use-local-wheel torchprime/hf_models/train.py
159+
```
160+
152161
## Contributing
153162

154163
Contributions are welcome! Please feel free to submit a pull request.

torchprime/launcher/Dockerfile

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
FROM us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm_cxx11_20250227
44

55
ARG USE_TRANSFORMERS=false
6+
ARG USE_LOCAL_WHEEL=false
7+
68
# Install system dependencies
79
RUN apt-get update && apt-get install -y curl gnupg
810

@@ -39,6 +41,23 @@ RUN pip install -e .
3941

4042
# Now copy the rest of the repo
4143
COPY . /workspaces/torchprime
44+
45+
# Install torch and torch_xla from local wheels if USE_LOCAL_WHEEL is true and they exist
46+
# under local_dist directory. Note that you need to build the torch and
47+
# torch_xla using the
48+
RUN if [ "$USE_LOCAL_WHEEL" = "true" ]; then \
49+
if [ -d "local_dist" ] && [ "$(find local_dist -name 'torch-*.whl' | wc -l)" -gt 0 ]; then \
50+
pip install local_dist/torch-*.whl; \
51+
else \
52+
echo "torch wheel not found in local_dist directory"; \
53+
fi; \
54+
if [ -d "local_dist" ] && [ "$(find local_dist -name 'torch_xla-*.whl' | wc -l)" -gt 0 ]; then \
55+
pip install local_dist/torch_xla-*.whl; \
56+
else \
57+
echo "torch_xla wheel not found in local_dist directory"; \
58+
fi; \
59+
fi
60+
4261
# This should not install any packages. Only symlink the source code.
4362
RUN pip install --no-deps -e .
4463

torchprime/launcher/buildpush.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,8 @@ def buildpush(
4848

4949
build_cmd = f"{sudo_cmd} docker build"
5050
if build_arg:
51-
build_cmd += f" --build-arg {build_arg}"
51+
for _arg in build_arg:
52+
build_cmd += f" --build-arg {_arg}"
5253
build_cmd += (
5354
f" --network=host --progress=auto -t {docker_tag} {context_dir} -f {docker_file}"
5455
)

torchprime/launcher/cli.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -208,8 +208,13 @@ def create_and_activate_gcloud(gcloud_config_name, config: Config):
208208
default=None,
209209
)
210210
@click.option("--use-hf", is_flag=True, help="Use HuggingFace transformer")
211+
@click.option(
212+
"--use-local-wheel",
213+
is_flag=True,
214+
help="Use local torch and torch_xla wheels under folder local_dist/",
215+
)
211216
@interactive
212-
def run(args, name: str | None, use_hf: bool):
217+
def run(args, name: str | None, use_hf: bool, use_local_wheel: bool):
213218
"""
214219
Runs the provided SPMD training command as an xpk job on a GKE cluster.
215220
"""
@@ -218,7 +223,11 @@ def run(args, name: str | None, use_hf: bool):
218223
click.echo(get_project_dir().absolute())
219224

220225
# Build docker image.
221-
build_arg = "USE_TRANSFORMERS=true" if use_hf else None
226+
build_arg = []
227+
if use_hf:
228+
build_arg.append("USE_TRANSFORMERS=true")
229+
if use_local_wheel:
230+
build_arg.append("USE_LOCAL_WHEEL=true")
222231
docker_project = config.docker_project
223232
if docker_project is None:
224233
docker_project = config.project

0 commit comments

Comments
 (0)