Add hl.inline_asm_elementwise #328
Merged: 1 commit, Jul 16, 2025
4 changes: 4 additions & 0 deletions docs/api/exceptions.md
@@ -202,6 +202,10 @@ These exceptions occur when Helion language functions are used incorrectly with
.. autoclass:: TracedArgNotSupported

Raised for unsupported argument types in traced functions.

.. autoclass:: InvalidAPIUsage

Raised for incorrect usage of Helion API functions.
```

## Configuration Errors
6 changes: 5 additions & 1 deletion helion/exc.py
@@ -236,7 +236,7 @@ class ErrorCompilingKernel(BaseError):


class NoTensorArgs(BaseError):
message = "Kernel took no tensor args, unclear what device to use."
message = "Kernel took no tensor or device args, unclear what device to use."


class _WrapException(BaseError):
@@ -327,3 +327,7 @@ class DeviceTensorSubscriptAssignmentNotAllowed(BaseError):

class InvalidSequenceSubscription(BaseError):
message = "Cannot subscript a sequence with non constant indices. Got '{0!s}'. "


class InvalidAPIUsage(BaseError):
message = "Invalid usage of Helion API: {0}"
1 change: 1 addition & 0 deletions helion/language/__init__.py
@@ -6,6 +6,7 @@
from .creation_ops import full as full
from .creation_ops import zeros as zeros
from .device_print import device_print as device_print
from .inline_asm_ops import inline_asm_elementwise as inline_asm_elementwise
from .loops import grid as grid
from .loops import tile as tile
from .memory_ops import atomic_add as atomic_add
207 changes: 207 additions & 0 deletions helion/language/inline_asm_ops.py
@@ -0,0 +1,207 @@
from __future__ import annotations

import ast
from typing import TYPE_CHECKING
from typing import Sequence
from typing import overload

import torch
from torch._inductor.utils import triton_type

from .. import exc
from .._compiler.ast_extension import create
from .._compiler.ast_extension import expr_from_string
from . import _decorators

if TYPE_CHECKING:
from .._compiler.inductor_lowering import CodegenState

__all__ = ["inline_asm_elementwise"]


@overload
@_decorators.api(is_device_only=True)
def inline_asm_elementwise(
asm: str,
constraints: str,
args: Sequence[torch.Tensor],
dtype: torch.dtype,
is_pure: bool,
pack: int,
) -> torch.Tensor: ...


@overload
@_decorators.api(is_device_only=True)
def inline_asm_elementwise(
asm: str,
constraints: str,
args: Sequence[torch.Tensor],
dtype: Sequence[torch.dtype],
is_pure: bool,
pack: int,
) -> tuple[torch.Tensor, ...]: ...


@_decorators.api(is_device_only=True)
def inline_asm_elementwise(
asm: str,
constraints: str,
args: Sequence[torch.Tensor],
dtype: torch.dtype | Sequence[torch.dtype],
is_pure: bool,
pack: int,
) -> torch.Tensor | tuple[torch.Tensor, ...]:
"""
Execute inline assembly over a tensor. Essentially, this is a map
where the function is inline assembly.

The input tensors in ``args`` are implicitly broadcast to the same
shape. ``dtype`` can be a tuple of types, in which case the output is
a tuple of tensors.

Each invocation of the inline asm processes ``pack`` elements at a
time. Exactly which set of inputs a block receives is unspecified.
Input elements smaller than 4 bytes are packed into 4-byte registers.

This op does not support an empty ``dtype``: the inline asm must
return at least one tensor, even if you don't need it. You can work
around this by returning a dummy tensor of arbitrary type; it
shouldn't cost you anything if you don't use it.

Args:
asm: the assembly to run; must match the target's assembly format
constraints: asm constraints in LLVM format
args: the input tensors, whose values are passed to the asm block
dtype: the element type(s) of the returned tensor(s)
is_pure: if true, the compiler assumes the asm block has no side effects
pack: the number of elements processed by one instance of the inline asm

Returns:
one tensor or a tuple of tensors of the given dtypes
"""
raise exc.NotInsideKernel


@_decorators.register_fake(inline_asm_elementwise)
def _(
asm: str,
constraints: str,
args: Sequence[torch.Tensor],
dtype: torch.dtype | Sequence[torch.dtype],
is_pure: bool,
pack: int,
) -> torch.Tensor | tuple[torch.Tensor, ...]:
from .._compiler.compile_environment import CompileEnvironment

# Basic validation
if not isinstance(asm, str):
raise exc.InvalidAPIUsage(f"asm must be a string, got {type(asm)}")
if not isinstance(constraints, str):
raise exc.InvalidAPIUsage(
f"constraints must be a string, got {type(constraints)}"
)
if not isinstance(is_pure, bool):
raise exc.InvalidAPIUsage(f"is_pure must be a bool, got {type(is_pure)}")
if not isinstance(pack, int):
raise exc.InvalidAPIUsage(f"pack must be an int, got {type(pack)}")

# Determine if we have multiple outputs
if isinstance(dtype, (tuple, list)):
dtypes = list(dtype)
has_multiple_outputs = True
else:
dtypes = [dtype]
has_multiple_outputs = False

# Validate dtype(s)
for dt in dtypes:
if not isinstance(dt, torch.dtype):
raise exc.InvalidAPIUsage(f"dtype must be torch.dtype, got {type(dt)}")

# Broadcast all inputs to the same shape
if args:
broadcast_shape = args[0].shape
for arg in args[1:]:
if arg.shape != broadcast_shape:
broadcast_shape = torch.broadcast_shapes(broadcast_shape, arg.shape)
else:
# With no input tensors there is nothing to broadcast against, so use a
# placeholder shape; when used inside a tile context, the compiler
# resolves the actual output shape.
broadcast_shape = (1,)

env = CompileEnvironment.current()
if has_multiple_outputs:
results = []
for dt in dtypes:
# Type assertion: dt is guaranteed to be torch.dtype due to validation above
assert isinstance(dt, torch.dtype)
result = torch.empty(broadcast_shape, dtype=dt, device=env.device)
results.append(result)
return tuple(results)

# Type assertion: dtypes[0] is guaranteed to be torch.dtype due to validation above
assert isinstance(dtypes[0], torch.dtype)
return torch.empty(broadcast_shape, dtype=dtypes[0], device=env.device)


@_decorators.codegen(inline_asm_elementwise)
def _(state: CodegenState) -> ast.AST | list[ast.AST]:
# Get arguments
asm_str = state.proxy_arg(0)
constraints_str = state.proxy_arg(1)
dtype = state.proxy_arg(3)
is_pure = state.proxy_arg(4)
pack = state.proxy_arg(5)

# Convert the list of tensor args to AST
# We need to create a proper list AST with the tensor elements
raw_args = state.ast_args[2]
if isinstance(raw_args, list):
# Create AST List node with the tensor elements
args_ast = create(ast.List, elts=raw_args, ctx=ast.Load())
else:
# Not already a list of ASTs; pass it through unchanged (shouldn't normally happen)
args_ast = raw_args

# Convert dtype to Triton type string(s)
if isinstance(dtype, (tuple, list)):
dtype_strs = [triton_type(dt) for dt in dtype if isinstance(dt, torch.dtype)]
dtype_arg = f"({', '.join(dtype_strs)})" # Use tuple syntax for multiple dtypes
has_multiple_outputs = True
else:
dtype_arg = (
triton_type(dtype) if isinstance(dtype, torch.dtype) else "tl.float32"
)
has_multiple_outputs = False

# Create the call to tl.inline_asm_elementwise
inline_asm_call = create(
ast.Call,
func=expr_from_string("tl.inline_asm_elementwise"),
args=[
create(ast.Constant, value=asm_str),
create(ast.Constant, value=constraints_str),
args_ast,
expr_from_string(dtype_arg),
create(ast.Constant, value=is_pure),
create(ast.Constant, value=pack),
],
keywords=[],
)

# Handle multiple outputs by creating getitem expressions
if has_multiple_outputs:
assert isinstance(dtype, (tuple, list)) # Type guard for len()
num_outputs = len(dtype)
return [
expr_from_string(
f"inline_asm_result[{i}]", inline_asm_result=inline_asm_call
)
for i in range(num_outputs)
]

return inline_asm_call
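To make the API concrete, here is a minimal sketch of a Helion kernel using the new op, written to match the `kernel_basic` output in the expected-journal file below (the actual test source is not part of this diff; `@helion.kernel` and `hl.tile` are the standard Helion entry points):

```python
import torch

import helion
import helion.language as hl


@helion.kernel()
def kernel_basic(x: torch.Tensor) -> torch.Tensor:
    result = torch.empty_like(x)
    for tile in hl.tile(x.size(0)):
        # Identity move through PTX inline asm: copy input $1 into output $0.
        result[tile] = hl.inline_asm_elementwise(
            "mov.u32 $0, $1;",  # PTX, so this kernel is CUDA-only
            "=r,r",             # one output register, one input register
            [x[tile]],
            torch.int32,
            is_pure=True,
            pack=1,
        )
    return result
```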
152 changes: 152 additions & 0 deletions test/test_inline_asm_elementwise.expected
@@ -0,0 +1,152 @@
This file is automatically generated by assertExpectedJournal calls in test_inline_asm_elementwise.py.
Update expected outputs by running tests with the EXPECTTEST_ACCEPT=1 environment variable set.

--- assertExpectedJournal(TestInlineAsmElementwise.test_inline_asm_basic_compilation)
from __future__ import annotations

import torch
import triton
import triton.language as tl
from helion.runtime import default_launcher as _default_launcher

@triton.jit
def _kernel_basic_kernel(x, result, x_size_0, result_stride_0, x_stride_0, _BLOCK_SIZE_0: tl.constexpr):
pid_0 = tl.program_id(0)
offset_0 = pid_0 * _BLOCK_SIZE_0
indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
mask_0 = indices_0 < x_size_0
load = tl.load(x + indices_0 * x_stride_0, mask_0, other=0)
result_val = tl.inline_asm_elementwise('mov.u32 $0, $1;', '=r,r', [load], tl.int32, True, 1)
tl.store(result + indices_0 * result_stride_0, result_val, mask_0)

def kernel_basic(x: torch.Tensor, *, _launcher=_default_launcher):
result = torch.empty_like(x)
_BLOCK_SIZE_0 = 16
_launcher(_kernel_basic_kernel, (triton.cdiv(x.size(0), _BLOCK_SIZE_0),), x, result, x.size(0), result.stride(0), x.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3)
return result

--- assertExpectedJournal(TestInlineAsmElementwise.test_inline_asm_empty_args)
from __future__ import annotations

import torch
import triton
import triton.language as tl
from helion.runtime import default_launcher as _default_launcher

@triton.jit
def _kernel_empty_args_kernel(result, x_size_0, result_stride_0, _BLOCK_SIZE_0: tl.constexpr):
pid_0 = tl.program_id(0)
offset_0 = pid_0 * _BLOCK_SIZE_0
indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
mask_0 = indices_0 < x_size_0
result_val = tl.inline_asm_elementwise('mov.u32 $0, 42;', '=r', [], tl.int32, True, 1)
tl.store(result + indices_0 * result_stride_0, result_val, mask_0)

def kernel_empty_args(x: torch.Tensor, *, _launcher=_default_launcher):
result = torch.empty_like(x)
_BLOCK_SIZE_0 = 16
_launcher(_kernel_empty_args_kernel, (triton.cdiv(x.size(0), _BLOCK_SIZE_0),), result, x.size(0), result.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3)
return result

--- assertExpectedJournal(TestInlineAsmElementwise.test_inline_asm_multiple_outputs)
from __future__ import annotations

import torch
import triton
import triton.language as tl
from helion.runtime import default_launcher as _default_launcher

@triton.jit
def _kernel_multiple_outputs_kernel(a, b, result_c, result_d, a_size_0, a_stride_0, b_stride_0, result_c_stride_0, result_d_stride_0, _BLOCK_SIZE_0: tl.constexpr):
pid_0 = tl.program_id(0)
offset_0 = pid_0 * _BLOCK_SIZE_0
indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
mask_0 = indices_0 < a_size_0
val_a = tl.load(a + indices_0 * a_stride_0, mask_0, other=0)
val_b = tl.load(b + indices_0 * b_stride_0, mask_0, other=0)
c_val = tl.inline_asm_elementwise('\n sub.u32 $0, $2, $3;\n sub.u32 $1, $3, $2;\n ', '=r,=r,r,r', [val_a, val_b], (tl.int32, tl.int32), True, 1)[0]
d_val = tl.inline_asm_elementwise('\n sub.u32 $0, $2, $3;\n sub.u32 $1, $3, $2;\n ', '=r,=r,r,r', [val_a, val_b], (tl.int32, tl.int32), True, 1)[1]
tl.store(result_c + indices_0 * result_c_stride_0, c_val, mask_0)
tl.store(result_d + indices_0 * result_d_stride_0, d_val, mask_0)

def kernel_multiple_outputs(a: torch.Tensor, b: torch.Tensor, *, _launcher=_default_launcher):
result_c = torch.empty_like(a)
result_d = torch.empty_like(a)
_BLOCK_SIZE_0 = 64
_launcher(_kernel_multiple_outputs_kernel, (triton.cdiv(a.size(0), _BLOCK_SIZE_0),), a, b, result_c, result_d, a.size(0), a.stride(0), b.stride(0), result_c.stride(0), result_d.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3)
return (result_c, result_d)

--- assertExpectedJournal(TestInlineAsmElementwise.test_inline_asm_packed)
from __future__ import annotations

import torch
import triton
import triton.language as tl
from helion.runtime import default_launcher as _default_launcher

@triton.jit
def _kernel_packed_asm_kernel(x, result, x_size_0, result_stride_0, x_stride_0, _BLOCK_SIZE_0: tl.constexpr):
pid_0 = tl.program_id(0)
offset_0 = pid_0 * _BLOCK_SIZE_0
indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
mask_0 = indices_0 < x_size_0
val = tl.load(x + indices_0 * x_stride_0, mask_0, other=0)
result_val = tl.inline_asm_elementwise('and.b32 $0, $1, 0x1F1F1F1F; shl.b32 $0, $0, 3;', '=r,r', [val], tl.int8, True, 4)
v_0 = result_val.to(tl.int8).to(tl.uint8)
tl.store(result + indices_0 * result_stride_0, v_0, mask_0)

def kernel_packed_asm(x: torch.Tensor, *, _launcher=_default_launcher):
result = torch.empty_like(x)
_BLOCK_SIZE_0 = 512
_launcher(_kernel_packed_asm_kernel, (triton.cdiv(x.size(0), _BLOCK_SIZE_0),), x, result, x.size(0), result.stride(0), x.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3)
return result

--- assertExpectedJournal(TestInlineAsmElementwise.test_inline_asm_shift_operation)
from __future__ import annotations

import torch
import triton
import triton.language as tl
from helion.runtime import default_launcher as _default_launcher

@triton.jit
def _kernel_shift_asm_kernel(x, y, result, x_size_0, result_stride_0, x_stride_0, y_stride_0, n, _BLOCK_SIZE_0: tl.constexpr):
pid_0 = tl.program_id(0)
offset_0 = pid_0 * _BLOCK_SIZE_0
indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
mask_0 = indices_0 < x_size_0
val_x = tl.load(x + indices_0 * x_stride_0, mask_0, other=0)
val_y = tl.load(y + indices_0 * y_stride_0, mask_0, other=0)
shift_val = tl.full([_BLOCK_SIZE_0], n, tl.int32)
result_val = tl.inline_asm_elementwise('shf.l.wrap.b32 $0, $1, $2, $3;', '=r,r,r,r', [val_x, val_y, shift_val], tl.int32, True, 1)
tl.store(result + indices_0 * result_stride_0, result_val, mask_0)

def kernel_shift_asm(x: torch.Tensor, y: torch.Tensor, n: int, *, _launcher=_default_launcher):
result = torch.empty_like(x)
_BLOCK_SIZE_0 = 128
_launcher(_kernel_shift_asm_kernel, (triton.cdiv(x.size(0), _BLOCK_SIZE_0),), x, y, result, x.size(0), result.stride(0), x.stride(0), y.stride(0), n, _BLOCK_SIZE_0, num_warps=4, num_stages=3)
return result

--- assertExpectedJournal(TestInlineAsmElementwise.test_inline_asm_simple)
from __future__ import annotations

import torch
import triton
import triton.language as tl
from helion.runtime import default_launcher as _default_launcher

@triton.jit
def _kernel_simple_asm_kernel(x, result, x_size_0, result_stride_0, x_stride_0, _BLOCK_SIZE_0: tl.constexpr):
pid_0 = tl.program_id(0)
offset_0 = pid_0 * _BLOCK_SIZE_0
indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
mask_0 = indices_0 < x_size_0
val = tl.load(x + indices_0 * x_stride_0, mask_0, other=0)
result_val = tl.inline_asm_elementwise('mov.u32 $0, $1;', '=r,r', [val], tl.int32, True, 1)
tl.store(result + indices_0 * result_stride_0, result_val, mask_0)

def kernel_simple_asm(x: torch.Tensor, *, _launcher=_default_launcher):
result = torch.empty_like(x)
_BLOCK_SIZE_0 = 16
_launcher(_kernel_simple_asm_kernel, (triton.cdiv(x.size(0), _BLOCK_SIZE_0),), x, result, x.size(0), result.stride(0), x.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3)
return result
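For the multiple-output case above, a plausible Helion-side source (a sketch reconstructed from the generated code; the actual test source is not shown in this diff) passes a tuple of dtypes and unpacks the tuple result:

```python
import torch

import helion
import helion.language as hl


@helion.kernel()
def kernel_multiple_outputs(a: torch.Tensor, b: torch.Tensor):
    result_c = torch.empty_like(a)
    result_d = torch.empty_like(a)
    for tile in hl.tile(a.size(0)):
        # Two outputs: $0 = $2 - $3 and $1 = $3 - $2.
        c, d = hl.inline_asm_elementwise(
            "\n sub.u32 $0, $2, $3;\n sub.u32 $1, $3, $2;\n ",
            "=r,=r,r,r",
            [a[tile], b[tile]],
            (torch.int32, torch.int32),
            is_pure=True,
            pack=1,
        )
        result_c[tile] = c
        result_d[tile] = d
    return result_c, result_d
```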