4 files changed: +45 / -5 lines.
-# note: the sharding annonation is currently hard coded in pytorch-tpu/transformers
+# note: the sharding annotation is currently hard coded in pytorch-tpu/transformers
 defaults:
   - _self_
@@ -15,7 +15,11 @@ train_script:
   args:
     dataset_name: "wikitext"
     dataset_config_name: "wikitext-103-raw-v1"
-    per_device_train_batch_size: 256 # this is global batch size if use minibatch
+
+    # If minibatch is False, this should be set to the global batch size.
+    # If minibatch is True, this should be set to the per host batch size.
+    per_device_train_batch_size: 256
+
     do_train: true
     output_dir: "test-clm"
     overwrite_output_dir: true
@@ -32,4 +36,4 @@ train_script:
     dataloader_drop_last: true
     flash_attention: true
     max_steps: 50
-    seed: 42
+    seed: 42
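
The rewritten comment distinguishes two interpretations of `per_device_train_batch_size` depending on the `minibatch` flag. Below is a minimal sketch of that relationship; the `global_batch_size` helper, the `num_hosts` parameter, and the 4-host example are illustrative assumptions, not values taken from this change.

```python
# Illustrative sketch only: how per_device_train_batch_size relates to the
# global batch size under the two minibatch modes described in the comment.
def global_batch_size(per_device_train_batch_size: int,
                      num_hosts: int,
                      minibatch: bool) -> int:
    if minibatch:
        # The config value is interpreted as the per-host batch size, so the
        # global batch is summed over all hosts.
        return per_device_train_batch_size * num_hosts
    # Otherwise the config value already is the global batch size.
    return per_device_train_batch_size


# With the config above (256) on a hypothetical 4-host slice:
assert global_batch_size(256, num_hosts=4, minibatch=True) == 1024
assert global_batch_size(256, num_hosts=4, minibatch=False) == 256
```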
+{
+  "architectures": [
+    "LlamaForCausalLM"
+  ],
+  "attention_bias": false,
+  "attention_dropout": 0.0,
+  "bos_token_id": 128000,
+  "eos_token_id": 128001,
+  "hidden_act": "silu",
+  "hidden_size": 16384,
+  "initializer_range": 0.02,
+  "intermediate_size": 53248,
+  "max_position_embeddings": 131072,
+  "mlp_bias": false,
+  "model_type": "llama",
+  "num_attention_heads": 128,
+  "num_hidden_layers": 126,
+  "num_key_value_heads": 8,
+  "pretraining_tp": 1,
+  "rms_norm_eps": 1e-05,
+  "rope_scaling": {
+    "factor": 8.0,
+    "low_freq_factor": 1.0,
+    "high_freq_factor": 4.0,
+    "original_max_position_embeddings": 8192,
+    "rope_type": "llama3"
+  },
+  "rope_theta": 500000.0,
+  "tie_word_embeddings": false,
+  "torch_dtype": "bfloat16",
+  "transformers_version": "4.42.3",
+  "use_cache": false,
+  "vocab_size": 128256
+}
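
The new file is a Hugging Face-style `config.json` for `LlamaForCausalLM`; the dimensions (hidden size 16384, 126 layers, Llama 3 RoPE scaling) appear to correspond to a Llama-3.1-405B-class model. As a sketch of how such a file is typically consumed, with a placeholder directory path since the diff does not show where the file lives:

```python
# Sketch: loading the new config.json with Hugging Face transformers.
# "path/to/model_dir" is a placeholder for the directory that holds this file.
from transformers import AutoConfig

config = AutoConfig.from_pretrained("path/to/model_dir")  # dir containing config.json
print(config.model_type)         # "llama"
print(config.hidden_size)        # 16384
print(config.num_hidden_layers)  # 126
```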
@@ -26,6 +26,9 @@ def build_command(config: DictConfig) -> list:
         args[k] = v

     for k, v in args.items():
+        if v is None:
+            # We may delete an argument by setting it to `null` on the CLI.
+            continue
         if isinstance(v, bool):
             if v:
                 cmd.append(f"--{k}")
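
The `None` check makes arguments removable from the Hydra command line: overriding a key with `null` now drops the corresponding flag entirely instead of emitting it with a `None` value. A small self-contained sketch of that behavior follows; the `build_args` helper and the example keys are illustrative, not the actual torchprime function.

```python
# Sketch of the skip-None behavior added above, outside of Hydra.
def build_args(args: dict) -> list[str]:
    cmd: list[str] = []
    for k, v in args.items():
        if v is None:
            # An argument overridden with `null` on the CLI is dropped entirely.
            continue
        if isinstance(v, bool):
            if v:
                cmd.append(f"--{k}")
        else:
            cmd.extend([f"--{k}", str(v)])
    return cmd


print(build_args({"do_train": True, "seed": 42, "resume_from_checkpoint": None}))
# ['--do_train', '--seed', '42']
```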
@@ -52,8 +52,7 @@ RUN if [ "$USE_TRANSFORMERS" = "true" ] && [ -d "local_transformers" ]; then \

 # Only install transformers if USE_TRANSFORMERS is true
 RUN if [ "$USE_TRANSFORMERS" = "true" ]; then \
-    pip install --no-deps -e /workspaces/torchprime/local_transformers; \
-    pip install --no-deps evaluate; \
+    pip install -e /workspaces/torchprime/local_transformers evaluate; \
     fi

 ENV LIBTPU_INIT_ARGS "--xla_tpu_scoped_vmem_limit_kib=98304 --xla_enable_async_all_gather=true --xla_tpu_overlap_compute_collective_tc=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_fuse_all_gather=true"