[CHORE] Repo-wide style cleanup #71

Merged: 3 commits, Jun 4, 2025

6 changes: 4 additions & 2 deletions .github/workflows/python-app.yml
@@ -5,9 +5,11 @@ name: Python application

on:
push:
branches: []
branches-ignore:
- '**'
pull_request:
branches: []
branches-ignore:
- '**'

permissions:
contents: read
13 changes: 7 additions & 6 deletions examples/elementwise_add_autotune.py
@@ -8,11 +8,11 @@
# Custom kernel for element-wise addition with autotuning
@triton.autotune(
configs=[
triton.Config({'BLOCK_SIZE': 32}, num_warps=1),
triton.Config({'BLOCK_SIZE': 64}, num_warps=2),
triton.Config({'BLOCK_SIZE': 128}, num_warps=4),
triton.Config({"BLOCK_SIZE": 32}, num_warps=1),
triton.Config({"BLOCK_SIZE": 64}, num_warps=2),
triton.Config({"BLOCK_SIZE": 128}, num_warps=4),
],
key=['n_elements'], # Key for selecting the optimal config based on input size
key=["n_elements"], # Key for selecting the optimal config based on input size
warmup=5,
rep=5,
)
@@ -37,10 +37,11 @@ def elementwise_add_kernel(
# Store the result
tl.store(output_ptr + offsets, output, mask=mask)


# Create PyTorch tensors as input
n_elements = 100000
x = torch.randn(n_elements, device='cuda')
y = torch.randn(n_elements, device='cuda')
x = torch.randn(n_elements, device="cuda")
y = torch.randn(n_elements, device="cuda")
output = torch.empty_like(x)

# Launch the Triton kernel
10 changes: 6 additions & 4 deletions examples/load_and_store.py
@@ -5,6 +5,7 @@
import triton_viz
from triton_viz.clients import Sanitizer


@triton_viz.trace(clients=Sanitizer(abort_on_error=True))
@triton.jit
def simple_kernel(X_ptr, Y_ptr, BLOCK_SIZE: tl.constexpr):
@@ -15,18 +16,19 @@ def simple_kernel(X_ptr, Y_ptr, BLOCK_SIZE: tl.constexpr):

tl.store(Y_ptr + idx, x)

if __name__ == '__main__':

if __name__ == "__main__":
BLOCK_SIZE = 1024
n_elements = 512

# Create input and output tensors
X = torch.arange(n_elements, dtype=torch.float32, device='cuda')
Y = torch.empty_like(X, device='cuda')
X = torch.arange(n_elements, dtype=torch.float32, device="cuda")
Y = torch.empty_like(X, device="cuda")

# Launch the Triton kernel
grid = lambda META: (triton.cdiv(n_elements, META["BLOCK_SIZE"]),)
simple_kernel[grid](X, Y, BLOCK_SIZE=BLOCK_SIZE)

# Verify the results
print("Input tensor X:", X)
print("Output tensor Y:", Y)
19 changes: 14 additions & 5 deletions tests/test_non_contiguous.py → examples/test_non_contiguous.py
@@ -1,5 +1,8 @@
import torch
from triton_viz.clients.utils import check_storage_contiguous, get_physical_addr_from_tensor_slice
from triton_viz.clients.utils import (
check_storage_contiguous,
get_physical_addr_from_tensor_slice,
)


def test_transpose():
@@ -8,20 +11,26 @@ def test_transpose():
print(b)
print([(0, b.numel() - 1)])


def test_2d_slice():
a = torch.arange(25).view(5, 5)
b = a[1:4, 1:4]
print('b:', b)
print("b:", b)
print("is_contiguous:", check_storage_contiguous(b))
segments = get_physical_addr_from_tensor_slice(b)
for start, end in segments:
print(f"[{(start - b.data_ptr()) / b.element_size()}, {(end - b.data_ptr()) / b.element_size()}]")
print(
f"[{(start - b.data_ptr()) / b.element_size()}, {(end - b.data_ptr()) / b.element_size()}]"
)


def test_3d_slice():
a = torch.arange(125).view(5, 5, 5)
b = a[1:4, 1:4, 1:4]
print('b:', b)
print("b:", b)
print("is_contiguous:", check_storage_contiguous(b))
segments = get_physical_addr_from_tensor_slice(b)
for start, end in segments:
print(f"[{(start - b.data_ptr()) / b.element_size()}, {(end - b.data_ptr()) / b.element_size()}]")
print(
f"[{(start - b.data_ptr()) / b.element_size()}, {(end - b.data_ptr()) / b.element_size()}]"
)
70 changes: 48 additions & 22 deletions examples/test_symbolic_execution.py
@@ -12,66 +12,80 @@ def add_kernel(x):
pid = tl.program_id(0)
addr = x + pid
tl.load(addr)
a = torch.randn(16, dtype=torch.float32, device='cuda')

a = torch.randn(16, dtype=torch.float32, device="cuda")
add_kernel[(2,)](a)


def test_tl_make_range():
@triton_viz.trace(clients=Sanitizer(abort_on_error=True))
@triton.jit
def make_range_kernel(x, BLOCK_SIZE: tl.constexpr):
tl.load(x)
offset = x + tl.arange(0, BLOCK_SIZE)
tl.load(offset)
a = torch.randn(16, dtype=torch.float32, device='cuda')

a = torch.randn(16, dtype=torch.float32, device="cuda")
make_range_kernel[(1,)](a, BLOCK_SIZE=16)


def test_tl_add():
@triton_viz.trace(clients=Sanitizer(abort_on_error=True))
@triton.jit
def program_id_kernel(x):
addr = x + 1
tl.load(addr)
a = torch.randn(16, dtype=torch.float32, device='cuda')

a = torch.randn(16, dtype=torch.float32, device="cuda")
program_id_kernel[(2,)](a)


def test_tl_sub():
@triton_viz.trace(clients=Sanitizer(abort_on_error=True))
@triton.jit
def sub_kernel(x):
addr = x - 1
tl.load(addr)
a = torch.randn(16, dtype=torch.float32, device='cuda')

a = torch.randn(16, dtype=torch.float32, device="cuda")
sub_kernel[(2,)](a)


def test_tl_mul():
@triton_viz.trace(clients=Sanitizer(abort_on_error=True))
@triton.jit
def mul_kernel(x, BLOCK_SIZE: tl.constexpr):
addr = x + (tl.arange(0, BLOCK_SIZE) * 2)
tl.load(addr)
a = torch.randn(32, dtype=torch.float32, device='cuda')

a = torch.randn(32, dtype=torch.float32, device="cuda")
mul_kernel[(1,)](a, BLOCK_SIZE=16)


def test_tl_div():
@triton_viz.trace(clients=Sanitizer(abort_on_error=True))
@triton.jit
def div_kernel(x, BLOCK_SIZE: tl.constexpr):
tl.load(x)
tl.load(x + (tl.arange(0, BLOCK_SIZE) // 2))
tl.load(x + tl.arange(0, BLOCK_SIZE))
a = torch.randn(32, dtype=torch.float32, device='cuda')

a = torch.randn(32, dtype=torch.float32, device="cuda")
div_kernel[(1,)](a, BLOCK_SIZE=16)


def test_tl_mod():
@triton_viz.trace(clients=Sanitizer(abort_on_error=True))
@triton.jit
def mod_kernel(x, BLOCK_SIZE: tl.constexpr):
tl.load(x)
tl.load(x + (tl.arange(0, BLOCK_SIZE) % 10))
tl.load(x + tl.arange(0, BLOCK_SIZE))
a = torch.randn(32, dtype=torch.float32, device='cuda')

a = torch.randn(32, dtype=torch.float32, device="cuda")
mod_kernel[(1,)](a, BLOCK_SIZE=16)


def test_vec_add():
@triton_viz.trace(clients=Sanitizer(abort_on_error=True))
@triton.jit
@@ -87,12 +101,13 @@ def add_kernel(x_ptr, y_ptr, output_ptr, BLOCK_SIZE: tl.constexpr):
access_size = 24
size = 17
BLOCK_SIZE = 8
a = torch.randn(size, dtype=torch.float32, device='cuda')
b = torch.randn(size, dtype=torch.float32, device='cuda')
output = torch.empty_like(a, device='cuda')
a = torch.randn(size, dtype=torch.float32, device="cuda")
b = torch.randn(size, dtype=torch.float32, device="cuda")
output = torch.empty_like(a, device="cuda")
grid = lambda meta: (triton.cdiv(access_size, meta["BLOCK_SIZE"]),)
add_kernel[grid](a, b, output, BLOCK_SIZE=BLOCK_SIZE)


def test_vec_add_mask():
@triton_viz.trace(clients=Sanitizer(abort_on_error=True))
@triton.jit
@@ -109,21 +124,26 @@ def add_kernel(x_ptr, y_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
access_size = 24
size = 17
BLOCK_SIZE = 8
a = torch.randn(size, dtype=torch.float32, device='cuda')
b = torch.randn(size, dtype=torch.float32, device='cuda')
output = torch.empty_like(a, device='cuda')
a = torch.randn(size, dtype=torch.float32, device="cuda")
Member:

Those devices shouldn't be set to cuda.

Collaborator (author):

If TRITON_INTERPRET=1, it would be fine to set the devices to cuda.
However, triton-viz cannot work with TRITON_INTERPRET=1 right now.

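A side note on the device discussion above: one common way to keep examples like these runnable on machines without a GPU is to pick the device at runtime. The sketch below is illustrative only and is not part of this PR; the device variable is an assumed helper, not something the diff introduces.

import torch

# Fall back to CPU when CUDA is unavailable (illustrative pattern, not in this PR).
device = "cuda" if torch.cuda.is_available() else "cpu"

size = 17
a = torch.randn(size, dtype=torch.float32, device=device)
b = torch.randn(size, dtype=torch.float32, device=device)
output = torch.empty_like(a)
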
b = torch.randn(size, dtype=torch.float32, device="cuda")
output = torch.empty_like(a, device="cuda")
grid = lambda meta: (triton.cdiv(access_size, meta["BLOCK_SIZE"]),)
add_kernel[grid](a, b, output, size, BLOCK_SIZE=BLOCK_SIZE)


def test_new_axis_column():
@triton_viz.trace(clients=Sanitizer(abort_on_error=True))
@triton.jit
def new_axis_kernel(out_ptr, BLOCK_ROW_SIZE: tl.constexpr):
pid = out_ptr + tl.program_id(0) * BLOCK_ROW_SIZE + tl.arange(0, BLOCK_ROW_SIZE)[:, None]
pid = (
out_ptr
+ tl.program_id(0) * BLOCK_ROW_SIZE
+ tl.arange(0, BLOCK_ROW_SIZE)[:, None]
)
tl.load(pid)

BLOCK_ROW_SIZE = 8
out = torch.empty((BLOCK_ROW_SIZE, 1), dtype=torch.int32, device='cuda')
out = torch.empty((BLOCK_ROW_SIZE, 1), dtype=torch.int32, device="cuda")
grid = lambda meta: (1,)
new_axis_kernel[grid](out, BLOCK_ROW_SIZE=BLOCK_ROW_SIZE)

@@ -132,14 +152,19 @@ def test_new_axis_row():
@triton_viz.trace(clients=Sanitizer(abort_on_error=True))
@triton.jit
def new_axis_kernel(out_ptr, BLOCK_ROW_SIZE: tl.constexpr):
pid = out_ptr + tl.program_id(0) * BLOCK_ROW_SIZE + tl.arange(0, BLOCK_ROW_SIZE)[None, :]
pid = (
out_ptr
+ tl.program_id(0) * BLOCK_ROW_SIZE
+ tl.arange(0, BLOCK_ROW_SIZE)[None, :]
)
tl.load(pid)

BLOCK_ROW_SIZE = 8
out = torch.empty((BLOCK_ROW_SIZE, 1), dtype=torch.int32, device='cuda')
out = torch.empty((BLOCK_ROW_SIZE, 1), dtype=torch.int32, device="cuda")
grid = lambda meta: (1,)
new_axis_kernel[grid](out, BLOCK_ROW_SIZE=BLOCK_ROW_SIZE)


def test_tl_maximum():
@triton_viz.trace(clients=Sanitizer(abort_on_error=True))
@triton.jit
@@ -158,12 +183,13 @@ def maximum_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):

size = 20
BLOCK_SIZE = 8
a = torch.randn(size, dtype=torch.float32, device='cuda')
b = torch.randn(size, dtype=torch.float32, device='cuda')
out = torch.empty_like(a, device='cuda')
a = torch.randn(size, dtype=torch.float32, device="cuda")
b = torch.randn(size, dtype=torch.float32, device="cuda")
out = torch.empty_like(a, device="cuda")
grid = lambda meta: (triton.cdiv(size, meta["BLOCK_SIZE"]),)
maximum_kernel[grid](a, b, out, size, BLOCK_SIZE=BLOCK_SIZE)


def test_tl_log():
@triton_viz.trace(clients=Sanitizer(abort_on_error=True))
@triton.jit
@@ -180,7 +206,7 @@ def log_kernel(x_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
BLOCK_SIZE = 8
eps = 0.01

a = torch.rand(size, dtype=torch.float32, device='cuda') + eps
out = torch.empty_like(a, device='cuda')
a = torch.rand(size, dtype=torch.float32, device="cuda") + eps
out = torch.empty_like(a, device="cuda")
grid = lambda meta: (triton.cdiv(size, meta["BLOCK_SIZE"]),)
log_kernel[grid](a, out, size, BLOCK_SIZE=BLOCK_SIZE)
10 changes: 5 additions & 5 deletions requirements.txt
@@ -1,7 +1,7 @@
setuptools
triton
gradio
chalk-diagrams @ git+https://github.com/chalk-diagrams/chalk.git
pyarrow
gradio
pre-commit
pytest
pyarrow
pytest
setuptools
triton
35 changes: 14 additions & 21 deletions tests/test_autotune_add.py
@@ -9,12 +9,13 @@

cfg.sanitizer_backend = "symexec"


@triton.autotune(
configs=[
triton.Config({'BLOCK_SIZE': 32}, num_warps=1),
triton.Config({'BLOCK_SIZE': 64}, num_warps=2),
triton.Config({"BLOCK_SIZE": 32}, num_warps=1),
triton.Config({"BLOCK_SIZE": 64}, num_warps=2),
],
key=['n_elements'],
key=["n_elements"],
)
@triton_viz.trace(clients=Sanitizer(abort_on_error=True))
@triton.jit
@@ -32,43 +33,35 @@ def add_kernel_no_mask(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
y_val = tl.load(y_ptr + offsets)
tl.store(out_ptr + offsets, x_val + y_val)


def test_autotune_add_inrange():
"""
This test uses n_elements = 128, matching the size of the input tensors.
It should NOT cause any out-of-bound access.
"""
x = torch.randn(128, device='cuda')
y = torch.randn(128, device='cuda')
x = torch.randn(128, device="cuda")
y = torch.randn(128, device="cuda")
out = torch.empty_like(x)

# The kernel launch uses n_elements=128, aligned with the tensor size.
grid = lambda META: (triton.cdiv(128, META['BLOCK_SIZE']),)
add_kernel_no_mask[grid](
x_ptr=x,
y_ptr=y,
out_ptr=out,
n_elements=128
)
grid = lambda META: (triton.cdiv(128, META["BLOCK_SIZE"]),)
add_kernel_no_mask[grid](x_ptr=x, y_ptr=y, out_ptr=out, n_elements=128)

print("test_autotune_add_inrange() passed: No out-of-bound access.")


def test_autotune_add_out_of_bound():
"""
This test deliberately sets n_elements = 256, exceeding the actual buffer size (128).
It will likely cause out-of-bound reads/writes, which may trigger errors or warnings.
"""
x = torch.randn(128, device='cuda')
y = torch.randn(128, device='cuda')
x = torch.randn(128, device="cuda")
y = torch.randn(128, device="cuda")
out = torch.empty_like(x)

# The kernel launch uses n_elements=256, exceeding the valid tensor size.
grid = lambda META: (triton.cdiv(256, META['BLOCK_SIZE']),)
add_kernel_no_mask[grid](
x_ptr=x,
y_ptr=y,
out_ptr=out,
n_elements=256
)
grid = lambda META: (triton.cdiv(256, META["BLOCK_SIZE"]),)
add_kernel_no_mask[grid](x_ptr=x, y_ptr=y, out_ptr=out, n_elements=256)

# Depending on hardware/drivers, this may or may not raise an error immediately.
print("test_autotune_add_oob() completed: Potential out-of-bound access occurred.")