Skip to content

Commit 2a92f79

Browse files
committed
Add hl.inline_asm_elementwise
stack-info: PR: #328, branch: jansel/stack/114
1 parent 17b2668 commit 2a92f79

File tree

6 files changed

+611
-1
lines changed

6 files changed

+611
-1
lines changed

docs/api/exceptions.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,10 @@ These exceptions occur when Helion language functions are used incorrectly with
202202
.. autoclass:: TracedArgNotSupported
203203
204204
Raised for unsupported argument types in traced functions.
205+
206+
.. autoclass:: InvalidAPIUsage
207+
208+
Raised for incorrect usage of Helion API functions.
205209
```
206210

207211
## Configuration Errors

helion/exc.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -236,7 +236,7 @@ class ErrorCompilingKernel(BaseError):
236236

237237

238238
class NoTensorArgs(BaseError):
239-
message = "Kernel took no tensor args, unclear what device to use."
239+
message = "Kernel took no tensor or device args, unclear what device to use."
240240

241241

242242
class _WrapException(BaseError):
@@ -323,3 +323,7 @@ class CannotReadDeviceVariableOnHost(BaseError):
323323

324324
class DeviceTensorSubscriptAssignmentNotAllowed(BaseError):
325325
message = "Cannot assign to subscript of device tensor '{0}'."
326+
327+
328+
class InvalidAPIUsage(BaseError):
329+
message = "Invalid usage of Helion API: {0}"

helion/language/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from .creation_ops import full as full
77
from .creation_ops import zeros as zeros
88
from .device_print import device_print as device_print
9+
from .inline_asm_ops import inline_asm_elementwise as inline_asm_elementwise
910
from .loops import grid as grid
1011
from .loops import tile as tile
1112
from .memory_ops import atomic_add as atomic_add

helion/language/inline_asm_ops.py

Lines changed: 207 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,207 @@
1+
from __future__ import annotations
2+
3+
import ast
4+
from typing import TYPE_CHECKING
5+
from typing import Sequence
6+
from typing import overload
7+
8+
import torch
9+
from torch._inductor.utils import triton_type
10+
11+
from .. import exc
12+
from .._compiler.ast_extension import create
13+
from .._compiler.ast_extension import expr_from_string
14+
from . import _decorators
15+
16+
if TYPE_CHECKING:
17+
from .._compiler.inductor_lowering import CodegenState
18+
19+
__all__ = ["inline_asm_elementwise"]
20+
21+
22+
@overload
23+
@_decorators.api(is_device_only=True)
24+
def inline_asm_elementwise(
25+
asm: str,
26+
constraints: str,
27+
args: Sequence[torch.Tensor],
28+
dtype: torch.dtype,
29+
is_pure: bool,
30+
pack: int,
31+
) -> torch.Tensor: ...
32+
33+
34+
@overload
35+
@_decorators.api(is_device_only=True)
36+
def inline_asm_elementwise(
37+
asm: str,
38+
constraints: str,
39+
args: Sequence[torch.Tensor],
40+
dtype: Sequence[torch.dtype],
41+
is_pure: bool,
42+
pack: int,
43+
) -> tuple[torch.Tensor, ...]: ...
44+
45+
46+
@_decorators.api(is_device_only=True)
47+
def inline_asm_elementwise(
48+
asm: str,
49+
constraints: str,
50+
args: Sequence[torch.Tensor],
51+
dtype: torch.dtype | Sequence[torch.dtype],
52+
is_pure: bool,
53+
pack: int,
54+
) -> torch.Tensor | tuple[torch.Tensor, ...]:
55+
"""
56+
Execute inline assembly over a tensor. Essentially, this is map
57+
where the function is inline assembly.
58+
59+
The input tensors args are implicitly broadcasted to the same shape.
60+
dtype can be a tuple of types, in which case the output is a
61+
tuple of tensors.
62+
63+
Each invocation of the inline asm processes pack elements at a
64+
time. Exactly which set of inputs a block receives is unspecified.
65+
Input elements of size less than 4 bytes are packed into 4-byte
66+
registers.
67+
68+
This op does not support empty dtype -- the inline asm must
69+
return at least one tensor, even if you don't need it. You can work
70+
around this by returning a dummy tensor of arbitrary type; it shouldn't
71+
cost you anything if you don't use it.
72+
73+
Args:
74+
asm: assembly to run. Must match target's assembly format.
75+
constraints: asm constraints in LLVM format
76+
args: the input tensors, whose values are passed to the asm block
77+
dtype: the element type(s) of the returned tensor(s)
78+
is_pure: if true, the compiler assumes the asm block has no side-effects
79+
pack: the number of elements to be processed by one instance of inline assembly
80+
81+
Returns:
82+
one tensor or a tuple of tensors of the given dtypes
83+
"""
84+
raise exc.NotInsideKernel
85+
86+
87+
@_decorators.register_fake(inline_asm_elementwise)
88+
def _(
89+
asm: str,
90+
constraints: str,
91+
args: Sequence[torch.Tensor],
92+
dtype: torch.dtype | Sequence[torch.dtype],
93+
is_pure: bool,
94+
pack: int,
95+
) -> torch.Tensor | tuple[torch.Tensor, ...]:
96+
from .._compiler.compile_environment import CompileEnvironment
97+
98+
# Basic validation
99+
if not isinstance(asm, str):
100+
raise exc.InvalidAPIUsage(f"asm must be a string, got {type(asm)}")
101+
if not isinstance(constraints, str):
102+
raise exc.InvalidAPIUsage(
103+
f"constraints must be a string, got {type(constraints)}"
104+
)
105+
if not isinstance(is_pure, bool):
106+
raise exc.InvalidAPIUsage(f"is_pure must be a bool, got {type(is_pure)}")
107+
if not isinstance(pack, int):
108+
raise exc.InvalidAPIUsage(f"pack must be an int, got {type(pack)}")
109+
110+
# Determine if we have multiple outputs
111+
if isinstance(dtype, (tuple, list)):
112+
dtypes = list(dtype)
113+
has_multiple_outputs = True
114+
else:
115+
dtypes = [dtype]
116+
has_multiple_outputs = False
117+
118+
# Validate dtype(s)
119+
for dt in dtypes:
120+
if not isinstance(dt, torch.dtype):
121+
raise exc.InvalidAPIUsage(f"dtype must be torch.dtype, got {type(dt)}")
122+
123+
# Broadcast all inputs to the same shape
124+
if args:
125+
broadcast_shape = args[0].shape
126+
for arg in args[1:]:
127+
if arg.shape != broadcast_shape:
128+
broadcast_shape = torch.broadcast_shapes(broadcast_shape, arg.shape)
129+
else:
130+
# For empty args, we need to infer the shape from context
131+
# The problem is that without input tensors, we can't determine the proper broadcast shape
132+
# However, when used in a tile context, the output should match the tile shape
133+
# For the fake function, we'll use a simple placeholder shape that the compiler will handle
134+
broadcast_shape = (1,)
135+
136+
env = CompileEnvironment.current()
137+
if has_multiple_outputs:
138+
results = []
139+
for dt in dtypes:
140+
# Type assertion: dt is guaranteed to be torch.dtype due to validation above
141+
assert isinstance(dt, torch.dtype)
142+
result = torch.empty(broadcast_shape, dtype=dt, device=env.device)
143+
results.append(result)
144+
return tuple(results)
145+
146+
# Type assertion: dtypes[0] is guaranteed to be torch.dtype due to validation above
147+
assert isinstance(dtypes[0], torch.dtype)
148+
return torch.empty(broadcast_shape, dtype=dtypes[0], device=env.device)
149+
150+
151+
@_decorators.codegen(inline_asm_elementwise)
152+
def _(state: CodegenState) -> ast.AST | list[ast.AST]:
153+
# Get arguments
154+
asm_str = state.proxy_arg(0)
155+
constraints_str = state.proxy_arg(1)
156+
dtype = state.proxy_arg(3)
157+
is_pure = state.proxy_arg(4)
158+
pack = state.proxy_arg(5)
159+
160+
# Convert the list of tensor args to AST
161+
# We need to create a proper list AST with the tensor elements
162+
raw_args = state.ast_args[2]
163+
if isinstance(raw_args, list):
164+
# Create AST List node with the tensor elements
165+
args_ast = create(ast.List, elts=raw_args, ctx=ast.Load())
166+
else:
167+
# If it's not a list, wrap it in a list (shouldn't normally happen)
168+
args_ast = raw_args
169+
170+
# Convert dtype to Triton type string(s)
171+
if isinstance(dtype, (tuple, list)):
172+
dtype_strs = [triton_type(dt) for dt in dtype if isinstance(dt, torch.dtype)]
173+
dtype_arg = f"({', '.join(dtype_strs)})" # Use tuple syntax for multiple dtypes
174+
has_multiple_outputs = True
175+
else:
176+
dtype_arg = (
177+
triton_type(dtype) if isinstance(dtype, torch.dtype) else "tl.float32"
178+
)
179+
has_multiple_outputs = False
180+
181+
# Create the call to tl.inline_asm_elementwise
182+
inline_asm_call = create(
183+
ast.Call,
184+
func=expr_from_string("tl.inline_asm_elementwise"),
185+
args=[
186+
create(ast.Constant, value=asm_str),
187+
create(ast.Constant, value=constraints_str),
188+
args_ast,
189+
expr_from_string(dtype_arg),
190+
create(ast.Constant, value=is_pure),
191+
create(ast.Constant, value=pack),
192+
],
193+
keywords=[],
194+
)
195+
196+
# Handle multiple outputs by creating getitem expressions
197+
if has_multiple_outputs:
198+
assert isinstance(dtype, (tuple, list)) # Type guard for len()
199+
num_outputs = len(dtype)
200+
return [
201+
expr_from_string(
202+
f"inline_asm_result[{i}]", inline_asm_result=inline_asm_call
203+
)
204+
for i in range(num_outputs)
205+
]
206+
207+
return inline_asm_call
Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,152 @@
1+
This file is automatically generated by assertExpectedJournal calls in test_inline_asm_elementwise.py.
2+
Update expected outputs by running tests with the EXPECTTEST_ACCEPT=1 environment variable set.
3+
4+
--- assertExpectedJournal(TestInlineAsmElementwise.test_inline_asm_basic_compilation)
5+
from __future__ import annotations
6+
7+
import torch
8+
import triton
9+
import triton.language as tl
10+
from helion.runtime import default_launcher as _default_launcher
11+
12+
@triton.jit
13+
def _kernel_basic_kernel(x, result, x_size_0, result_stride_0, x_stride_0, _BLOCK_SIZE_0: tl.constexpr):
14+
pid_0 = tl.program_id(0)
15+
offset_0 = pid_0 * _BLOCK_SIZE_0
16+
indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
17+
mask_0 = indices_0 < x_size_0
18+
load = tl.load(x + indices_0 * x_stride_0, mask_0, other=0)
19+
result_val = tl.inline_asm_elementwise('mov.u32 $0, $1;', '=r,r', [load], tl.int32, True, 1)
20+
tl.store(result + indices_0 * result_stride_0, result_val, mask_0)
21+
22+
def kernel_basic(x: torch.Tensor, *, _launcher=_default_launcher):
23+
result = torch.empty_like(x)
24+
_BLOCK_SIZE_0 = 16
25+
_launcher(_kernel_basic_kernel, (triton.cdiv(x.size(0), _BLOCK_SIZE_0),), x, result, x.size(0), result.stride(0), x.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3)
26+
return result
27+
28+
--- assertExpectedJournal(TestInlineAsmElementwise.test_inline_asm_empty_args)
29+
from __future__ import annotations
30+
31+
import torch
32+
import triton
33+
import triton.language as tl
34+
from helion.runtime import default_launcher as _default_launcher
35+
36+
@triton.jit
37+
def _kernel_empty_args_kernel(result, x_size_0, result_stride_0, _BLOCK_SIZE_0: tl.constexpr):
38+
pid_0 = tl.program_id(0)
39+
offset_0 = pid_0 * _BLOCK_SIZE_0
40+
indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
41+
mask_0 = indices_0 < x_size_0
42+
result_val = tl.inline_asm_elementwise('mov.u32 $0, 42;', '=r', [], tl.int32, True, 1)
43+
tl.store(result + indices_0 * result_stride_0, result_val, mask_0)
44+
45+
def kernel_empty_args(x: torch.Tensor, *, _launcher=_default_launcher):
46+
result = torch.empty_like(x)
47+
_BLOCK_SIZE_0 = 16
48+
_launcher(_kernel_empty_args_kernel, (triton.cdiv(x.size(0), _BLOCK_SIZE_0),), result, x.size(0), result.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3)
49+
return result
50+
51+
--- assertExpectedJournal(TestInlineAsmElementwise.test_inline_asm_multiple_outputs)
52+
from __future__ import annotations
53+
54+
import torch
55+
import triton
56+
import triton.language as tl
57+
from helion.runtime import default_launcher as _default_launcher
58+
59+
@triton.jit
60+
def _kernel_multiple_outputs_kernel(a, b, result_c, result_d, a_size_0, a_stride_0, b_stride_0, result_c_stride_0, result_d_stride_0, _BLOCK_SIZE_0: tl.constexpr):
61+
pid_0 = tl.program_id(0)
62+
offset_0 = pid_0 * _BLOCK_SIZE_0
63+
indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
64+
mask_0 = indices_0 < a_size_0
65+
val_a = tl.load(a + indices_0 * a_stride_0, mask_0, other=0)
66+
val_b = tl.load(b + indices_0 * b_stride_0, mask_0, other=0)
67+
c_val = tl.inline_asm_elementwise('\n sub.u32 $0, $2, $3;\n sub.u32 $1, $3, $2;\n ', '=r,=r,r,r', [val_a, val_b], (tl.int32, tl.int32), True, 1)[0]
68+
d_val = tl.inline_asm_elementwise('\n sub.u32 $0, $2, $3;\n sub.u32 $1, $3, $2;\n ', '=r,=r,r,r', [val_a, val_b], (tl.int32, tl.int32), True, 1)[1]
69+
tl.store(result_c + indices_0 * result_c_stride_0, c_val, mask_0)
70+
tl.store(result_d + indices_0 * result_d_stride_0, d_val, mask_0)
71+
72+
def kernel_multiple_outputs(a: torch.Tensor, b: torch.Tensor, *, _launcher=_default_launcher):
73+
result_c = torch.empty_like(a)
74+
result_d = torch.empty_like(a)
75+
_BLOCK_SIZE_0 = 64
76+
_launcher(_kernel_multiple_outputs_kernel, (triton.cdiv(a.size(0), _BLOCK_SIZE_0),), a, b, result_c, result_d, a.size(0), a.stride(0), b.stride(0), result_c.stride(0), result_d.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3)
77+
return (result_c, result_d)
78+
79+
--- assertExpectedJournal(TestInlineAsmElementwise.test_inline_asm_packed)
80+
from __future__ import annotations
81+
82+
import torch
83+
import triton
84+
import triton.language as tl
85+
from helion.runtime import default_launcher as _default_launcher
86+
87+
@triton.jit
88+
def _kernel_packed_asm_kernel(x, result, x_size_0, result_stride_0, x_stride_0, _BLOCK_SIZE_0: tl.constexpr):
89+
pid_0 = tl.program_id(0)
90+
offset_0 = pid_0 * _BLOCK_SIZE_0
91+
indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
92+
mask_0 = indices_0 < x_size_0
93+
val = tl.load(x + indices_0 * x_stride_0, mask_0, other=0)
94+
result_val = tl.inline_asm_elementwise('and.b32 $0, $1, 0x1F1F1F1F; shl.b32 $0, $0, 3;', '=r,r', [val], tl.int8, True, 4)
95+
v_0 = result_val.to(tl.int8).to(tl.uint8)
96+
tl.store(result + indices_0 * result_stride_0, v_0, mask_0)
97+
98+
def kernel_packed_asm(x: torch.Tensor, *, _launcher=_default_launcher):
99+
result = torch.empty_like(x)
100+
_BLOCK_SIZE_0 = 512
101+
_launcher(_kernel_packed_asm_kernel, (triton.cdiv(x.size(0), _BLOCK_SIZE_0),), x, result, x.size(0), result.stride(0), x.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3)
102+
return result
103+
104+
--- assertExpectedJournal(TestInlineAsmElementwise.test_inline_asm_shift_operation)
105+
from __future__ import annotations
106+
107+
import torch
108+
import triton
109+
import triton.language as tl
110+
from helion.runtime import default_launcher as _default_launcher
111+
112+
@triton.jit
113+
def _kernel_shift_asm_kernel(x, y, result, x_size_0, result_stride_0, x_stride_0, y_stride_0, n, _BLOCK_SIZE_0: tl.constexpr):
114+
pid_0 = tl.program_id(0)
115+
offset_0 = pid_0 * _BLOCK_SIZE_0
116+
indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
117+
mask_0 = indices_0 < x_size_0
118+
val_x = tl.load(x + indices_0 * x_stride_0, mask_0, other=0)
119+
val_y = tl.load(y + indices_0 * y_stride_0, mask_0, other=0)
120+
shift_val = tl.full([_BLOCK_SIZE_0], n, tl.int32)
121+
result_val = tl.inline_asm_elementwise('shf.l.wrap.b32 $0, $1, $2, $3;', '=r,r,r,r', [val_x, val_y, shift_val], tl.int32, True, 1)
122+
tl.store(result + indices_0 * result_stride_0, result_val, mask_0)
123+
124+
def kernel_shift_asm(x: torch.Tensor, y: torch.Tensor, n: int, *, _launcher=_default_launcher):
125+
result = torch.empty_like(x)
126+
_BLOCK_SIZE_0 = 128
127+
_launcher(_kernel_shift_asm_kernel, (triton.cdiv(x.size(0), _BLOCK_SIZE_0),), x, y, result, x.size(0), result.stride(0), x.stride(0), y.stride(0), n, _BLOCK_SIZE_0, num_warps=4, num_stages=3)
128+
return result
129+
130+
--- assertExpectedJournal(TestInlineAsmElementwise.test_inline_asm_simple)
131+
from __future__ import annotations
132+
133+
import torch
134+
import triton
135+
import triton.language as tl
136+
from helion.runtime import default_launcher as _default_launcher
137+
138+
@triton.jit
139+
def _kernel_simple_asm_kernel(x, result, x_size_0, result_stride_0, x_stride_0, _BLOCK_SIZE_0: tl.constexpr):
140+
pid_0 = tl.program_id(0)
141+
offset_0 = pid_0 * _BLOCK_SIZE_0
142+
indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
143+
mask_0 = indices_0 < x_size_0
144+
val = tl.load(x + indices_0 * x_stride_0, mask_0, other=0)
145+
result_val = tl.inline_asm_elementwise('mov.u32 $0, $1;', '=r,r', [val], tl.int32, True, 1)
146+
tl.store(result + indices_0 * result_stride_0, result_val, mask_0)
147+
148+
def kernel_simple_asm(x: torch.Tensor, *, _launcher=_default_launcher):
149+
result = torch.empty_like(x)
150+
_BLOCK_SIZE_0 = 16
151+
_launcher(_kernel_simple_asm_kernel, (triton.cdiv(x.size(0), _BLOCK_SIZE_0),), x, result, x.size(0), result.stride(0), x.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3)
152+
return result

0 commit comments

Comments
 (0)