|
1010 | 1010 | },
|
1011 | 1011 | {
|
1012 | 1012 | "cell_type": "markdown",
|
1013 |
| - "metadata": { |
1014 |
| - "id": "jCApVd7671c1" |
1015 |
| - }, |
| 1013 | + "id": "3813cbf2", |
| 1014 | + "metadata": {}, |
1016 | 1015 | "source": [
|
1017 |
| - "## Disconnect the Colab runtime" |
| 1016 | + "# Profiling for Hyperparameter Tuning" |
1018 | 1017 | ]
|
1019 | 1018 | },
|
1020 | 1019 | {
|
1021 | 1020 | "cell_type": "code",
|
1022 |
| - "execution_count": 14, |
1023 |
| - "metadata": { |
1024 |
| - "id": "NsqYdbrDVKSq" |
1025 |
| - }, |
| 1021 | + "execution_count": null, |
| 1022 | + "id": "b5d933c6", |
| 1023 | + "metadata": {}, |
1026 | 1024 | "outputs": [],
|
1027 | 1025 | "source": [
|
1028 |
| - "from google.colab import runtime\n", |
1029 |
| - "runtime.unassign()" |
| 1026 | + "!pip install -Uq tensorboard-plugin-profile tensorflow tensorboard" |
1030 | 1027 | ]
|
1031 | 1028 | },
|
1032 | 1029 | {
|
1033 | 1030 | "cell_type": "markdown",
|
1034 |
| - "metadata": { |
1035 |
| - "id": "Yj0vj28bIPwI" |
1036 |
| - }, |
| 1031 | + "id": "2ac5fc4d", |
| 1032 | + "metadata": {}, |
| 1033 | + "source": [ |
| 1034 | + "Load the tensorboard colab extension." |
| 1035 | + ] |
| 1036 | + }, |
| 1037 | + { |
| 1038 | + "cell_type": "code", |
| 1039 | + "execution_count": null, |
| 1040 | + "id": "74f0c212", |
| 1041 | + "metadata": {}, |
| 1042 | + "outputs": [], |
| 1043 | + "source": [ |
| 1044 | + "%load_ext tensorboard" |
| 1045 | + ] |
| 1046 | + }, |
| 1047 | + { |
| 1048 | + "cell_type": "markdown", |
| 1049 | + "id": "17c6131f", |
| 1050 | + "metadata": {}, |
1037 | 1051 | "source": [
|
1038 |
| - "## One more thing\n", |
| 1052 | + "As we're going to be running this model a number of times, we need some scaffolding to more easily compare our work. For a baseline, we'll need to perform some warmup to guarantee that our code is JIT'd and that our TPUs are warm. For improved comparability, we'll only start tracing after we've finished warmup." |
| 1053 | + ] |
| 1054 | + }, |
| 1055 | + { |
| 1056 | + "cell_type": "code", |
| 1057 | + "execution_count": null, |
| 1058 | + "id": "ddfd576e", |
| 1059 | + "metadata": {}, |
| 1060 | + "outputs": [], |
| 1061 | + "source": [ |
| 1062 | + "trace_dir = \"/tmp/jax-trace/\"\n", |
| 1063 | + "\n", |
| 1064 | + "def loop_step(batch, step):\n", |
| 1065 | + " input_batch = jnp.array(jnp.array(batch).T)\n", |
| 1066 | + " target_batch = prep_target_batch(input_batch)\n", |
| 1067 | + " train_step(model, optimizer, metrics, jax.device_put((input_batch, target_batch), NamedSharding(mesh, P('batch', None))))\n", |
1039 | 1068 | "\n",
|
1040 |
| - "Remember in cell #5, we use 4-way data parallel and 2-way tensor parallel. Of course there are different ways to partition your model/data. For example, 8-way data parallel is another popular way. To switch to 8-way data parallel, uncomment the last line in cell # 4 to replace the `Mesh` definition with:\n", |
| 1069 | + "def generate_trace():\n", |
| 1070 | + " tracing_steps = 30\n", |
| 1071 | + " warmup_steps = 5\n", |
| 1072 | + " for current_step in range(warmup_steps + tracing_steps):\n", |
| 1073 | + " if current_step == warmup_steps:\n", |
| 1074 | + " jax.profiler.start_trace(trace_dir)\n", |
| 1075 | + " with jax.profiler.StepTraceAnnotation(\"train\", step_num=current_step):\n", |
| 1076 | + " batch = next(text_dl)\n", |
| 1077 | + " loop_step(batch, current_step)\n", |
| 1078 | + "\n", |
| 1079 | + " jax.profiler.stop_trace()" |
| 1080 | + ] |
| 1081 | + }, |
| 1082 | + { |
| 1083 | + "cell_type": "markdown", |
| 1084 | + "id": "de70f5b7", |
| 1085 | + "metadata": {}, |
| 1086 | + "source": [ |
| 1087 | + "Now we'll perform some traces to compare results of different batch sizes. This will take several minutes as we need to reprocess our input data to prepare new batches each time." |
| 1088 | + ] |
| 1089 | + }, |
| 1090 | + { |
| 1091 | + "cell_type": "code", |
| 1092 | + "execution_count": null, |
| 1093 | + "id": "bc9452a6", |
| 1094 | + "metadata": {}, |
| 1095 | + "outputs": [], |
| 1096 | + "source": [ |
| 1097 | + "trace_dir = \"/tmp/jax-trace-batch-comparison/\"\n", |
| 1098 | + "\n", |
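| | + "# generate_trace() pulls batches from the global text_dl, so we rebuild the iterator for each batch size\n", |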
| 1099 | + "batch_size = 64\n", |
| 1100 | + "text_dl = iter(load_and_preprocess_data('TinyStories-train.txt', batch_size, maxlen))\n", |
| 1101 | + "generate_trace()\n", |
| 1102 | + "\n", |
| 1103 | + "batch_size = 256\n", |
| 1104 | + "text_dl = iter(load_and_preprocess_data('TinyStories-train.txt', batch_size, maxlen))\n", |
| 1105 | + "generate_trace()" |
| 1106 | + ] |
| 1107 | + }, |
| 1108 | + { |
| 1109 | + "cell_type": "markdown", |
| 1110 | + "id": "ea379965", |
| 1111 | + "metadata": {}, |
| 1112 | + "source": [ |
| 1113 | + "Run Tensorboard with the Profiler Plugin to compare our runs. Runs are listed in order from newest to oldest, so the top run in the list will be have `batch_size = 256`.\n", |
| 1114 | + "\n", |
| 1115 | + "The key metrics to focus on here for this hyperparameter are FLOPS Utilization and Average Step Time.\n", |
| 1116 | + "\n", |
| 1117 | + "In general, we want to maximize FLOPS Utilization while minimizing the step time per training example. In this case, we can see that increasing the batch size from 64 -> 256 achieves both of those. FLOPS increases from 16% to 27%. Average Step Time increase from 100ms to 260ms, however we increased our batch size by 300%. This means we move from 1.5ms per training example to 1.02ms per training example." |
| 1118 | + ] |
| 1119 | + }, |
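| | + { |
| | + "cell_type": "markdown", |
| | + "id": "a7f3e9c1", |
| | + "metadata": {}, |
| | + "source": [ |
| | + "As a quick sanity check of the arithmetic above, the cell below recomputes the per-example cost from the quoted step times (these are the figures read off the profiler overview, not fresh measurements):" |
| | + ] |
| | + }, |
| | + { |
| | + "cell_type": "code", |
| | + "execution_count": null, |
| | + "id": "c4d8b2f0", |
| | + "metadata": {}, |
| | + "outputs": [], |
| | + "source": [ |
| | + "# Average step times (ms) quoted from the profiler overview, keyed by batch size.\n", |
| | + "step_time_ms = {64: 100, 256: 260}\n", |
| | + "for bs, ms in step_time_ms.items():\n", |
| | + "    print(f\"batch={bs}: {ms / bs:.2f} ms per training example\")" |
| | + ] |
| | + }, |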
| 1120 | + { |
| 1121 | + "cell_type": "code", |
| 1122 | + "execution_count": null, |
| 1123 | + "id": "b86c565a", |
| 1124 | + "metadata": {}, |
| 1125 | + "outputs": [], |
| 1126 | + "source": [ |
| 1127 | + "%tensorboard --logdir=$trace_dir" |
| 1128 | + ] |
| 1129 | + }, |
| 1130 | + { |
| 1131 | + "cell_type": "markdown", |
| 1132 | + "id": "657967a5", |
| 1133 | + "metadata": {}, |
| 1134 | + "source": [ |
| 1135 | + "Next, we can explore alternative parallelism methods. In cell #4, we used 4-way data parallel and 2-way tensor parallel. 8-way data parallel is another popular way. Let's compare results between them. To switch to 8-way data parallel, we'll replace the `Mesh` definition with:\n", |
1041 | 1136 | "\n",
|
1042 | 1137 | "`mesh = Mesh(mesh_utils.create_device_mesh((8, 1)), ('batch', 'model'))`\n",
|
1043 | 1138 | "\n",
|
1044 | 1139 | "JAX will automatically figure out how to shard the model and data to use the new partition strategy and nothing else need to be done. Re-connect the TPU runtime and run it again to see how it runs.\n",
|
1045 | 1140 | "\n",
|
1046 | 1141 | "How simple and powerful is this! And that's the beauty of JAX automatic parallelism."
|
1047 | 1142 | ]
|
| 1143 | + }, |
| 1144 | + { |
| 1145 | + "cell_type": "code", |
| 1146 | + "execution_count": null, |
| 1147 | + "id": "80daa8dc", |
| 1148 | + "metadata": {}, |
| 1149 | + "outputs": [], |
| 1150 | + "source": [ |
| 1151 | + "trace_dir = \"/tmp/jax-trace-parallelism-comparison/\"\n", |
| 1152 | + "\n", |
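| | + "# loop_step shards each batch with the global mesh, so rebinding mesh here switches the partitioning used by the next trace\n", |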
| 1153 | + "mesh = Mesh(mesh_utils.create_device_mesh((4, 2)), ('batch', 'model'))\n", |
| 1154 | + "generate_trace()\n", |
| 1155 | + "\n", |
| 1156 | + "mesh = Mesh(mesh_utils.create_device_mesh((8, 1)), ('batch', 'model'))\n", |
| 1157 | + "generate_trace()" |
| 1158 | + ] |
| 1159 | + }, |
| 1160 | + { |
| 1161 | + "cell_type": "markdown", |
| 1162 | + "id": "ad96e72b", |
| 1163 | + "metadata": {}, |
| 1164 | + "source": [ |
| 1165 | + "Once again we'll run tensorboard.\n", |
| 1166 | + "\n", |
| 1167 | + "Looking at the results, we see that the step times are nearly the same, however the FLOPS Utilization is at 13% for 8-way data parallelism compared to 27% or 4-way data parallelism.\n", |
| 1168 | + "\n", |
| 1169 | + "By looking at the Trace Viewer tool and looking under each TPU's ops, we can see that the TPUs spend a large amount of time idle while waiting for the host, as well as spending a good amount of time in `reduce_sum` operations." |
| 1170 | + ] |
| 1171 | + }, |
| 1172 | + { |
| 1173 | + "cell_type": "code", |
| 1174 | + "execution_count": null, |
| 1175 | + "id": "780e9c72", |
| 1176 | + "metadata": {}, |
| 1177 | + "outputs": [], |
| 1178 | + "source": [ |
| 1179 | + "%tensorboard --logdir=$trace_dir" |
| 1180 | + ] |
| 1181 | + }, |
| 1182 | + { |
| 1183 | + "cell_type": "markdown", |
| 1184 | + "id": "deca486e", |
| 1185 | + "metadata": {}, |
| 1186 | + "source": [ |
| 1187 | + "By changing hyperparameters and comparing profiles, we're able to gain significant insights into our bottlenecks and limitations. These are just two examples of hyperparameters to tune, but plenty more of them will have significant effects on training speed and resource utilization." |
| 1188 | + ] |
| 1189 | + }, |
| 1190 | + { |
| 1191 | + "cell_type": "markdown", |
| 1192 | + "metadata": { |
| 1193 | + "id": "jCApVd7671c1" |
| 1194 | + }, |
| 1195 | + "source": [ |
| 1196 | + "## Disconnect the Colab runtime" |
| 1197 | + ] |
| 1198 | + }, |
| 1199 | + { |
| 1200 | + "cell_type": "code", |
| 1201 | + "execution_count": 14, |
| 1202 | + "metadata": { |
| 1203 | + "id": "NsqYdbrDVKSq" |
| 1204 | + }, |
| 1205 | + "outputs": [], |
| 1206 | + "source": [ |
| 1207 | + "from google.colab import runtime\n", |
| 1208 | + "runtime.unassign()" |
| 1209 | + ] |
1048 | 1210 | }
|
1049 | 1211 | ],
|
1050 | 1212 | "metadata": {
|