diff --git a/.github/scripts/install_triton.sh b/.github/scripts/install_triton.sh
index 5e957e11..5dfbb5bd 100755
--- a/.github/scripts/install_triton.sh
+++ b/.github/scripts/install_triton.sh
@@ -1,11 +1,37 @@
 #!/bin/bash
 set -ex
+
+# Parse command line arguments
+USE_CPU_BACKEND=false
+while [[ $# -gt 0 ]]; do
+  case $1 in
+    --cpu)
+      USE_CPU_BACKEND=true
+      shift
+      ;;
+    *)
+      echo "Unknown option: $1"
+      exit 1
+      ;;
+  esac
+done
+
 (
     mkdir -p /tmp/$USER
     pushd /tmp/$USER
     pip uninstall -y triton pytorch-triton || true
     rm -rf triton/ || true
-    git clone https://github.com/triton-lang/triton.git  # install triton latest main
+
+    # Clone the appropriate repository based on backend
+    if [ "$USE_CPU_BACKEND" = true ]; then
+        # Install triton-cpu from the triton-cpu repository
+        git clone --recursive https://github.com/triton-lang/triton-cpu.git triton
+    else
+        # Install triton from the main repository for the GPU backend
+        git clone https://github.com/triton-lang/triton.git triton
+    fi
+
+    # Shared build process for both backends
     (
         pushd triton/
         conda config --set channel_priority strict
@@ -14,10 +40,14 @@ set -ex
         conda install -y -c conda-forge gcc_linux-64=13 gxx_linux-64=13 gcc=13 gxx=13
         pip install -r python/requirements.txt
         # Use TRITON_PARALLEL_LINK_JOBS=2 to avoid OOM on CPU CI machines
-        MAX_JOBS=$(nproc) TRITON_PARALLEL_LINK_JOBS=2 pip install .  # install to conda site-packages/ folder
+        if [ "$USE_CPU_BACKEND" = true ]; then
+            MAX_JOBS=$(nproc) TRITON_PARALLEL_LINK_JOBS=2 pip install -e python  # editable install; site-packages links back to this source tree
+        else
+            MAX_JOBS=$(nproc) TRITON_PARALLEL_LINK_JOBS=2 pip install .  # install to conda site-packages/ folder
+        fi
         popd
     )
-    rm -rf triton/
+    # rm -rf triton/  # keep the checkout: the editable (-e) CPU install resolves imports from this tree
     popd
)
exit 0
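The CPU branch builds triton-cpu as an editable install (`pip install -e python`), which is why the final `rm -rf triton/` cleanup is now commented out: the installed package resolves imports from the cloned source tree, so deleting it would break the install. A quick way to sanity-check whichever backend the script produced is a tiny vector-add kernel. The sketch below is illustrative only (the `smoke_test` helper is not part of this PR) and assumes the triton-cpu fork keeps the mainline `@triton.jit` API and runs kernels on CPU tensors when `TRITON_CPU_BACKEND=1`:

```python
import os

import torch
import triton
import triton.language as tl


@triton.jit
def _add_kernel(x_ptr, y_ptr, out_ptr, n, BLOCK: tl.constexpr):
    # One program instance handles one BLOCK-sized slice of the vectors.
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    x = tl.load(x_ptr + offs, mask=mask)
    y = tl.load(y_ptr + offs, mask=mask)
    tl.store(out_ptr + offs, x + y, mask=mask)


def smoke_test(n: int = 1024) -> None:
    # Mirror helion._testing: TRITON_CPU_BACKEND=1 selects CPU tensors.
    device = "cpu" if os.environ.get("TRITON_CPU_BACKEND", "0") == "1" else "cuda"
    x = torch.randn(n, device=device)
    y = torch.randn(n, device=device)
    out = torch.empty_like(x)
    _add_kernel[(triton.cdiv(n, 256),)](x, y, out, n, BLOCK=256)
    torch.testing.assert_close(out, x + y)
```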
diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index 9850ae94..4e35e931 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -39,3 +39,32 @@ jobs:
           ./.github/scripts/install_triton.sh
           pip install -r requirements.txt
           python -m unittest discover -s test/ -p "*.py" -v -t .
+
+  test_cpu_triton:
+    name: test-cpu-py${{ matrix.python-version }}-triton-cpu
+    strategy:
+      fail-fast: true
+      matrix:
+        python-version: ["3.12"]
+        include:
+          - name: A10G
+            runs-on: linux.g5.4xlarge.nvidia.gpu
+            torch-spec: '--pre torch --index-url https://download.pytorch.org/whl/nightly/cu126'
+            gpu-arch-type: "cuda"
+            gpu-arch-version: "12.6"
+    uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
+    with:
+      timeout: 120
+      runner: ${{ matrix.runs-on }}
+      gpu-arch-type: ${{ matrix.gpu-arch-type }}
+      gpu-arch-version: ${{ matrix.gpu-arch-version }}
+      submodules: recursive
+      script: |
+        conda create -n venv python=${{ matrix.python-version }} -y
+        conda activate venv
+        python -m pip install --upgrade pip
+        pip install ${{ matrix.torch-spec }}
+        time ./.github/scripts/install_triton.sh --cpu
+        pip install -r requirements.txt
+        pip install pytest pytest-timeout
+        TRITON_CPU_BACKEND=1 pytest --timeout 60 test
diff --git a/helion/_testing.py b/helion/_testing.py
index 21ecb714..6fd31d27 100644
--- a/helion/_testing.py
+++ b/helion/_testing.py
@@ -1,8 +1,12 @@
 from __future__ import annotations
 
+import functools
 import importlib
+import os
 import sys
 from typing import TYPE_CHECKING
+from typing import Callable
+import unittest
 
 import torch
 
@@ -15,7 +19,19 @@
     from .runtime.kernel import Kernel
 
 
-DEVICE = torch.device("cuda")
+USE_TRITON_CPU_BACKEND: bool = os.environ.get("TRITON_CPU_BACKEND", "0") == "1"
+
+if USE_TRITON_CPU_BACKEND:
+    DEVICE = torch.device("cpu")
+else:
+    DEVICE = torch.device("cuda")
+
+
+skipIfTritonCpu: Callable[[], Callable[[Callable[..., object]], Callable[..., object]]] = (
+    functools.partial(
+        unittest.skipIf, USE_TRITON_CPU_BACKEND, "does not work with triton cpu"
+    )
+)
 
 
 def import_path(filename: Path) -> types.ModuleType:
diff --git a/test/test_loops.py b/test/test_loops.py
index 148251c3..5611d0dc 100644
--- a/test/test_loops.py
+++ b/test/test_loops.py
@@ -11,6 +11,7 @@
 from helion._testing import DEVICE
 from helion._testing import code_and_output
 from helion._testing import import_path
+from helion._testing import skipIfTritonCpu
 import helion.language as hl
 
 datadir = Path(__file__).parent / "data"
@@ -154,6 +155,7 @@ def _device_loop_3d_make_precompiler(x: torch.Tensor):
     return make_precompiler(_device_loop_3d_kernel)(x, out, out.stride(0), out.stride(1), out.stride(2), out.stride(3), x.stride(0), x.stride(1), x.stride(2), x.stride(3), b, c, d, _BLOCK_SIZE_3, _BLOCK_SIZE_2, _BLOCK_SIZE_1, num_warps=4, num_stages=3)""",
         )
 
+    @skipIfTritonCpu()
     def test_3d_device_loop1(self):
         args = (torch.randn([128, 128, 128, 128], device=DEVICE),)
         code, result = code_and_output(
@@ -263,6 +265,7 @@ def _device_loop_3d_make_precompiler(x: torch.Tensor):
     return make_precompiler(_device_loop_3d_kernel)(x, out, out.stride(0), out.stride(1), out.stride(2), out.stride(3), x.stride(0), x.stride(1), x.stride(2), x.stride(3), a, b, c, d, _BLOCK_SIZE_0, _BLOCK_SIZE_1_2_3, num_warps=4, num_stages=3)""",
        )
 
+    @skipIfTritonCpu()
     def test_3d_device_loop3(self):
         args = (torch.randn([128, 128, 128, 128], device=DEVICE),)
         code, result = code_and_output(
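`skipIfTritonCpu` is a `functools.partial` over `unittest.skipIf` with the condition and reason pre-bound, so each call site invokes it as a zero-argument decorator factory (`@skipIfTritonCpu()`), matching the call shape of `unittest.skip`-style helpers. A minimal standalone sketch of the same pattern; the env-var check is inlined here rather than imported from `helion._testing`, and the test class is hypothetical:

```python
import functools
import os
import unittest

USE_TRITON_CPU_BACKEND = os.environ.get("TRITON_CPU_BACKEND", "0") == "1"

# Calling skipIfTritonCpu() returns
# unittest.skipIf(USE_TRITON_CPU_BACKEND, "does not work with triton cpu"),
# i.e. a decorator that skips the test when the CPU backend is active.
skipIfTritonCpu = functools.partial(
    unittest.skipIf, USE_TRITON_CPU_BACKEND, "does not work with triton cpu"
)


class Example(unittest.TestCase):
    @skipIfTritonCpu()  # note the trailing parentheses
    def test_gpu_only_feature(self) -> None:
        self.assertTrue(True)


if __name__ == "__main__":
    unittest.main()
```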
diff --git a/test/test_print.py b/test/test_print.py
index 9ba7e67a..1975eb6c 100644
--- a/test/test_print.py
+++ b/test/test_print.py
@@ -12,6 +12,7 @@
 import helion
 from helion._testing import DEVICE
 from helion._testing import code_and_output
+from helion._testing import skipIfTritonCpu
 import helion.language as hl
 
 
@@ -107,6 +108,7 @@ def run_test_with_and_without_triton_interpret_envvar(self, test_func):
         else:
             os.environ["TRITON_INTERPRET"] = original_env
 
+    @skipIfTritonCpu()
     def test_basic_print(self):
         """Test basic print with prefix and tensor values"""
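Note that `USE_TRITON_CPU_BACKEND` (and therefore `DEVICE`) is evaluated once, at import time of `helion._testing`, so `TRITON_CPU_BACKEND=1` must be in the environment before any test module is imported; the workflow job above guarantees this by setting the variable on the `pytest` command line. A minimal sketch of the same constraint in a standalone script, assuming `helion` is importable:

```python
import os

# Must be set before helion._testing is imported, since that module
# reads TRITON_CPU_BACKEND at import time to choose DEVICE.
os.environ.setdefault("TRITON_CPU_BACKEND", "1")

import torch
from helion._testing import DEVICE

x = torch.randn(128, device=DEVICE)  # CPU tensor under the CPU backend
print(DEVICE)  # prints "cpu"
```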