Commit 6771cf5

oulgen and svekars authored
Add tutorial for user defined triton kernels (#2783)
* Add tutorial for user defined triton kernels

---------

Co-authored-by: Svetlana Karslioglu <svekars@meta.com>
1 parent 5fbef68 commit 6771cf5

4 files changed: +187 −0 lines changed

.jenkins/metadata.json

Lines changed: 3 additions & 0 deletions
@@ -45,6 +45,9 @@
     "intermediate_source/scaled_dot_product_attention_tutorial.py": {
         "needs": "linux.g5.4xlarge.nvidia.gpu"
     },
+    "recipes_source/torch_compile_user_defined_triton_kernel_tutorial.py": {
+        "needs": "linux.g5.4xlarge.nvidia.gpu"
+    },
     "prototype_source/gpu_quantization_torchao_tutorial.py": {
         "needs": "linux.g5.4xlarge.nvidia.gpu"
     }

en-wordlist.txt

Lines changed: 4 additions & 0 deletions
@@ -34,6 +34,7 @@ Chatbots
 Chen
 Colab
 Colorectal
+Composability
 Conda
 Conv
 ConvNet
@@ -270,6 +271,7 @@ approximators
 autodiff
 autoencoder
 autograd
+autotune
 autotuner
 backend
 backends
@@ -303,6 +305,7 @@ composable
 concat
 conda
 config
+configs
 contrastive
 conv
 convolutional
@@ -551,6 +554,7 @@ torchviz
 traceback
 tradeoff
 tradeoffs
+triton
 uint
 umap
 uncomment

recipes_source/recipes_index.rst

Lines changed: 9 additions & 0 deletions
@@ -307,6 +307,15 @@ Recipes are bite-sized, actionable examples of how to use specific PyTorch featu
    :link: ../recipes/compiling_optimizer.html
    :tags: Model-Optimization

+.. Using User-Defined Triton Kernels with ``torch.compile``
+
+.. customcarditem::
+   :header: Using User-Defined Triton Kernels with ``torch.compile``
+   :card_description: Learn how to use user-defined Triton kernels with ``torch.compile``
+   :image: ../_static/img/thumbnails/cropped/generic-pytorch-logo.png
+   :link: ../recipes/torch_compile_user_defined_triton_kernel_tutorial.html
+   :tags: Model-Optimization
+
 .. Intel(R) Extension for PyTorch*

 .. customcarditem::
recipes_source/torch_compile_user_defined_triton_kernel_tutorial.py

Lines changed: 171 additions & 0 deletions
@@ -0,0 +1,171 @@
# -*- coding: utf-8 -*-

"""
Using User-Defined Triton Kernels with ``torch.compile``
=========================================================
**Author:** `Oguz Ulgen <https://github.com/oulgen>`_
"""

######################################################################
# User-defined Triton kernels can be used to optimize specific parts of your
# model's computation. These kernels are written in Triton's language, which is designed
# to make it easier to achieve peak hardware performance. By using user-defined Triton
# kernels with ``torch.compile``, you can integrate these optimized computations into
# your PyTorch model, potentially achieving significant performance improvements.
#
# This recipe demonstrates how you can use user-defined Triton kernels with ``torch.compile``.
#
# Prerequisites
# -------------------
#
# Before starting this recipe, make sure that you have the following:
#
# * Basic understanding of ``torch.compile`` and Triton. See:
#
#   * `torch.compiler API documentation <https://pytorch.org/docs/stable/torch.compiler.html#torch-compiler>`__
#   * `Introduction to torch.compile <https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html>`__
#   * `Triton language documentation <https://triton-lang.org/main/index.html>`__
#
# * PyTorch 2.3 or later
# * A GPU that supports Triton
#

import torch
from torch.utils._triton import has_triton

######################################################################
# Basic Usage
# --------------------
#
# In this example, we will use a simple vector addition kernel from the Triton documentation
# with ``torch.compile``.
# For reference, see `Triton documentation <https://triton-lang.org/main/getting-started/tutorials/01-vector-add.html>`__.
#

if not has_triton():
    print("Skipping because triton is not supported on this device.")
else:
    import triton
    from triton import language as tl

    @triton.jit
    def add_kernel(
        in_ptr0,
        in_ptr1,
        out_ptr,
        n_elements,
        BLOCK_SIZE: "tl.constexpr",
    ):
        pid = tl.program_id(axis=0)
        block_start = pid * BLOCK_SIZE
        offsets = block_start + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        x = tl.load(in_ptr0 + offsets, mask=mask)
        y = tl.load(in_ptr1 + offsets, mask=mask)
        output = x + y
        tl.store(out_ptr + offsets, output, mask=mask)

    @torch.compile(fullgraph=True)
    def add_fn(x, y):
        output = torch.zeros_like(x)
        n_elements = output.numel()
        grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
        add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=4)
        return output

    x = torch.randn(4, device="cuda")
    y = torch.randn(4, device="cuda")
    out = add_fn(x, y)
    print(f"Vector addition of\nX:\t{x}\nY:\t{y}\nis equal to\n{out}")
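
######################################################################
# Editor's sketch (illustrative, not part of the committed file): a quick
# sanity check that the compiled, Triton-backed ``add_fn`` matches plain eager
# addition. ``torch.testing.assert_close`` raises an error if the two results
# diverge beyond floating-point tolerances.

if has_triton():
    expected = x + y  # eager-mode reference result
    torch.testing.assert_close(out, expected)
    print("add_fn output matches eager addition")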

######################################################################
# Advanced Usage
# -------------------------------------------------------------------
#
# Triton's autotune feature is a powerful tool that automatically optimizes the configuration
# parameters of your Triton kernels. It explores a range of possible configurations and
# selects the one that delivers the best performance for your specific use case.
#
# When used with ``torch.compile``, ``triton.autotune`` can help ensure that your PyTorch
# model is running as efficiently as possible. Here is an example of using ``torch.compile``
# and ``triton.autotune``.
#
# .. note::
#
#    ``torch.compile`` only supports the ``configs`` and ``key`` arguments to ``triton.autotune``.

if not has_triton():
    print("Skipping because triton is not supported on this device.")
else:
    import triton
    from triton import language as tl

    @triton.autotune(
        configs=[
            triton.Config({"BLOCK_SIZE": 4}, num_stages=3, num_warps=8),
            triton.Config({"BLOCK_SIZE": 4}, num_stages=4, num_warps=4),
            triton.Config({"BLOCK_SIZE": 2}, num_stages=3, num_warps=8),
            triton.Config({"BLOCK_SIZE": 2}, num_stages=4, num_warps=4),
        ],
        key=[],
    )
    @triton.jit
    def add_kernel_autotuned(
        in_ptr0,
        in_ptr1,
        out_ptr,
        n_elements,
        BLOCK_SIZE: "tl.constexpr",
    ):
        pid = tl.program_id(axis=0)
        block_start = pid * BLOCK_SIZE
        offsets = block_start + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        x = tl.load(in_ptr0 + offsets, mask=mask)
        y = tl.load(in_ptr1 + offsets, mask=mask)
        output = x + y
        tl.store(out_ptr + offsets, output, mask=mask)

    @torch.compile(fullgraph=True)
    def add_fn(x, y):
        output = torch.zeros_like(x)
        n_elements = output.numel()
        grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
        add_kernel_autotuned[grid](x, y, output, n_elements)
        return output

    x = torch.randn(4, device="cuda")
    y = torch.randn(4, device="cuda")
    out = add_fn(x, y)
    print(f"Vector addition of\nX:\t{x}\nY:\t{y}\nis equal to\n{out}")
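
######################################################################
# Editor's sketch (illustrative, not part of the committed file): with
# ``key=[]`` above, the autotuner benchmarks the configs once and reuses the
# winner for every call. Listing kernel argument names in ``key`` instead makes
# Triton re-run autotuning whenever those arguments change. The hypothetical
# ``add_kernel_autotuned_by_size`` below is keyed on ``n_elements``; only the
# decorator differs from ``add_kernel_autotuned``.

if has_triton():

    @triton.autotune(
        configs=[
            triton.Config({"BLOCK_SIZE": 4}, num_stages=3, num_warps=8),
            triton.Config({"BLOCK_SIZE": 2}, num_stages=4, num_warps=4),
        ],
        key=["n_elements"],  # re-tune whenever n_elements takes a new value
    )
    @triton.jit
    def add_kernel_autotuned_by_size(
        in_ptr0,
        in_ptr1,
        out_ptr,
        n_elements,
        BLOCK_SIZE: "tl.constexpr",
    ):
        pid = tl.program_id(axis=0)
        offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        x = tl.load(in_ptr0 + offsets, mask=mask)
        y = tl.load(in_ptr1 + offsets, mask=mask)
        tl.store(out_ptr + offsets, x + y, mask=mask)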

######################################################################
# Composability and Limitations
# --------------------------------------------------------------------
#
# As of PyTorch 2.3, the support for user-defined Triton kernels in ``torch.compile``
# includes dynamic shapes, ``torch.autograd.Function``, JIT inductor, and AOT inductor.
# You can use these features together to build complex, high-performance models.
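
######################################################################
# Editor's sketch (illustrative, not part of the committed file): because
# dynamic shapes are supported, the compiled ``add_fn`` defined above can be
# reused for inputs of a different length without any changes; ``torch.compile``
# guards on the size and recompiles or re-specializes as needed. The tensor
# names ``a`` and ``b`` are hypothetical.

if has_triton():
    a = torch.randn(10, device="cuda")
    b = torch.randn(10, device="cuda")
    torch.testing.assert_close(add_fn(a, b), a + b)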

######################################################################
# However, there are certain limitations to be aware of:
#
# * **Tensor Subclasses:** Currently, there is no support for
#   tensor subclasses and other advanced features.
# * **Triton Features:** While ``triton.heuristics`` can be used either standalone or
#   before ``triton.autotune``, it cannot be used after ``triton.autotune``. This
#   implies that if ``triton.heuristics`` and ``triton.autotune`` are to be used
#   together, ``triton.heuristics`` must be used first (see the sketch after this list).
#
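
######################################################################
# Editor's sketch (illustrative, not part of the committed file) of the
# ordering described above: ``triton.heuristics`` is applied before
# ``triton.autotune``, so in the decorator stack ``@triton.autotune`` sits on
# top of ``@triton.heuristics``, which sits directly on ``@triton.jit``. The
# kernel name and the ``EVEN_SIZE`` meta-parameter are hypothetical.

if has_triton():

    @triton.autotune(
        configs=[
            triton.Config({"BLOCK_SIZE": 4}, num_stages=3, num_warps=8),
            triton.Config({"BLOCK_SIZE": 2}, num_stages=4, num_warps=4),
        ],
        key=[],
    )
    @triton.heuristics(
        # Derived from the runtime arguments (and the chosen config) before launch.
        {"EVEN_SIZE": lambda args: args["n_elements"] % args["BLOCK_SIZE"] == 0}
    )
    @triton.jit
    def add_kernel_with_heuristics(
        in_ptr0,
        in_ptr1,
        out_ptr,
        n_elements,
        BLOCK_SIZE: "tl.constexpr",
        EVEN_SIZE: "tl.constexpr",
    ):
        pid = tl.program_id(axis=0)
        offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        if EVEN_SIZE:
            # Size is a multiple of BLOCK_SIZE, so no mask is needed.
            x = tl.load(in_ptr0 + offsets)
            y = tl.load(in_ptr1 + offsets)
            tl.store(out_ptr + offsets, x + y)
        else:
            mask = offsets < n_elements
            x = tl.load(in_ptr0 + offsets, mask=mask)
            y = tl.load(in_ptr1 + offsets, mask=mask)
            tl.store(out_ptr + offsets, x + y, mask=mask)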

######################################################################
# Conclusion
# -----------
# In this recipe, we explored how to utilize user-defined Triton kernels
# with ``torch.compile``. We delved into the basic usage of a simple
# vector addition kernel and advanced usage involving Triton's autotune
# feature. We also discussed the composability of user-defined Triton
# kernels with other PyTorch features and highlighted some current limitations.
#
# See Also
# ---------
#
# * `Compiling the Optimizers <https://pytorch.org/tutorials/recipes/compiling_optimizer.html>`__
# * `Implementing High-Performance Transformers with Scaled Dot Product Attention <https://pytorch.org/tutorials/intermediate/scaled_dot_product_attention_tutorial.html>`__

0 commit comments
