From e5d49b75ca7ddeaf0f343817a7b0e937a4ce5bdc Mon Sep 17 00:00:00 2001 From: Jason Ansel Date: Tue, 15 Jul 2025 08:53:59 -0700 Subject: [PATCH] Add hl.inline_asm_elementwise stack-info: PR: https://github.com/pytorch-labs/helion/pull/328, branch: jansel/stack/114 --- docs/api/exceptions.md | 4 + helion/exc.py | 6 +- helion/language/__init__.py | 1 + helion/language/inline_asm_ops.py | 207 ++++++++++++++++++ test/test_inline_asm_elementwise.expected | 152 ++++++++++++++ test/test_inline_asm_elementwise.py | 242 ++++++++++++++++++++++ 6 files changed, 611 insertions(+), 1 deletion(-) create mode 100644 helion/language/inline_asm_ops.py create mode 100644 test/test_inline_asm_elementwise.expected create mode 100644 test/test_inline_asm_elementwise.py diff --git a/docs/api/exceptions.md b/docs/api/exceptions.md index 1541fe1e..645ea07d 100644 --- a/docs/api/exceptions.md +++ b/docs/api/exceptions.md @@ -202,6 +202,10 @@ These exceptions occur when Helion language functions are used incorrectly with .. autoclass:: TracedArgNotSupported Raised for unsupported argument types in traced functions. + +.. autoclass:: InvalidAPIUsage + + Raised for incorrect usage of Helion API functions. ``` ## Configuration Errors diff --git a/helion/exc.py b/helion/exc.py index c1d979ca..b27f3d03 100644 --- a/helion/exc.py +++ b/helion/exc.py @@ -236,7 +236,7 @@ class ErrorCompilingKernel(BaseError): class NoTensorArgs(BaseError): - message = "Kernel took no tensor args, unclear what device to use." + message = "Kernel took no tensor or device args, unclear what device to use." class _WrapException(BaseError): @@ -327,3 +327,7 @@ class DeviceTensorSubscriptAssignmentNotAllowed(BaseError): class InvalidSequenceSubscription(BaseError): message = "Cannot subscript a sequence with non constant indices. Got '{0!s}'. " + + +class InvalidAPIUsage(BaseError): + message = "Invalid usage of Helion API: {0}" diff --git a/helion/language/__init__.py b/helion/language/__init__.py index e38d6ad1..eb924e66 100644 --- a/helion/language/__init__.py +++ b/helion/language/__init__.py @@ -6,6 +6,7 @@ from .creation_ops import full as full from .creation_ops import zeros as zeros from .device_print import device_print as device_print +from .inline_asm_ops import inline_asm_elementwise as inline_asm_elementwise from .loops import grid as grid from .loops import tile as tile from .memory_ops import atomic_add as atomic_add diff --git a/helion/language/inline_asm_ops.py b/helion/language/inline_asm_ops.py new file mode 100644 index 00000000..936acd0d --- /dev/null +++ b/helion/language/inline_asm_ops.py @@ -0,0 +1,207 @@ +from __future__ import annotations + +import ast +from typing import TYPE_CHECKING +from typing import Sequence +from typing import overload + +import torch +from torch._inductor.utils import triton_type + +from .. import exc +from .._compiler.ast_extension import create +from .._compiler.ast_extension import expr_from_string +from . import _decorators + +if TYPE_CHECKING: + from .._compiler.inductor_lowering import CodegenState + +__all__ = ["inline_asm_elementwise"] + + +@overload +@_decorators.api(is_device_only=True) +def inline_asm_elementwise( + asm: str, + constraints: str, + args: Sequence[torch.Tensor], + dtype: torch.dtype, + is_pure: bool, + pack: int, +) -> torch.Tensor: ... + + +@overload +@_decorators.api(is_device_only=True) +def inline_asm_elementwise( + asm: str, + constraints: str, + args: Sequence[torch.Tensor], + dtype: Sequence[torch.dtype], + is_pure: bool, + pack: int, +) -> tuple[torch.Tensor, ...]: ... + + +@_decorators.api(is_device_only=True) +def inline_asm_elementwise( + asm: str, + constraints: str, + args: Sequence[torch.Tensor], + dtype: torch.dtype | Sequence[torch.dtype], + is_pure: bool, + pack: int, +) -> torch.Tensor | tuple[torch.Tensor, ...]: + """ + Execute inline assembly over a tensor. Essentially, this is map + where the function is inline assembly. + + The input tensors args are implicitly broadcasted to the same shape. + dtype can be a tuple of types, in which case the output is a + tuple of tensors. + + Each invocation of the inline asm processes pack elements at a + time. Exactly which set of inputs a block receives is unspecified. + Input elements of size less than 4 bytes are packed into 4-byte + registers. + + This op does not support empty dtype -- the inline asm must + return at least one tensor, even if you don't need it. You can work + around this by returning a dummy tensor of arbitrary type; it shouldn't + cost you anything if you don't use it. + + Args: + asm: assembly to run. Must match target's assembly format. + constraints: asm constraints in LLVM format + args: the input tensors, whose values are passed to the asm block + dtype: the element type(s) of the returned tensor(s) + is_pure: if true, the compiler assumes the asm block has no side-effects + pack: the number of elements to be processed by one instance of inline assembly + + Returns: + one tensor or a tuple of tensors of the given dtypes + """ + raise exc.NotInsideKernel + + +@_decorators.register_fake(inline_asm_elementwise) +def _( + asm: str, + constraints: str, + args: Sequence[torch.Tensor], + dtype: torch.dtype | Sequence[torch.dtype], + is_pure: bool, + pack: int, +) -> torch.Tensor | tuple[torch.Tensor, ...]: + from .._compiler.compile_environment import CompileEnvironment + + # Basic validation + if not isinstance(asm, str): + raise exc.InvalidAPIUsage(f"asm must be a string, got {type(asm)}") + if not isinstance(constraints, str): + raise exc.InvalidAPIUsage( + f"constraints must be a string, got {type(constraints)}" + ) + if not isinstance(is_pure, bool): + raise exc.InvalidAPIUsage(f"is_pure must be a bool, got {type(is_pure)}") + if not isinstance(pack, int): + raise exc.InvalidAPIUsage(f"pack must be an int, got {type(pack)}") + + # Determine if we have multiple outputs + if isinstance(dtype, (tuple, list)): + dtypes = list(dtype) + has_multiple_outputs = True + else: + dtypes = [dtype] + has_multiple_outputs = False + + # Validate dtype(s) + for dt in dtypes: + if not isinstance(dt, torch.dtype): + raise exc.InvalidAPIUsage(f"dtype must be torch.dtype, got {type(dt)}") + + # Broadcast all inputs to the same shape + if args: + broadcast_shape = args[0].shape + for arg in args[1:]: + if arg.shape != broadcast_shape: + broadcast_shape = torch.broadcast_shapes(broadcast_shape, arg.shape) + else: + # For empty args, we need to infer the shape from context + # The problem is that without input tensors, we can't determine the proper broadcast shape + # However, when used in a tile context, the output should match the tile shape + # For the fake function, we'll use a simple placeholder shape that the compiler will handle + broadcast_shape = (1,) + + env = CompileEnvironment.current() + if has_multiple_outputs: + results = [] + for dt in dtypes: + # Type assertion: dt is guaranteed to be torch.dtype due to validation above + assert isinstance(dt, torch.dtype) + result = torch.empty(broadcast_shape, dtype=dt, device=env.device) + results.append(result) + return tuple(results) + + # Type assertion: dtypes[0] is guaranteed to be torch.dtype due to validation above + assert isinstance(dtypes[0], torch.dtype) + return torch.empty(broadcast_shape, dtype=dtypes[0], device=env.device) + + +@_decorators.codegen(inline_asm_elementwise) +def _(state: CodegenState) -> ast.AST | list[ast.AST]: + # Get arguments + asm_str = state.proxy_arg(0) + constraints_str = state.proxy_arg(1) + dtype = state.proxy_arg(3) + is_pure = state.proxy_arg(4) + pack = state.proxy_arg(5) + + # Convert the list of tensor args to AST + # We need to create a proper list AST with the tensor elements + raw_args = state.ast_args[2] + if isinstance(raw_args, list): + # Create AST List node with the tensor elements + args_ast = create(ast.List, elts=raw_args, ctx=ast.Load()) + else: + # If it's not a list, wrap it in a list (shouldn't normally happen) + args_ast = raw_args + + # Convert dtype to Triton type string(s) + if isinstance(dtype, (tuple, list)): + dtype_strs = [triton_type(dt) for dt in dtype if isinstance(dt, torch.dtype)] + dtype_arg = f"({', '.join(dtype_strs)})" # Use tuple syntax for multiple dtypes + has_multiple_outputs = True + else: + dtype_arg = ( + triton_type(dtype) if isinstance(dtype, torch.dtype) else "tl.float32" + ) + has_multiple_outputs = False + + # Create the call to tl.inline_asm_elementwise + inline_asm_call = create( + ast.Call, + func=expr_from_string("tl.inline_asm_elementwise"), + args=[ + create(ast.Constant, value=asm_str), + create(ast.Constant, value=constraints_str), + args_ast, + expr_from_string(dtype_arg), + create(ast.Constant, value=is_pure), + create(ast.Constant, value=pack), + ], + keywords=[], + ) + + # Handle multiple outputs by creating getitem expressions + if has_multiple_outputs: + assert isinstance(dtype, (tuple, list)) # Type guard for len() + num_outputs = len(dtype) + return [ + expr_from_string( + f"inline_asm_result[{i}]", inline_asm_result=inline_asm_call + ) + for i in range(num_outputs) + ] + + return inline_asm_call diff --git a/test/test_inline_asm_elementwise.expected b/test/test_inline_asm_elementwise.expected new file mode 100644 index 00000000..ed5e4a21 --- /dev/null +++ b/test/test_inline_asm_elementwise.expected @@ -0,0 +1,152 @@ +This file is automatically generated by assertExpectedJournal calls in test_inline_asm_elementwise.py. +Update expected outputs by running tests with the EXPECTTEST_ACCEPT=1 environment variable set. + +--- assertExpectedJournal(TestInlineAsmElementwise.test_inline_asm_basic_compilation) +from __future__ import annotations + +import torch +import triton +import triton.language as tl +from helion.runtime import default_launcher as _default_launcher + +@triton.jit +def _kernel_basic_kernel(x, result, x_size_0, result_stride_0, x_stride_0, _BLOCK_SIZE_0: tl.constexpr): + pid_0 = tl.program_id(0) + offset_0 = pid_0 * _BLOCK_SIZE_0 + indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32) + mask_0 = indices_0 < x_size_0 + load = tl.load(x + indices_0 * x_stride_0, mask_0, other=0) + result_val = tl.inline_asm_elementwise('mov.u32 $0, $1;', '=r,r', [load], tl.int32, True, 1) + tl.store(result + indices_0 * result_stride_0, result_val, mask_0) + +def kernel_basic(x: torch.Tensor, *, _launcher=_default_launcher): + result = torch.empty_like(x) + _BLOCK_SIZE_0 = 16 + _launcher(_kernel_basic_kernel, (triton.cdiv(x.size(0), _BLOCK_SIZE_0),), x, result, x.size(0), result.stride(0), x.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3) + return result + +--- assertExpectedJournal(TestInlineAsmElementwise.test_inline_asm_empty_args) +from __future__ import annotations + +import torch +import triton +import triton.language as tl +from helion.runtime import default_launcher as _default_launcher + +@triton.jit +def _kernel_empty_args_kernel(result, x_size_0, result_stride_0, _BLOCK_SIZE_0: tl.constexpr): + pid_0 = tl.program_id(0) + offset_0 = pid_0 * _BLOCK_SIZE_0 + indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32) + mask_0 = indices_0 < x_size_0 + result_val = tl.inline_asm_elementwise('mov.u32 $0, 42;', '=r', [], tl.int32, True, 1) + tl.store(result + indices_0 * result_stride_0, result_val, mask_0) + +def kernel_empty_args(x: torch.Tensor, *, _launcher=_default_launcher): + result = torch.empty_like(x) + _BLOCK_SIZE_0 = 16 + _launcher(_kernel_empty_args_kernel, (triton.cdiv(x.size(0), _BLOCK_SIZE_0),), result, x.size(0), result.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3) + return result + +--- assertExpectedJournal(TestInlineAsmElementwise.test_inline_asm_multiple_outputs) +from __future__ import annotations + +import torch +import triton +import triton.language as tl +from helion.runtime import default_launcher as _default_launcher + +@triton.jit +def _kernel_multiple_outputs_kernel(a, b, result_c, result_d, a_size_0, a_stride_0, b_stride_0, result_c_stride_0, result_d_stride_0, _BLOCK_SIZE_0: tl.constexpr): + pid_0 = tl.program_id(0) + offset_0 = pid_0 * _BLOCK_SIZE_0 + indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32) + mask_0 = indices_0 < a_size_0 + val_a = tl.load(a + indices_0 * a_stride_0, mask_0, other=0) + val_b = tl.load(b + indices_0 * b_stride_0, mask_0, other=0) + c_val = tl.inline_asm_elementwise('\n sub.u32 $0, $2, $3;\n sub.u32 $1, $3, $2;\n ', '=r,=r,r,r', [val_a, val_b], (tl.int32, tl.int32), True, 1)[0] + d_val = tl.inline_asm_elementwise('\n sub.u32 $0, $2, $3;\n sub.u32 $1, $3, $2;\n ', '=r,=r,r,r', [val_a, val_b], (tl.int32, tl.int32), True, 1)[1] + tl.store(result_c + indices_0 * result_c_stride_0, c_val, mask_0) + tl.store(result_d + indices_0 * result_d_stride_0, d_val, mask_0) + +def kernel_multiple_outputs(a: torch.Tensor, b: torch.Tensor, *, _launcher=_default_launcher): + result_c = torch.empty_like(a) + result_d = torch.empty_like(a) + _BLOCK_SIZE_0 = 64 + _launcher(_kernel_multiple_outputs_kernel, (triton.cdiv(a.size(0), _BLOCK_SIZE_0),), a, b, result_c, result_d, a.size(0), a.stride(0), b.stride(0), result_c.stride(0), result_d.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3) + return (result_c, result_d) + +--- assertExpectedJournal(TestInlineAsmElementwise.test_inline_asm_packed) +from __future__ import annotations + +import torch +import triton +import triton.language as tl +from helion.runtime import default_launcher as _default_launcher + +@triton.jit +def _kernel_packed_asm_kernel(x, result, x_size_0, result_stride_0, x_stride_0, _BLOCK_SIZE_0: tl.constexpr): + pid_0 = tl.program_id(0) + offset_0 = pid_0 * _BLOCK_SIZE_0 + indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32) + mask_0 = indices_0 < x_size_0 + val = tl.load(x + indices_0 * x_stride_0, mask_0, other=0) + result_val = tl.inline_asm_elementwise('and.b32 $0, $1, 0x1F1F1F1F; shl.b32 $0, $0, 3;', '=r,r', [val], tl.int8, True, 4) + v_0 = result_val.to(tl.int8).to(tl.uint8) + tl.store(result + indices_0 * result_stride_0, v_0, mask_0) + +def kernel_packed_asm(x: torch.Tensor, *, _launcher=_default_launcher): + result = torch.empty_like(x) + _BLOCK_SIZE_0 = 512 + _launcher(_kernel_packed_asm_kernel, (triton.cdiv(x.size(0), _BLOCK_SIZE_0),), x, result, x.size(0), result.stride(0), x.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3) + return result + +--- assertExpectedJournal(TestInlineAsmElementwise.test_inline_asm_shift_operation) +from __future__ import annotations + +import torch +import triton +import triton.language as tl +from helion.runtime import default_launcher as _default_launcher + +@triton.jit +def _kernel_shift_asm_kernel(x, y, result, x_size_0, result_stride_0, x_stride_0, y_stride_0, n, _BLOCK_SIZE_0: tl.constexpr): + pid_0 = tl.program_id(0) + offset_0 = pid_0 * _BLOCK_SIZE_0 + indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32) + mask_0 = indices_0 < x_size_0 + val_x = tl.load(x + indices_0 * x_stride_0, mask_0, other=0) + val_y = tl.load(y + indices_0 * y_stride_0, mask_0, other=0) + shift_val = tl.full([_BLOCK_SIZE_0], n, tl.int32) + result_val = tl.inline_asm_elementwise('shf.l.wrap.b32 $0, $1, $2, $3;', '=r,r,r,r', [val_x, val_y, shift_val], tl.int32, True, 1) + tl.store(result + indices_0 * result_stride_0, result_val, mask_0) + +def kernel_shift_asm(x: torch.Tensor, y: torch.Tensor, n: int, *, _launcher=_default_launcher): + result = torch.empty_like(x) + _BLOCK_SIZE_0 = 128 + _launcher(_kernel_shift_asm_kernel, (triton.cdiv(x.size(0), _BLOCK_SIZE_0),), x, y, result, x.size(0), result.stride(0), x.stride(0), y.stride(0), n, _BLOCK_SIZE_0, num_warps=4, num_stages=3) + return result + +--- assertExpectedJournal(TestInlineAsmElementwise.test_inline_asm_simple) +from __future__ import annotations + +import torch +import triton +import triton.language as tl +from helion.runtime import default_launcher as _default_launcher + +@triton.jit +def _kernel_simple_asm_kernel(x, result, x_size_0, result_stride_0, x_stride_0, _BLOCK_SIZE_0: tl.constexpr): + pid_0 = tl.program_id(0) + offset_0 = pid_0 * _BLOCK_SIZE_0 + indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32) + mask_0 = indices_0 < x_size_0 + val = tl.load(x + indices_0 * x_stride_0, mask_0, other=0) + result_val = tl.inline_asm_elementwise('mov.u32 $0, $1;', '=r,r', [val], tl.int32, True, 1) + tl.store(result + indices_0 * result_stride_0, result_val, mask_0) + +def kernel_simple_asm(x: torch.Tensor, *, _launcher=_default_launcher): + result = torch.empty_like(x) + _BLOCK_SIZE_0 = 16 + _launcher(_kernel_simple_asm_kernel, (triton.cdiv(x.size(0), _BLOCK_SIZE_0),), x, result, x.size(0), result.stride(0), x.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3) + return result diff --git a/test/test_inline_asm_elementwise.py b/test/test_inline_asm_elementwise.py new file mode 100644 index 00000000..e91fa910 --- /dev/null +++ b/test/test_inline_asm_elementwise.py @@ -0,0 +1,242 @@ +from __future__ import annotations + +import unittest + +import pytest +import torch + +import helion +from helion._testing import DEVICE +from helion._testing import TestCase +from helion._testing import code_and_output +import helion.language as hl + + +class TestInlineAsmElementwise(TestCase): + @pytest.mark.skipif( + DEVICE.type != "cuda", reason="inline_asm_elementwise is only supported on CUDA" + ) + def test_inline_asm_simple(self): + """Test basic inline_asm_elementwise with simple assembly""" + + @helion.kernel(use_default_config=True) + def kernel_simple_asm(x: torch.Tensor) -> torch.Tensor: + result = torch.empty_like(x) + for tile in hl.tile(x.shape): + val = x[tile] + # Simple mov instruction - copy input to output + result_val = hl.inline_asm_elementwise( + "mov.u32 $0, $1;", + "=r,r", + [val], + dtype=val.dtype, + is_pure=True, + pack=1, + ) + result[tile] = result_val + return result + + x = torch.randint(0, 100, [16], device=DEVICE, dtype=torch.int32) + code, result = code_and_output(kernel_simple_asm, (x,)) + self.assertExpectedJournal(code) + torch.testing.assert_close(result, x) + + @pytest.mark.skipif( + DEVICE.type != "cuda", reason="inline_asm_elementwise is only supported on CUDA" + ) + def test_inline_asm_shift_operation(self): + """Test inline_asm_elementwise with shift operation (similar to Triton test)""" + + @helion.kernel(use_default_config=True) + def kernel_shift_asm(x: torch.Tensor, y: torch.Tensor, n: int) -> torch.Tensor: + result = torch.empty_like(x) + for tile in hl.tile(x.shape): + val_x = x[tile] + val_y = y[tile] + shift_val = hl.full(tile, n, dtype=torch.int32) + # Shift left wrap operation + result_val = hl.inline_asm_elementwise( + "shf.l.wrap.b32 $0, $1, $2, $3;", + "=r,r,r,r", + [val_x, val_y, shift_val], + dtype=torch.int32, + is_pure=True, + pack=1, + ) + result[tile] = result_val + return result + + shape = [128] + x = torch.randint(0, 2**16, shape, device=DEVICE, dtype=torch.int32) + y = torch.randint(0, 2**16, shape, device=DEVICE, dtype=torch.int32) + n = 17 + + code, result = code_and_output(kernel_shift_asm, (x, y, n)) + self.assertExpectedJournal(code) + + # Expected: (y << n) | (x >> (32 - n)) + expected = (y << n) | (x >> (32 - n)) + torch.testing.assert_close(result, expected) + + @pytest.mark.skipif( + DEVICE.type != "cuda", reason="inline_asm_elementwise is only supported on CUDA" + ) + def test_inline_asm_multiple_outputs(self): + """Test inline_asm_elementwise with multiple outputs""" + + @helion.kernel(use_default_config=True) + def kernel_multiple_outputs( + a: torch.Tensor, b: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: + result_c = torch.empty_like(a) + result_d = torch.empty_like(a) + + for tile in hl.tile(a.shape): + val_a = a[tile] + val_b = b[tile] + + # C = A - B, D = B - A + c_val, d_val = hl.inline_asm_elementwise( + """ + sub.u32 $0, $2, $3; + sub.u32 $1, $3, $2; + """, + "=r,=r,r,r", + [val_a, val_b], + dtype=(torch.int32, torch.int32), + is_pure=True, + pack=1, + ) + result_c[tile] = c_val + result_d[tile] = d_val + + return result_c, result_d + + shape = [64] + a = torch.randint(0, 2**16, shape, device=DEVICE, dtype=torch.int32) + b = torch.randint(0, 2**16, shape, device=DEVICE, dtype=torch.int32) + + code, (result_c, result_d) = code_and_output(kernel_multiple_outputs, (a, b)) + self.assertExpectedJournal(code) + + # Expected results + expected_c = a - b + expected_d = b - a + + torch.testing.assert_close(result_c, expected_c) + torch.testing.assert_close(result_d, expected_d) + + @pytest.mark.skipif( + DEVICE.type != "cuda", reason="inline_asm_elementwise is only supported on CUDA" + ) + def test_inline_asm_packed(self): + """Test inline_asm_elementwise with pack > 1""" + + @helion.kernel(use_default_config=True) + def kernel_packed_asm(x: torch.Tensor) -> torch.Tensor: + result = torch.empty_like(x) + for tile in hl.tile(x.shape): + val = x[tile] + # Shift 4x8bit values together, pack=4 + result_val = hl.inline_asm_elementwise( + "and.b32 $0, $1, 0x1F1F1F1F; shl.b32 $0, $0, 3;", + "=r,r", + [val], + dtype=torch.int8, + is_pure=True, + pack=4, + ) + result[tile] = result_val + return result + + shape = [512] + x = torch.randint(0, 256, shape, device=DEVICE, dtype=torch.uint8) + + code, result = code_and_output(kernel_packed_asm, (x,)) + self.assertExpectedJournal(code) + + # Expected: x shifted left by 3 (x << 3) + expected = x << 3 + torch.testing.assert_close(result, expected) + + def test_inline_asm_error_cases(self): + """Test error cases for inline_asm_elementwise""" + + @helion.kernel(use_default_config=True) + def kernel_invalid_asm(x: torch.Tensor) -> torch.Tensor: + result = torch.empty_like(x) + for tile in hl.tile(x.shape): + # Should raise error - invalid dtype + result_val = hl.inline_asm_elementwise( + "mov.u32 $0, $1;", + "=r,r", + [x[tile]], + dtype="invalid_dtype", # Invalid dtype + is_pure=True, + pack=1, + ) + result[tile] = result_val + return result + + x = torch.randint(0, 100, [16], device=DEVICE, dtype=torch.int32) + with self.assertRaises(helion.exc.InvalidAPIUsage): + code, result = code_and_output(kernel_invalid_asm, (x,)) + + @pytest.mark.skipif( + DEVICE.type != "cuda", reason="inline_asm_elementwise is only supported on CUDA" + ) + def test_inline_asm_empty_args(self): + """Test inline_asm_elementwise with empty args (should work like Triton)""" + + @helion.kernel(use_default_config=True) + def kernel_empty_args(x: torch.Tensor) -> torch.Tensor: + result = torch.empty_like(x) + for tile in hl.tile(x.shape): + # Empty args should work - generates output with context shape + result_val = hl.inline_asm_elementwise( + "mov.u32 $0, 42;", # No input registers, just output constant + "=r", # Only output constraint + [], # Empty args + dtype=torch.int32, + is_pure=True, + pack=1, + ) + result[tile] = result_val + return result + + x = torch.randint(0, 100, [16], device=DEVICE, dtype=torch.int32) + # This should work without error + code, result = code_and_output(kernel_empty_args, (x,)) + self.assertExpectedJournal(code) + + # Should create a tensor filled with 42 + expected = torch.full([16], 42, dtype=torch.int32, device=DEVICE) + torch.testing.assert_close(result, expected) + + def test_inline_asm_basic_compilation(self): + """Test that inline_asm_elementwise compiles without errors (no CUDA requirement)""" + + @helion.kernel(use_default_config=True) + def kernel_basic(x: torch.Tensor) -> torch.Tensor: + result = torch.empty_like(x) + for tile in hl.tile(x.shape): + # Simple compilation test + result_val = hl.inline_asm_elementwise( + "mov.u32 $0, $1;", + "=r,r", + [x[tile]], + dtype=torch.int32, + is_pure=True, + pack=1, + ) + result[tile] = result_val + return result + + x = torch.randint(0, 100, [16], device=DEVICE, dtype=torch.int32) + # Just test that it compiles + code, result = code_and_output(kernel_basic, (x,)) + self.assertExpectedJournal(code) + + +if __name__ == "__main__": + unittest.main()