# -*- coding: utf-8 -*-

"""
Using User-Defined Triton Kernels with ``torch.compile``
=========================================================
**Author:** `Oguz Ulgen <https://github.com/oulgen>`_
"""

######################################################################
# User-defined Triton kernels can be used to optimize specific parts of your
# model's computation. These kernels are written in Triton's language, which is designed
# to make it easier to achieve peak hardware performance. By using user-defined Triton
# kernels with ``torch.compile``, you can integrate these optimized computations into
# your PyTorch model, potentially achieving significant performance improvements.
#
# This recipe demonstrates how to use user-defined Triton kernels with ``torch.compile``.
#
# Prerequisites
# -------------------
#
# Before starting this recipe, make sure that you have the following:
#
# * Basic understanding of ``torch.compile`` and Triton. See:
#
#   * `torch.compiler API documentation <https://pytorch.org/docs/stable/torch.compiler.html#torch-compiler>`__
#   * `Introduction to torch.compile <https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html>`__
#   * `Triton language documentation <https://triton-lang.org/main/index.html>`__
#
# * PyTorch 2.3 or later
# * A GPU that supports Triton
#

import torch
from torch.utils._triton import has_triton
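
######################################################################
# We can quickly confirm the prerequisites above by printing the installed
# PyTorch version and checking whether Triton is usable on the current device.

print(f"PyTorch version: {torch.__version__}")
print(f"Triton available: {has_triton()}")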

######################################################################
# Basic Usage
# --------------------
#
# In this example, we will use a simple vector addition kernel from the Triton documentation
# with ``torch.compile``.
# For reference, see `Triton documentation <https://triton-lang.org/main/getting-started/tutorials/01-vector-add.html>`__.
#

if not has_triton():
    print("Skipping because triton is not supported on this device.")
else:
    import triton
    from triton import language as tl

    @triton.jit
    def add_kernel(
        in_ptr0,
        in_ptr1,
        out_ptr,
        n_elements,
        BLOCK_SIZE: "tl.constexpr",
    ):
        pid = tl.program_id(axis=0)
        block_start = pid * BLOCK_SIZE
        offsets = block_start + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        x = tl.load(in_ptr0 + offsets, mask=mask)
        y = tl.load(in_ptr1 + offsets, mask=mask)
        output = x + y
        tl.store(out_ptr + offsets, output, mask=mask)

    @torch.compile(fullgraph=True)
    def add_fn(x, y):
        output = torch.zeros_like(x)
        n_elements = output.numel()
        grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
        add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=4)
        return output

    x = torch.randn(4, device="cuda")
    y = torch.randn(4, device="cuda")
    out = add_fn(x, y)
    print(f"Vector addition of\nX:\t{x}\nY:\t{y}\nis equal to\n{out}")
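
    # Quick sanity check: the compiled kernel's output should match eager-mode
    # addition (``torch.testing.assert_close`` raises if the values differ).
    torch.testing.assert_close(out, x + y)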

######################################################################
# Advanced Usage
# -------------------------------------------------------------------
#
# Triton's autotune feature is a powerful tool that automatically optimizes the configuration
# parameters of your Triton kernels. It explores a range of possible configurations and
# selects the one that delivers the best performance for your specific use case.
#
# When used with ``torch.compile``, ``triton.autotune`` can help ensure that your PyTorch
# model runs as efficiently as possible. Here is an example of using ``torch.compile``
# and ``triton.autotune``.
#
# .. note::
#
#    ``torch.compile`` only supports the ``configs`` and ``key`` arguments of ``triton.autotune``.

if not has_triton():
    print("Skipping because triton is not supported on this device.")
else:
    import triton
    from triton import language as tl

    @triton.autotune(
        configs=[
            triton.Config({"BLOCK_SIZE": 4}, num_stages=3, num_warps=8),
            triton.Config({"BLOCK_SIZE": 4}, num_stages=4, num_warps=4),
            triton.Config({"BLOCK_SIZE": 2}, num_stages=3, num_warps=8),
            triton.Config({"BLOCK_SIZE": 2}, num_stages=4, num_warps=4),
        ],
        key=[],
    )
    @triton.jit
    def add_kernel_autotuned(
        in_ptr0,
        in_ptr1,
        out_ptr,
        n_elements,
        BLOCK_SIZE: "tl.constexpr",
    ):
        pid = tl.program_id(axis=0)
        block_start = pid * BLOCK_SIZE
        offsets = block_start + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        x = tl.load(in_ptr0 + offsets, mask=mask)
        y = tl.load(in_ptr1 + offsets, mask=mask)
        output = x + y
        tl.store(out_ptr + offsets, output, mask=mask)

    @torch.compile(fullgraph=True)
    def add_fn(x, y):
        output = torch.zeros_like(x)
        n_elements = output.numel()
        grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
        add_kernel_autotuned[grid](x, y, output, n_elements)
        return output

    x = torch.randn(4, device="cuda")
    y = torch.randn(4, device="cuda")
    out = add_fn(x, y)
    print(f"Vector addition of\nX:\t{x}\nY:\t{y}\nis equal to\n{out}")

######################################################################
# Composability and Limitations
# --------------------------------------------------------------------
#
# As of PyTorch 2.3, the support for user-defined Triton kernels in ``torch.compile``
# includes dynamic shapes, ``torch.autograd.Function``, JIT inductor, and AOT inductor.
# You can use these features together to build complex, high-performance models;
# a minimal sketch of combining a user-defined Triton kernel with
# ``torch.autograd.Function`` follows the list of limitations below.
#
# However, there are certain limitations to be aware of:
#
# * **Tensor Subclasses:** Currently, there is no support for
#   tensor subclasses and other advanced features.
# * **Triton Features:** While ``triton.heuristics`` can be used either standalone or
#   before ``triton.autotune``, it cannot be used after ``triton.autotune``. This
#   implies that if ``triton.heuristics`` and ``triton.autotune`` are to be used
#   together, ``triton.heuristics`` must be used first.
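
######################################################################
# As an illustration of this composability, here is a minimal sketch (the names
# ``_AddFn`` and ``add_with_autograd`` are made up for this example) that calls
# the ``add_kernel`` defined in the Basic Usage section from the forward of a
# ``torch.autograd.Function`` and compiles the result with ``torch.compile``.
# Because vector addition is linear, the backward simply passes the incoming
# gradient through to both inputs.

if not has_triton():
    print("Skipping because triton is not supported on this device.")
else:

    class _AddFn(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x, y):
            # Reuse the Triton kernel from the Basic Usage section.
            output = torch.zeros_like(x)
            n_elements = output.numel()
            grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
            add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=4)
            return output

        @staticmethod
        def backward(ctx, grad_output):
            # d(x + y)/dx = d(x + y)/dy = 1, so the gradient flows through unchanged.
            return grad_output, grad_output

    @torch.compile
    def add_with_autograd(x, y):
        return _AddFn.apply(x, y)

    x = torch.randn(4, device="cuda", requires_grad=True)
    y = torch.randn(4, device="cuda", requires_grad=True)
    out = add_with_autograd(x, y)
    out.sum().backward()
    print(f"Gradients:\ndX:\t{x.grad}\ndY:\t{y.grad}")

######################################################################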
# Conclusion
# -----------
# In this recipe, we explored how to use user-defined Triton kernels
# with ``torch.compile``. We delved into the basic usage of a simple
# vector addition kernel and advanced usage involving Triton's autotune
# feature. We also discussed the composability of user-defined Triton
# kernels with other PyTorch features and highlighted some current limitations.
#
# See Also
# ---------
#
# * `Compiling the Optimizers <https://pytorch.org/tutorials/recipes/compiling_optimizer.html>`__
# * `Implementing High-Performance Transformers with Scaled Dot Product Attention <https://pytorch.org/tutorials/intermediate/scaled_dot_product_attention_tutorial.html>`__