
Commit 12a0bb8

Add Profiling Section to LLM Pretraining Tutorial (#89)

1 parent b4b4ec0 commit 12a0bb8

File tree

2 files changed (+262, -23 lines)

docs/JAX_for_LLM_pretraining.ipynb

Lines changed: 177 additions & 15 deletions
@@ -1010,41 +1010,203 @@
   },
   {
    "cell_type": "markdown",
-   "metadata": {
-    "id": "jCApVd7671c1"
-   },
+   "id": "3813cbf2",
+   "metadata": {},
    "source": [
-    "## Disconnect the Colab runtime"
+    "# Profiling for Hyperparameter Tuning"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 14,
-   "metadata": {
-    "id": "NsqYdbrDVKSq"
-   },
+   "execution_count": null,
+   "id": "b5d933c6",
+   "metadata": {},
    "outputs": [],
    "source": [
-    "from google.colab import runtime\n",
-    "runtime.unassign()"
+    "!pip install -Uq tensorboard-plugin-profile tensorflow tensorboard"
    ]
   },
   {
    "cell_type": "markdown",
-   "metadata": {
-    "id": "Yj0vj28bIPwI"
-   },
+   "id": "2ac5fc4d",
+   "metadata": {},
+   "source": [
+    "Load the TensorBoard Colab extension."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "74f0c212",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "%load_ext tensorboard"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "17c6131f",
+   "metadata": {},
    "source": [
-    "## One more thing\n",
+    "As we're going to be running this model a number of times, we need some scaffolding to more easily compare our work. For a baseline, we'll need to perform some warmup to guarantee that our code is JIT'd and that our TPUs are warm. For improved comparability, we'll only start tracing after we've finished warmup."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "ddfd576e",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "trace_dir = \"/tmp/jax-trace/\"\n",
+    "\n",
+    "def loop_step(batch, step):\n",
+    "    input_batch = jnp.array(jnp.array(batch).T)\n",
+    "    target_batch = prep_target_batch(input_batch)\n",
+    "    train_step(model, optimizer, metrics, jax.device_put((input_batch, target_batch), NamedSharding(mesh, P('batch', None))))\n",
     "\n",
-    "Remember in cell #5, we use 4-way data parallel and 2-way tensor parallel. Of course there are different ways to partition your model/data. For example, 8-way data parallel is another popular way. To switch to 8-way data parallel, uncomment the last line in cell # 4 to replace the `Mesh` definition with:\n",
+    "def generate_trace():\n",
+    "    tracing_steps = 30\n",
+    "    warmup_steps = 5\n",
+    "    for current_step in range(warmup_steps + tracing_steps):\n",
+    "        if current_step == warmup_steps:\n",
+    "            jax.profiler.start_trace(trace_dir)\n",
+    "        with jax.profiler.StepTraceAnnotation(\"train\", step_num=current_step):\n",
+    "            batch = next(text_dl)\n",
+    "            loop_step(batch, current_step)\n",
+    "\n",
+    "    jax.profiler.stop_trace()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "de70f5b7",
+   "metadata": {},
+   "source": [
+    "Now we'll perform some traces to compare results of different batch sizes. This will take several minutes as we need to reprocess our input data to prepare new batches each time."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "bc9452a6",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "trace_dir = \"/tmp/jax-trace-batch-comparison/\"\n",
+    "\n",
+    "batch_size = 64\n",
+    "text_dl = iter(load_and_preprocess_data('TinyStories-train.txt', batch_size, maxlen))\n",
+    "generate_trace()\n",
+    "\n",
+    "batch_size = 256\n",
+    "text_dl = iter(load_and_preprocess_data('TinyStories-train.txt', batch_size, maxlen))\n",
+    "generate_trace()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "ea379965",
+   "metadata": {},
+   "source": [
+    "Run TensorBoard with the Profiler Plugin to compare our runs. Runs are listed from newest to oldest, so the top run in the list will have `batch_size = 256`.\n",
+    "\n",
+    "The key metrics to focus on for this hyperparameter are FLOPS Utilization and Average Step Time.\n",
+    "\n",
+    "In general, we want to maximize FLOPS Utilization while minimizing the step time per training example. In this case, increasing the batch size from 64 to 256 achieves both: FLOPS Utilization rises from 16% to 27%, and while Average Step Time increases from 100ms to 260ms, the batch size grew by 300%, so the time per training example drops from about 1.5ms to 1.02ms."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "b86c565a",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "%tensorboard --logdir=$trace_dir"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "657967a5",
+   "metadata": {},
+   "source": [
+    "Next, we can explore alternative parallelism methods. In cell #4, we used 4-way data parallel and 2-way tensor parallel. 8-way data parallel is another popular way. Let's compare results between them. To switch to 8-way data parallel, we'll replace the `Mesh` definition with:\n",
     "\n",
     "`mesh = Mesh(mesh_utils.create_device_mesh((8, 1)), ('batch', 'model'))`\n",
     "\n",
     "JAX will automatically figure out how to shard the model and data to use the new partition strategy and nothing else needs to be done. Re-connect the TPU runtime and run it again to see how it runs.\n",
     "\n",
     "How simple and powerful is this! And that's the beauty of JAX automatic parallelism."
    ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "80daa8dc",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "trace_dir = \"/tmp/jax-trace-parallelism-comparison/\"\n",
+    "\n",
+    "mesh = Mesh(mesh_utils.create_device_mesh((4, 2)), ('batch', 'model'))\n",
+    "generate_trace()\n",
+    "\n",
+    "mesh = Mesh(mesh_utils.create_device_mesh((8, 1)), ('batch', 'model'))\n",
+    "generate_trace()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "ad96e72b",
+   "metadata": {},
+   "source": [
+    "Once again, we'll run TensorBoard.\n",
+    "\n",
+    "Looking at the results, we see that the step times are nearly the same, but FLOPS Utilization is at 13% for 8-way data parallelism compared to 27% for 4-way data parallelism.\n",
+    "\n",
+    "Looking at the Trace Viewer tool and inspecting each TPU's ops, we can see that the TPUs spend a large amount of time idle waiting for the host, as well as a good amount of time in `reduce_sum` operations."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "780e9c72",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "%tensorboard --logdir=$trace_dir"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "deca486e",
+   "metadata": {},
+   "source": [
+    "By changing hyperparameters and comparing profiles, we're able to gain significant insight into our bottlenecks and limitations. These are just two examples of hyperparameters to tune, but plenty more will have significant effects on training speed and resource utilization."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "jCApVd7671c1"
+   },
+   "source": [
+    "## Disconnect the Colab runtime"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 14,
+   "metadata": {
+    "id": "NsqYdbrDVKSq"
+   },
+   "outputs": [],
+   "source": [
+    "from google.colab import runtime\n",
+    "runtime.unassign()"
+   ]
   }
  ],
  "metadata": {

docs/JAX_for_LLM_pretraining.md

Lines changed: 85 additions & 8 deletions
@@ -462,25 +462,102 @@ checkpointer.save('/content/save', state)
 !ls /content/save/
 ```
 
-+++ {"id": "jCApVd7671c1"}
+# Profiling for Hyperparameter Tuning
 
-## Disconnect the Colab runtime
+```{code-cell}
+!pip install -Uq tensorboard-plugin-profile tensorflow tensorboard
+```
+
+Load the TensorBoard Colab extension.
 
 ```{code-cell}
-:id: NsqYdbrDVKSq
+%load_ext tensorboard
+```
 
-from google.colab import runtime
-runtime.unassign()
+As we're going to be running this model a number of times, we need some scaffolding to more easily compare our work. For a baseline, we'll need to perform some warmup to guarantee that our code is JIT'd and that our TPUs are warm. For improved comparability, we'll only start tracing after we've finished warmup.
+
+```{code-cell}
+trace_dir = "/tmp/jax-trace/"
+
+def loop_step(batch, step):
+    input_batch = jnp.array(jnp.array(batch).T)
+    target_batch = prep_target_batch(input_batch)
+    train_step(model, optimizer, metrics, jax.device_put((input_batch, target_batch), NamedSharding(mesh, P('batch', None))))
+
+def generate_trace():
+    tracing_steps = 30
+    warmup_steps = 5
+    for current_step in range(warmup_steps + tracing_steps):
+        if current_step == warmup_steps:
+            jax.profiler.start_trace(trace_dir)
+        with jax.profiler.StepTraceAnnotation("train", step_num=current_step):
+            batch = next(text_dl)
+            loop_step(batch, current_step)
+
+    jax.profiler.stop_trace()
 ```
 
-+++ {"id": "Yj0vj28bIPwI"}
+Now we'll perform some traces to compare results of different batch sizes. This will take several minutes as we need to reprocess our input data to prepare new batches each time.
 
-## One more thing
+```{code-cell}
+trace_dir = "/tmp/jax-trace-batch-comparison/"
 
-Remember in cell #5, we use 4-way data parallel and 2-way tensor parallel. Of course there are different ways to partition your model/data. For example, 8-way data parallel is another popular way. To switch to 8-way data parallel, uncomment the last line in cell # 4 to replace the `Mesh` definition with:
+batch_size = 64
+text_dl = iter(load_and_preprocess_data('TinyStories-train.txt', batch_size, maxlen))
+generate_trace()
+
+batch_size = 256
+text_dl = iter(load_and_preprocess_data('TinyStories-train.txt', batch_size, maxlen))
+generate_trace()
+```
+
+Run TensorBoard with the Profiler Plugin to compare our runs. Runs are listed from newest to oldest, so the top run in the list will have `batch_size = 256`.
+
+The key metrics to focus on for this hyperparameter are FLOPS Utilization and Average Step Time.
+
+In general, we want to maximize FLOPS Utilization while minimizing the step time per training example. In this case, increasing the batch size from 64 to 256 achieves both: FLOPS Utilization rises from 16% to 27%, and while Average Step Time increases from 100ms to 260ms, the batch size grew by 300%, so the time per training example drops from about 1.5ms to 1.02ms.
+
+```{code-cell}
+%tensorboard --logdir=$trace_dir
+```
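
The per-example figures above follow directly from the quoted step times; as a quick sanity check of that arithmetic (a sketch using only the numbers already cited in the commit, not part of the original notebook):

```python
# Time per training example = average step time / batch size,
# using the step times quoted in the discussion above.
for batch_size, step_ms in [(64, 100), (256, 260)]:
    print(f"batch_size={batch_size}: {step_ms / batch_size:.2f} ms per training example")
# batch_size=64: 1.56 ms per training example
# batch_size=256: 1.02 ms per training example
```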
+
+Next, we can explore alternative parallelism methods. In cell #4, we used 4-way data parallel and 2-way tensor parallel. 8-way data parallel is another popular way. Let's compare results between them. To switch to 8-way data parallel, we'll replace the `Mesh` definition with:
 
 `mesh = Mesh(mesh_utils.create_device_mesh((8, 1)), ('batch', 'model'))`
 
 JAX will automatically figure out how to shard the model and data to use the new partition strategy and nothing else needs to be done. Re-connect the TPU runtime and run it again to see how it runs.
 
 How simple and powerful is this! And that's the beauty of JAX automatic parallelism.
+
+```{code-cell}
+trace_dir = "/tmp/jax-trace-parallelism-comparison/"
+
+mesh = Mesh(mesh_utils.create_device_mesh((4, 2)), ('batch', 'model'))
+generate_trace()
+
+mesh = Mesh(mesh_utils.create_device_mesh((8, 1)), ('batch', 'model'))
+generate_trace()
+```
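
A minimal sketch (not part of the original commit) of how one could confirm what each mesh actually does to the input batch, using `jax.debug.visualize_array_sharding`; it assumes an 8-device runtime and uses illustrative batch dimensions in place of the tutorial's `batch_size` and `maxlen`:

```python
# Visualize how a stand-in input batch is partitioned under each mesh shape.
# Assumes 8 devices (e.g. a TPU v2/v3-8); the dimensions below are illustrative.
import jax
import jax.numpy as jnp
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

batch_size, seq_len = 256, 512  # illustrative; the tutorial defines its own values
sample = jnp.zeros((batch_size, seq_len), dtype=jnp.int32)

for shape in ((4, 2), (8, 1)):
    mesh = Mesh(mesh_utils.create_device_mesh(shape), ('batch', 'model'))
    # Same placement as loop_step: shard the batch dimension, replicate the rest.
    sharded = jax.device_put(sample, NamedSharding(mesh, P('batch', None)))
    print(f"mesh shape {shape}:")
    jax.debug.visualize_array_sharding(sharded)
```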
+
+Once again, we'll run TensorBoard.
+
+Looking at the results, we see that the step times are nearly the same, but FLOPS Utilization is at 13% for 8-way data parallelism compared to 27% for 4-way data parallelism.
+
+Looking at the Trace Viewer tool and inspecting each TPU's ops, we can see that the TPUs spend a large amount of time idle waiting for the host, as well as a good amount of time in `reduce_sum` operations.
+
+```{code-cell}
+%tensorboard --logdir=$trace_dir
+```
+
+By changing hyperparameters and comparing profiles, we're able to gain significant insight into our bottlenecks and limitations. These are just two examples of hyperparameters to tune, but plenty more will have significant effects on training speed and resource utilization.
+
++++ {"id": "jCApVd7671c1"}
+
+## Disconnect the Colab runtime
+
+```{code-cell}
+:id: NsqYdbrDVKSq
+
+from google.colab import runtime
+runtime.unassign()
+```
