
Commit 07f3dfe

Update HF runner configs (#110)
Added a Llama 3.1 405B configuration and fixed some comments in the config YAML. In particular, the comment for per_device_train_batch_size is based on http://shortn/_SR8wqscDlo. Also changed the Dockerfile to install all transformers dependencies, because local_transformers may be newer than our pinned transformers dependency and may then require a newer dependency, such as tokenizers==0.20 instead of tokenizers==0.19.
1 parent ed8655d commit 07f3dfe

File tree

4 files changed: +45 −5 lines

torchprime/hf_models/configs/default.yaml

Lines changed: 7 additions & 3 deletions
@@ -1,4 +1,4 @@
-# note: the sharding annonation is currently hard coded in pytorch-tpu/transformers
+# note: the sharding annotation is currently hard coded in pytorch-tpu/transformers
 defaults:
   - _self_
 
@@ -15,7 +15,11 @@ train_script:
   args:
     dataset_name: "wikitext"
     dataset_config_name: "wikitext-103-raw-v1"
-    per_device_train_batch_size: 256 # this is global batch size if use minibatch
+
+    # If minibatch is False, this should be set to the global batch size.
+    # If minibatch is True, this should be set to the per host batch size.
+    per_device_train_batch_size: 256
+
     do_train: true
     output_dir: "test-clm"
     overwrite_output_dir: true
@@ -32,4 +36,4 @@ train_script:
     dataloader_drop_last: true
     flash_attention: true
     max_steps: 50
-    seed: 42
+    seed: 42
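
To make the new batch-size comment concrete, here is a minimal sketch of the arithmetic it implies; num_hosts and minibatch are illustrative names for this example, not variables taken from the trainer.

def global_batch_size(per_device_train_batch_size: int, num_hosts: int, minibatch: bool) -> int:
    """Illustrative only: derive the global batch size from the config value.

    minibatch=False: the config value already is the global batch size.
    minibatch=True:  the config value is the per-host batch size, so the global
                     batch size is that value multiplied by the number of hosts.
    """
    if minibatch:
        return per_device_train_batch_size * num_hosts
    return per_device_train_batch_size


# Example: 256 per host on a 4-host slice with minibatch enabled -> 1024 globally.
print(global_batch_size(256, num_hosts=4, minibatch=True))   # 1024
print(global_batch_size(256, num_hosts=4, minibatch=False))  # 256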
New file: Llama 3.1 405B model configuration

Lines changed: 34 additions & 0 deletions
@@ -0,0 +1,34 @@
+{
+  "architectures": [
+    "LlamaForCausalLM"
+  ],
+  "attention_bias": false,
+  "attention_dropout": 0.0,
+  "bos_token_id": 128000,
+  "eos_token_id": 128001,
+  "hidden_act": "silu",
+  "hidden_size": 16384,
+  "initializer_range": 0.02,
+  "intermediate_size": 53248,
+  "max_position_embeddings": 131072,
+  "mlp_bias": false,
+  "model_type": "llama",
+  "num_attention_heads": 128,
+  "num_hidden_layers": 126,
+  "num_key_value_heads": 8,
+  "pretraining_tp": 1,
+  "rms_norm_eps": 1e-05,
+  "rope_scaling": {
+    "factor": 8.0,
+    "low_freq_factor": 1.0,
+    "high_freq_factor": 4.0,
+    "original_max_position_embeddings": 8192,
+    "rope_type": "llama3"
+  },
+  "rope_theta": 500000.0,
+  "tie_word_embeddings": false,
+  "torch_dtype": "bfloat16",
+  "transformers_version": "4.42.3",
+  "use_cache": false,
+  "vocab_size": 128256
+}
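
As a rough sanity check, the new JSON should load into a transformers LlamaConfig without instantiating the model; the file path below is a placeholder (the new file's location is not shown in this diff), and a transformers version new enough to recognize the llama3 rope_scaling format is assumed.

from transformers import LlamaConfig

# Placeholder path: the actual filename added by this commit is not shown above.
config = LlamaConfig.from_json_file("llama-3.1-405b.json")

# Spot-check a few of the 405B hyperparameters that appear in the diff.
assert config.hidden_size == 16384
assert config.num_hidden_layers == 126
assert config.num_key_value_heads == 8
print(config.rope_scaling)  # expect the llama3 rope scaling block from the JSON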

torchprime/hf_models/train.py

Lines changed: 3 additions & 0 deletions
@@ -26,6 +26,9 @@ def build_command(config: DictConfig) -> list:
         args[k] = v
 
     for k, v in args.items():
+        if v is None:
+            # We may delete an argument by setting it to `null` on the CLI.
+            continue
         if isinstance(v, bool):
            if v:
                cmd.append(f"--{k}")
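
For context, here is a standalone approximation of the argument loop that this change touches; it is a sketch of the behaviour (skip null-ed arguments, turn true booleans into bare flags), not a verbatim copy of build_command, and the formatting of non-boolean values is an assumption.

def args_to_flags(args: dict) -> list[str]:
    """Sketch of build_command's flag-building loop after this change."""
    cmd: list[str] = []
    for k, v in args.items():
        if v is None:
            # An argument can be deleted by setting it to `null` on the CLI.
            continue
        if isinstance(v, bool):
            # The diff only shows the true branch; this sketch drops false values.
            if v:
                cmd.append(f"--{k}")
            continue
        # Assumed formatting for everything else: "--key value".
        cmd.extend([f"--{k}", str(v)])
    return cmd


print(args_to_flags({"per_device_train_batch_size": 256, "do_train": True, "output_dir": None}))
# ['--per_device_train_batch_size', '256', '--do_train']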

torchprime/launcher/Dockerfile

Lines changed: 1 addition & 2 deletions
@@ -52,8 +52,7 @@ RUN if [ "$USE_TRANSFORMERS" = "true" ] && [ -d "local_transformers" ]; then \
 
 # Only install transformers if USE_TRANSFORMERS is true
 RUN if [ "$USE_TRANSFORMERS" = "true" ]; then \
-    pip install --no-deps -e /workspaces/torchprime/local_transformers; \
-    pip install --no-deps evaluate; \
+    pip install -e /workspaces/torchprime/local_transformers evaluate; \
     fi
 
 ENV LIBTPU_INIT_ARGS "--xla_tpu_scoped_vmem_limit_kib=98304 --xla_enable_async_all_gather=true --xla_tpu_overlap_compute_collective_tc=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_fuse_all_gather=true"
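
Because local_transformers is now installed with its dependencies, a quick post-build check inside the image can confirm which transformers and tokenizers versions pip actually resolved; this is just a convenience snippet, not part of the commit.

# Run inside the built container, e.g. to confirm that a newer local_transformers
# pulled in tokenizers==0.20 rather than the previously pinned tokenizers==0.19.
import tokenizers
import transformers

print("transformers:", transformers.__version__)
print("tokenizers:", tokenizers.__version__)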

0 commit comments
