Skip to content

[DEV] Indirect load #75

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: keren/v2.0
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 47 additions & 0 deletions tests/test_indirect_load.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import torch
import triton
import triton.language as tl

import triton_viz
from triton_viz.clients import Sanitizer
from triton_viz import config as cfg


# Select the "symexec" sanitizer backend in triton_viz's config. This runs at
# import time, before the Sanitizer client below is constructed.
# NOTE(review): "symexec" presumably means symbolic execution — confirm.
cfg.sanitizer_backend = "symexec"

@triton_viz.trace(clients=Sanitizer(abort_on_error=True))
@triton.jit
def indirect_load_kernel(idx_ptr, src_ptr, dst_ptr, BLOCK_SIZE: tl.constexpr):
    # Gather kernel: dst[offsets] = src[idx[offsets]] for this program's block.
    # No mask is applied, so every offset in the block is accessed.
    program = tl.program_id(0)
    offsets = program * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)

    # First load fetches the indices; the second load's address depends on
    # those loaded values (the "indirect" load).
    gather_idx = tl.load(idx_ptr + offsets)
    gathered = tl.load(src_ptr + gather_idx)
    tl.store(dst_ptr + offsets, gathered)

def test_indirect_load_inrange():
    """Launch the indirect-load kernel with indices that all lie inside ``src``.

    ``idx`` is ``arange(128)`` and ``src`` has 128 elements, so every gathered
    address is in range; the sanitizer (``abort_on_error=True``) should not abort.
    """
    n_elements = 128
    idx = torch.arange(n_elements, dtype=torch.int32)
    src = torch.rand(n_elements)
    dst = torch.empty_like(src)

    def grid(META):
        # One program per BLOCK_SIZE-sized chunk of the input.
        return (triton.cdiv(n_elements, META["BLOCK_SIZE"]),)

    indirect_load_kernel[grid](idx, src, dst, BLOCK_SIZE=32)

    # NOTE(review): the value check is disabled in the original; confirm whether
    # the sanitizer backend actually writes `dst` before re-enabling it.
    # assert torch.allclose(dst, src), "indirect load failed"
    print("✓ test passed")


# def test_indirect_load_out_of_bound():
# """
# This test deliberately sets n_elements = 256, exceeding the actual buffer size (128).
# It will likely cause out-of-bound reads/writes, which may trigger errors or warnings.
# """
# x = torch.arange(128, device="cuda", dtype=torch.int32)
# out = torch.empty_like(x)

# # The kernel launch uses n_elements=256, which exceeds the size of x.
# grid = lambda META: (triton.cdiv(256, META["BLOCK_SIZE"]),)
# indirect_load_kernel[grid](x_ptr=x, out_ptr=out, n_elements=256)

# print("test_indirect_load_out_of_bound() passed: Out-of-bound access detected.")
44 changes: 43 additions & 1 deletion triton_viz/clients/sanitizer/sanitizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -451,7 +451,11 @@ def __init__(self, op, *args):
self.dtype_tt = None
self.shape = []
self.children = {} # Used for storing child expressions
# check if the number of arguments is correct
# Functions and arguments for concretization
self._concrete_fn = None
self._concrete_args = ()
self._concrete_kwargs = {}
# leaf nodes
if self.op == "const":
self.value = args[0]
if len(args) >= 2:
Expand Down Expand Up @@ -755,6 +759,38 @@ def _to_z3(self, node):
# Other operations can be implemented as needed
raise NotImplementedError(f"Eval for op {node.op} is not implemented")

def has_op(self, op_name: str) -> bool:
    """Return True if this node or any descendant has operator ``op_name``.

    The search walks the values of ``self.children`` depth-first and stops at
    the first match.

    Args:
        op_name: Operator name to look for (compared against ``self.op``).

    Returns:
        True when ``op_name`` occurs anywhere in this expression tree,
        False otherwise.
    """
    if self.op == op_name:
        return True
    # Only the child expressions matter here; the dict keys are not used.
    return any(child.has_op(op_name) for child in self.children.values())

@staticmethod
def _concretize_item(obj):
    """Concretize ``obj`` if it is a ``SymbolicExpr``; otherwise return it unchanged."""
    if not isinstance(obj, SymbolicExpr):
        return obj
    return obj.concretize()

def concretize(self):
    """
    Evaluate this symbolic expression and return the resulting concrete value.

    A ``"const"`` node returns ``self.value`` directly. Any other node calls
    the recorded ``self._concrete_fn`` with the recorded positional and
    keyword arguments, each of which is first passed through
    ``self._concretize_item``.

    Raises:
        RuntimeError: If the node is not a constant and no concrete function
            has been recorded for it.
    """
    if self.op == "const":
        return self.value
    if self._concrete_fn is None:
        raise RuntimeError("Concrete function is not set for this SymbolicExpr.")
    new_args = [self._concretize_item(a) for a in self._concrete_args]
    new_kw = {k: self._concretize_item(v) for k, v in self._concrete_kwargs.items()}
    return self._concrete_fn(*new_args, **new_kw)

class ConstTupleExpr(SymbolicExpr):
    """A ``"const"`` symbolic node constructed from ``value`` converted to a tuple."""

    def __init__(self, value):
        as_tuple = tuple(value)
        super().__init__("const", as_tuple)


class ConstTupleExpr(SymbolicExpr):
Expand Down Expand Up @@ -840,6 +876,12 @@ def op_raw_load_overrider(ptr, cache_modifier, eviction_policy, is_volatile):
def op_load_overrider(
ptr, mask, other, cache_modifier, eviction_policy, is_volatile
):
# deal with indirect loads
if isinstance(ptr, SymbolicExpr) and ptr.has_op("load"):
print("indirect loading:", ptr)
# ptr = ptr.concretize()
# print('concretized ptr:', ptr)

# make sure ptr is a SymbolicExpr
if isinstance(ptr, TensorHandle) and isinstance(ptr.dtype, tl.pointer_type):
ptr = SymbolicExpr("load", SymbolicExpr.from_value(ptr))
Expand Down
6 changes: 6 additions & 0 deletions triton_viz/core/patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,12 @@ def __call__(self, *args, **kwargs):
)
else:
ret = self.op_overrider(*args, **kwargs)
from ..clients.sanitizer.sanitizer import SymbolicExpr

if isinstance(ret, SymbolicExpr):
ret._concrete_fn = self.op
ret._concrete_args = args
ret._concrete_kwargs = kwargs
else:
ret = self.op(*args, **kwargs)
if self.after_callback:
Expand Down
Loading