From ae992417da5d3e59059d21ac77204d48b7dc75a7 Mon Sep 17 00:00:00 2001 From: Hao Wu Date: Mon, 9 Jun 2025 17:33:51 -0400 Subject: [PATCH 1/3] support TRITON_INTERPRET=1 --- triton_viz/core/trace.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/triton_viz/core/trace.py b/triton_viz/core/trace.py index 7e20ca2a..03970149 100644 --- a/triton_viz/core/trace.py +++ b/triton_viz/core/trace.py @@ -35,9 +35,15 @@ def add_client(self, new_client: Union[Client, str]) -> None: self.client_manager.add_clients([new_client_instance]) def __init__(self, kernel: JITFunction, client: Union[str, Client]) -> None: - assert isinstance(kernel, JITFunction), "Kernel must be a JITFunction" - self.interpreter_fn = InterpretedFunction(kernel.fn) self.fn = kernel + if isinstance(kernel, InterpretedFunction): + self.interpreter_fn = kernel + elif isinstance(kernel, JITFunction): + self.interpreter_fn = InterpretedFunction(kernel.fn) + else: + raise TypeError( + f"Kernel must be JITFunction or InterpretedFunction, got {type(kernel)}" + ) self.arg_names = kernel.arg_names self.client_manager = ClientManager() self.add_client(client) @@ -76,7 +82,7 @@ def decorator(kernel) -> Trace: return kernel # First-time wrapping - if isinstance(kernel, JITFunction): + if isinstance(kernel, (JITFunction, InterpretedFunction)): return Trace(kernel, clients) # If the object is already a Trace, just append the new client(s) @@ -85,8 +91,7 @@ def decorator(kernel) -> Trace: trace.add_client(clients) return trace - # If the object is neither a JITFunction nor Trace, raise an error - raise TypeError(f"Expected JITFunction, got {type(kernel)}") + raise TypeError(f"Expected JITFunction, InterpretedFunction or Trace, got {type(kernel)}") return decorator From 31062fd7a28e9b68e78ad8b4b03b10b6ca334543 Mon Sep 17 00:00:00 2001 From: Hao Wu Date: Mon, 9 Jun 2025 17:49:07 -0400 Subject: [PATCH 2/3] fix type annotation --- triton_viz/core/trace.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/triton_viz/core/trace.py b/triton_viz/core/trace.py index 03970149..44ff8d5e 100644 --- a/triton_viz/core/trace.py +++ b/triton_viz/core/trace.py @@ -34,7 +34,7 @@ def add_client(self, new_client: Union[Client, str]) -> None: new_client_instance = self._normalize_client(new_client) self.client_manager.add_clients([new_client_instance]) - def __init__(self, kernel: JITFunction, client: Union[str, Client]) -> None: + def __init__(self, kernel: Union[JITFunction, InterpretedFunction], client: Union[str, Client]) -> None: self.fn = kernel if isinstance(kernel, InterpretedFunction): self.interpreter_fn = kernel From f177ce09bd2d98be5ecbcb7d593250bff30307b0 Mon Sep 17 00:00:00 2001 From: Hao Wu Date: Wed, 11 Jun 2025 14:20:32 -0700 Subject: [PATCH 3/3] [FIX] Fix CI (#74) --- .github/workflows/python-app.yml | 14 +++++++++----- tests/test_autotune_add.py | 15 ++++++++++----- tests/test_config.py | 3 ++- tests/test_print_traceback.py | 2 +- tests/test_wrapper.py | 1 + triton_viz/core/trace.py | 10 ++++++++-- 6 files changed, 31 insertions(+), 14 deletions(-) diff --git a/.github/workflows/python-app.yml b/.github/workflows/python-app.yml index 3bc0f3e1..55a18749 100644 --- a/.github/workflows/python-app.yml +++ b/.github/workflows/python-app.yml @@ -5,11 +5,13 @@ name: Python application on: push: - branches-ignore: - - '**' + branches: + - main + - keren/v2.0 pull_request: - branches-ignore: - - '**' + branches: + - main + - keren/v2.0 permissions: contents: read @@ -21,6 +23,8 @@ concurrency: jobs: build: runs-on: ubuntu-latest + env: + TRITON_INTERPRET: "1" steps: - uses: actions/checkout@v3 @@ -44,7 +48,7 @@ jobs: pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu121 pip uninstall pytorch-triton -y - - name: Clone Triton and Install + - name: Install Triton run: | pip install triton==3.1.0 diff --git a/tests/test_autotune_add.py b/tests/test_autotune_add.py index e0ece54e..5c1cf79c 100644 --- a/tests/test_autotune_add.py +++ b/tests/test_autotune_add.py @@ -1,3 +1,4 @@ +import pytest import torch import triton import triton.language as tl @@ -7,8 +8,12 @@ from triton_viz import config as cfg -cfg.sanitizer_backend = "symexec" +try: + torch.cuda.current_device() +except: + pytest.skip("This test requires a CUDA-enabled environment.", allow_module_level=True) +cfg.sanitizer_backend = "symexec" @triton.autotune( configs=[ @@ -39,8 +44,8 @@ def test_autotune_add_inrange(): This test uses n_elements = 128, matching the size of the input tensors. It should NOT cause any out-of-bound access. """ - x = torch.randn(128, device="cuda") - y = torch.randn(128, device="cuda") + x = torch.randn(128) + y = torch.randn(128) out = torch.empty_like(x) # The kernel launch uses n_elements=128, aligned with the tensor size. @@ -55,8 +60,8 @@ def test_autotune_add_out_of_bound(): This test deliberately sets n_elements = 256, exceeding the actual buffer size (128). It will likely cause out-of-bound reads/writes, which may trigger errors or warnings. """ - x = torch.randn(128, device="cuda") - y = torch.randn(128, device="cuda") + x = torch.randn(128) + y = torch.randn(128) out = torch.empty_like(x) # The kernel launch uses n_elements=256, exceeding the valid tensor size. diff --git a/tests/test_config.py b/tests/test_config.py index c8926448..8ce9d9c9 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -1,4 +1,5 @@ -import pytest +import pytest, os +os.environ["TRITON_SANITIZER_BACKEND"] = "off" import triton_viz.core.config as cfg diff --git a/tests/test_print_traceback.py b/tests/test_print_traceback.py index 226d54e3..e09f999b 100644 --- a/tests/test_print_traceback.py +++ b/tests/test_print_traceback.py @@ -27,7 +27,7 @@ def kernel_A(ptr, n): def test_print_nested_functions(): - x = torch.arange(4, device="cuda", dtype=torch.float32) + x = torch.arange(4, dtype=torch.float32) print("Input:", x) # We'll launch a grid bigger than x.numel() to force a out-of-bounds error diff --git a/tests/test_wrapper.py b/tests/test_wrapper.py index 925c521f..54028c4a 100644 --- a/tests/test_wrapper.py +++ b/tests/test_wrapper.py @@ -59,6 +59,7 @@ def _decorator(fn): env = os.environ.copy() env["PYTHONPATH"] = str(tmp_path) + os.pathsep + env.get("PYTHONPATH", "") env["TRITON_SANITIZER_BACKEND"] = "symexec" + env["TRITON_INTERPRET"] = "1" # run the dummy program using triton-sanitizer cmd = ["triton-sanitizer", str(tmp_path / "dummy_program.py")] diff --git a/triton_viz/core/trace.py b/triton_viz/core/trace.py index 44ff8d5e..ef4f2ed1 100644 --- a/triton_viz/core/trace.py +++ b/triton_viz/core/trace.py @@ -34,7 +34,11 @@ def add_client(self, new_client: Union[Client, str]) -> None: new_client_instance = self._normalize_client(new_client) self.client_manager.add_clients([new_client_instance]) - def __init__(self, kernel: Union[JITFunction, InterpretedFunction], client: Union[str, Client]) -> None: + def __init__( + self, + kernel: Union[JITFunction, InterpretedFunction], + client: Union[str, Client], + ) -> None: self.fn = kernel if isinstance(kernel, InterpretedFunction): self.interpreter_fn = kernel @@ -91,7 +95,9 @@ def decorator(kernel) -> Trace: trace.add_client(clients) return trace - raise TypeError(f"Expected JITFunction, InterpretedFunction or Trace, got {type(kernel)}") + raise TypeError( + f"Expected JITFunction, InterpretedFunction or Trace, got {type(kernel)}" + ) return decorator