diff --git a/examples/helion_puzzles.ipynb b/examples/helion_puzzles.ipynb new file mode 100644 index 00000000..947bb175 --- /dev/null +++ b/examples/helion_puzzles.ipynb @@ -0,0 +1,470 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Helion Puzzles\n", + "\n", + "Programming for accelerators such as GPUs is critical for modern AI systems. This often means programming directly in proprietary low-level languages such as CUDA. Helion is a Python-embedded domain-specific language (DSL) for authoring machine learning kernels, designed to compile down to Triton, a performant backend for programming GPUs and other devices.\n", + "\n", + "Helion aims to raise the level of abstraction compared to Triton, making it easier to write correct and efficient kernels while enabling more automation in the autotuning process.\n", + "\n", + "This set of puzzles is meant to teach you how to use Helion from first principles in an interactive fashion. You will start with trivial examples and build your way up to real algorithms like Flash Attention and Quantized neural networks." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Setup\n", + "\n", + "First, let's install the necessary dependencies. Helion requires a recent version of PyTorch and a development version of Triton." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%%capture\n", + "# Only need to run the first time.\n", + "!pip install torch\n", + "# !pip install git+https://github.com/triton-lang/triton.git\n", + "!pip install git+https://github.com/pytorch-labs/helion.git" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "import helion\n", + "import helion.language as hl\n", + "from torch import Tensor" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Let's also create a simple testing function to verify our implementations." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from triton.testing import do_bench\n", + "def test_kernel(kernel_fn, spec_fn, *args):\n", + " \"\"\"Test a Helion kernel against a reference implementation.\"\"\"\n", + " # Run our implementation\n", + " result = kernel_fn(*args)\n", + " # Run reference implementation\n", + " expected = spec_fn(*args)\n", + "\n", + " # Check if results match\n", + " torch.testing.assert_close(result, expected)\n", + " print(\"✅ Results Match ✅\")\n", + "\n", + "\n", + "def benchmark_kernel(kernel_fn, *args, **kwargs):\n", + " \"\"\"Benchmark a Helion kernel.\"\"\"\n", + " no_args = lambda: kernel_fn(*args, **kwargs)\n", + " time_in_ms = do_bench(no_args)\n", + " print(f\"⏱ Time: {time_in_ms} ms\")\n", + "\n", + "\n", + "def compare_implementations(kernel_fn, spec_fn, *args, **kwargs):\n", + " \"\"\"Benchmark a Helion kernel and its reference implementation.\"\"\"\n", + " kernel_no_args = lambda: kernel_fn(*args, **kwargs)\n", + " spec_no_args = lambda: spec_fn(*args, **kwargs)\n", + " kernel_time = do_bench(kernel_no_args)\n", + " spec_time = do_bench(spec_no_args)\n", + " print(f\"⏱ Helion Kernel Time: {kernel_time:.3f} ms, PyTorch Reference Time: {spec_time:.3f} ms, Speedup: {spec_time/kernel_time:.3f}x\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Introduction to Helion\n", + "\n", + "Helion allows you to write GPU kernels using familiar PyTorch syntax. The code outside the `for` loops is standard PyTorch code executed on the CPU. The code inside the `for` loops is compiled into a Triton kernel, resulting in a single GPU kernel.\n", + "\n", + "Unlike raw Triton, Helion handles memory management, tiling, and other low-level details automatically. This allows you to focus on the algorithm rather than the implementation details.\n", + "\n", + "## Basic Structure of a Helion Kernel\n", + "\n", + "A Helion kernel consists of two main parts:\n", + "1. **Host Code**: Standard PyTorch code executed on the CPU (outside the loops)\n", + "2. **Device Code**: Operations inside `hl.tile()` loops that execute on the GPU\n", + "\n", + "Let's examine a simple example:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "output": { + "id": 1048275227229175, + "loadingStatus": "loaded" + } + }, + "outputs": [], + "source": [ + "@helion.kernel(config=helion.Config(block_sizes = [128, 128])) # The @helion.kernel decorator marks this function for compilation\n", + "def example_add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:\n", + " # Host code: Standard PyTorch operations\n", + " m, n = x.size()\n", + " out = torch.empty_like(x) # Allocate output tensor\n", + "\n", + " # The hl.tile loop defines the parallel execution structure\n", + " for tile_m, tile_n in hl.tile([m, n]):\n", + " # Device code: Everything inside the hl.tile loop runs on GPU\n", + " out[tile_m, tile_n] = x[tile_m, tile_n] + y[tile_m, tile_n] # Simple element-wise addition expressed w/ pytorch ops\n", + "\n", + " return out # Return the result back to the host\n", + "\n", + "# Create some sample data\n", + "x = torch.randn(10, 10, device=\"cuda\")\n", + "y = torch.randn(10, 10, device=\"cuda\")\n", + "\n", + "# Run the kernel\n", + "result = example_add(x, y)\n", + "\n", + "# Verify result\n", + "expected = x + y\n", + "torch.testing.assert_close(result, expected)\n", + "print(\"✅ Results Match ✅\")\n", + "benchmark_kernel(example_add, x, y)\n", + "compare_implementations(example_add, torch.add, x, y)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Autotuning in Helion\n", + "\n", + "In the previous example, we explicitly specified a configuration using `config=helion.Config(block_sizes=[128, 128])`. This bypasses Helion's autotuning mechanism and uses our predefined settings. While this is quick to run, manually choosing optimal parameters can be challenging and hardware-dependent.\n", + "\n", + "### What is Autotuning?\n", + "\n", + "Autotuning is Helion's process of automatically finding the best configuration parameters for your specific:\n", + "- Hardware (GPU model)\n", + "- Problem size\n", + "- Operation patterns\n", + "\n", + "When you omit the `config` parameter, Helion will automatically search for the optimal configuration:\n", + "\n", + "```python\n", + "@helion.kernel() # No config = automatic tuning\n", + "def autotuned_add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:\n", + " m, n = x.size()\n", + " out = torch.empty_like(x)\n", + " for tile_m, tile_n in hl.tile([m, n]):\n", + " out[tile_m, tile_n] = x[tile_m, tile_n] + y[tile_m, tile_n]\n", + " return out\n", + "```\n", + "\n", + "Feel free to remove the above code to see how much more performant it is than the original, although be warned it might take some time 😃" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now let's move on to our puzzles!" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Puzzle 1: Constant Add\n", + "\n", + "Add a constant to a vector." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "output": { + "id": 1202076708065505, + "loadingStatus": "loaded" + } + }, + "outputs": [], + "source": [ + "def add_spec(x: Tensor) -> Tensor:\n", + " \"\"\"This is the spec that you should implement.\"\"\"\n", + " return x + 10.\n", + "\n", + " # ---- ✨ Is this the best block size? ----\n", + "@helion.kernel(config = helion.Config(block_sizes = [1,]))\n", + "def add_kernel(x: torch.Tensor) -> torch.Tensor:\n", + " # ---- ✨ Your Code Here ✨----\n", + " # Set up the output buffer which you will return\n", + "\n", + " # Use Helion to tile the computation\n", + " for tile_n in hl.tile(TILE_RANGE):\n", + " # ---- ✨ Your Code Here ✨----\n", + " # Get the tile from x and add 10\n", + "\n", + " return out\n", + "\n", + "# Test the kernel\n", + "x = torch.randn(8192, device=\"cuda\")\n", + "test_kernel(add_kernel, add_spec, x)\n", + "benchmark_kernel(add_kernel, x)\n", + "compare_implementations(add_kernel, add_spec, x)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Puzzle 2: Outer Vector Add\n", + "\n", + "Add two vectors using an outer product pattern." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "output": { + "id": 1745892802671787, + "loadingStatus": "loaded" + } + }, + "outputs": [], + "source": [ + "def broadcast_add_spec(x: Tensor, y: Tensor) -> Tensor:\n", + " return x[None, :] + y[:, None]\n", + "\n", + " # ---- ✨ What should the block sizes be? ----\n", + "@helion.kernel(config = helion.Config(block_sizes = []))\n", + "def broadcast_add_kernel(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:\n", + " # Get tensor sizes\n", + " # ---- ✨ Your Code Here ✨----\n", + "\n", + " return out\n", + "\n", + "# Test the kernel\n", + "x = torch.randn(1142, device=\"cuda\")\n", + "y = torch.randn(512, device=\"cuda\")\n", + "test_kernel(broadcast_add_kernel, broadcast_add_spec, x, y)\n", + "benchmark_kernel(broadcast_add_kernel, x, y)\n", + "compare_implementations(broadcast_add_kernel, broadcast_add_spec, x, y)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Puzzle 3: Fused Outer Multiplication\n", + "\n", + "Multiply a row vector to a column vector and take a relu." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def mul_relu_block_spec(x: Tensor, y: Tensor) -> Tensor:\n", + " return torch.relu(x[None, :] * y[:, None])\n", + "\n", + "\n", + " # ---- ✨ Is this the best block size? ----\n", + "@helion.kernel(config = helion.Config(block_sizes = [32, 32]))\n", + "def mul_relu_block_kernel(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:\n", + " return out\n", + "\n", + "# Test the kernel\n", + "x = torch.randn(512, device=\"cuda\")\n", + "y = torch.randn(512, device=\"cuda\")\n", + "test_kernel(mul_relu_block_kernel, mul_relu_block_spec, x, y)\n", + "compare_implementations(mul_relu_block_kernel, mul_relu_block_spec, x, y)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Puzzle 4: Long Sum\n", + "\n", + "Sum of a batch of numbers. TODO Give good example of how this reduction is done and why we need to register blocks" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def sum_spec(x: Float32[Tensor, \"4 200\"]) -> Float32[Tensor, \"4\"]:\n", + " return x.sum(1)\n", + "\n", + " # ---- ✨ Your Code Here ✨----\n", + "@helion.kernel()\n", + "def sum_kernel(x: torch.Tensor) -> torch.Tensor:\n", + " # Get tensor sizes\n", + " # ---- ✨ Your Code Here ✨----\n", + "\n", + " return out\n", + "\n", + "# Test the kernel\n", + "x = torch.randn(4, 200, device=\"cuda\")\n", + "test_kernel(sum_kernel, sum_spec, x)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Autotuning in Helion\n", + "\n", + "One of the major advantages of Helion is its sophisticated autotuning capability. Let's see how we can leverage this for our matrix multiplication kernel:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "import helion\n", + "import helion.language as hl\n", + "import time\n", + "\n", + "# Define a matrix multiplication kernel\n", + "@helion.kernel() # No config means autotuning will be used\n", + "def matmul_autotune(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:\n", + " m, k = x.size()\n", + " k, n = y.size()\n", + " out = torch.empty([m, n], dtype=x.dtype, device=x.device)\n", + "\n", + " for tile_m, tile_n in hl.tile([m, n]):\n", + " acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)\n", + " for tile_k in hl.tile(k):\n", + " acc = acc + torch.matmul(x[tile_m, tile_k], y[tile_k, tile_n])\n", + " out[tile_m, tile_n] = acc\n", + "\n", + " return out\n", + "\n", + "# Create larger tensors for better autotuning results\n", + "x = torch.randn(1024, 1024, device=\"cuda\")\n", + "y = torch.randn(1024, 1024, device=\"cuda\")\n", + "\n", + "# First run will trigger autotuning\n", + "print(\"Running with autotuning (this might take a while)...\")\n", + "start = time.time()\n", + "result = matmul_autotune(x, y)\n", + "end = time.time()\n", + "print(f\"First run time (including autotuning): {end - start:.2f}s\")\n", + "\n", + "# Second run will use the tuned configuration\n", + "start = time.time()\n", + "result = matmul_autotune(x, y)\n", + "end = time.time()\n", + "print(f\"Second run time (using tuned config): {end - start:.2f}s\")\n", + "\n", + "# Verify correctness\n", + "expected = x @ y\n", + "print(f\"Result is correct: {torch.allclose(result, expected, rtol=1e-2, atol=1e-2)}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Hardcoding Configurations\n", + "\n", + "After autotuning, you might want to hardcode the best configuration:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Example of hardcoding a configuration after autotuning\n", + "@helion.kernel(config=helion.Config(\n", + " block_sizes=[[64, 128], [16]],\n", + " loop_orders=[[1, 0]],\n", + " num_warps=4,\n", + " num_stages=3,\n", + " indexing='block_ptr',\n", + " l2_grouping=32\n", + "))\n", + "def matmul_fixed_config(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:\n", + " m, k = x.size()\n", + " k, n = y.size()\n", + " out = torch.empty([m, n], dtype=x.dtype, device=x.device)\n", + "\n", + " for tile_m, tile_n in hl.tile([m, n]):\n", + " acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)\n", + " for tile_k in hl.tile(k):\n", + " acc = acc + torch.matmul(x[tile_m, tile_k], y[tile_k, tile_n])\n", + " out[tile_m, tile_n] = acc\n", + "\n", + " return out\n", + "\n", + "# Run with fixed configuration (no autotuning)\n", + "start = time.time()\n", + "result = matmul_fixed_config(x, y)\n", + "end = time.time()\n", + "print(f\"Run time with fixed config: {end - start:.2f}s\")\n", + "\n", + "# Verify correctness\n", + "expected = x @ y\n", + "print(f\"Result is correct: {torch.allclose(result, expected, rtol=1e-2, atol=1e-2)}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Conclusion\n", + "\n", + "In this notebook, we've explored how to use Helion to write efficient GPU kernels using a high-level, PyTorch-like syntax. The key advantages of Helion include:\n", + "\n", + "1. **Higher-level abstraction** than raw Triton, making it easier to write correct kernels\n", + "2. **Automatic tiling and memory management**, eliminating a common source of bugs\n", + "3. **Powerful autotuning** that can explore a wide range of implementations automatically\n", + "4. **Familiar PyTorch syntax** that builds on existing knowledge\n", + "\n", + "These puzzles should give you a good foundation for writing your own Helion kernels for a variety of applications." + ] + } + ], + "metadata": { + "fileHeader": "", + "fileUid": "015576dc-f856-4f67-b6b8-2116e623b467", + "isAdHoc": false, + "kernelspec": { + "display_name": "nightly (local)", + "language": "python", + "name": "nightly_local" + }, + "language_info": { + "name": "python" + }, + "notebookId": "1739719030234768", + "notebookNumber": "N7197491", + "orig_nbformat": 4 + } +} diff --git a/pyproject.toml b/pyproject.toml index 8e9a42f9..a798891a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,7 +41,7 @@ src = ["helion"] docstring-code-format = true quote-style = "double" line-ending = "lf" -exclude = [".github/*"] +exclude = [".github/*", "*.ipynb"] [tool.ruff.lint] select = [ @@ -62,7 +62,7 @@ ignore = [ ] extend-safe-fixes = ["TC", "UP045", "RUF013", "RSE102"] preview = true -exclude = ["test/data/*", ".github/*"] +exclude = ["test/data/*", ".github/*", "", "*.ipynb"] [tool.ruff.lint.isort] extra-standard-library = ["typing_extensions"]