|
| 1 | +{ |
| 2 | + "cells": [ |
| 3 | + { |
| 4 | + "cell_type": "code", |
| 5 | + "execution_count": null, |
| 6 | + "metadata": { |
| 7 | + "collapsed": false |
| 8 | + }, |
| 9 | + "outputs": [], |
| 10 | + "source": [ |
| 11 | + "# For tips on running notebooks in Google Colab, see\n", |
| 12 | + "# https://pytorch.org/tutorials/beginner/colab\n", |
| 13 | + "%matplotlib inline" |
| 14 | + ] |
| 15 | + }, |
| 16 | + { |
| 17 | + "cell_type": "markdown", |
| 18 | + "metadata": {}, |
| 19 | + "source": [ |
| 20 | + "(beta) Explicit horizontal fusion with foreach\\_map and torch.compile\n", |
| 21 | + "============================================================\n", |
| 22 | + "\n", |
| 23 | + "**Author:** [Michael Lazos](https://github.com/mlazos)\n" |
| 24 | + ] |
| 25 | + }, |
| 26 | + { |
| 27 | + "cell_type": "markdown", |
| 28 | + "metadata": {}, |
| 29 | + "source": [ |
| 30 | + "Horizontal fusion is a key optimization in ML compilers. In eager,\n", |
| 31 | + "\n", |
| 32 | + ": this is typically expressed using the torch.\\_foreach\\* ops which\n", |
| 33 | + " parallelizes operations across a list of tensors. However,\n", |
| 34 | + " supporting all possible permutations of arguments is quite difficult\n", |
| 35 | + " (e.g. mixtures of scalars and lists). Foreach\\_map allows conversion\n", |
| 36 | + " of any pointwise op in `torch` to a horiztonally fused foreach\n", |
| 37 | + " variant. In this tutorial, we will demonstrate how to implement the\n", |
| 38 | + " Adam optimizer with `foreach_map` to generate a fully fused kernel.\n", |
| 39 | + "\n", |
| 40 | + "<div style=\"background-color: #54c7ec; color: #fff; font-weight: 700; padding-left: 10px; padding-top: 5px; padding-bottom: 5px\"><strong>NOTE:</strong></div>\n", |
| 41 | + "\n", |
| 42 | + "<div style=\"background-color: #f3f4f7; padding-left: 10px; padding-top: 10px; padding-bottom: 10px; padding-right: 10px\">\n", |
| 43 | + "\n", |
| 44 | + "<p>This tutorial requires PyTorch 2.7.0 or later.</p>\n", |
| 45 | + "\n", |
| 46 | + "</div>\n", |
| 47 | + "\n" |
| 48 | + ] |
| 49 | + }, |
| 50 | + { |
| 51 | + "cell_type": "markdown", |
| 52 | + "metadata": {}, |
| 53 | + "source": [ |
| 54 | + "Model Setup\n", |
| 55 | + "===========\n", |
| 56 | + "\n", |
| 57 | + "For this example, we\\'ll use a simple sequence of linear layers. We\n", |
| 58 | + "instantiate an independent copy to compare the two optimizer\n", |
| 59 | + "implementations.\n" |
| 60 | + ] |
| 61 | + }, |
| 62 | + { |
| 63 | + "cell_type": "code", |
| 64 | + "execution_count": null, |
| 65 | + "metadata": { |
| 66 | + "collapsed": false |
| 67 | + }, |
| 68 | + "outputs": [], |
| 69 | + "source": [ |
| 70 | + "import torch\n", |
| 71 | + "\n", |
| 72 | + "# exit cleanly if we are on a device that doesn't support ``torch.compile``\n", |
| 73 | + "if torch.cuda.get_device_capability() < (7, 0):\n", |
| 74 | + " print(\"Exiting because torch.compile is not supported on this device.\")\n", |
| 75 | + " import sys\n", |
| 76 | + " sys.exit(0)\n", |
| 77 | + "\n", |
| 78 | + "# Create simple model\n", |
| 79 | + "model = torch.nn.Sequential(\n", |
| 80 | + " *[torch.nn.Linear(1024, 1024, False, device=\"cuda\") for _ in range(10)]\n", |
| 81 | + ")\n", |
| 82 | + "model_copy = torch.nn.Sequential(\n", |
| 83 | + " *[torch.nn.Linear(1024, 1024, False, device=\"cuda\") for _ in range(10)]\n", |
| 84 | + ")\n", |
| 85 | + "input = torch.rand(1024, device=\"cuda\")\n", |
| 86 | + "\n", |
| 87 | + "# run forward pass\n", |
| 88 | + "output = model(input)\n", |
| 89 | + "output_copy = model_copy(input)\n", |
| 90 | + "\n", |
| 91 | + "# run backward to populate the grads for our optimizer below\n", |
| 92 | + "output.sum().backward()\n", |
| 93 | + "output_copy.sum().backward()" |
| 94 | + ] |
| 95 | + }, |
| 96 | + { |
| 97 | + "cell_type": "markdown", |
| 98 | + "metadata": {}, |
| 99 | + "source": [ |
| 100 | + "Helper functions for foreach\\_map implementation\n", |
| 101 | + "================================================\n", |
| 102 | + "\n", |
| 103 | + "In this section, we\\'ll begin our implementation of the Adam optimizer.\n" |
| 104 | + ] |
| 105 | + }, |
| 106 | + { |
| 107 | + "cell_type": "code", |
| 108 | + "execution_count": null, |
| 109 | + "metadata": { |
| 110 | + "collapsed": false |
| 111 | + }, |
| 112 | + "outputs": [], |
| 113 | + "source": [ |
| 114 | + "from torch._higher_order_ops.foreach_map import foreach_map\n", |
| 115 | + "\n", |
| 116 | + "# Helper function to extract optimizer states from a torch.optim.Adam instance\n", |
| 117 | + "def get_inputs(optim):\n", |
| 118 | + " steps = []\n", |
| 119 | + " params = []\n", |
| 120 | + " grads = []\n", |
| 121 | + " exp_avgs = []\n", |
| 122 | + " exp_avg_sqs = []\n", |
| 123 | + " for group in optim.param_groups:\n", |
| 124 | + " for p in group[\"params\"]:\n", |
| 125 | + " params.append(p)\n", |
| 126 | + " grads.append(p.grad)\n", |
| 127 | + " state = optim.state[p]\n", |
| 128 | + " exp_avgs.append(state[\"exp_avg\"])\n", |
| 129 | + " exp_avg_sqs.append(state[\"exp_avg_sq\"])\n", |
| 130 | + " steps.append(state[\"step\"])\n", |
| 131 | + "\n", |
| 132 | + " return steps, params, exp_avgs, exp_avg_sqs\n", |
| 133 | + "\n", |
| 134 | + "\n", |
| 135 | + "# Functions to update the different optimizer states\n", |
| 136 | + "def update_exp_avg_sq(exp_avg_sq, grad, beta2):\n", |
| 137 | + " return exp_avg_sq.mul(beta2).addcmul(grad, grad, value=1 - beta2)\n", |
| 138 | + "\n", |
| 139 | + "def update_param(param, step, exp_avg, exp_avg_sq, beta1, beta2, lr, eps):\n", |
| 140 | + " bias_correction1 = 1 - torch.pow(beta1, step)\n", |
| 141 | + " bias_correction2 = (1 - torch.pow(beta2, step)).sqrt()\n", |
| 142 | + " step_size = (lr / bias_correction1).neg()\n", |
| 143 | + " denom = (exp_avg_sq.sqrt() / (bias_correction2 * step_size)).add(eps / step_size)\n", |
| 144 | + " return torch.add(param, torch.div(exp_avg, denom))\n", |
| 145 | + "\n", |
| 146 | + "# Our full Adam implementation\n", |
| 147 | + "def foreach_map_adam(\n", |
| 148 | + " steps,\n", |
| 149 | + " params,\n", |
| 150 | + " exp_avgs,\n", |
| 151 | + " exp_avg_sqs,\n", |
| 152 | + " weight_decay=0,\n", |
| 153 | + " beta1=0.9,\n", |
| 154 | + " beta2=0.999,\n", |
| 155 | + " lr=1e-3,\n", |
| 156 | + " eps=1e-8,\n", |
| 157 | + "):\n", |
| 158 | + " with torch.no_grad():\n", |
| 159 | + " grads = [param.grad for param in params]\n", |
| 160 | + " # update step\n", |
| 161 | + " updated_steps = foreach_map(lambda x: x + 1, steps)\n", |
| 162 | + " torch._foreach_copy_(steps, updated_steps)\n", |
| 163 | + "\n", |
| 164 | + " if weight_decay != 0:\n", |
| 165 | + " foreach_map(torch.add, (grads,), alpha=weight_decay)\n", |
| 166 | + "\n", |
| 167 | + " # Higher-order operators (HOPs) cannot have multiple outputs at the moment\n", |
| 168 | + " # need to call foreach_map once for each output\n", |
| 169 | + " exp_avgs_updated = foreach_map(torch.lerp, exp_avgs, grads, 1 - beta1)\n", |
| 170 | + " exp_avgs_sq_updated = foreach_map(update_exp_avg_sq, exp_avg_sqs, grads, beta2)\n", |
| 171 | + " params_updated = foreach_map(\n", |
| 172 | + " update_param,\n", |
| 173 | + " params,\n", |
| 174 | + " steps,\n", |
| 175 | + " exp_avgs_updated,\n", |
| 176 | + " exp_avgs_sq_updated,\n", |
| 177 | + " beta1,\n", |
| 178 | + " beta2,\n", |
| 179 | + " lr,\n", |
| 180 | + " eps,\n", |
| 181 | + " )\n", |
| 182 | + " # Higher-order operators (HOPs) don't support input mutation today\n", |
| 183 | + " # so manually update the states in-place\n", |
| 184 | + " torch._foreach_copy_(exp_avgs, exp_avgs_updated)\n", |
| 185 | + " torch._foreach_copy_(exp_avg_sqs, exp_avgs_sq_updated)\n", |
| 186 | + " torch._foreach_copy_(params, params_updated)\n", |
| 187 | + " return" |
| 188 | + ] |
| 189 | + }, |
| 190 | + { |
| 191 | + "cell_type": "markdown", |
| 192 | + "metadata": {}, |
| 193 | + "source": [ |
| 194 | + "Setting up and running the compiled kernel\n", |
| 195 | + "==========================================\n", |
| 196 | + "\n", |
| 197 | + "In this section, we\\'ll run our Adam optimizer and compare the results\n", |
| 198 | + "\n", |
| 199 | + "<div style=\"background-color: #54c7ec; color: #fff; font-weight: 700; padding-left: 10px; padding-top: 5px; padding-bottom: 5px\"><strong>NOTE:</strong></div>\n", |
| 200 | + "\n", |
| 201 | + "<div style=\"background-color: #f3f4f7; padding-left: 10px; padding-top: 10px; padding-bottom: 10px; padding-right: 10px\">\n", |
| 202 | + "\n", |
| 203 | + "<p><code>torch.compile</code> is only supported on CUDA devices that have a compute capability of 7.0 or higher.</p>\n", |
| 204 | + "\n", |
| 205 | + "</div>\n", |
| 206 | + "\n" |
| 207 | + ] |
| 208 | + }, |
| 209 | + { |
| 210 | + "cell_type": "code", |
| 211 | + "execution_count": null, |
| 212 | + "metadata": { |
| 213 | + "collapsed": false |
| 214 | + }, |
| 215 | + "outputs": [], |
| 216 | + "source": [ |
| 217 | + "opt_eager = torch.optim.Adam(model.parameters(), lr=torch.tensor(0.01))\n", |
| 218 | + "opt_eager_copy = torch.optim.Adam(model_copy.parameters(), lr=torch.tensor(0.01))\n", |
| 219 | + "\n", |
| 220 | + "# warm up the optimizer state dict\n", |
| 221 | + "opt_eager.step()\n", |
| 222 | + "opt_eager_copy.step()\n", |
| 223 | + "\n", |
| 224 | + "inputs = get_inputs(opt_eager_copy)\n", |
| 225 | + "compiled_adam = torch.compile(foreach_map_adam)\n", |
| 226 | + "\n", |
| 227 | + "# optionally view the output code\n", |
| 228 | + "torch._logging.set_logs(output_code=True)\n", |
| 229 | + "\n", |
| 230 | + "# Warmup runs to compile the function\n", |
| 231 | + "for _ in range(5):\n", |
| 232 | + " opt_eager.step()\n", |
| 233 | + " compiled_adam(*inputs)\n", |
| 234 | + "\n", |
| 235 | + "for eager_p, compile_p in zip(opt_eager.param_groups[0][\"params\"], opt_eager_copy.param_groups[0][\"params\"]):\n", |
| 236 | + " torch.allclose(eager_p, compile_p)\n", |
| 237 | + "\n", |
| 238 | + "# Benchmark performance\n", |
| 239 | + "\n", |
| 240 | + " # Let's define a helpful benchmarking function:\n", |
| 241 | + "import torch.utils.benchmark as benchmark\n", |
| 242 | + "\n", |
| 243 | + "def benchmark_torch_function_in_microseconds(f, *args, **kwargs):\n", |
| 244 | + " t0 = benchmark.Timer(\n", |
| 245 | + " stmt=\"f(*args, **kwargs)\", globals={\"args\": args, \"kwargs\": kwargs, \"f\": f}\n", |
| 246 | + " )\n", |
| 247 | + " return t0.blocked_autorange().mean * 1e6\n", |
| 248 | + "\n", |
| 249 | + "eager_runtime = benchmark_torch_function_in_microseconds(opt_eager.step)\n", |
| 250 | + "compiled_runtime = benchmark_torch_function_in_microseconds(lambda: compiled_adam(*inputs))\n", |
| 251 | + "\n", |
| 252 | + "assert eager_runtime > compiled_runtime\n", |
| 253 | + " \n", |
| 254 | + "print(f\"eager runtime: {eager_runtime}us\")\n", |
| 255 | + "print(f\"compiled runtime: {compiled_runtime}us\")" |
| 256 | + ] |
| 257 | + }, |
| 258 | + { |
| 259 | + "cell_type": "markdown", |
| 260 | + "metadata": {}, |
| 261 | + "source": [ |
| 262 | + "Conclusion\n", |
| 263 | + "==========\n", |
| 264 | + "\n", |
| 265 | + "In this tutorial, we successfully implemented a custom fully-fused Adam\n", |
| 266 | + "optimizer using foreach\\_map. By leveraging the power of foreach\\_map\n", |
| 267 | + "and torch.compile, we were able to create an optimized version of the\n", |
| 268 | + "Adam optimizer that can be used in various machine learning\n", |
| 269 | + "applications. This tutorial provides a comprehensive guide on how to use\n", |
| 270 | + "foreach\\_map and torch.compile to optimize machine learning models, and\n", |
| 271 | + "serves as a valuable resource for developers looking to improve the\n", |
| 272 | + "performance of their models with horizontal fusion.\n", |
| 273 | + "\n", |
| 274 | + "See also:\n", |
| 275 | + "\n", |
| 276 | + "- [Compiled optimizer\n", |
| 277 | + " tutorial](https://pytorch.org/tutorials/recipes/compiling_optimizer.html) -\n", |
| 278 | + " an intro into the compiled optimizer.\n", |
| 279 | + "- [Compiling the optimizer with\n", |
| 280 | + " PT2](https://dev-discuss.pytorch.org/t/compiling-the-optimizer-with-pt2/1669) -\n", |
| 281 | + " deeper technical details on the compiled optimizer.\n" |
| 282 | + ] |
| 283 | + } |
| 284 | + ], |
| 285 | + "metadata": { |
| 286 | + "kernelspec": { |
| 287 | + "display_name": "Python 3", |
| 288 | + "language": "python", |
| 289 | + "name": "python3" |
| 290 | + }, |
| 291 | + "language_info": { |
| 292 | + "codemirror_mode": { |
| 293 | + "name": "ipython", |
| 294 | + "version": 3 |
| 295 | + }, |
| 296 | + "file_extension": ".py", |
| 297 | + "mimetype": "text/x-python", |
| 298 | + "name": "python", |
| 299 | + "nbconvert_exporter": "python", |
| 300 | + "pygments_lexer": "ipython3", |
| 301 | + "version": "3.10.12" |
| 302 | + } |
| 303 | + }, |
| 304 | + "nbformat": 4, |
| 305 | + "nbformat_minor": 0 |
| 306 | +} |
0 commit comments