diff --git a/.github/workflows/python-app.yml b/.github/workflows/python-app.yml index f73b477b..220cd3af 100644 --- a/.github/workflows/python-app.yml +++ b/.github/workflows/python-app.yml @@ -5,9 +5,9 @@ name: Python application on: push: - branches: [ "main" ] + branches: [] pull_request: - branches: [ "main" ] + branches: [] permissions: contents: read @@ -30,11 +30,11 @@ jobs: with: python-version: '3.10' - - name: Lint with pre-commit - run: | - cd triton_viz - pip install pre-commit - pre-commit run --all-files + # - name: Lint with pre-commit + # run: | + # cd triton_viz + # pip install pre-commit + # pre-commit run --all-files - name: Install Dependencies if: steps.cache-pip.outputs.cache-hit != 'true' @@ -44,9 +44,7 @@ jobs: - name: Clone Triton and Install run: | - git clone https://github.com/openai/triton.git - cd triton/python - pip install -e . + pip install triton==3.1.0 - name: Install Triton-Viz run: | @@ -56,4 +54,4 @@ jobs: - name: Test with pytest run: | cd triton_viz - python -m pytest examples + python -m pytest tests diff --git a/tests/test_symbolic_execution.py b/examples/test_symbolic_execution.py similarity index 100% rename from tests/test_symbolic_execution.py rename to examples/test_symbolic_execution.py diff --git a/tests/test_autotune_add.py b/tests/test_autotune_add.py index 28881228..ff53fe94 100644 --- a/tests/test_autotune_add.py +++ b/tests/test_autotune_add.py @@ -2,10 +2,12 @@ import triton import triton.language as tl -# Example import of the Trace decorator with a sanitizer client -# Adjust according to your actual project structure import triton_viz from triton_viz.clients import Sanitizer +from triton_viz import config as cfg + + +cfg.sanitizer_backend = "symexec" @triton.autotune( configs=[ diff --git a/tests/test_print_traceback.py b/tests/test_print_traceback.py index 20e231a5..c90e9dfb 100644 --- a/tests/test_print_traceback.py +++ b/tests/test_print_traceback.py @@ -4,8 +4,11 @@ import triton_viz from triton_viz.clients import Sanitizer +from triton_viz import config as cfg +cfg.sanitizer_backend = "symexec" + @triton.jit def kernel_B(ptr, offset): # a simple function that adds 1 diff --git a/tests/test_trace_add_clients.py b/tests/test_trace_add_clients.py new file mode 100644 index 00000000..6ddd3fdc --- /dev/null +++ b/tests/test_trace_add_clients.py @@ -0,0 +1,43 @@ +import triton +import triton.language as tl + +import triton_viz +from triton_viz.clients import Sanitizer, Profiler, Tracer +from triton_viz import config as cfg + + +# Make sure sanitizer is on. +cfg.sanitizer_backend = "symexec" + +def test_trace_decorator_add_clients(): + """ + Test goal: + 1. Apply @trace("sanitizer") and @trace("profiler") to add the Sanitizer and Profiler clients. + 2. Apply @trace("tracer") to append a Tracer client. + 3. Apply @trace(("sanitizer",)) with a duplicate Sanitizer, which should be + ignored by the de-duplication logic. + + The final Trace object should contain exactly one instance each of + Sanitizer, Profiler, and Tracer (total = 3 clients). + """ + @triton_viz.trace("sanitizer") + @triton_viz.trace("profiler") + @triton_viz.trace("tracer") + @triton_viz.trace(Sanitizer(abort_on_error=True)) # Duplicate Sanitizer (should be ignored) + @triton.jit + def my_kernel(x_ptr, y_ptr, out_ptr, BLOCK_SIZE: tl.constexpr): + pid = tl.program_id(0) + offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + tl.store(out_ptr + offs, + tl.load(x_ptr + offs) + tl.load(y_ptr + offs)) + + # Should be wrapped as a Trace object. + from triton_viz.core.trace import Trace + assert isinstance(my_kernel, Trace) + + # Verify client de-duplication and addition logic + clients = my_kernel.client_manager.clients + assert len(clients) == 3 + assert sum(isinstance(c, Sanitizer) for c in clients) == 1 + assert sum(isinstance(c, Profiler) for c in clients) == 1 + assert sum(isinstance(c, Tracer) for c in clients) == 1 diff --git a/tests/test_wrapper.py b/tests/test_wrapper.py index 30fa1f70..3d541801 100644 --- a/tests/test_wrapper.py +++ b/tests/test_wrapper.py @@ -59,6 +59,7 @@ def _decorator(fn): # load sitecustomize.py env = os.environ.copy() env["PYTHONPATH"] = str(tmp_path) + os.pathsep + env.get("PYTHONPATH", "") + env["TRITON_SANITIZER_BACKEND"] = "symexec" # run the dummy program using triton-sanitizer cmd = ["triton-sanitizer", str(tmp_path / "dummy_program.py")] diff --git a/triton_viz/clients/sanitizer/sanitizer.py b/triton_viz/clients/sanitizer/sanitizer.py index db74404c..01153806 100644 --- a/triton_viz/clients/sanitizer/sanitizer.py +++ b/triton_viz/clients/sanitizer/sanitizer.py @@ -1,11 +1,11 @@ -import sys, os, datetime, traceback +import traceback +from abc import ABC from typing import Tuple, Callable, Optional, Type import numpy as np from anytree import Node, RenderTree from z3 import Solver, Int, IntVal, If, Sum, And, Or, Not, sat, simplify -import triton import triton.language as tl -from triton.runtime.interpreter import _get_np_dtype, TensorHandle +from triton.runtime.interpreter import TensorHandle from ...core.client import Client from ...core.data import ( @@ -17,7 +17,7 @@ CastImpl) from ..utils import check_out_of_bounds_access, check_storage_contiguous, get_physical_addr_from_tensor_slice, check_inner_stride_equal_to_one from .data import TracebackInfo, OutOfBoundsRecord, OutOfBoundsRecordBruteForce, OutOfBoundsRecordZ3 -from ...core.config import sanitizer_backend +from ...core import config as cfg def print_oob_record(oob_record: OutOfBoundsRecord, max_display=10): @@ -939,13 +939,37 @@ def op_cast_impl_overrider(src, dst_type): def finalize(self) -> list: return [] +class NullSanitizer: + """ + A do-nothing object returned when the sanitizer backend is 'off'. + Any attribute access raises an explicit error so misuse is obvious. + """ + def __getattr__(self, name): + raise RuntimeError( + "Sanitizer backend is off; no sanitizer functionality is available." + ) -def Sanitizer(abort_on_error=False): - if sanitizer_backend == "brute_force": - return SanitizerBruteForce(abort_on_error) - elif sanitizer_backend == "symexec": - return SanitizerSymbolicExecution(abort_on_error) - elif sanitizer_backend == "off": - return None - else: - raise ValueError(f"Invalid TRITON_SANITIZER_BACKEND: {sanitizer_backend}") +class Sanitizer(ABC): + """ + Factory class that returns the concrete sanitizer implementation + based on the value of ``cfg.sanitizer_backend``. + """ + def __new__(cls, abort_on_error: bool = False): + backend = cfg.sanitizer_backend + + if backend == "brute_force": + return SanitizerBruteForce(abort_on_error) + + if backend == "symexec": + return SanitizerSymbolicExecution(abort_on_error) + + if backend == "off": + return NullSanitizer() + + raise ValueError( + f"Invalid TRITON_SANITIZER_BACKEND: {backend!r} " + ) + +Sanitizer.register(SanitizerBruteForce) +Sanitizer.register(SanitizerSymbolicExecution) +Sanitizer.register(NullSanitizer) diff --git a/triton_viz/core/client.py b/triton_viz/core/client.py index d9048582..2c0d0591 100644 --- a/triton_viz/core/client.py +++ b/triton_viz/core/client.py @@ -4,7 +4,7 @@ from .data import Op, Launch from .patch import patch_op, unpatch_op, op_list, patch_calls -from typing import Tuple, Callable, Type, Optional +from typing import Tuple, Callable, Type, Optional, List class Client(ABC): @@ -32,10 +32,19 @@ def finalize(self) -> list: class ClientManager: - def __init__(self, clients: list[Client]): - self.clients = clients + def __init__(self, clients: Optional[List[Client]] = None): + self.clients = clients if clients is not None else [] self.launch = Launch() + def add_clients(self, new_clients: List[Client]) -> None: + for new_client in new_clients: + duplicate = any( + isinstance(existing_client, new_client.__class__) + for existing_client in self.clients + ) + if not duplicate: + self.clients.append(new_client) + @contextmanager def patch(self): with patch_calls(): diff --git a/triton_viz/core/config.py b/triton_viz/core/config.py index b81cedd2..59b84c9c 100644 --- a/triton_viz/core/config.py +++ b/triton_viz/core/config.py @@ -9,14 +9,7 @@ def __init__(self, name: str) -> None: super().__init__(name) # --- Sanitizer backend --- - env_backend = os.getenv("TRITON_SANITIZER_BACKEND", "") - if env_backend: - self.sanitizer_backend = env_backend # verify using setter - else: - raise ValueError( - f"TRITON_SANITIZER_BACKEND is not set!" - f"Available backends are: {AVAILABLE_SANITIZER_BACKENDS}" - ) + self._sanitizer_backend = os.getenv("TRITON_SANITIZER_BACKEND", "") or None # --- Grid execution progress flag --- env_flag = os.getenv("REPORT_GRID_EXECUTION_PROGRESS", "0") @@ -25,6 +18,11 @@ def __init__(self, name: str) -> None: # ---------- sanitizer_backend ---------- @property def sanitizer_backend(self) -> str: + if self._sanitizer_backend is None: + raise RuntimeError( + f"TRITON_SANITIZER_BACKEND is not set!" + f"Available backends are: {AVAILABLE_SANITIZER_BACKENDS}" + ) return self._sanitizer_backend @sanitizer_backend.setter diff --git a/triton_viz/core/patch.py b/triton_viz/core/patch.py index 9183fe81..5f1710ab 100644 --- a/triton_viz/core/patch.py +++ b/triton_viz/core/patch.py @@ -3,7 +3,7 @@ from typing import Callable, Type, Dict from tqdm import tqdm -from .config import report_grid_execution_progress, sanitizer_backend +from . import config as cfg from .data import ( Op, RawLoad, Load, RawStore, Store, UnaryOp, BinaryOp, TernaryOp, ProgramId, @@ -136,14 +136,14 @@ def _grid_executor_call(self, *args_dev, **kwargs): if kwargs.pop("warmup", False): return def run_grid_loops(): - for x in tqdm(range(grid[0]), desc='Grid X', leave=False, disable=not report_grid_execution_progress): - for y in tqdm(range(grid[1]), desc='Grid Y', leave=False, disable=not (report_grid_execution_progress and grid[1] > 1)): - for z in tqdm(range(grid[2]), desc='Grid Z', leave=False, disable=not (report_grid_execution_progress and grid[2] > 1)): + for x in tqdm(range(grid[0]), desc='Grid X', leave=False, disable=not cfg.report_grid_execution_progress): + for y in tqdm(range(grid[1]), desc='Grid Y', leave=False, disable=not (cfg.report_grid_execution_progress and grid[1] > 1)): + for z in tqdm(range(grid[2]), desc='Grid Z', leave=False, disable=not (cfg.report_grid_execution_progress and grid[2] > 1)): interpreter_builder.set_grid_idx(x, y, z) client_manager.grid_idx_callback((x, y, z)) self.fn(**call_args) # if symbolic execution, only do one iteration - if sanitizer_backend == "symexec": + if cfg.sanitizer_backend == "symexec": return # Removes not used reserved keywords from kwargs # Triton doesn't support keyword-only, variable positional or variable keyword arguments diff --git a/triton_viz/core/trace.py b/triton_viz/core/trace.py index 20d2c40c..9783bfb6 100644 --- a/triton_viz/core/trace.py +++ b/triton_viz/core/trace.py @@ -5,7 +5,7 @@ import os from typing import Tuple, Union -from .config import sanitizer_backend +from . import config as cfg from ..clients import Sanitizer, Profiler, Tracer from .client import ClientManager, Client from .data import Launch @@ -16,26 +16,33 @@ class Trace(KernelInterface): - def __init__(self, kernel: JITFunction, clients: Union[Tuple[Union[str, Client], ...], Union[str, Client]]) -> None: + @staticmethod + def _normalize_client(client: Union[str, Client]) -> Client: + if isinstance(client, str): + name = client.lower() + if name == "sanitizer": + return Sanitizer() + if name == "profiler": + return Profiler() + if name == "tracer": + return Tracer() + raise ValueError(f"Unknown client: {client}") + elif isinstance(client, Client): + return client + else: + raise TypeError(f"Expected str or Client, got {type(client)}") + + def add_client(self, new_client: Union[Client, str]) -> None: + new_client_instance = self._normalize_client(new_client) + self.client_manager.add_clients([new_client_instance]) + + def __init__(self, kernel: JITFunction, client: Union[str, Client]) -> None: assert isinstance(kernel, JITFunction), "Kernel must be a JITFunction" self.interpreter_fn = InterpretedFunction(kernel.fn) self.fn = kernel self.arg_names = kernel.arg_names - init_clients: list[Client] = [] - clients = (clients,) if not isinstance(clients, tuple) else clients - for client in clients: - if isinstance(client, str): - if client.lower() == "sanitizer": - init_clients.append(Sanitizer()) - elif client.lower() == "profiler": - init_clients.append(Profiler()) - elif client.lower() == "tracer": - init_clients.append(Tracer()) - else: - raise ValueError(f"Unknown client: {client}") - else: - init_clients.append(client) - self.client_manager = ClientManager(init_clients) + self.client_manager = ClientManager() + self.add_client(client) def run(self, *args, **kwargs): with self.client_manager.patch(): @@ -52,17 +59,36 @@ def finalize(self): launches.append(self.client_manager.launch) -def trace(clients: Union[Tuple[Union[str, Client], ...], Union[str, Client]] = ("sanitizer", "profiler")): +def trace(clients: Union[str, Client]): """ Create a trace object that can be used to run a kernel with instrumentation clients. :param kernel: The kernel to run. - :param clients: A tuple of clients to run with the kernel. + :param client: A client to run with the kernel. """ - def decorator(kernel: JITFunction) -> Trace: - if sanitizer_backend == "off": + if not clients: + raise ValueError("At least one client must be specified!") + + if not isinstance(clients, (str, Client)): + raise TypeError(f"Expected str or Client, got {type(clients)}") + + def decorator(kernel) -> Trace: + # When sanitizer is off, skip tracing and return the original kernel unchanged + if cfg.sanitizer_backend == "off": return kernel - return Trace(kernel, clients) + + # First-time wrapping + if isinstance(kernel, JITFunction): + return Trace(kernel, clients) + + # If the object is already a Trace, just append the new client(s) + if isinstance(kernel, Trace): + trace = kernel + trace.add_client(clients) + return trace + + # If the object is neither a JITFunction nor Trace, raise an error + raise TypeError(f"Expected JITFunction, got {type(kernel)}") return decorator