From 093777d08a885ee40de11504de485f8b84bd8278 Mon Sep 17 00:00:00 2001 From: Hao Wu Date: Wed, 4 Jun 2025 13:50:37 -0400 Subject: [PATCH 1/3] linting with pre-commit --- examples/elementwise_add_autotune.py | 13 +- examples/load_and_store.py | 10 +- {tests => examples}/test_non_contiguous.py | 19 +- examples/test_symbolic_execution.py | 70 ++-- requirements.txt | 10 +- tests/test_autotune_add.py | 35 +- tests/test_print_traceback.py | 5 +- tests/test_trace_add_clients.py | 10 +- tests/test_wrapper.py | 11 +- triton_viz/__init__.py | 2 +- triton_viz/clients/__init__.py | 9 +- triton_viz/clients/profiler/profiler.py | 8 +- triton_viz/clients/sanitizer/data.py | 6 +- triton_viz/clients/sanitizer/sanitizer.py | 390 ++++++++++++++------- triton_viz/clients/tracer/tracer.py | 22 +- triton_viz/clients/utils.py | 38 +- triton_viz/core/__init__.py | 63 +++- triton_viz/core/client.py | 10 +- triton_viz/core/config.py | 2 + triton_viz/core/patch.py | 116 ++++-- triton_viz/core/trace.py | 5 +- triton_viz/wrapper.py | 10 +- 22 files changed, 600 insertions(+), 264 deletions(-) rename {tests => examples}/test_non_contiguous.py (58%) diff --git a/examples/elementwise_add_autotune.py b/examples/elementwise_add_autotune.py index 1dc34e2d..79fd41aa 100644 --- a/examples/elementwise_add_autotune.py +++ b/examples/elementwise_add_autotune.py @@ -8,11 +8,11 @@ # Custom kernel for element-wise addition with autotuning @triton.autotune( configs=[ - triton.Config({'BLOCK_SIZE': 32}, num_warps=1), - triton.Config({'BLOCK_SIZE': 64}, num_warps=2), - triton.Config({'BLOCK_SIZE': 128}, num_warps=4), + triton.Config({"BLOCK_SIZE": 32}, num_warps=1), + triton.Config({"BLOCK_SIZE": 64}, num_warps=2), + triton.Config({"BLOCK_SIZE": 128}, num_warps=4), ], - key=['n_elements'], # Key for selecting the optimal config based on input size + key=["n_elements"], # Key for selecting the optimal config based on input size warmup=5, rep=5, ) @@ -37,10 +37,11 @@ def elementwise_add_kernel( # Store the result tl.store(output_ptr + offsets, output, mask=mask) + # Create PyTorch tensors as input n_elements = 100000 -x = torch.randn(n_elements, device='cuda') -y = torch.randn(n_elements, device='cuda') +x = torch.randn(n_elements, device="cuda") +y = torch.randn(n_elements, device="cuda") output = torch.empty_like(x) # Launch the Triton kernel diff --git a/examples/load_and_store.py b/examples/load_and_store.py index 4c2202d8..8736a71c 100644 --- a/examples/load_and_store.py +++ b/examples/load_and_store.py @@ -5,6 +5,7 @@ import triton_viz from triton_viz.clients import Sanitizer + @triton_viz.trace(clients=Sanitizer(abort_on_error=True)) @triton.jit def simple_kernel(X_ptr, Y_ptr, BLOCK_SIZE: tl.constexpr): @@ -15,18 +16,19 @@ def simple_kernel(X_ptr, Y_ptr, BLOCK_SIZE: tl.constexpr): tl.store(Y_ptr + idx, x) -if __name__ == '__main__': + +if __name__ == "__main__": BLOCK_SIZE = 1024 n_elements = 512 # Create input and output tensors - X = torch.arange(n_elements, dtype=torch.float32, device='cuda') - Y = torch.empty_like(X, device='cuda') + X = torch.arange(n_elements, dtype=torch.float32, device="cuda") + Y = torch.empty_like(X, device="cuda") # Launch the Triton kernel grid = lambda META: (triton.cdiv(n_elements, META["BLOCK_SIZE"]),) simple_kernel[grid](X, Y, BLOCK_SIZE=BLOCK_SIZE) - + # Verify the results print("Input tensor X:", X) print("Output tensor Y:", Y) diff --git a/tests/test_non_contiguous.py b/examples/test_non_contiguous.py similarity index 58% rename from tests/test_non_contiguous.py rename to examples/test_non_contiguous.py index a1f40a30..1ae59433 100644 --- a/tests/test_non_contiguous.py +++ b/examples/test_non_contiguous.py @@ -1,5 +1,8 @@ import torch -from triton_viz.clients.utils import check_storage_contiguous, get_physical_addr_from_tensor_slice +from triton_viz.clients.utils import ( + check_storage_contiguous, + get_physical_addr_from_tensor_slice, +) def test_transpose(): @@ -8,20 +11,26 @@ def test_transpose(): print(b) print([(0, b.numel() - 1)]) + def test_2d_slice(): a = torch.arange(25).view(5, 5) b = a[1:4, 1:4] - print('b:', b) + print("b:", b) print("is_contiguous:", check_storage_contiguous(b)) segments = get_physical_addr_from_tensor_slice(b) for start, end in segments: - print(f"[{(start - b.data_ptr()) / b.element_size()}, {(end - b.data_ptr()) / b.element_size()}]") + print( + f"[{(start - b.data_ptr()) / b.element_size()}, {(end - b.data_ptr()) / b.element_size()}]" + ) + def test_3d_slice(): a = torch.arange(125).view(5, 5, 5) b = a[1:4, 1:4, 1:4] - print('b:', b) + print("b:", b) print("is_contiguous:", check_storage_contiguous(b)) segments = get_physical_addr_from_tensor_slice(b) for start, end in segments: - print(f"[{(start - b.data_ptr()) / b.element_size()}, {(end - b.data_ptr()) / b.element_size()}]") + print( + f"[{(start - b.data_ptr()) / b.element_size()}, {(end - b.data_ptr()) / b.element_size()}]" + ) diff --git a/examples/test_symbolic_execution.py b/examples/test_symbolic_execution.py index df7f2b53..c465801d 100644 --- a/examples/test_symbolic_execution.py +++ b/examples/test_symbolic_execution.py @@ -12,9 +12,11 @@ def add_kernel(x): pid = tl.program_id(0) addr = x + pid tl.load(addr) - a = torch.randn(16, dtype=torch.float32, device='cuda') + + a = torch.randn(16, dtype=torch.float32, device="cuda") add_kernel[(2,)](a) + def test_tl_make_range(): @triton_viz.trace(clients=Sanitizer(abort_on_error=True)) @triton.jit @@ -22,36 +24,44 @@ def make_range_kernel(x, BLOCK_SIZE: tl.constexpr): tl.load(x) offset = x + tl.arange(0, BLOCK_SIZE) tl.load(offset) - a = torch.randn(16, dtype=torch.float32, device='cuda') + + a = torch.randn(16, dtype=torch.float32, device="cuda") make_range_kernel[(1,)](a, BLOCK_SIZE=16) + def test_tl_add(): @triton_viz.trace(clients=Sanitizer(abort_on_error=True)) @triton.jit def program_id_kernel(x): addr = x + 1 tl.load(addr) - a = torch.randn(16, dtype=torch.float32, device='cuda') + + a = torch.randn(16, dtype=torch.float32, device="cuda") program_id_kernel[(2,)](a) + def test_tl_sub(): @triton_viz.trace(clients=Sanitizer(abort_on_error=True)) @triton.jit def sub_kernel(x): addr = x - 1 tl.load(addr) - a = torch.randn(16, dtype=torch.float32, device='cuda') + + a = torch.randn(16, dtype=torch.float32, device="cuda") sub_kernel[(2,)](a) + def test_tl_mul(): @triton_viz.trace(clients=Sanitizer(abort_on_error=True)) @triton.jit def mul_kernel(x, BLOCK_SIZE: tl.constexpr): addr = x + (tl.arange(0, BLOCK_SIZE) * 2) tl.load(addr) - a = torch.randn(32, dtype=torch.float32, device='cuda') + + a = torch.randn(32, dtype=torch.float32, device="cuda") mul_kernel[(1,)](a, BLOCK_SIZE=16) + def test_tl_div(): @triton_viz.trace(clients=Sanitizer(abort_on_error=True)) @triton.jit @@ -59,9 +69,11 @@ def div_kernel(x, BLOCK_SIZE: tl.constexpr): tl.load(x) tl.load(x + (tl.arange(0, BLOCK_SIZE) // 2)) tl.load(x + tl.arange(0, BLOCK_SIZE)) - a = torch.randn(32, dtype=torch.float32, device='cuda') + + a = torch.randn(32, dtype=torch.float32, device="cuda") div_kernel[(1,)](a, BLOCK_SIZE=16) + def test_tl_mod(): @triton_viz.trace(clients=Sanitizer(abort_on_error=True)) @triton.jit @@ -69,9 +81,11 @@ def mod_kernel(x, BLOCK_SIZE: tl.constexpr): tl.load(x) tl.load(x + (tl.arange(0, BLOCK_SIZE) % 10)) tl.load(x + tl.arange(0, BLOCK_SIZE)) - a = torch.randn(32, dtype=torch.float32, device='cuda') + + a = torch.randn(32, dtype=torch.float32, device="cuda") mod_kernel[(1,)](a, BLOCK_SIZE=16) + def test_vec_add(): @triton_viz.trace(clients=Sanitizer(abort_on_error=True)) @triton.jit @@ -87,12 +101,13 @@ def add_kernel(x_ptr, y_ptr, output_ptr, BLOCK_SIZE: tl.constexpr): access_size = 24 size = 17 BLOCK_SIZE = 8 - a = torch.randn(size, dtype=torch.float32, device='cuda') - b = torch.randn(size, dtype=torch.float32, device='cuda') - output = torch.empty_like(a, device='cuda') + a = torch.randn(size, dtype=torch.float32, device="cuda") + b = torch.randn(size, dtype=torch.float32, device="cuda") + output = torch.empty_like(a, device="cuda") grid = lambda meta: (triton.cdiv(access_size, meta["BLOCK_SIZE"]),) add_kernel[grid](a, b, output, BLOCK_SIZE=BLOCK_SIZE) + def test_vec_add_mask(): @triton_viz.trace(clients=Sanitizer(abort_on_error=True)) @triton.jit @@ -109,21 +124,26 @@ def add_kernel(x_ptr, y_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr): access_size = 24 size = 17 BLOCK_SIZE = 8 - a = torch.randn(size, dtype=torch.float32, device='cuda') - b = torch.randn(size, dtype=torch.float32, device='cuda') - output = torch.empty_like(a, device='cuda') + a = torch.randn(size, dtype=torch.float32, device="cuda") + b = torch.randn(size, dtype=torch.float32, device="cuda") + output = torch.empty_like(a, device="cuda") grid = lambda meta: (triton.cdiv(access_size, meta["BLOCK_SIZE"]),) add_kernel[grid](a, b, output, size, BLOCK_SIZE=BLOCK_SIZE) + def test_new_axis_column(): @triton_viz.trace(clients=Sanitizer(abort_on_error=True)) @triton.jit def new_axis_kernel(out_ptr, BLOCK_ROW_SIZE: tl.constexpr): - pid = out_ptr + tl.program_id(0) * BLOCK_ROW_SIZE + tl.arange(0, BLOCK_ROW_SIZE)[:, None] + pid = ( + out_ptr + + tl.program_id(0) * BLOCK_ROW_SIZE + + tl.arange(0, BLOCK_ROW_SIZE)[:, None] + ) tl.load(pid) BLOCK_ROW_SIZE = 8 - out = torch.empty((BLOCK_ROW_SIZE, 1), dtype=torch.int32, device='cuda') + out = torch.empty((BLOCK_ROW_SIZE, 1), dtype=torch.int32, device="cuda") grid = lambda meta: (1,) new_axis_kernel[grid](out, BLOCK_ROW_SIZE=BLOCK_ROW_SIZE) @@ -132,14 +152,19 @@ def test_new_axis_row(): @triton_viz.trace(clients=Sanitizer(abort_on_error=True)) @triton.jit def new_axis_kernel(out_ptr, BLOCK_ROW_SIZE: tl.constexpr): - pid = out_ptr + tl.program_id(0) * BLOCK_ROW_SIZE + tl.arange(0, BLOCK_ROW_SIZE)[None, :] + pid = ( + out_ptr + + tl.program_id(0) * BLOCK_ROW_SIZE + + tl.arange(0, BLOCK_ROW_SIZE)[None, :] + ) tl.load(pid) BLOCK_ROW_SIZE = 8 - out = torch.empty((BLOCK_ROW_SIZE, 1), dtype=torch.int32, device='cuda') + out = torch.empty((BLOCK_ROW_SIZE, 1), dtype=torch.int32, device="cuda") grid = lambda meta: (1,) new_axis_kernel[grid](out, BLOCK_ROW_SIZE=BLOCK_ROW_SIZE) + def test_tl_maximum(): @triton_viz.trace(clients=Sanitizer(abort_on_error=True)) @triton.jit @@ -158,12 +183,13 @@ def maximum_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr): size = 20 BLOCK_SIZE = 8 - a = torch.randn(size, dtype=torch.float32, device='cuda') - b = torch.randn(size, dtype=torch.float32, device='cuda') - out = torch.empty_like(a, device='cuda') + a = torch.randn(size, dtype=torch.float32, device="cuda") + b = torch.randn(size, dtype=torch.float32, device="cuda") + out = torch.empty_like(a, device="cuda") grid = lambda meta: (triton.cdiv(size, meta["BLOCK_SIZE"]),) maximum_kernel[grid](a, b, out, size, BLOCK_SIZE=BLOCK_SIZE) + def test_tl_log(): @triton_viz.trace(clients=Sanitizer(abort_on_error=True)) @triton.jit @@ -180,7 +206,7 @@ def log_kernel(x_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr): BLOCK_SIZE = 8 eps = 0.01 - a = torch.rand(size, dtype=torch.float32, device='cuda') + eps - out = torch.empty_like(a, device='cuda') + a = torch.rand(size, dtype=torch.float32, device="cuda") + eps + out = torch.empty_like(a, device="cuda") grid = lambda meta: (triton.cdiv(size, meta["BLOCK_SIZE"]),) log_kernel[grid](a, out, size, BLOCK_SIZE=BLOCK_SIZE) diff --git a/requirements.txt b/requirements.txt index 1c827fd8..07d6459c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,7 @@ -setuptools -triton -gradio chalk-diagrams @ git+https://github.com/chalk-diagrams/chalk.git -pyarrow +gradio pre-commit -pytest \ No newline at end of file +pyarrow +pytest +setuptools +triton diff --git a/tests/test_autotune_add.py b/tests/test_autotune_add.py index ff53fe94..e0ece54e 100644 --- a/tests/test_autotune_add.py +++ b/tests/test_autotune_add.py @@ -9,12 +9,13 @@ cfg.sanitizer_backend = "symexec" + @triton.autotune( configs=[ - triton.Config({'BLOCK_SIZE': 32}, num_warps=1), - triton.Config({'BLOCK_SIZE': 64}, num_warps=2), + triton.Config({"BLOCK_SIZE": 32}, num_warps=1), + triton.Config({"BLOCK_SIZE": 64}, num_warps=2), ], - key=['n_elements'], + key=["n_elements"], ) @triton_viz.trace(clients=Sanitizer(abort_on_error=True)) @triton.jit @@ -32,43 +33,35 @@ def add_kernel_no_mask(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constex y_val = tl.load(y_ptr + offsets) tl.store(out_ptr + offsets, x_val + y_val) + def test_autotune_add_inrange(): """ This test uses n_elements = 128, matching the size of the input tensors. It should NOT cause any out-of-bound access. """ - x = torch.randn(128, device='cuda') - y = torch.randn(128, device='cuda') + x = torch.randn(128, device="cuda") + y = torch.randn(128, device="cuda") out = torch.empty_like(x) # The kernel launch uses n_elements=128, aligned with the tensor size. - grid = lambda META: (triton.cdiv(128, META['BLOCK_SIZE']),) - add_kernel_no_mask[grid]( - x_ptr=x, - y_ptr=y, - out_ptr=out, - n_elements=128 - ) + grid = lambda META: (triton.cdiv(128, META["BLOCK_SIZE"]),) + add_kernel_no_mask[grid](x_ptr=x, y_ptr=y, out_ptr=out, n_elements=128) print("test_autotune_add_inrange() passed: No out-of-bound access.") + def test_autotune_add_out_of_bound(): """ This test deliberately sets n_elements = 256, exceeding the actual buffer size (128). It will likely cause out-of-bound reads/writes, which may trigger errors or warnings. """ - x = torch.randn(128, device='cuda') - y = torch.randn(128, device='cuda') + x = torch.randn(128, device="cuda") + y = torch.randn(128, device="cuda") out = torch.empty_like(x) # The kernel launch uses n_elements=256, exceeding the valid tensor size. - grid = lambda META: (triton.cdiv(256, META['BLOCK_SIZE']),) - add_kernel_no_mask[grid]( - x_ptr=x, - y_ptr=y, - out_ptr=out, - n_elements=256 - ) + grid = lambda META: (triton.cdiv(256, META["BLOCK_SIZE"]),) + add_kernel_no_mask[grid](x_ptr=x, y_ptr=y, out_ptr=out, n_elements=256) # Depending on hardware/drivers, this may or may not raise an error immediately. print("test_autotune_add_oob() completed: Potential out-of-bound access occurred.") diff --git a/tests/test_print_traceback.py b/tests/test_print_traceback.py index c90e9dfb..226d54e3 100644 --- a/tests/test_print_traceback.py +++ b/tests/test_print_traceback.py @@ -9,12 +9,14 @@ cfg.sanitizer_backend = "symexec" + @triton.jit def kernel_B(ptr, offset): # a simple function that adds 1 val = tl.load(ptr + offset) return val + 1 + @triton_viz.trace(clients=Sanitizer(abort_on_error=True)) @triton.jit def kernel_A(ptr, n): @@ -23,8 +25,9 @@ def kernel_A(ptr, n): val = kernel_B(ptr, pid) tl.store(ptr + pid, val) + def test_print_nested_functions(): - x = torch.arange(4, device='cuda', dtype=torch.float32) + x = torch.arange(4, device="cuda", dtype=torch.float32) print("Input:", x) # We'll launch a grid bigger than x.numel() to force a out-of-bounds error diff --git a/tests/test_trace_add_clients.py b/tests/test_trace_add_clients.py index 6ddd3fdc..9bcb5d59 100644 --- a/tests/test_trace_add_clients.py +++ b/tests/test_trace_add_clients.py @@ -9,6 +9,7 @@ # Make sure sanitizer is on. cfg.sanitizer_backend = "symexec" + def test_trace_decorator_add_clients(): """ Test goal: @@ -20,19 +21,22 @@ def test_trace_decorator_add_clients(): The final Trace object should contain exactly one instance each of Sanitizer, Profiler, and Tracer (total = 3 clients). """ + @triton_viz.trace("sanitizer") @triton_viz.trace("profiler") @triton_viz.trace("tracer") - @triton_viz.trace(Sanitizer(abort_on_error=True)) # Duplicate Sanitizer (should be ignored) + @triton_viz.trace( + Sanitizer(abort_on_error=True) + ) # Duplicate Sanitizer (should be ignored) @triton.jit def my_kernel(x_ptr, y_ptr, out_ptr, BLOCK_SIZE: tl.constexpr): pid = tl.program_id(0) offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - tl.store(out_ptr + offs, - tl.load(x_ptr + offs) + tl.load(y_ptr + offs)) + tl.store(out_ptr + offs, tl.load(x_ptr + offs) + tl.load(y_ptr + offs)) # Should be wrapped as a Trace object. from triton_viz.core.trace import Trace + assert isinstance(my_kernel, Trace) # Verify client de-duplication and addition logic diff --git a/tests/test_wrapper.py b/tests/test_wrapper.py index 3d541801..925c521f 100644 --- a/tests/test_wrapper.py +++ b/tests/test_wrapper.py @@ -5,9 +5,6 @@ import os import subprocess -import triton -import triton.language as tl - """ Test that triton_viz.wrapper works correctly: @@ -15,6 +12,8 @@ triton.runtime.interpreter.jit with wrapper._patched_jit 2. The first use of @triton.jit must invoke triton_viz.trace(Sanitizer) """ + + def test_cli_invocation(): """ Simulate running: @@ -75,6 +74,6 @@ def _decorator(fn): assert proc.returncode == 0, f"CLI exited with {proc.returncode}\n{proc.stderr}" # Check if trace was called once and only once trace_count = proc.stdout.count("TRACE_CALLED") - assert trace_count == 1, ( - "triton_viz.trace should be invoked exactly once via CLI path" - ) + assert ( + trace_count == 1 + ), "triton_viz.trace should be invoked exactly once via CLI path" diff --git a/triton_viz/__init__.py b/triton_viz/__init__.py index e301c8b6..799bbff2 100644 --- a/triton_viz/__init__.py +++ b/triton_viz/__init__.py @@ -1,4 +1,4 @@ from .core import trace, clear, config from .visualizer import launch -__all__ = ["trace", "clear", "launch"] +__all__ = ["trace", "clear", "config", "launch"] diff --git a/triton_viz/clients/__init__.py b/triton_viz/clients/__init__.py index 1cc1faef..a4c9004a 100644 --- a/triton_viz/clients/__init__.py +++ b/triton_viz/clients/__init__.py @@ -4,4 +4,11 @@ from .sanitizer.data import OutOfBoundsRecord from .tracer.tracer import Tracer -__all__ = ["Profiler", "Sanitizer", "LoadStoreBytes", "OpTypeCounts", "OutOfBoundsRecord", "Tracer"] +__all__ = [ + "Profiler", + "Sanitizer", + "LoadStoreBytes", + "OpTypeCounts", + "OutOfBoundsRecord", + "Tracer", +] diff --git a/triton_viz/clients/profiler/profiler.py b/triton_viz/clients/profiler/profiler.py index e38e3598..407e121f 100644 --- a/triton_viz/clients/profiler/profiler.py +++ b/triton_viz/clients/profiler/profiler.py @@ -35,8 +35,12 @@ def _report_load_store_bytes(self, type, ptr: TensorHandle, mask: TensorHandle): self.store_bytes.total_bytes_attempted += total_bytes_attempted self.store_bytes.total_bytes_true += total_bytes_true - def register_op_callback(self, op: Type[Op]) -> Tuple[Optional[Callable], Optional[Callable], Optional[Callable]]: - def pre_load_callback(ptr, mask, other, cache_modifier, eviction_policy, is_volatile): + def register_op_callback( + self, op: Type[Op] + ) -> Tuple[Optional[Callable], Optional[Callable], Optional[Callable]]: + def pre_load_callback( + ptr, mask, other, cache_modifier, eviction_policy, is_volatile + ): self._report_load_store_bytes("load", ptr, mask) def pre_store_callback(ptr, value, mask, cache_modifier, eviction_policy): diff --git a/triton_viz/clients/sanitizer/data.py b/triton_viz/clients/sanitizer/data.py index f93e5b32..35a12a47 100644 --- a/triton_viz/clients/sanitizer/data.py +++ b/triton_viz/clients/sanitizer/data.py @@ -14,12 +14,14 @@ class TracebackInfo: func_name: str line_of_code: str + @dataclass class OutOfBoundsRecord: op_type: Type[Union[Store, Load]] tensor: torch.Tensor user_code_tracebacks: List[TracebackInfo] + @dataclass class OutOfBoundsRecordBruteForce(OutOfBoundsRecord): offsets: NDArray[np.int_] @@ -28,12 +30,13 @@ class OutOfBoundsRecordBruteForce(OutOfBoundsRecord): invalid_access_masks: NDArray[np.bool_] corrected_offsets: NDArray[np.int_] + @dataclass class OutOfBoundsRecordZ3(OutOfBoundsRecord): """ Attributes: constraints (List[z3.z3.BoolRef]): - A collection of Z3 constraint expressions defining valid ranges + A collection of Z3 constraint expressions defining valid ranges for memory access. Any address falling outside these ranges is considered invalid. violation_address (int): @@ -49,5 +52,6 @@ class OutOfBoundsRecordZ3(OutOfBoundsRecord): In this scenario, 200 is an out-of-bounds address because it falls outside the valid ranges described by these constraints. """ + constraints: List[z3.z3.BoolRef] violation_address: int diff --git a/triton_viz/clients/sanitizer/sanitizer.py b/triton_viz/clients/sanitizer/sanitizer.py index 01153806..964e7230 100644 --- a/triton_viz/clients/sanitizer/sanitizer.py +++ b/triton_viz/clients/sanitizer/sanitizer.py @@ -9,14 +9,41 @@ from ...core.client import Client from ...core.data import ( - Op, RawLoad, Load, RawStore, Store, - UnaryOp, BinaryOp, TernaryOp, ProgramId, - Dot, MakeRange, AddPtr, ExpandDims, Broadcast, ReduceSum, - Splat, MakeBlockPointer, TensorPointerLoad, - TensorPointerStore, Idiv, Rsqrt, - CastImpl) -from ..utils import check_out_of_bounds_access, check_storage_contiguous, get_physical_addr_from_tensor_slice, check_inner_stride_equal_to_one -from .data import TracebackInfo, OutOfBoundsRecord, OutOfBoundsRecordBruteForce, OutOfBoundsRecordZ3 + Op, + RawLoad, + Load, + RawStore, + Store, + UnaryOp, + BinaryOp, + TernaryOp, + ProgramId, + Dot, + MakeRange, + AddPtr, + ExpandDims, + Broadcast, + ReduceSum, + Splat, + MakeBlockPointer, + TensorPointerLoad, + TensorPointerStore, + Idiv, + Rsqrt, + CastImpl, +) +from ..utils import ( + check_out_of_bounds_access, + check_storage_contiguous, + get_physical_addr_from_tensor_slice, + check_inner_stride_equal_to_one, +) +from .data import ( + TracebackInfo, + OutOfBoundsRecord, + OutOfBoundsRecordBruteForce, + OutOfBoundsRecordZ3, +) from ...core import config as cfg @@ -46,11 +73,17 @@ def print_oob_record(oob_record: OutOfBoundsRecord, max_display=10): print(" Out-Of-Bounds Access Detected ") print("============================================================") print(f"Operation: {op_type}") - print(f"Tensor Info: dtype={tensor.dtype}, shape={tensor.shape}, device={tensor.device}") + print( + f"Tensor Info: dtype={tensor.dtype}, shape={tensor.shape}, device={tensor.device}" + ) print(f"Tensor base memory address: {tensor.data_ptr()}") - print("Valid Access Range: [0, %d)" % (np.prod(tensor.shape) * tensor.element_size())) + print( + "Valid Access Range: [0, %d)" % (np.prod(tensor.shape) * tensor.element_size()) + ) for traceback_info in oob_record.user_code_tracebacks: - print(f"File: {traceback_info.filename}, Line: {traceback_info.lineno}, in {traceback_info.func_name}") + print( + f"File: {traceback_info.filename}, Line: {traceback_info.lineno}, in {traceback_info.func_name}" + ) print(f" Code: {traceback_info.line_of_code}") print("------------------------------------------------------------") @@ -81,12 +114,15 @@ def print_oob_record(oob_record: OutOfBoundsRecord, max_display=10): print(constraint) else: - raise NotImplementedError("Invalid OutOfBoundsRecord type: " + str(type(oob_record))) + raise NotImplementedError( + "Invalid OutOfBoundsRecord type: " + str(type(oob_record)) + ) print("============================================================") print(" End of Out-Of-Bounds Record Details ") print("============================================================") + def _get_traceback_info(): """ Why do both _grid_executor_call and _jit_function_call appear in the call stacks? @@ -108,12 +144,18 @@ def _get_traceback_info(): # scan the call stack for i, frame in enumerate(stack_summary): user_code_index = None - if ('_jit_function_call' in frame.name - and 'triton_viz/core/patch.py' in frame.filename): - user_code_index = i + 2 # _grid_executor_call -> run_grid_loops -> user code - elif ('_grid_executor_call' in frame.name - and 'triton_viz/core/patch.py' in frame.filename): - user_code_index = i + 2 # the same as above + if ( + "_jit_function_call" in frame.name + and "triton_viz/core/patch.py" in frame.filename + ): + user_code_index = ( + i + 2 + ) # _grid_executor_call -> run_grid_loops -> user code + elif ( + "_grid_executor_call" in frame.name + and "triton_viz/core/patch.py" in frame.filename + ): + user_code_index = i + 2 # the same as above if user_code_index is not None: frame = stack_summary[user_code_index] @@ -125,24 +167,28 @@ def _get_traceback_info(): filename=oob_filename, lineno=oob_lineno, func_name=oob_func_name, - line_of_code=oob_line_of_code + line_of_code=oob_line_of_code, ) user_code_tracebacks.append(traceback_info) return user_code_tracebacks + def _get_tensor(tensor_list, data_ptr): - # From a give ptr, get where the original tensor is stored - # Tensors have been sorted by ptr - ret_idx = 0 - for i in range(len(tensor_list)): - if data_ptr < tensor_list[i].data_ptr(): - break - ret_idx = i - return tensor_list[ret_idx] + # From a give ptr, get where the original tensor is stored + # Tensors have been sorted by ptr + ret_idx = 0 + for i in range(len(tensor_list)): + if data_ptr < tensor_list[i].data_ptr(): + break + ret_idx = i + return tensor_list[ret_idx] + class SanitizerBruteForce(Client): - def __init__(self, callpath: Optional[bool] = True, abort_on_error: Optional[bool] = True): + def __init__( + self, callpath: Optional[bool] = True, abort_on_error: Optional[bool] = True + ): self.callpath = callpath self.abort_on_error = abort_on_error self.tensors: list = [] @@ -150,11 +196,15 @@ def __init__(self, callpath: Optional[bool] = True, abort_on_error: Optional[boo def _report(self, op_type, record): traceback_info = _get_traceback_info() - oob_record = OutOfBoundsRecordBruteForce(op_type=op_type, user_code_tracebacks=traceback_info, **record) + oob_record = OutOfBoundsRecordBruteForce( + op_type=op_type, user_code_tracebacks=traceback_info, **record + ) if self.abort_on_error: if np.any(oob_record.invalid_access_masks): print_oob_record(oob_record) - assert False, "Out-of-bounds access detected. See detailed report above." + assert ( + False + ), "Out-of-bounds access detected. See detailed report above." else: self.records.append(oob_record) @@ -171,22 +221,28 @@ def grid_idx_callback(self, grid_idx: Tuple[int]): def grid_callback(self, grid: Tuple[int]): self.tensors = sorted(self.tensors, key=lambda x: x.data_ptr()) - def register_op_callback(self, op_type: Type[Op]) -> Tuple[Optional[Callable], Optional[Callable]]: - def pre_load_callback(ptr, mask, other, cache_modifier, eviction_policy, is_volatile): + def register_op_callback( + self, op_type: Type[Op] + ) -> Tuple[Optional[Callable], Optional[Callable]]: + def pre_load_callback( + ptr, mask, other, cache_modifier, eviction_policy, is_volatile + ): first_loc = np.unravel_index(np.argmax(mask, axis=None), mask.data.shape) first_ptr = ptr.data[first_loc] tensor = _get_tensor(self.tensors, first_ptr) oob = check_out_of_bounds_access(ptr.data, mask.data, tensor) self._report(op_type, oob) - ptr.data = tensor.data_ptr() + oob['corrected_offsets'] + ptr.data = tensor.data_ptr() + oob["corrected_offsets"] def pre_store_callback(ptr, value, mask, cache_modifier, eviction_policy): first_loc = np.unravel_index(np.argmax(mask, axis=None), mask.data.shape) first_ptr = ptr.data[first_loc] tensor = _get_tensor(self.tensors, first_ptr) oob = check_out_of_bounds_access(ptr.data, mask.data, tensor) - self._report(op_type, check_out_of_bounds_access(ptr.data, mask.data, tensor)) - ptr.data = tensor.data_ptr() + oob['corrected_offsets'] + self._report( + op_type, check_out_of_bounds_access(ptr.data, mask.data, tensor) + ) + ptr.data = tensor.data_ptr() + oob["corrected_offsets"] if op_type is Load: return pre_load_callback, None, None @@ -198,15 +254,17 @@ def pre_store_callback(ptr, value, mask, cache_modifier, eviction_policy): def finalize(self) -> list: return self.records + class SymbolicExprDataWrapper: - ''' + """ This wrapper is used as a workaround of triton interpreter legacy code. In def _get_bool(self) of class tensor, "data = self.handle.data return bool(data) if data.size == 1 else True" Since we replaced TensorHandle with SymbolicExpr, we need to wrap SymbolicExpr with a class that has size attribute, and data.size != 1. - ''' + """ + def __init__(self, value, symbolic_expr): self.value = value self.symbolic_expr = symbolic_expr @@ -218,7 +276,9 @@ def size(self): def __int__(self): int_val = self.symbolic_expr.eval() if not isinstance(int_val, int): - raise ValueError(f"SymbolicExprDataWrapper is type: {type(int_val)}, value: {int_val} and cannot be converted to int") + raise ValueError( + f"SymbolicExprDataWrapper is type: {type(int_val)}, value: {int_val} and cannot be converted to int" + ) return int_val def __str__(self): @@ -227,11 +287,23 @@ def __str__(self): def __repr__(self): return self.value + class SymbolicExpr: BASIC_OPS = ("const", "pid", "arange") INDIRECT_OPS = ("load", "store") - UNARY_OPS = ("cos", "exp", "exp2", "abs", "floor", "ceil", "log", - "log2", "sqrt", "sin", "rsqrt",) + UNARY_OPS = ( + "cos", + "exp", + "exp2", + "abs", + "floor", + "ceil", + "log", + "log2", + "sqrt", + "sin", + "rsqrt", + ) BINARY_OP_SYMBOL_TABLE = { "add": "+", "sub": "-", @@ -254,15 +326,16 @@ class SymbolicExpr: POINTER_OPS = ("make_block_ptr",) BROADCAST_OPS = ("splat", "expand_dims", "broadcast") SUPPORTED_OPS = ( - BASIC_OPS + - INDIRECT_OPS + - UNARY_OPS + - BINARY_OPS + - TERNARY_OPS + - REDUCE_OPS + - POINTER_OPS + - BROADCAST_OPS + BASIC_OPS + + INDIRECT_OPS + + UNARY_OPS + + BINARY_OPS + + TERNARY_OPS + + REDUCE_OPS + + POINTER_OPS + + BROADCAST_OPS ) + def __init__(self, op, *args): """ :param op: Operation type, e.g. "const", "add", "sub", "mul", "div", "pid", "arange" @@ -309,12 +382,14 @@ def __init__(self, op, *args): assert len(args) == 2, f"{self.op} op expects two arguments!" self.lhs = args[0] self.rhs = args[1] - if not self.lhs.shape: # lhs is a scalar + if not self.lhs.shape: # lhs is a scalar ret_shape = self.rhs.shape - elif not self.rhs.shape: # rhs is a scalar + elif not self.rhs.shape: # rhs is a scalar ret_shape = self.lhs.shape - else: # both are blocks - assert self.lhs.shape == self.rhs.shape, f"lhs shape {self.lhs.shape} should be equal to rhs shape {self.rhs.shape}" + else: # both are blocks + assert ( + self.lhs.shape == self.rhs.shape + ), f"lhs shape {self.lhs.shape} should be equal to rhs shape {self.rhs.shape}" ret_shape = self.lhs.shape self.shape = ret_shape elif self.op in self.TERNARY_OPS: @@ -333,7 +408,8 @@ def __init__(self, op, *args): elif self.op == "dot": self.a = args[0] self.b = args[1] - if len(args) >= 3: self.d = args[2] + if len(args) >= 3: + self.d = args[2] else: raise NotImplementedError(f"Unsupported reduce op: {self.op}") elif self.op == "make_block_ptr": @@ -424,7 +500,7 @@ def to_anytree(self): for child_name, child_expr in self._children(): # Build the child subtree child_node = child_expr.to_anytree() - # Prefix the child node’s name with the field name + # Prefix the child node's name with the field name child_node.name = f"{child_name}: {child_node.name}" child_node.parent = root @@ -509,13 +585,19 @@ def data(self): @classmethod def from_value(cls, var): triton_scala_dtypes = ( - tl.int8, tl.int16, tl.int32, tl.int64, - tl.uint8, tl.uint16, tl.uint32, tl.uint64, - tl.float16, tl.float32, tl.float64 - ) - builtin_scala_types = ( - int, float + tl.int8, + tl.int16, + tl.int32, + tl.int64, + tl.uint8, + tl.uint16, + tl.uint32, + tl.uint64, + tl.float16, + tl.float32, + tl.float64, ) + builtin_scala_types = (int, float) # if already SymbolicExpr if isinstance(var, cls): return var @@ -531,7 +613,9 @@ def from_value(cls, var): # if a pointer elif isinstance(var.dtype, tl.pointer_type): if len(var.data) != 1: - raise ValueError("Unsupported tl.pointer_type with length more than one!") + raise ValueError( + "Unsupported tl.pointer_type with length more than one!" + ) return cls("const", var.data.item(), var.get_element_ty()) else: raise ValueError("Unsupported TensorHandle dtype", var.dtype) @@ -547,8 +631,8 @@ def eval(self): - expr: Z3 expression corresponding to the root node - constraints: list of Z3 BoolExpr objects, recording all range constraints created by program_id and arange """ - self._arange_counter = 0 # Used to name arange variables - self._arange_dict = {} # make sure each arange only has one name + self._arange_counter = 0 # Used to name arange variables + self._arange_dict = {} # make sure each arange only has one name self._vars = {} self._constraints = [] expr = self._to_z3(self) @@ -593,26 +677,40 @@ def _to_z3(self, node): return If(c >= 0, c, -c) if node.op in self.UNARY_OPS: c = self._to_z3(node.arg) - if node.op == "abs": return If(c >= 0, c, -c) + if node.op == "abs": + return If(c >= 0, c, -c) raise NotImplementedError(f"Unary op {node.op} is not implemented") # Binary arithmetic, comparison, etc. if node.op in self.BINARY_OPS: l = self._to_z3(node.lhs) r = self._to_z3(node.rhs) - if node.op == "add": return l + r - if node.op == "sub": return l - r - if node.op == "mul": return l * r - if node.op in ("idiv"): return l / r - if node.op == "mod": return l % r - if node.op == "less": return l < r - if node.op == "less_equal": return l <= r - if node.op == "greater": return l > r - if node.op == "greater_equal": return l >= r - if node.op == "equal": return l == r - if node.op == "not_equal": return l != r - if node.op == "maximum": return If(l >= r, l, r) - if node.op == "bitwise_and": return And(l, r) + if node.op == "add": + return l + r + if node.op == "sub": + return l - r + if node.op == "mul": + return l * r + if node.op in ("idiv"): + return l / r + if node.op == "mod": + return l % r + if node.op == "less": + return l < r + if node.op == "less_equal": + return l <= r + if node.op == "greater": + return l > r + if node.op == "greater_equal": + return l >= r + if node.op == "equal": + return l == r + if node.op == "not_equal": + return l != r + if node.op == "maximum": + return If(l >= r, l, r) + if node.op == "bitwise_and": + return And(l, r) # where(cond, lhs, rhs) if node.op == "where": @@ -640,6 +738,7 @@ def _to_z3(self, node): # Other operations can be implemented as needed raise NotImplementedError(f"Eval for op {node.op} is not implemented") + class SanitizerSymbolicExecution(Client): def __init__(self, abort_on_error): self.abort_on_error = abort_on_error @@ -650,15 +749,19 @@ def __init__(self, abort_on_error): self.unique_load_store_id = 0 def _check_range_satisfiable(self, access_addr, expr_constraints): - out_of_bound_constraint = Not(Or( - *(And(start <= access_addr, access_addr <= end) - for start, end in self.tensor_addrs) - )) + out_of_bound_constraint = Not( + Or( + *( + And(start <= access_addr, access_addr <= end) + for start, end in self.tensor_addrs + ) + ) + ) s = Solver() s.add(out_of_bound_constraint) s.add(And(*expr_constraints)) if s.check() == sat: - print('out of bound access detected!') + print("out of bound access detected!") def _report(self, op_type, tensor, violation_address): traceback_info = _get_traceback_info() @@ -671,7 +774,9 @@ def _report(self, op_type, tensor, violation_address): ) if self.abort_on_error: print_oob_record(oob_record) - raise ValueError("Out-of-bounds access detected. See detailed report above.") + raise ValueError( + "Out-of-bounds access detected. See detailed report above." + ) else: self.records.append(oob_record) @@ -685,7 +790,9 @@ def arg_callback(self, arg, arg_cvt): elif check_inner_stride_equal_to_one(arg): tensor_physical_addresses = get_physical_addr_from_tensor_slice(arg) else: - raise ValueError("The address sanitizer only supports contiguouly stored tensors for now!") + raise ValueError( + "The address sanitizer only supports contiguouly stored tensors for now!" + ) self.tensors.append(arg) self.tensor_addrs.extend(tensor_physical_addresses) @@ -696,15 +803,21 @@ def grid_callback(self, grid: Tuple[int]): def grid_idx_callback(self, grid_idx: Tuple[int]): pass - def register_op_callback(self, op_type: Type[Op]) -> Tuple[Optional[Callable], Optional[Callable]]: + def register_op_callback( + self, op_type: Type[Op] + ) -> Tuple[Optional[Callable], Optional[Callable]]: def op_program_id_overrider(axis): assert self.grid, "Grid not initialized!" return SymbolicExpr("pid", self.grid, axis) def op_raw_load_overrider(ptr, cache_modifier, eviction_policy, is_volatile): - return op_load_overrider(ptr, None, None, cache_modifier, eviction_policy, is_volatile) + return op_load_overrider( + ptr, None, None, cache_modifier, eviction_policy, is_volatile + ) - def op_load_overrider(ptr, mask, other, cache_modifier, eviction_policy, is_volatile): + def op_load_overrider( + ptr, mask, other, cache_modifier, eviction_policy, is_volatile + ): # make sure ptr is a SymbolicExpr if isinstance(ptr, TensorHandle) and isinstance(ptr.dtype, tl.pointer_type): ptr = SymbolicExpr("load", SymbolicExpr.from_value(ptr)) @@ -760,16 +873,16 @@ def op_store_overrider(ptr, value, mask, cache_modifier, eviction_policy): def op_unary_op_overrider(arg, op): _unary_map = { - np.cos: "cos", - np.exp: "exp", - np.exp2: "exp2", - np.abs: "abs", + np.cos: "cos", + np.exp: "exp", + np.exp2: "exp2", + np.abs: "abs", np.floor: "floor", - np.ceil: "ceil", - np.log: "log", - np.log2: "log2", - np.sqrt: "sqrt", - np.sin: "sin", + np.ceil: "ceil", + np.log: "log", + np.log2: "log2", + np.sqrt: "sqrt", + np.sin: "sin", } arg = SymbolicExpr.from_value(arg) try: @@ -780,26 +893,28 @@ def op_unary_op_overrider(arg, op): def op_binary_op_overrider(lhs, rhs, op): _binary_map = { - np.add: lambda lhs, rhs: lhs + rhs, - np.subtract: lambda lhs, rhs: lhs - rhs, - np.multiply: lambda lhs, rhs: lhs * rhs, - np.divide: lambda lhs, rhs: lhs / rhs, - np.less: lambda lhs, rhs: lhs < rhs, - np.less_equal: lambda lhs, rhs: lhs <= rhs, - np.greater: lambda lhs, rhs: lhs > rhs, + np.add: lambda lhs, rhs: lhs + rhs, + np.subtract: lambda lhs, rhs: lhs - rhs, + np.multiply: lambda lhs, rhs: lhs * rhs, + np.divide: lambda lhs, rhs: lhs / rhs, + np.less: lambda lhs, rhs: lhs < rhs, + np.less_equal: lambda lhs, rhs: lhs <= rhs, + np.greater: lambda lhs, rhs: lhs > rhs, np.greater_equal: lambda lhs, rhs: lhs >= rhs, - np.not_equal: lambda lhs, rhs: lhs != rhs, - np.equal: lambda lhs, rhs: lhs == rhs, - np.fmod: lambda lhs, rhs: lhs % rhs, - np.maximum: lambda lhs, rhs: SymbolicExpr("maximum", lhs, rhs), - np.bitwise_and: lambda lhs, rhs: SymbolicExpr("bitwise_and", lhs, rhs), + np.not_equal: lambda lhs, rhs: lhs != rhs, + np.equal: lambda lhs, rhs: lhs == rhs, + np.fmod: lambda lhs, rhs: lhs % rhs, + np.maximum: lambda lhs, rhs: SymbolicExpr("maximum", lhs, rhs), + np.bitwise_and: lambda lhs, rhs: SymbolicExpr("bitwise_and", lhs, rhs), } lhs = SymbolicExpr.from_value(lhs) rhs = SymbolicExpr.from_value(rhs) try: func = _binary_map[op] except KeyError: - raise NotImplementedError(f"Unsupported binary operation: {op} between {lhs} and {rhs}") + raise NotImplementedError( + f"Unsupported binary operation: {op} between {lhs} and {rhs}" + ) return func(lhs, rhs) def op_ternary_op_overrider(lhs, rhs, other, op): @@ -809,12 +924,14 @@ def op_ternary_op_overrider(lhs, rhs, other, op): if op is np.where: return SymbolicExpr("where", lhs, rhs, other) else: - raise NotImplementedError(f"Unsupported ternary operation: {op} between {lhs}, {rhs} and {other}") + raise NotImplementedError( + f"Unsupported ternary operation: {op} between {lhs}, {rhs} and {other}" + ) def op_addptr_overrider(ptr, offset): - ''' + """ In addptr operator, ptr is a pointer address with dtype_tt, and offset is a scalar. - ''' + """ # Read dtype_tt from ptr. # Here, ptr is either a TensorHandle or a SymbolicExpr. dtype_tt = ptr.get_element_ty() @@ -864,34 +981,48 @@ def op_splat_overrider(arg, shape): arg = SymbolicExpr.from_value(arg) return SymbolicExpr("splat", arg, shape) - def op_make_block_ptr_overrider(base, shape, strides, offsets, tensor_shape, order): + def op_make_block_ptr_overrider( + base, shape, strides, offsets, tensor_shape, order + ): base = SymbolicExpr.from_value(base) - assert len(shape) == len(strides) == len(offsets) == len(tensor_shape) == len(order), \ - f"Length of shape ({len(shape)}), strides ({len(strides)}), offsets ({len(offsets)}), tensor_shape ({len(tensor_shape)}) and order ({len(order)}) must be the same!" + assert ( + len(shape) + == len(strides) + == len(offsets) + == len(tensor_shape) + == len(order) + ), f"Length of shape ({len(shape)}), strides ({len(strides)}), offsets ({len(offsets)}), tensor_shape ({len(tensor_shape)}) and order ({len(order)}) must be the same!" shape = [SymbolicExpr.from_value(shape_i) for shape_i in shape] strides = [SymbolicExpr.from_value(strides_i) for strides_i in strides] offsets = [SymbolicExpr.from_value(offset_i) for offset_i in offsets] - tensor_shape = [SymbolicExpr.from_value(tensor_shape_i) for tensor_shape_i in tensor_shape] + tensor_shape = [ + SymbolicExpr.from_value(tensor_shape_i) + for tensor_shape_i in tensor_shape + ] order = [SymbolicExpr.from_value(order_i) for order_i in order] ret = SymbolicExpr( - "make_block_ptr", - base, - shape, - strides, - offsets, - tensor_shape, - order) + "make_block_ptr", base, shape, strides, offsets, tensor_shape, order + ) ret.set_element_ty(base.get_element_ty()) print(ret) return ret - def op_tensor_pointer_load_overrider(ptr, boundary_check, padding_option, cache_modifier, eviction_policy, is_volatile): + def op_tensor_pointer_load_overrider( + ptr, + boundary_check, + padding_option, + cache_modifier, + eviction_policy, + is_volatile, + ): raise NotImplementedError("TensorPointerLoad is not supported yet.") - def op_tensor_pointer_store_overrider(ptr, value, boundary_check, cache_modifier, eviction_policy): + def op_tensor_pointer_store_overrider( + ptr, value, boundary_check, cache_modifier, eviction_policy + ): raise NotImplementedError("TensorPointerStore is not supported yet.") def op_idiv_overrider(lhs, rhs): @@ -939,21 +1070,25 @@ def op_cast_impl_overrider(src, dst_type): def finalize(self) -> list: return [] + class NullSanitizer: """ A do-nothing object returned when the sanitizer backend is 'off'. Any attribute access raises an explicit error so misuse is obvious. """ + def __getattr__(self, name): raise RuntimeError( "Sanitizer backend is off; no sanitizer functionality is available." ) + class Sanitizer(ABC): """ Factory class that returns the concrete sanitizer implementation based on the value of ``cfg.sanitizer_backend``. """ + def __new__(cls, abort_on_error: bool = False): backend = cfg.sanitizer_backend @@ -966,9 +1101,8 @@ def __new__(cls, abort_on_error: bool = False): if backend == "off": return NullSanitizer() - raise ValueError( - f"Invalid TRITON_SANITIZER_BACKEND: {backend!r} " - ) + raise ValueError(f"Invalid TRITON_SANITIZER_BACKEND: {backend!r} ") + Sanitizer.register(SanitizerBruteForce) Sanitizer.register(SanitizerSymbolicExecution) diff --git a/triton_viz/clients/tracer/tracer.py b/triton_viz/clients/tracer/tracer.py index c5f95954..834ebb09 100644 --- a/triton_viz/clients/tracer/tracer.py +++ b/triton_viz/clients/tracer/tracer.py @@ -17,7 +17,11 @@ def _convert_grid_idx(grid_idx) -> Optional[Tuple[int, int, int]]: class Tracer(Client): - def __init__(self, callpath: Optional[bool] = True, grid_idx: Optional[Union[Tuple[int], int]] = None): + def __init__( + self, + callpath: Optional[bool] = True, + grid_idx: Optional[Union[Tuple[int], int]] = None, + ): self.callpath = callpath self.grid_idx = _convert_grid_idx(grid_idx) self.records: list = [] @@ -47,20 +51,28 @@ def grid_idx_callback(self, grid_idx: Tuple[int]): def grid_callback(self, grid: Tuple[int]): self.tensors = sorted(self.tensors, key=lambda x: x.data_ptr()) - def register_op_callback(self, op_type: Type[Op]) -> Tuple[Optional[Callable], Optional[Callable]]: - def pre_load_callback(ptr, mask, other, cache_modifier, eviction_policy, is_volatile): + def register_op_callback( + self, op_type: Type[Op] + ) -> Tuple[Optional[Callable], Optional[Callable]]: + def pre_load_callback( + ptr, mask, other, cache_modifier, eviction_policy, is_volatile + ): if not self.sample: return first_ptr = np.reshape(ptr.data, (-1))[0] tensor = self._get_tensor(first_ptr) - self.records.append(Load(tensor.data_ptr(), ptr.data - tensor.data_ptr(), mask.data)) + self.records.append( + Load(tensor.data_ptr(), ptr.data - tensor.data_ptr(), mask.data) + ) def pre_store_callback(ptr, value, mask, cache_modifier, eviction_policy): if not self.sample: return first_ptr = np.reshape(ptr.data, (-1))[0] tensor = self._get_tensor(first_ptr) - self.records.append(Store(tensor.data_ptr(), ptr.data - tensor.data_ptr(), mask.data)) + self.records.append( + Store(tensor.data_ptr(), ptr.data - tensor.data_ptr(), mask.data) + ) def post_reduce_sum_callback(ret, input, axis=None, keep_dims=False): if not self.sample: diff --git a/triton_viz/clients/utils.py b/triton_viz/clients/utils.py index 517dfc5e..53d52e45 100644 --- a/triton_viz/clients/utils.py +++ b/triton_viz/clients/utils.py @@ -4,19 +4,21 @@ import itertools -def check_out_of_bounds_access(ptrs: npt.NDArray, masks: npt.NDArray[np.bool_], tensor: torch.Tensor): +def check_out_of_bounds_access( + ptrs: npt.NDArray, masks: npt.NDArray[np.bool_], tensor: torch.Tensor +): offsets = ptrs - tensor.data_ptr() max_valid_offset = np.prod(tensor.shape) * tensor.element_size() valid_access_masks = (offsets >= 0) & (offsets < max_valid_offset) invalid_access_masks = (~valid_access_masks) & masks corrected_offsets = np.where(valid_access_masks, offsets, 0) return { - 'tensor': tensor, - 'offsets': offsets, - 'masks': masks, - 'valid_access_masks': valid_access_masks & masks, - 'invalid_access_masks': invalid_access_masks, - 'corrected_offsets': corrected_offsets, + "tensor": tensor, + "offsets": offsets, + "masks": masks, + "valid_access_masks": valid_access_masks & masks, + "invalid_access_masks": invalid_access_masks, + "corrected_offsets": corrected_offsets, } @@ -25,9 +27,12 @@ def check_storage_contiguous(tensor: torch.Tensor): # 1. Sort strides from smallest to largest # 2. If the tensor is contiguous, the stride product should be the same of the shape product of all previous dimensions from triton.runtime.jit import TensorWrapper + if isinstance(tensor, TensorWrapper): tensor = tensor.base - assert type(tensor) == torch.Tensor, f"Only torch.Tensor is supported, but found {type(tensor)}" + assert ( + type(tensor) == torch.Tensor + ), f"Only torch.Tensor is supported, but found {type(tensor)}" shape_prod = 1 indices = sorted(range(len(tensor.stride())), key=tensor.stride().__getitem__) for i, index in enumerate(indices): @@ -40,9 +45,11 @@ def check_storage_contiguous(tensor: torch.Tensor): shape_prod *= shape return True + def check_inner_stride_equal_to_one(tensor: torch.Tensor): return sorted(tensor.stride())[0] == 1 + def get_physical_addr_from_tensor_slice(tensor: torch.Tensor): if sorted(tensor.stride())[0] != 1: raise ValueError("inner dim must be contiguous!") @@ -52,9 +59,14 @@ def get_physical_addr_from_tensor_slice(tensor: torch.Tensor): segments = [] for idxs in itertools.product(*(range(tensor.size(d)) for d in outer_dims)): - offset = tensor.storage_offset() + sum(idx * tensor.stride(d) for idx, d in zip(idxs, outer_dims)) - segments.append(( - tensor.data_ptr() + offset * tensor.element_size(), - tensor.data_ptr() + (offset + tensor.size(inner_dim) - 1) * tensor.element_size() - )) + offset = tensor.storage_offset() + sum( + idx * tensor.stride(d) for idx, d in zip(idxs, outer_dims) + ) + segments.append( + ( + tensor.data_ptr() + offset * tensor.element_size(), + tensor.data_ptr() + + (offset + tensor.size(inner_dim) - 1) * tensor.element_size(), + ) + ) return segments diff --git a/triton_viz/core/__init__.py b/triton_viz/core/__init__.py index da681b32..e6a5e16b 100644 --- a/triton_viz/core/__init__.py +++ b/triton_viz/core/__init__.py @@ -1,15 +1,58 @@ from .trace import trace, clear from .data import ( - Op, ProgramId, RawStore, Store, RawLoad, - Load, UnaryOp, BinaryOp, TernaryOp, Dot, MakeRange, AddPtr, - ExpandDims, Broadcast, Reduce, ReduceSum, ReduceMax, - ReduceMin, Splat, MakeBlockPointer, TensorPointerLoad, - TensorPointerStore, Idiv, Rsqrt, CastImpl) + Op, + ProgramId, + RawStore, + Store, + RawLoad, + Load, + UnaryOp, + BinaryOp, + TernaryOp, + Dot, + MakeRange, + AddPtr, + ExpandDims, + Broadcast, + Reduce, + ReduceSum, + ReduceMax, + ReduceMin, + Splat, + MakeBlockPointer, + TensorPointerLoad, + TensorPointerStore, + Idiv, + Rsqrt, + CastImpl, +) __all__ = [ - "trace", "clear", "Op", "ProgramId", "RawStore", - "Store", "RawLoad", "Load", "UnaryOp", "BinaryOp", "TernaryOp", - "Dot", "MakeRange", "AddPtr", "ExpandDims", "Broadcast", "Reduce", - "ReduceSum", "ReduceMax", "ReduceMin", "Splat", "MakeBlockPointer", - "TensorPointerLoad", "TensorPointerStore", "Idiv", "Rsqrt", "CastImpl", + "trace", + "clear", + "Op", + "ProgramId", + "RawStore", + "Store", + "RawLoad", + "Load", + "UnaryOp", + "BinaryOp", + "TernaryOp", + "Dot", + "MakeRange", + "AddPtr", + "ExpandDims", + "Broadcast", + "Reduce", + "ReduceSum", + "ReduceMax", + "ReduceMin", + "Splat", + "MakeBlockPointer", + "TensorPointerLoad", + "TensorPointerStore", + "Idiv", + "Rsqrt", + "CastImpl", ] diff --git a/triton_viz/core/client.py b/triton_viz/core/client.py index 2c0d0591..42d17a44 100644 --- a/triton_viz/core/client.py +++ b/triton_viz/core/client.py @@ -23,7 +23,9 @@ def grid_idx_callback(self, grid_idx: Tuple[int]): pass @abstractmethod - def register_op_callback(self, op: Type[Op]) -> Tuple[Optional[Callable], Optional[Callable]]: + def register_op_callback( + self, op: Type[Op] + ) -> Tuple[Optional[Callable], Optional[Callable]]: pass @abstractmethod @@ -50,7 +52,11 @@ def patch(self): with patch_calls(): for client in self.clients: for op in op_list: - before_callback, after_callback, op_overrider = client.register_op_callback(op) + ( + before_callback, + after_callback, + op_overrider, + ) = client.register_op_callback(op) patch_op(op, before_callback, after_callback, op_overrider) try: yield diff --git a/triton_viz/core/config.py b/triton_viz/core/config.py index 59b84c9c..86de6193 100644 --- a/triton_viz/core/config.py +++ b/triton_viz/core/config.py @@ -4,6 +4,7 @@ # Back-end options recognised by the sanitizer AVAILABLE_SANITIZER_BACKENDS = ("off", "brute_force", "symexec") + class Config(types.ModuleType): def __init__(self, name: str) -> None: super().__init__(name) @@ -60,5 +61,6 @@ def report_grid_execution_progress(self, flag: bool) -> None: def available_backends(self): return AVAILABLE_SANITIZER_BACKENDS + # Replace the current module object with a live Config instance sys.modules[__name__] = Config(__name__) diff --git a/triton_viz/core/patch.py b/triton_viz/core/patch.py index 5f1710ab..2f29195b 100644 --- a/triton_viz/core/patch.py +++ b/triton_viz/core/patch.py @@ -5,13 +5,31 @@ from . import config as cfg from .data import ( - Op, RawLoad, Load, RawStore, Store, - UnaryOp, BinaryOp, TernaryOp, ProgramId, - Dot, MakeRange, AddPtr, ReduceSum, - Splat, ExpandDims, Broadcast, ReduceMax, ReduceMin, - MakeBlockPointer, TensorPointerLoad, - TensorPointerStore, Idiv, Rsqrt, - CastImpl) + Op, + RawLoad, + Load, + RawStore, + Store, + UnaryOp, + BinaryOp, + TernaryOp, + ProgramId, + Dot, + MakeRange, + AddPtr, + ReduceSum, + Splat, + ExpandDims, + Broadcast, + ReduceMax, + ReduceMin, + MakeBlockPointer, + TensorPointerLoad, + TensorPointerStore, + Idiv, + Rsqrt, + CastImpl, +) import inspect from triton.runtime.interpreter import ( GridExecutor, @@ -22,12 +40,29 @@ from triton.runtime import JITFunction op_list = [ - ProgramId, RawStore, Store, RawLoad, Load, - UnaryOp, BinaryOp, TernaryOp, Dot, MakeRange, - AddPtr, Splat, ExpandDims, Broadcast, - ReduceMax, ReduceMin, ReduceSum, - MakeBlockPointer, TensorPointerLoad, - TensorPointerStore, Idiv, Rsqrt, CastImpl, + ProgramId, + RawStore, + Store, + RawLoad, + Load, + UnaryOp, + BinaryOp, + TernaryOp, + Dot, + MakeRange, + AddPtr, + Splat, + ExpandDims, + Broadcast, + ReduceMax, + ReduceMin, + ReduceSum, + MakeBlockPointer, + TensorPointerLoad, + TensorPointerStore, + Idiv, + Rsqrt, + CastImpl, ] original_ops = { ProgramId: interpreter_builder.create_get_program_id, @@ -55,7 +90,7 @@ reduce_map: Dict[Type[Op], Callable] = { ReduceMax: tl.max, ReduceMin: tl.min, - ReduceSum: tl.sum + ReduceSum: tl.sum, } @@ -75,7 +110,10 @@ def __call__(self, *args, **kwargs): # see triton.runtime.interpreter:ReduceOps.sum # First, convert input from tl.tensor to TensorHandle. Here, input tensor is args[0] # Then, convert return value from TensorHandle to tl.tensor - ret = tl.core.tensor(self.op_overrider(args[0].handle, *args[1:], **kwargs), args[0].dtype) + ret = tl.core.tensor( + self.op_overrider(args[0].handle, *args[1:], **kwargs), + args[0].dtype, + ) else: ret = self.op_overrider(*args, **kwargs) else: @@ -86,7 +124,12 @@ def __call__(self, *args, **kwargs): return ret -def patch_op(op_type: Type[Op], before_callback: Callable, after_callback: Callable, op_overrider: Callable): +def patch_op( + op_type: Type[Op], + before_callback: Callable, + after_callback: Callable, + op_overrider: Callable, +): """ Register a callback to be called before and after an operator is executed. @@ -98,12 +141,20 @@ def patch_op(op_type: Type[Op], before_callback: Callable, after_callback: Calla # create a new function that calls the before_callback, the original op and the after_callback op_name = original_ops[op_type].__name__ current_op = getattr(interpreter_builder, op_name) - patched_op = PatchOp(current_op, op_type, before_callback, after_callback, op_overrider) - setattr(interpreter_builder, op_name, lambda *args, **kwargs: patched_op(*args, **kwargs)) + patched_op = PatchOp( + current_op, op_type, before_callback, after_callback, op_overrider + ) + setattr( + interpreter_builder, + op_name, + lambda *args, **kwargs: patched_op(*args, **kwargs), + ) elif op_type in reduce_map: op_name = reduce_map[op_type].__name__ current_op = getattr(tl, op_name) - patched_op = PatchOp(current_op, op_type, before_callback, after_callback, op_overrider) + patched_op = PatchOp( + current_op, op_type, before_callback, after_callback, op_overrider + ) setattr(tl, op_name, lambda *args, **kwargs: patched_op(*args, **kwargs)) else: raise ValueError(f"Patching operator {op_type} not supported") @@ -135,22 +186,41 @@ def _unpatch_lang(): def _grid_executor_call(self, *args_dev, **kwargs): if kwargs.pop("warmup", False): return + def run_grid_loops(): - for x in tqdm(range(grid[0]), desc='Grid X', leave=False, disable=not cfg.report_grid_execution_progress): - for y in tqdm(range(grid[1]), desc='Grid Y', leave=False, disable=not (cfg.report_grid_execution_progress and grid[1] > 1)): - for z in tqdm(range(grid[2]), desc='Grid Z', leave=False, disable=not (cfg.report_grid_execution_progress and grid[2] > 1)): + for x in tqdm( + range(grid[0]), + desc="Grid X", + leave=False, + disable=not cfg.report_grid_execution_progress, + ): + for y in tqdm( + range(grid[1]), + desc="Grid Y", + leave=False, + disable=not (cfg.report_grid_execution_progress and grid[1] > 1), + ): + for z in tqdm( + range(grid[2]), + desc="Grid Z", + leave=False, + disable=not (cfg.report_grid_execution_progress and grid[2] > 1), + ): interpreter_builder.set_grid_idx(x, y, z) client_manager.grid_idx_callback((x, y, z)) self.fn(**call_args) # if symbolic execution, only do one iteration if cfg.sanitizer_backend == "symexec": return + # Removes not used reserved keywords from kwargs # Triton doesn't support keyword-only, variable positional or variable keyword arguments # It's safe to inspect only positional or keyword arguments (i.e., argspec.args) argspec = inspect.getfullargspec(self.fn) triton_viz_args = ["client_manager"] - kwargs = {k: v for k, v in kwargs.items() if k in argspec.args or k in triton_viz_args} + kwargs = { + k: v for k, v in kwargs.items() if k in argspec.args or k in triton_viz_args + } client_manager = kwargs.pop("client_manager") args_hst, kwargs_hst = self._init_args_hst(args_dev, kwargs) # Remaps core language functions to interpreted ones diff --git a/triton_viz/core/trace.py b/triton_viz/core/trace.py index 9783bfb6..7e20ca2a 100644 --- a/triton_viz/core/trace.py +++ b/triton_viz/core/trace.py @@ -2,8 +2,7 @@ from triton.runtime.interpreter import InterpretedFunction from triton import JITFunction -import os -from typing import Tuple, Union +from typing import Union from . import config as cfg from ..clients import Sanitizer, Profiler, Tracer @@ -15,7 +14,6 @@ class Trace(KernelInterface): - @staticmethod def _normalize_client(client: Union[str, Client]) -> Client: if isinstance(client, str): @@ -89,6 +87,7 @@ def decorator(kernel) -> Trace: # If the object is neither a JITFunction nor Trace, raise an error raise TypeError(f"Expected JITFunction, got {type(kernel)}") + return decorator diff --git a/triton_viz/wrapper.py b/triton_viz/wrapper.py index ec59330a..956ff071 100644 --- a/triton_viz/wrapper.py +++ b/triton_viz/wrapper.py @@ -7,21 +7,26 @@ # store the original triton.jit _original_jit = triton.jit + def sanitizer_wrapper(kernel): abort_on_error = True tracer = triton_viz.trace(clients=Sanitizer(abort_on_error=abort_on_error)) return tracer(kernel) + def _patched_jit(fn=None, **jit_kw): - if fn is None: # @triton.jit(**opts) + if fn is None: # @triton.jit(**opts) + def _decorator(f): k = _original_jit(**jit_kw)(f) return sanitizer_wrapper(k) + return _decorator - else: # @triton.jit + else: # @triton.jit k = _original_jit(fn) return sanitizer_wrapper(k) + def apply(): """ Apply the sanitizer wrapper to triton.jit and run the user script. @@ -30,6 +35,7 @@ def apply(): triton.jit = _patched_jit triton.language.jit = _patched_jit import triton.runtime.interpreter as _interp + _interp.jit = _patched_jit # run user script From cc28e63bd721a588e3e9aeee66943eb1204dd320 Mon Sep 17 00:00:00 2001 From: Hao Wu Date: Wed, 4 Jun 2025 14:14:15 -0400 Subject: [PATCH 2/3] precommit analysis.py --- triton_viz/visualizer/analysis.py | 25 +++++++++++++++++++------ 1 file changed, 19 insertions(+), 6 deletions(-) diff --git a/triton_viz/visualizer/analysis.py b/triton_viz/visualizer/analysis.py index 5bd7153f..d8b29197 100644 --- a/triton_viz/visualizer/analysis.py +++ b/triton_viz/visualizer/analysis.py @@ -14,21 +14,34 @@ def analyze_records(records): record_data = [["Grid Size", tuple(grid_size)]] elif isinstance(record, OpTypeCounts): op_type_counts = record.type_counts - record_data += [[op_type, count] for op_type, count in op_type_counts.items()] + record_data += [ + [op_type, count] for op_type, count in op_type_counts.items() + ] elif isinstance(record, LoadStoreBytes): + def calculate_ratio(record): - return (record.total_bytes_true / record.total_bytes_attempted - if record.total_bytes_attempted > 0 else 0) + return ( + record.total_bytes_true / record.total_bytes_attempted + if record.total_bytes_attempted > 0 + else 0 + ) + if record.type == "load": total_load_bytes_true = record.total_bytes_true overall_load_ratio = calculate_ratio(record) - record_data.append(["Total number of bytes loaded", total_load_bytes_true]) + record_data.append( + ["Total number of bytes loaded", total_load_bytes_true] + ) record_data.append(["Masked Load Ratio", round(overall_load_ratio, 3)]) elif record.type == "store": total_store_bytes_true = record.total_bytes_true overall_store_ratio = calculate_ratio(record) - record_data.append(["Total number of bytes stored", total_store_bytes_true]) - record_data.append(["Masked Store Ratio", round(overall_store_ratio, 3)]) + record_data.append( + ["Total number of bytes stored", total_store_bytes_true] + ) + record_data.append( + ["Masked Store Ratio", round(overall_store_ratio, 3)] + ) if record_data is not None: data.append(record_data) From d4dabbcd10c92a51d854441fb8dc76659b998c34 Mon Sep 17 00:00:00 2001 From: Hao Wu Date: Wed, 4 Jun 2025 14:14:35 -0400 Subject: [PATCH 3/3] ignore workflow for now --- .github/workflows/python-app.yml | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/.github/workflows/python-app.yml b/.github/workflows/python-app.yml index 220cd3af..3bc0f3e1 100644 --- a/.github/workflows/python-app.yml +++ b/.github/workflows/python-app.yml @@ -5,9 +5,11 @@ name: Python application on: push: - branches: [] + branches-ignore: + - '**' pull_request: - branches: [] + branches-ignore: + - '**' permissions: contents: read