Skip to content

Commit 825320e

Browse files
authored
Support run trainer locally (#111)
* Support run trainer locally * nit * update on feedback
1 parent e100715 commit 825320e

File tree

3 files changed

+103
-15
lines changed

3 files changed

+103
-15
lines changed

README.md

Lines changed: 41 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -169,33 +169,69 @@ Finally, each model may also provide a GPU "original" version that illustrates
169169
and attributes where this model code came from, if any. This also helps to
170170
show case what changes we have done to make it performant on TPU. The original
171171
version is not expected to be run.
172-
173172
## Contributing
174173

175174
Contributions are welcome! Please feel free to submit a pull request.
176175

177176
When developing, use `pip install -e '.[dev]'` to install dev dependencies such
178177
as linter and formatter.
179178

180-
How to run tests:
179+
### How to run tests:
181180

182181
```sh
183182
pytest
184183
```
185184

186-
How to run some of the tests, and re-run them whenever you change a file:
185+
### How to run some of the tests, and re-run them whenever you change a file:
187186

188187
```sh
189188
tp -i test ... # replace with path to tests/directories
190189
```
191190

192-
How to format:
191+
192+
### How to run HuggingFace transformer models
193+
Torchprime supports running HuggingFace models by taking advantage of `tp run`.
194+
To use huggingface models, you can clone
195+
[huggingface/transformers](https://github.com/huggingface/transformers) under
196+
torchprime and name it as `local_transformers`. This allows you to pick any
197+
branch or make code modifications in transformers for experimentation:
198+
```
199+
git clone https://github.com/huggingface/transformers.git local_transformers
200+
```
201+
If the local HuggingFace transformers repo doesn't exist, torchprime will automatically clone
202+
the repo and build the Docker image for the experiment. To switch to HuggingFace models,
203+
add the `--use-hf` flag to the `tp run` command:
204+
```
205+
tp run --use-hf torchprime/hf_models/train.py
206+
```
207+
208+
### How to run inside the docker container locally
209+
You can also run locally with Docker, without XPK. When running inside the docker
210+
container, it will use the same dependencies and build process as used in the
211+
XPK approach, improving the hermeticity and reliability.
212+
```
213+
tp docker-run torchprime/torch_xla_models/train.py
214+
```
215+
This will run the TorchPrime docker image locally. You can also add `--use-hf`
216+
to run a HuggingFace model locally.
217+
```
218+
tp docker-run --use-hf torchprime/hf_models/train.py
219+
```
220+
221+
### How to run locally without XPK:
222+
```
223+
tp docker-run torchprime/torch_xla_models/train.py
224+
```
225+
This will run the TorchPrime docker image locally. You can also add `--use-hf`
226+
to run a HuggingFace model locally.
227+
228+
### How to format:
193229

194230
```sh
195231
ruff format
196232
```
197233

198-
How to lint:
234+
### How to lint:
199235

200236
```sh
201237
ruff check [--fix]

torchprime/launcher/buildpush.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
def buildpush(
1616
torchprime_project_id,
1717
torchprime_docker_url=None,
18-
torchprime_docker_tag=None,
18+
push_docker=True,
1919
*,
2020
build_arg=None,
2121
) -> str:
@@ -36,7 +36,7 @@ def buildpush(
3636

3737
# Determine Docker tag
3838
default_tag = f"{datetime_str}-{random_chars}"
39-
docker_tag = torchprime_docker_tag if torchprime_docker_tag else default_tag
39+
docker_tag = default_tag
4040

4141
# Determine Docker URL
4242
default_url = f"gcr.io/{torchprime_project_id}/torchprime-{user}:{docker_tag}"
@@ -62,7 +62,8 @@ def buildpush(
6262
_run(
6363
f"{sudo_cmd} docker tag {docker_tag} {docker_url}",
6464
)
65-
_run(f"{sudo_cmd} docker push {docker_url}")
65+
if push_docker:
66+
_run(f"{sudo_cmd} docker push {docker_url}")
6667
except subprocess.CalledProcessError as e:
6768
print(f"Error running command: {e}")
6869
exit(e.returncode)
@@ -83,9 +84,10 @@ def _run(command):
8384
# Read environment variables or use defaults
8485
torchprime_project_id = os.getenv("TORCHPRIME_PROJECT_ID", "tpu-pytorch")
8586
torchprime_docker_url = os.getenv("TORCHPRIME_DOCKER_URL", None)
86-
torchprime_docker_tag = os.getenv("TORCHPRIME_DOCKER_TAG", None)
87+
push_docker_str = os.getenv("TORCHPRIME_PUSH_DOCKER", "true")
88+
push_docker = push_docker_str.lower() in ("true", "1", "yes", "y")
8789
buildpush(
8890
torchprime_project_id,
8991
torchprime_docker_url,
90-
torchprime_docker_tag,
92+
push_docker,
9193
)

torchprime/launcher/cli.py

Lines changed: 55 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,13 @@
2323
import torchprime.launcher.doctor
2424
from torchprime.launcher.buildpush import buildpush
2525

26+
_DOCKER_ENV_FORWARD_LIST = [
27+
"HF_TOKEN",
28+
"XLA_IR_DEBUG",
29+
"XLA_HLO_DEBUG",
30+
"LIBTPU_INIT_ARGS",
31+
]
32+
2633

2734
@dataclass_json
2835
@dataclass
@@ -195,6 +202,53 @@ def create_and_activate_gcloud(gcloud_config_name, config: Config):
195202
)
196203

197204

205+
@cli.command(
206+
name="docker-run",
207+
context_settings=dict(
208+
ignore_unknown_options=True,
209+
),
210+
)
211+
@click.argument("args", nargs=-1, type=click.UNPROCESSED)
212+
@click.option("--use-hf", is_flag=True, help="Use HuggingFace transformer")
213+
def docker_run(args, use_hf: bool):
214+
"""
215+
Runs the provided training command locally for quick testing.
216+
"""
217+
config = read_config()
218+
219+
click.echo(get_project_dir().absolute())
220+
221+
# Build docker image.
222+
build_arg = "USE_TRANSFORMERS=true" if use_hf else None
223+
docker_project = config.docker_project
224+
if docker_project is None:
225+
docker_project = config.project
226+
docker_url = buildpush(docker_project, push_docker=False, build_arg=build_arg)
227+
# Forward a bunch of important env vars.
228+
env_forwarding = [
229+
arg for env_var in _DOCKER_ENV_FORWARD_LIST for arg in forward_env(env_var)
230+
]
231+
command = [
232+
"python",
233+
] + list(args)
234+
docker_command = [
235+
"docker",
236+
"run",
237+
"-i",
238+
*env_forwarding,
239+
"--privileged",
240+
"--net",
241+
"host",
242+
"--rm",
243+
"-v",
244+
f"{os.getcwd()}:/workspace",
245+
"-w",
246+
"/workspace",
247+
docker_url,
248+
] + command
249+
subprocess.run(docker_command, check=True)
250+
251+
198252
@cli.command(
199253
context_settings=dict(
200254
ignore_unknown_options=True,
@@ -255,12 +309,8 @@ def run(
255309

256310
# Forward a bunch of important env vars.
257311
env_forwarding = [
258-
*forward_env("HF_TOKEN"), # HuggingFace token
259-
*forward_env("XLA_IR_DEBUG"), # torch_xla debugging flag
260-
*forward_env("XLA_HLO_DEBUG"), # torch_xla debugging flag
261-
*forward_env("LIBTPU_INIT_ARGS"), # XLA flags
312+
arg for env_var in _DOCKER_ENV_FORWARD_LIST for arg in forward_env(env_var)
262313
]
263-
264314
# Pass artifact dir and jobset name as env vars.
265315
artifact_arg = [
266316
"--env",

0 commit comments

Comments
 (0)