From 33fe88df873e903cc60d42c79a578819870a3a01 Mon Sep 17 00:00:00 2001 From: Hao Wu Date: Wed, 11 Jun 2025 18:02:08 -0400 Subject: [PATCH 1/2] wip: indirect load --- tests/test_indirect_load.py | 47 +++++++++++++++++++++++ triton_viz/clients/sanitizer/sanitizer.py | 45 +++++++++++++++++++++- triton_viz/core/patch.py | 6 +++ 3 files changed, 97 insertions(+), 1 deletion(-) create mode 100644 tests/test_indirect_load.py diff --git a/tests/test_indirect_load.py b/tests/test_indirect_load.py new file mode 100644 index 00000000..1d735b26 --- /dev/null +++ b/tests/test_indirect_load.py @@ -0,0 +1,47 @@ +import torch +import triton +import triton.language as tl + +import triton_viz +from triton_viz.clients import Sanitizer +from triton_viz import config as cfg + + +cfg.sanitizer_backend = "symexec" + +@triton_viz.trace(clients=Sanitizer(abort_on_error=True)) +@triton.jit +def indirect_load_kernel(idx_ptr, src_ptr, dst_ptr, BLOCK_SIZE: tl.constexpr): + pid = tl.program_id(0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + + indices = tl.load(idx_ptr + offsets) + out_val = tl.load(src_ptr + indices) + tl.store(dst_ptr + offsets, out_val) + +def test_indirect_load_inrange(): + idx = torch.arange(128, device="cuda", dtype=torch.int32) + src = torch.rand(128, device="cuda") + dst = torch.empty_like(src) + + grid = lambda META: (triton.cdiv(128, META['BLOCK_SIZE']),) + indirect_load_kernel[grid](idx, src, dst, BLOCK_SIZE=32) + + # assert torch.allclose(dst, src), "indirect load failed" + print("✓ test passed") + + +# def test_indirect_load_out_of_bound(): +# """ +# This test deliberately sets n_elements = 256, exceeding the actual buffer size (128). +# It will likely cause out-of-bound reads/writes, which may trigger errors or warnings. +# """ +# x = torch.arange(128, device="cuda", dtype=torch.int32) +# out = torch.empty_like(x) + +# # The kernel launch uses n_elements=256, which exceeds the size of x. +# grid = lambda META: (triton.cdiv(256, META["BLOCK_SIZE"]),) +# indirect_load_kernel[grid](x_ptr=x, out_ptr=out, n_elements=256) + +# print("test_indirect_load_out_of_bound() passed: Out-of-bound access detected.") diff --git a/triton_viz/clients/sanitizer/sanitizer.py b/triton_viz/clients/sanitizer/sanitizer.py index 964e7230..09efa7e7 100644 --- a/triton_viz/clients/sanitizer/sanitizer.py +++ b/triton_viz/clients/sanitizer/sanitizer.py @@ -348,7 +348,12 @@ def __init__(self, op, *args): self.attrs = {} self.dtype_tt = None self.shape = [] - # check if the number of arguments is correct + self.children = {} # Used for storing child expressions + # Functions and arguments for concretization + self._concrete_fn = None + self._concrete_args = () + self._concrete_kwargs = {} + # leaf nodes if self.op == "const": self.value = args[0] if len(args) >= 2: @@ -738,6 +743,38 @@ def _to_z3(self, node): # Other operations can be implemented as needed raise NotImplementedError(f"Eval for op {node.op} is not implemented") + def has_op(self, op_name: str) -> bool: + if self.op == op_name: + return True + for child_key, child_symbolic_expr in self.children.items(): + if child_symbolic_expr.has_op(op_name): + return True + return False + + @staticmethod + def _concretize_item(obj): + return obj.concretize() if isinstance(obj, SymbolicExpr) else obj + + def concretize(self): + """ + Concretize the symbolic expression into a concrete value. + This is used to evaluate the symbolic expression and return a concrete value. + """ + if self.op == "splat": + print("op:", self.op) + print("arg:", self._concrete_args) + print("kwargs:", self._concrete_kwargs) + if self.op == "const": + return self.value + if self._concrete_fn is None: + raise RuntimeError("Concrete function is not set for this SymbolicExpr.") + new_args = [self._concretize_item(a) for a in self._concrete_args] + new_kw = {k: self._concretize_item(v) for k, v in self._concrete_kwargs.items()} + return self._concrete_fn(*new_args, **new_kw) + +class ConstTupleExpr(SymbolicExpr): + def __init__(self, value): + super().__init__("const", tuple(value)) class SanitizerSymbolicExecution(Client): def __init__(self, abort_on_error): @@ -818,6 +855,12 @@ def op_raw_load_overrider(ptr, cache_modifier, eviction_policy, is_volatile): def op_load_overrider( ptr, mask, other, cache_modifier, eviction_policy, is_volatile ): + # deal with indirect loads + if isinstance(ptr, SymbolicExpr) and ptr.has_op("load"): + print("indirect loading:", ptr) + # ptr = ptr.concretize() + # print('concretized ptr:', ptr) + # make sure ptr is a SymbolicExpr if isinstance(ptr, TensorHandle) and isinstance(ptr.dtype, tl.pointer_type): ptr = SymbolicExpr("load", SymbolicExpr.from_value(ptr)) diff --git a/triton_viz/core/patch.py b/triton_viz/core/patch.py index 2f29195b..2f82fdc7 100644 --- a/triton_viz/core/patch.py +++ b/triton_viz/core/patch.py @@ -116,6 +116,12 @@ def __call__(self, *args, **kwargs): ) else: ret = self.op_overrider(*args, **kwargs) + from ..clients.sanitizer.sanitizer import SymbolicExpr + + if isinstance(ret, SymbolicExpr): + ret._concrete_fn = self.op + ret._concrete_args = args + ret._concrete_kwargs = kwargs else: ret = self.op(*args, **kwargs) if self.after_callback: From 494cc0ce077d0b421a02234a0031248210f0cb32 Mon Sep 17 00:00:00 2001 From: Hao Wu Date: Wed, 11 Jun 2025 18:08:22 -0400 Subject: [PATCH 2/2] remove cuda dependencies in indirect load's unittest --- tests/test_indirect_load.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_indirect_load.py b/tests/test_indirect_load.py index 1d735b26..f6afe6fc 100644 --- a/tests/test_indirect_load.py +++ b/tests/test_indirect_load.py @@ -21,8 +21,8 @@ def indirect_load_kernel(idx_ptr, src_ptr, dst_ptr, BLOCK_SIZE: tl.constexpr): tl.store(dst_ptr + offsets, out_val) def test_indirect_load_inrange(): - idx = torch.arange(128, device="cuda", dtype=torch.int32) - src = torch.rand(128, device="cuda") + idx = torch.arange(128, dtype=torch.int32) + src = torch.rand(128) dst = torch.empty_like(src) grid = lambda META: (triton.cdiv(128, META['BLOCK_SIZE']),)