diff --git a/helion/_compiler/device_ir.py b/helion/_compiler/device_ir.py index 9f66dd32..ffa9c2c6 100644 --- a/helion/_compiler/device_ir.py +++ b/helion/_compiler/device_ir.py @@ -502,11 +502,32 @@ def _extract_tile_begin_end(self, for_node: ast.For) -> tuple[object, object]: end = self.visit(args[1]) return begin, end + def _handle_tuple_unrolling( + self, + node: ast.For, + ) -> None: + """Handle unrolling of loops that iterate over tuples of tensors.""" + # Get the sequence of tensors to iterate over + sequence_value = self.visit(node.iter) + assert isinstance(sequence_value, (tuple, list)), ( + f"Expected tuple or list, got {type(sequence_value)}" + ) + # Unroll the loop by executing the body for each tensor in the sequence + for tensor_value in sequence_value: + self._assign(node.target, tensor_value) + self._body(node.body) + def visit_For(self, node: ast.For) -> None: assert isinstance(node, ExtendedAST) assert not node.orelse assert isinstance(node.iter, ExtendedAST) iter_type = node.iter._type_info + + # Check if we're iterating directly over a sequence (tuple unrolling) + if isinstance(iter_type, SequenceType): + self._handle_tuple_unrolling(node) + return + if not isinstance(iter_type, IterType): raise exc.InvalidDeviceForLoop(iter_type) inner_type: TypeInfo = iter_type.inner diff --git a/helion/language/__init__.py b/helion/language/__init__.py index eb924e66..245b9ab6 100644 --- a/helion/language/__init__.py +++ b/helion/language/__init__.py @@ -8,6 +8,7 @@ from .device_print import device_print as device_print from .inline_asm_ops import inline_asm_elementwise as inline_asm_elementwise from .loops import grid as grid +from .loops import static_range as static_range from .loops import tile as tile from .memory_ops import atomic_add as atomic_add from .memory_ops import load as load diff --git a/helion/language/loops.py b/helion/language/loops.py index e5075d0b..03d874e8 100644 --- a/helion/language/loops.py +++ b/helion/language/loops.py @@ -25,10 +25,12 @@ from .._compiler.compile_environment import CompileEnvironment from .._compiler.type_propagation import GridIndexType from .._compiler.type_propagation import IterType +from .._compiler.type_propagation import LiteralType from .._compiler.type_propagation import Origin from .._compiler.type_propagation import SequenceType from .._compiler.type_propagation import TileIndexType from .._compiler.type_propagation import TypeInfo +from .._compiler.variable_origin import GetItemOrigin from ..autotuner.config_spec import ConfigSpec from ..autotuner.config_spec import FlattenLoopSpec from ..autotuner.config_spec import L2GroupingSpec @@ -48,7 +50,7 @@ from .._compiler.inductor_lowering import CodegenState -__all__ = ["grid", "tile"] +__all__ = ["grid", "static_range", "tile"] @overload @@ -633,3 +635,266 @@ def _( @_decorators.codegen(grid) def _(state: CodegenState) -> ast.AST: return _codegen_loop_helper(state) + + +@_decorators.device_func_replacement(builtins.zip) +@_decorators.api(is_device_only=True, cache_type=True) +def _zip_replacement( + *args: tuple[object, ...] | list[object], + strict: bool = False, +) -> tuple[tuple[object, ...], ...]: + """ + Device replacement for zip() that returns tuples for unrolling. + + This replacement enables zip() to work in device kernels by converting + the zip result to a tuple of tuples, which can then be unrolled by the + existing tuple iteration logic. + + Args: + *args: Sequences to zip together + + Returns: + Tuple of tuples containing zipped elements + + Examples: + .. 
code-block:: python
+
+            @helion.kernel
+            def kernel_with_zip(a_tensors, b_tensors):
+                out = torch.zeros_like(a_tensors[0])
+                for tile in hl.tile(out.size(0)):
+                    acc = torch.zeros([tile], dtype=torch.float32, device=out.device)
+                    for a, b in zip(a_tensors, b_tensors):
+                        # This gets unrolled at compile time
+                        acc += a[tile] * b[tile]
+                    out[tile] = acc
+                return out
+    """
+    raise exc.NotInsideKernel
+
+
+@_decorators.type_propagation(_zip_replacement)
+def _(
+    *args: TypeInfo,
+    origin: Origin,
+    **kwargs: object,
+) -> TypeInfo:
+    """Type propagation for zip replacement that preserves tensor types."""
+    # Accept but ignore the strict keyword argument
+    if not args:
+        return SequenceType(origin, ())
+
+    # Convert all arguments to SequenceType
+    sequences = []
+    for arg in args:
+        if not isinstance(arg, SequenceType):
+            raise exc.TypeInferenceError(
+                f"zip() argument must be a sequence, got {arg}"
+            )
+        sequences.append(arg.unpack())
+
+    # Check all sequences have the same length
+    length = 0
+    if sequences:
+        length = len(sequences[0])
+        for i, seq in enumerate(sequences[1:], 1):
+            if len(seq) != length:
+                raise exc.TypeInferenceError(
+                    f"zip() argument {i} has length {len(seq)}, expected {length}"
+                )
+
+    # Build result as tuple of tuples, preserving existing TypeInfo objects
+    result_elements = []
+    for i in range(length):
+        # Create a tuple containing the i-th element from each sequence
+        tuple_elements = tuple(seq[i] for seq in sequences)
+        tuple_type = SequenceType(GetItemOrigin(origin, i), tuple_elements)
+        result_elements.append(tuple_type)
+
+    return SequenceType(origin, tuple(result_elements))
+
+
+@_decorators.register_to_device_ir(_zip_replacement)
+def _(
+    tracer: object,
+    *flat_args: object,
+) -> object:
+    """Device IR handler for zip - returns the zipped result for unrolling."""
+    # flat_args contains the prepared arguments: (tensor_sequences, strict_value)
+    if not flat_args:
+        return ()
+
+    # Extract sequences and strict parameter
+    if len(flat_args) == 2:
+        sequences = flat_args[0]  # This should be the tuple of sequences
+        strict = flat_args[1]  # This should be the strict parameter
+        assert isinstance(strict, bool)
+    else:
+        assert len(flat_args) == 1
+        sequences = flat_args[0]
+        strict = False
+    return [*builtins.zip(*sequences, strict=strict)]  # type: ignore[arg-type]
+
+
+@_decorators.device_func_replacement(builtins.enumerate)
+@_decorators.api(is_device_only=True, cache_type=True)
+def _enumerate_replacement(
+    iterable: tuple[object, ...] | list[object],
+    start: int = 0,
+) -> tuple[tuple[int, object], ...]:
+    """
+    Device replacement for enumerate() that returns tuples for unrolling.
+
+    This replacement enables enumerate() to work in device kernels by converting
+    the enumerate result to a tuple of (index, value) tuples, which can then be
+    unrolled by the existing tuple iteration logic.
+ + Args: + iterable: Sequence to enumerate + start: Starting value for the counter (default: 0) + + Returns: + Tuple of (index, value) tuples + """ + raise exc.NotInsideKernel + + +@_decorators.type_propagation(_enumerate_replacement) +def _( + iterable: TypeInfo, + start: TypeInfo | None = None, + *, + origin: Origin, +) -> TypeInfo: + """Type propagation for enumerate replacement that preserves tensor types.""" + if not isinstance(iterable, SequenceType): + raise exc.TypeInferenceError( + f"enumerate() argument must be a sequence, got {iterable}" + ) + + # Get the start value + start_value = 0 + if start is not None and start.is_literal(): + start_val = start.as_literal() + if isinstance(start_val, int): + start_value = start_val + + # Build result as tuple of (index, value) tuples + sequence_elements = iterable.unpack() + result_elements = [] + + for i, element in enumerate(sequence_elements): + # Create (index, value) tuple + index_literal = LiteralType(origin, start_value + i) + tuple_elements = (index_literal, element) + tuple_type = SequenceType(GetItemOrigin(origin, i), tuple_elements) + result_elements.append(tuple_type) + + return SequenceType(origin, tuple(result_elements)) + + +@_decorators.register_to_device_ir(_enumerate_replacement) +def _( + tracer: object, + *flat_args: object, +) -> object: + """Device IR handler for enumerate - returns the enumerated result for unrolling.""" + if len(flat_args) == 2: + iterable = flat_args[0] + start = flat_args[1] + assert isinstance(start, int) + else: + assert len(flat_args) == 1 + iterable = flat_args[0] + start = 0 + return [*builtins.enumerate(iterable, start=start)] # type: ignore[arg-type] + + +@_decorators.api(is_device_only=True, cache_type=True) +def static_range( + begin_or_end: int, + end_or_none: int | None = None, + /, + step: int = 1, +) -> Iterator[int]: + """ + Create a range that gets unrolled at compile time by iterating over constant integer values. + + This function is similar to Python's built-in range(), but it generates a sequence + of integer constants that triggers loop unrolling behavior in Helion kernels. The loop + is completely unrolled at compile time, with each iteration becoming separate + instructions in the generated code. + + Args: + begin_or_end: If 2+ positional args provided, the start of range (integer). + Otherwise, the end of range (integer). + end_or_none: If 2+ positional args provided, the end of range (integer). + step: Step size for iteration (integer, default: 1) + + Returns: + Iterator[int]: Iterator over constant integer values + + Examples: + Simple unrolled loop: + + .. code-block:: python + + @helion.kernel + def unrolled_example(x: torch.Tensor) -> torch.Tensor: + result = torch.zeros_like(x) + + for tile in hl.tile(x.size(0)): + acc = torch.zeros([tile], dtype=x.dtype, device=x.device) + # This loop gets completely unrolled + for i in hl.static_range(3): + acc += x[tile] * i + result[tile] = acc + + return result + + Range with start and step: + + .. 
code-block:: python + + @helion.kernel + def kernel_stepped_unroll(x: torch.Tensor) -> torch.Tensor: + result = torch.zeros_like(x) + + for tile in hl.tile(x.size(0)): + acc = torch.zeros([tile], dtype=x.dtype, device=x.device) + # Unroll loop from 2 to 8 with step 2: [2, 4, 6] + for i in hl.static_range(2, 8, 2): + acc += x[tile] * i + result[tile] = acc + + return result + + Note: + - Only constant integer values are supported + - The range must be small enough to avoid compilation timeouts + - Each iteration becomes separate instructions in the generated Triton code + - Use for small, fixed iteration counts where unrolling is beneficial + """ + raise exc.NotInsideKernel + + +@_decorators.register_fake(static_range) +def _( + begin_or_end: int, + end_or_none: int | None = None, + /, + step: int = 1, +) -> tuple[int, ...]: + """Fake function for static_range - validates integer constants and returns tuple(range(...)).""" + # Validate that inputs are compile-time constants + if end_or_none is not None: + begin_val = begin_or_end + end_val = end_or_none + else: + begin_val = 0 + end_val = begin_or_end + + if ( + not isinstance(begin_val, int) + or not isinstance(end_val, int) + or not isinstance(step, int) + ): + raise exc.TypeInferenceError("static_range requires constant integer arguments") + + # Return tuple(range(...)) which will trigger existing tuple/list unrolling + return tuple(range(begin_val, end_val, step)) diff --git a/test/test_unroll_tuples.expected b/test/test_unroll_tuples.expected new file mode 100644 index 00000000..80c871b6 --- /dev/null +++ b/test/test_unroll_tuples.expected @@ -0,0 +1,445 @@ +This file is automatically generated by assertExpectedJournal calls in test_unroll_tuples.py. +Update expected outputs by running tests with the EXPECTTEST_ACCEPT=1 environment variable set. 
+ +--- assertExpectedJournal(TestUnrollTuples.test_basic_tuple_addition) +from __future__ import annotations + +import torch +import triton +import triton.language as tl +from helion.runtime import default_launcher as _default_launcher + +@triton.jit +def _kernel_tuple_addition_kernel(out, a_shared_tuple_item_0, a_shared_tuple_item_1, a_shared_tuple_item_2, out_size_0, a_shared_tuple_item_0_stride_0, a_shared_tuple_item_1_stride_0, a_shared_tuple_item_2_stride_0, out_stride_0, _BLOCK_SIZE_0: tl.constexpr): + pid_0 = tl.program_id(0) + offset_0 = pid_0 * _BLOCK_SIZE_0 + indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32) + mask_0 = indices_0 < out_size_0 + acc = tl.full([_BLOCK_SIZE_0], 0, tl.float32) + load = tl.load(a_shared_tuple_item_0 + indices_0 * a_shared_tuple_item_0_stride_0, mask_0, other=0) + v_0 = acc + load + load_1 = tl.load(a_shared_tuple_item_1 + indices_0 * a_shared_tuple_item_1_stride_0, mask_0, other=0) + v_1 = v_0 + load_1 + load_2 = tl.load(a_shared_tuple_item_2 + indices_0 * a_shared_tuple_item_2_stride_0, mask_0, other=0) + v_2 = v_1 + load_2 + tl.store(out + indices_0 * out_stride_0, v_2, mask_0) + +def kernel_tuple_addition(a_shared_tuple: tuple[torch.Tensor, ...], *, _launcher=_default_launcher): + """Basic test: iterate over a tuple of tensors and sum them.""" + out = torch.empty_like(a_shared_tuple[0]) + _BLOCK_SIZE_0 = 32 + _launcher(_kernel_tuple_addition_kernel, (triton.cdiv(out.size(0), _BLOCK_SIZE_0),), out, a_shared_tuple[0], a_shared_tuple[1], a_shared_tuple[2], out.size(0), a_shared_tuple[0].stride(0), a_shared_tuple[1].stride(0), a_shared_tuple[2].stride(0), out.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3) + return out + +--- assertExpectedJournal(TestUnrollTuples.test_constants_iteration) +from __future__ import annotations + +import torch +import triton +import triton.language as tl +from helion.runtime import default_launcher as _default_launcher + +@triton.jit +def _kernel_constants_iteration_kernel(x, result, x_size_0, result_stride_0, x_stride_0, _BLOCK_SIZE_0: tl.constexpr): + pid_0 = tl.program_id(0) + offset_0 = pid_0 * _BLOCK_SIZE_0 + indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32) + mask_0 = indices_0 < x_size_0 + acc = tl.full([_BLOCK_SIZE_0], 0, tl.float32) + load = tl.load(x + indices_0 * x_stride_0, mask_0, other=0) + v_0 = 1.0 + v_1 = load * v_0 + v_2 = acc + v_1 + load_1 = tl.load(x + indices_0 * x_stride_0, mask_0, other=0) + v_3 = 2.0 + v_4 = load_1 * v_3 + v_5 = v_2 + v_4 + load_2 = tl.load(x + indices_0 * x_stride_0, mask_0, other=0) + v_6 = 3.0 + v_7 = load_2 * v_6 + v_8 = v_5 + v_7 + tl.store(result + indices_0 * result_stride_0, v_8, mask_0) + +def kernel_constants_iteration(x: torch.Tensor, *, _launcher=_default_launcher): + """Test iteration over a tuple/list of constants.""" + result = torch.zeros_like(x) + _BLOCK_SIZE_0 = 32 + _launcher(_kernel_constants_iteration_kernel, (triton.cdiv(x.size(0), _BLOCK_SIZE_0),), x, result, x.size(0), result.stride(0), x.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3) + return result + +--- assertExpectedJournal(TestUnrollTuples.test_enumerate_constants) +from __future__ import annotations + +import torch +import triton +import triton.language as tl +from helion.runtime import default_launcher as _default_launcher + +@triton.jit +def _kernel_enumerate_constants_kernel(x, result, x_size_0, result_stride_0, x_stride_0, _BLOCK_SIZE_0: tl.constexpr): + pid_0 = tl.program_id(0) + offset_0 = pid_0 * _BLOCK_SIZE_0 + indices_0 = (offset_0 + tl.arange(0, 
_BLOCK_SIZE_0)).to(tl.int32) + mask_0 = indices_0 < x_size_0 + acc = tl.full([_BLOCK_SIZE_0], 0, tl.float32) + load = tl.load(x + indices_0 * x_stride_0, mask_0, other=0) + v_0 = 2.0 + v_1 = load * v_0 + v_2 = 0.0 + v_3 = v_1 * v_2 + v_4 = acc + v_3 + load_1 = tl.load(x + indices_0 * x_stride_0, mask_0, other=0) + v_5 = 3.0 + v_6 = load_1 * v_5 + v_7 = 1.0 + v_8 = v_6 * v_7 + v_9 = v_4 + v_8 + load_2 = tl.load(x + indices_0 * x_stride_0, mask_0, other=0) + v_10 = 4.0 + v_11 = load_2 * v_10 + v_12 = 2.0 + v_13 = v_11 * v_12 + v_14 = v_9 + v_13 + tl.store(result + indices_0 * result_stride_0, v_14, mask_0) + +def kernel_enumerate_constants(x: torch.Tensor, *, _launcher=_default_launcher): + """Test enumerate over constants.""" + result = torch.zeros_like(x) + _BLOCK_SIZE_0 = 32 + _launcher(_kernel_enumerate_constants_kernel, (triton.cdiv(x.size(0), _BLOCK_SIZE_0),), x, result, x.size(0), result.stride(0), x.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3) + return result + +--- assertExpectedJournal(TestUnrollTuples.test_enumerate_iteration) +from __future__ import annotations + +import torch +import triton +import triton.language as tl +from helion.runtime import default_launcher as _default_launcher + +@triton.jit +def _kernel_enumerate_iteration_kernel(tensors_item_2, tensors_item_0, tensors_item_1, result, tensors_item_2_size_0, result_stride_0, tensors_item_0_stride_0, tensors_item_1_stride_0, tensors_item_2_stride_0, _BLOCK_SIZE_0: tl.constexpr): + pid_0 = tl.program_id(0) + offset_0 = pid_0 * _BLOCK_SIZE_0 + indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32) + mask_0 = indices_0 < tensors_item_2_size_0 + acc = tl.full([_BLOCK_SIZE_0], 0, tl.float32) + load = tl.load(tensors_item_0 + indices_0 * tensors_item_0_stride_0, mask_0, other=0) + v_0 = 1.0 + v_1 = load * v_0 + v_2 = acc + v_1 + load_1 = tl.load(tensors_item_1 + indices_0 * tensors_item_1_stride_0, mask_0, other=0) + v_3 = 2.0 + v_4 = load_1 * v_3 + v_5 = v_2 + v_4 + load_2 = tl.load(tensors_item_2 + indices_0 * tensors_item_2_stride_0, mask_0, other=0) + v_6 = 3.0 + v_7 = load_2 * v_6 + v_8 = v_5 + v_7 + tl.store(result + indices_0 * result_stride_0, v_8, mask_0) + +def kernel_enumerate_iteration(tensors: tuple[torch.Tensor, torch.Tensor, torch.Tensor], *, _launcher=_default_launcher): + """Test iteration using enumerate over tensors.""" + result = torch.zeros_like(tensors[0]) + _BLOCK_SIZE_0 = 32 + _launcher(_kernel_enumerate_iteration_kernel, (triton.cdiv(tensors[2].size(0), _BLOCK_SIZE_0),), tensors[2], tensors[0], tensors[1], result, tensors[2].size(0), result.stride(0), tensors[0].stride(0), tensors[1].stride(0), tensors[2].stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3) + return result + +--- assertExpectedJournal(TestUnrollTuples.test_enumerate_with_start) +from __future__ import annotations + +import torch +import triton +import triton.language as tl +from helion.runtime import default_launcher as _default_launcher + +@triton.jit +def _kernel_enumerate_with_start_kernel(result, tensors_item_0, tensors_item_1, result_size_0, result_stride_0, tensors_item_0_stride_0, tensors_item_1_stride_0, _BLOCK_SIZE_0: tl.constexpr): + pid_0 = tl.program_id(0) + offset_0 = pid_0 * _BLOCK_SIZE_0 + indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32) + mask_0 = indices_0 < result_size_0 + acc = tl.full([_BLOCK_SIZE_0], 0, tl.float32) + load = tl.load(tensors_item_0 + indices_0 * tensors_item_0_stride_0, mask_0, other=0) + v_0 = 5.0 + v_1 = load * v_0 + v_2 = acc + v_1 + load_1 = 
tl.load(tensors_item_1 + indices_0 * tensors_item_1_stride_0, mask_0, other=0) + v_3 = 6.0 + v_4 = load_1 * v_3 + v_5 = v_2 + v_4 + tl.store(result + indices_0 * result_stride_0, v_5, mask_0) + +def kernel_enumerate_with_start(tensors: tuple[torch.Tensor, torch.Tensor], *, _launcher=_default_launcher): + """Test enumerate with custom start value.""" + result = torch.zeros_like(tensors[0]) + _BLOCK_SIZE_0 = 32 + _launcher(_kernel_enumerate_with_start_kernel, (triton.cdiv(result.size(0), _BLOCK_SIZE_0),), result, tensors[0], tensors[1], result.size(0), result.stride(0), tensors[0].stride(0), tensors[1].stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3) + return result + +--- assertExpectedJournal(TestUnrollTuples.test_list_constants_iteration) +from __future__ import annotations + +import torch +import triton +import triton.language as tl +from helion.runtime import default_launcher as _default_launcher + +@triton.jit +def _kernel_list_constants_iteration_kernel(x, result, x_size_0, result_stride_0, x_stride_0, _BLOCK_SIZE_0: tl.constexpr): + pid_0 = tl.program_id(0) + offset_0 = pid_0 * _BLOCK_SIZE_0 + indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32) + mask_0 = indices_0 < x_size_0 + acc = tl.full([_BLOCK_SIZE_0], 0, tl.float32) + load = tl.load(x + indices_0 * x_stride_0, mask_0, other=0) + v_0 = 0.5 + v_1 = load * v_0 + v_2 = acc + v_1 + load_1 = tl.load(x + indices_0 * x_stride_0, mask_0, other=0) + v_3 = 1.5 + v_4 = load_1 * v_3 + v_5 = v_2 + v_4 + load_2 = tl.load(x + indices_0 * x_stride_0, mask_0, other=0) + v_6 = 2.5 + v_7 = load_2 * v_6 + v_8 = v_5 + v_7 + tl.store(result + indices_0 * result_stride_0, v_8, mask_0) + +def kernel_list_constants_iteration(x: torch.Tensor, *, _launcher=_default_launcher): + """Test iteration over a list of constants.""" + result = torch.zeros_like(x) + _BLOCK_SIZE_0 = 32 + _launcher(_kernel_list_constants_iteration_kernel, (triton.cdiv(x.size(0), _BLOCK_SIZE_0),), x, result, x.size(0), result.stride(0), x.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3) + return result + +--- assertExpectedJournal(TestUnrollTuples.test_mixed_constants_and_tensors) +from __future__ import annotations + +import torch +import triton +import triton.language as tl +from helion.runtime import default_launcher as _default_launcher + +@triton.jit +def _kernel_mixed_constants_and_tensors_kernel(result, tensors_item_0, tensors_item_1, result_size_0, result_stride_0, tensors_item_0_stride_0, tensors_item_1_stride_0, constants_item_0, constants_item_1, _BLOCK_SIZE_0: tl.constexpr): + pid_0 = tl.program_id(0) + offset_0 = pid_0 * _BLOCK_SIZE_0 + indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32) + mask_0 = indices_0 < result_size_0 + acc = tl.full([_BLOCK_SIZE_0], 0, tl.float32) + load = tl.load(tensors_item_0 + indices_0 * tensors_item_0_stride_0, mask_0, other=0) + v_0 = acc + load + load_1 = tl.load(tensors_item_1 + indices_0 * tensors_item_1_stride_0, mask_0, other=0) + v_1 = v_0 + load_1 + v_2 = constants_item_0.to(tl.float32) + v_3 = v_1 * v_2 + v_4 = constants_item_1.to(tl.float32) + v_5 = v_3 * v_4 + tl.store(result + indices_0 * result_stride_0, v_5, mask_0) + +def kernel_mixed_constants_and_tensors(tensors: tuple[torch.Tensor, torch.Tensor], constants: tuple[int, int], *, _launcher=_default_launcher): + """Test mixed iteration over both tensors and constants.""" + result = torch.zeros_like(tensors[0]) + _BLOCK_SIZE_0 = 32 + _launcher(_kernel_mixed_constants_and_tensors_kernel, (triton.cdiv(result.size(0), _BLOCK_SIZE_0),), result, 
tensors[0], tensors[1], result.size(0), result.stride(0), tensors[0].stride(0), tensors[1].stride(0), constants[0], constants[1], _BLOCK_SIZE_0, num_warps=4, num_stages=3) + return result + +--- assertExpectedJournal(TestUnrollTuples.test_nested_tuple_iteration) +from __future__ import annotations + +import torch +import triton +import triton.language as tl +from helion.runtime import default_launcher as _default_launcher + +@triton.jit +def _kernel_nested_tuple_iteration_kernel(result, a_tuple_item_0, a_tuple_item_1, b_tuple_item_0, b_tuple_item_1, result_size_0, a_tuple_item_0_stride_0, a_tuple_item_1_stride_0, b_tuple_item_0_stride_0, b_tuple_item_1_stride_0, result_stride_0, _BLOCK_SIZE_0: tl.constexpr): + pid_0 = tl.program_id(0) + offset_0 = pid_0 * _BLOCK_SIZE_0 + indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32) + mask_0 = indices_0 < result_size_0 + temp = tl.full([_BLOCK_SIZE_0], 0, tl.float32) + load = tl.load(a_tuple_item_0 + indices_0 * a_tuple_item_0_stride_0, mask_0, other=0) + v_0 = temp + load + load_1 = tl.load(a_tuple_item_1 + indices_0 * a_tuple_item_1_stride_0, mask_0, other=0) + v_1 = v_0 + load_1 + load_2 = tl.load(b_tuple_item_0 + indices_0 * b_tuple_item_0_stride_0, mask_0, other=0) + v_2 = v_1 * load_2 + load_3 = tl.load(b_tuple_item_1 + indices_0 * b_tuple_item_1_stride_0, mask_0, other=0) + v_3 = v_2 * load_3 + tl.store(result + indices_0 * result_stride_0, v_3, mask_0) + +def kernel_nested_tuple_iteration(a_tuple: tuple[torch.Tensor, torch.Tensor], b_tuple: tuple[torch.Tensor, torch.Tensor], *, _launcher=_default_launcher): + """Test nested iteration over multiple tuples.""" + result = torch.zeros_like(a_tuple[0]) + _BLOCK_SIZE_0 = 64 + _launcher(_kernel_nested_tuple_iteration_kernel, (triton.cdiv(result.size(0), _BLOCK_SIZE_0),), result, a_tuple[0], a_tuple[1], b_tuple[0], b_tuple[1], result.size(0), a_tuple[0].stride(0), a_tuple[1].stride(0), b_tuple[0].stride(0), b_tuple[1].stride(0), result.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3) + return result + +--- assertExpectedJournal(TestUnrollTuples.test_single_element_tuple) +from __future__ import annotations + +import torch +import triton +import triton.language as tl +from helion.runtime import default_launcher as _default_launcher + +@triton.jit +def _kernel_tuple_addition_kernel(out, a_shared_tuple_item_0, out_size_0, a_shared_tuple_item_0_stride_0, out_stride_0, _BLOCK_SIZE_0: tl.constexpr): + pid_0 = tl.program_id(0) + offset_0 = pid_0 * _BLOCK_SIZE_0 + indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32) + mask_0 = indices_0 < out_size_0 + acc = tl.full([_BLOCK_SIZE_0], 0, tl.float32) + load = tl.load(a_shared_tuple_item_0 + indices_0 * a_shared_tuple_item_0_stride_0, mask_0, other=0) + v_0 = acc + load + tl.store(out + indices_0 * out_stride_0, v_0, mask_0) + +def kernel_tuple_addition(a_shared_tuple: tuple[torch.Tensor, ...], *, _launcher=_default_launcher): + """Basic test: iterate over a tuple of tensors and sum them.""" + out = torch.empty_like(a_shared_tuple[0]) + _BLOCK_SIZE_0 = 16 + _launcher(_kernel_tuple_addition_kernel, (triton.cdiv(out.size(0), _BLOCK_SIZE_0),), out, a_shared_tuple[0], out.size(0), a_shared_tuple[0].stride(0), out.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3) + return out + +--- assertExpectedJournal(TestUnrollTuples.test_static_range_iteration) +from __future__ import annotations + +import torch +import triton +import triton.language as tl +from helion.runtime import default_launcher as _default_launcher + +@triton.jit +def 
_kernel_static_range_iteration_kernel(x, result, x_size_0, result_stride_0, x_stride_0, _BLOCK_SIZE_0: tl.constexpr): + pid_0 = tl.program_id(0) + offset_0 = pid_0 * _BLOCK_SIZE_0 + indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32) + mask_0 = indices_0 < x_size_0 + acc = tl.full([_BLOCK_SIZE_0], 0, tl.float32) + load = tl.load(x + indices_0 * x_stride_0, mask_0, other=0) + v_0 = 1.0 + v_1 = load * v_0 + v_2 = acc + v_1 + load_1 = tl.load(x + indices_0 * x_stride_0, mask_0, other=0) + v_3 = 2.0 + v_4 = load_1 * v_3 + v_5 = v_2 + v_4 + load_2 = tl.load(x + indices_0 * x_stride_0, mask_0, other=0) + v_6 = 3.0 + v_7 = load_2 * v_6 + v_8 = v_5 + v_7 + load_3 = tl.load(x + indices_0 * x_stride_0, mask_0, other=0) + v_9 = 4.0 + v_10 = load_3 * v_9 + v_11 = v_8 + v_10 + tl.store(result + indices_0 * result_stride_0, v_11, mask_0) + +def kernel_static_range_iteration(x: torch.Tensor, *, _launcher=_default_launcher): + """Test iteration using hl.static_range.""" + result = torch.zeros_like(x) + _BLOCK_SIZE_0 = 32 + _launcher(_kernel_static_range_iteration_kernel, (triton.cdiv(x.size(0), _BLOCK_SIZE_0),), x, result, x.size(0), result.stride(0), x.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3) + return result + +--- assertExpectedJournal(TestUnrollTuples.test_static_range_with_start) +from __future__ import annotations + +import torch +import triton +import triton.language as tl +from helion.runtime import default_launcher as _default_launcher + +@triton.jit +def _kernel_static_range_with_start_kernel(x, result, x_size_0, result_stride_0, x_stride_0, _BLOCK_SIZE_0: tl.constexpr): + pid_0 = tl.program_id(0) + offset_0 = pid_0 * _BLOCK_SIZE_0 + indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32) + mask_0 = indices_0 < x_size_0 + acc = tl.full([_BLOCK_SIZE_0], 0, tl.float32) + load = tl.load(x + indices_0 * x_stride_0, mask_0, other=0) + v_0 = 2.0 + v_1 = load * v_0 + v_2 = acc + v_1 + load_1 = tl.load(x + indices_0 * x_stride_0, mask_0, other=0) + v_3 = 3.0 + v_4 = load_1 * v_3 + v_5 = v_2 + v_4 + load_2 = tl.load(x + indices_0 * x_stride_0, mask_0, other=0) + v_6 = 4.0 + v_7 = load_2 * v_6 + v_8 = v_5 + v_7 + tl.store(result + indices_0 * result_stride_0, v_8, mask_0) + +def kernel_static_range_with_start(x: torch.Tensor, *, _launcher=_default_launcher): + """Test static_range with start parameter.""" + result = torch.zeros_like(x) + _BLOCK_SIZE_0 = 32 + _launcher(_kernel_static_range_with_start_kernel, (triton.cdiv(x.size(0), _BLOCK_SIZE_0),), x, result, x.size(0), result.stride(0), x.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3) + return result + +--- assertExpectedJournal(TestUnrollTuples.test_tuple_with_scaling_factors) +from __future__ import annotations + +import torch +import triton +import triton.language as tl +from helion.runtime import default_launcher as _default_launcher + +@triton.jit +def _kernel_tuple_with_scaling_kernel(tensor3, tensor1, tensor2, output, tensor3_size_0, output_stride_0, tensor1_stride_0, tensor2_stride_0, tensor3_stride_0, scale1, scale2, scale3, _BLOCK_SIZE_0: tl.constexpr): + pid_0 = tl.program_id(0) + offset_0 = pid_0 * _BLOCK_SIZE_0 + indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32) + mask_0 = indices_0 < tensor3_size_0 + temp = tl.full([_BLOCK_SIZE_0], 0, tl.float32) + load = tl.load(tensor1 + indices_0 * tensor1_stride_0, mask_0, other=0) + v_0 = load * scale1 + v_1 = temp + v_0 + load_1 = tl.load(tensor2 + indices_0 * tensor2_stride_0, mask_0, other=0) + v_2 = load_1 * scale2 + v_3 = v_1 + v_2 + load_2 
= tl.load(tensor3 + indices_0 * tensor3_stride_0, mask_0, other=0) + v_4 = load_2 * scale3 + v_5 = v_3 + v_4 + tl.store(output + indices_0 * output_stride_0, v_5, mask_0) + +def kernel_tuple_with_scaling(tensor1: torch.Tensor, tensor2: torch.Tensor, tensor3: torch.Tensor, scale1: float, scale2: float, scale3: float, *, _launcher=_default_launcher): + """Test iteration over tensors with corresponding scalar multipliers.""" + output = torch.zeros_like(tensor1) + _BLOCK_SIZE_0 = 64 + _launcher(_kernel_tuple_with_scaling_kernel, (triton.cdiv(tensor3.size(0), _BLOCK_SIZE_0),), tensor3, tensor1, tensor2, output, tensor3.size(0), output.stride(0), tensor1.stride(0), tensor2.stride(0), tensor3.stride(0), scale1, scale2, scale3, _BLOCK_SIZE_0, num_warps=4, num_stages=3) + return output + +--- assertExpectedJournal(TestUnrollTuples.test_zip_iteration) +from __future__ import annotations + +import torch +import triton +import triton.language as tl +from helion.runtime import default_launcher as _default_launcher + +@triton.jit +def _kernel_zip_iteration_kernel(result, tensors_a_item_0, tensors_b_item_0, tensors_a_item_1, tensors_b_item_1, result_size_0, result_stride_0, tensors_a_item_0_stride_0, tensors_a_item_1_stride_0, tensors_b_item_0_stride_0, tensors_b_item_1_stride_0, _BLOCK_SIZE_0: tl.constexpr): + pid_0 = tl.program_id(0) + offset_0 = pid_0 * _BLOCK_SIZE_0 + indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32) + mask_0 = indices_0 < result_size_0 + acc = tl.full([_BLOCK_SIZE_0], 0, tl.float32) + load = tl.load(tensors_a_item_0 + indices_0 * tensors_a_item_0_stride_0, mask_0, other=0) + load_1 = tl.load(tensors_b_item_0 + indices_0 * tensors_b_item_0_stride_0, mask_0, other=0) + v_0 = load * load_1 + v_1 = acc + v_0 + load_2 = tl.load(tensors_a_item_1 + indices_0 * tensors_a_item_1_stride_0, mask_0, other=0) + load_3 = tl.load(tensors_b_item_1 + indices_0 * tensors_b_item_1_stride_0, mask_0, other=0) + v_2 = load_2 * load_3 + v_3 = v_1 + v_2 + tl.store(result + indices_0 * result_stride_0, v_3, mask_0) + +def kernel_zip_iteration(tensors_a: tuple[torch.Tensor, torch.Tensor], tensors_b: tuple[torch.Tensor, torch.Tensor], *, _launcher=_default_launcher): + """Test iteration over zip of tuples.""" + result = torch.zeros_like(tensors_a[0]) + _BLOCK_SIZE_0 = 64 + _launcher(_kernel_zip_iteration_kernel, (triton.cdiv(result.size(0), _BLOCK_SIZE_0),), result, tensors_a[0], tensors_b[0], tensors_a[1], tensors_b[1], result.size(0), result.stride(0), tensors_a[0].stride(0), tensors_a[1].stride(0), tensors_b[0].stride(0), tensors_b[1].stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3) + return result diff --git a/test/test_unroll_tuples.py b/test/test_unroll_tuples.py new file mode 100644 index 00000000..0869df72 --- /dev/null +++ b/test/test_unroll_tuples.py @@ -0,0 +1,438 @@ +from __future__ import annotations + +import unittest + +import torch + +import helion +from helion._testing import DEVICE +from helion._testing import TestCase +from helion._testing import code_and_output +import helion.language as hl + + +@helion.kernel(use_default_config=True) +def kernel_tuple_addition( + a_shared_tuple: tuple[torch.Tensor, ...], +) -> torch.Tensor: + """Basic test: iterate over a tuple of tensors and sum them.""" + out = torch.empty_like(a_shared_tuple[0]) + for tile_n in hl.tile(out.size(0)): + acc = torch.zeros([tile_n], dtype=torch.float32, device=out.device) + for a_tensor in a_shared_tuple: + acc += a_tensor[tile_n] + out[tile_n] = acc + return out + + 
+@helion.kernel(use_default_config=True) +def kernel_tuple_with_scaling( + tensor1: torch.Tensor, + tensor2: torch.Tensor, + tensor3: torch.Tensor, + scale1: float, + scale2: float, + scale3: float, +) -> torch.Tensor: + """Test iteration over tensors with corresponding scalar multipliers.""" + tensors = (tensor1, tensor2, tensor3) + scales = (scale1, scale2, scale3) + output = torch.zeros_like(tensor1) + for tile_idx in hl.tile(output.size(0)): + temp = torch.zeros([tile_idx], dtype=torch.float32, device=output.device) + for tensor, scale in zip(tensors, scales, strict=True): + temp += tensor[tile_idx] * scale + output[tile_idx] = temp + return output + + +@helion.kernel(use_default_config=True) +def kernel_nested_tuple_iteration( + a_tuple: tuple[torch.Tensor, torch.Tensor], + b_tuple: tuple[torch.Tensor, torch.Tensor], +) -> torch.Tensor: + """Test nested iteration over multiple tuples.""" + result = torch.zeros_like(a_tuple[0]) + for tile_idx in hl.tile(result.size(0)): + temp = torch.zeros([tile_idx], dtype=torch.float32, device=result.device) + + for a_tensor in a_tuple: + temp += a_tensor[tile_idx] + + for b_tensor in b_tuple: + temp *= b_tensor[tile_idx] + + result[tile_idx] = temp + return result + + +@helion.kernel(use_default_config=True) +def kernel_constants_iteration( + x: torch.Tensor, +) -> torch.Tensor: + """Test iteration over a tuple/list of constants.""" + result = torch.zeros_like(x) + for tile_idx in hl.tile(result.size(0)): + acc = torch.zeros([tile_idx], dtype=torch.float32, device=result.device) + # Iterate over constants + for multiplier in (1, 2, 3): + acc += x[tile_idx] * multiplier + result[tile_idx] = acc + return result + + +@helion.kernel(use_default_config=True) +def kernel_list_constants_iteration( + x: torch.Tensor, +) -> torch.Tensor: + """Test iteration over a list of constants.""" + result = torch.zeros_like(x) + for tile_idx in hl.tile(result.size(0)): + acc = torch.zeros([tile_idx], dtype=torch.float32, device=result.device) + # Iterate over constants in a list + for multiplier in [0.5, 1.5, 2.5]: + acc += x[tile_idx] * multiplier + result[tile_idx] = acc + return result + + +@helion.kernel(use_default_config=True) +def kernel_zip_iteration( + tensors_a: tuple[torch.Tensor, torch.Tensor], + tensors_b: tuple[torch.Tensor, torch.Tensor], +) -> torch.Tensor: + """Test iteration over zip of tuples.""" + result = torch.zeros_like(tensors_a[0]) + for tile_idx in hl.tile(result.size(0)): + acc = torch.zeros([tile_idx], dtype=torch.float32, device=result.device) + # Iterate over zip of tensors + for a_tensor, b_tensor in zip(tensors_a, tensors_b, strict=False): + acc += a_tensor[tile_idx] * b_tensor[tile_idx] + result[tile_idx] = acc + return result + + +@helion.kernel(use_default_config=True) +def kernel_static_range_iteration( + x: torch.Tensor, +) -> torch.Tensor: + """Test iteration using hl.static_range.""" + result = torch.zeros_like(x) + for tile_idx in hl.tile(result.size(0)): + acc = torch.zeros([tile_idx], dtype=torch.float32, device=result.device) + # Use static_range for unrolled loop + for i in hl.static_range(4): + acc += x[tile_idx] * (i + 1) + result[tile_idx] = acc + return result + + +@helion.kernel(use_default_config=True) +def kernel_static_range_with_start( + x: torch.Tensor, +) -> torch.Tensor: + """Test static_range with start parameter.""" + result = torch.zeros_like(x) + for tile_idx in hl.tile(result.size(0)): + acc = torch.zeros([tile_idx], dtype=torch.float32, device=result.device) + # Use static_range(start, end) + for i in 
hl.static_range(2, 5): + acc += x[tile_idx] * i + result[tile_idx] = acc + return result + + +@helion.kernel(use_default_config=True) +def kernel_mixed_constants_and_tensors( + tensors: tuple[torch.Tensor, torch.Tensor], + constants: tuple[int, int], +) -> torch.Tensor: + """Test mixed iteration over both tensors and constants.""" + result = torch.zeros_like(tensors[0]) + for tile_idx in hl.tile(result.size(0)): + acc = torch.zeros([tile_idx], dtype=torch.float32, device=result.device) + + # First, iterate over tensors + for tensor in tensors: + acc += tensor[tile_idx] + + # Then, iterate over constants and multiply + for constant in constants: + acc *= constant + + result[tile_idx] = acc + return result + + +@helion.kernel(use_default_config=True) +def kernel_enumerate_iteration( + tensors: tuple[torch.Tensor, torch.Tensor, torch.Tensor], +) -> torch.Tensor: + """Test iteration using enumerate over tensors.""" + result = torch.zeros_like(tensors[0]) + for tile_idx in hl.tile(result.size(0)): + acc = torch.zeros([tile_idx], dtype=torch.float32, device=result.device) + # Iterate with enumerate to get index and tensor + for i, tensor in enumerate(tensors): + acc += tensor[tile_idx] * (i + 1) # Weight by index + 1 + result[tile_idx] = acc + return result + + +@helion.kernel(use_default_config=True) +def kernel_enumerate_with_start( + tensors: tuple[torch.Tensor, torch.Tensor], +) -> torch.Tensor: + """Test enumerate with custom start value.""" + result = torch.zeros_like(tensors[0]) + for tile_idx in hl.tile(result.size(0)): + acc = torch.zeros([tile_idx], dtype=torch.float32, device=result.device) + # Enumerate starting from 5 + for i, tensor in enumerate(tensors, start=5): + acc += tensor[tile_idx] * i + result[tile_idx] = acc + return result + + +@helion.kernel(use_default_config=True) +def kernel_enumerate_constants( + x: torch.Tensor, +) -> torch.Tensor: + """Test enumerate over constants.""" + result = torch.zeros_like(x) + for tile_idx in hl.tile(result.size(0)): + acc = torch.zeros([tile_idx], dtype=torch.float32, device=result.device) + # Enumerate over constant values + for i, multiplier in enumerate((2, 3, 4)): + acc += x[tile_idx] * multiplier * i + result[tile_idx] = acc + return result + + +class TestUnrollTuples(TestCase): + def test_basic_tuple_addition(self): + """Test basic iteration over tuple of tensors with addition.""" + size = (32,) + tensor1 = torch.randn(size, device=DEVICE) + tensor2 = torch.randn(size, device=DEVICE) + tensor3 = torch.randn(size, device=DEVICE) + + tuple_arg = (tensor1, tensor2, tensor3) + + code, result = code_and_output(kernel_tuple_addition, (tuple_arg,)) + + # Validate generated code + self.assertExpectedJournal(code) + + # Test correctness + expected = tensor1 + tensor2 + tensor3 + torch.testing.assert_close(result, expected) + + def test_tuple_with_scaling_factors(self): + """Test iteration with corresponding scalar values.""" + size = (48,) + tensor1 = torch.randn(size, device=DEVICE) + tensor2 = torch.randn(size, device=DEVICE) + tensor3 = torch.randn(size, device=DEVICE) + + scale1, scale2, scale3 = 2.0, 0.5, 1.5 + + code, result = code_and_output( + kernel_tuple_with_scaling, + (tensor1, tensor2, tensor3, scale1, scale2, scale3), + ) + + # Validate generated code + self.assertExpectedJournal(code) + + # Test correctness + expected = tensor1 * scale1 + tensor2 * scale2 + tensor3 * scale3 + torch.testing.assert_close(result, expected) + + def test_nested_tuple_iteration(self): + """Test nested loops over multiple tuples.""" + size = (40,) + 
a1 = torch.randn(size, device=DEVICE) + a2 = torch.randn(size, device=DEVICE) + b1 = torch.randn(size, device=DEVICE) + 1.0 # Avoid zeros for multiplication + b2 = torch.randn(size, device=DEVICE) + 1.0 + + a_tuple = (a1, a2) + b_tuple = (b1, b2) + + code, result = code_and_output( + kernel_nested_tuple_iteration, (a_tuple, b_tuple) + ) + + # Validate generated code + self.assertExpectedJournal(code) + + # Test correctness + temp = a1 + a2 + expected = temp * b1 * b2 + torch.testing.assert_close(result, expected) + + def test_single_element_tuple(self): + """Test with single-element tuple.""" + size = (16,) + tensor = torch.randn(size, device=DEVICE) + + tuple_arg = (tensor,) + + code, result = code_and_output(kernel_tuple_addition, (tuple_arg,)) + + # Validate generated code + self.assertExpectedJournal(code) + + # Test correctness - should just copy the tensor + torch.testing.assert_close(result, tensor) + + def test_constants_iteration(self): + """Test iteration over tuple of constants.""" + size = (24,) + x = torch.randn(size, device=DEVICE) + + code, result = code_and_output(kernel_constants_iteration, (x,)) + + # Validate generated code + self.assertExpectedJournal(code) + + # Test correctness - should be x * (1 + 2 + 3) = x * 6 + expected = x * 6 + torch.testing.assert_close(result, expected) + + def test_list_constants_iteration(self): + """Test iteration over list of constants.""" + size = (20,) + x = torch.randn(size, device=DEVICE) + + code, result = code_and_output(kernel_list_constants_iteration, (x,)) + + # Validate generated code + self.assertExpectedJournal(code) + + # Test correctness - should be x * (0.5 + 1.5 + 2.5) = x * 4.5 + expected = x * 4.5 + torch.testing.assert_close(result, expected) + + def test_zip_iteration(self): + """Test iteration over zip of tuples.""" + # Create one reference tensor and use it to create others with same size + reference = torch.randn((36,), device=DEVICE) + a1 = torch.randn_like(reference) + a2 = torch.randn_like(reference) + b1 = torch.randn_like(reference) + b2 = torch.randn_like(reference) + + tensors_a = (a1, a2) + tensors_b = (b1, b2) + + code, result = code_and_output(kernel_zip_iteration, (tensors_a, tensors_b)) + + # Validate generated code + self.assertExpectedJournal(code) + + # Test correctness - should be a1*b1 + a2*b2 + expected = a1 * b1 + a2 * b2 + torch.testing.assert_close(result, expected) + + def test_static_range_iteration(self): + """Test iteration using hl.static_range.""" + size = (28,) + x = torch.randn(size, device=DEVICE) + + code, result = code_and_output(kernel_static_range_iteration, (x,)) + + # Validate generated code + self.assertExpectedJournal(code) + + # Test correctness - should be x * (1 + 2 + 3 + 4) = x * 10 + expected = x * 10 + torch.testing.assert_close(result, expected) + + def test_static_range_with_start(self): + """Test static_range with start parameter.""" + size = (18,) + x = torch.randn(size, device=DEVICE) + + code, result = code_and_output(kernel_static_range_with_start, (x,)) + + # Validate generated code + self.assertExpectedJournal(code) + + # Test correctness - should be x * (2 + 3 + 4) = x * 9 + expected = x * 9 + torch.testing.assert_close(result, expected) + + def test_mixed_constants_and_tensors(self): + """Test mixed iteration over both tensors and constants.""" + size = (22,) + tensor1 = torch.randn(size, device=DEVICE) + tensor2 = torch.randn(size, device=DEVICE) + + tensors = (tensor1, tensor2) + constants = (2, 3) + + code, result = code_and_output( + 
kernel_mixed_constants_and_tensors, (tensors, constants) + ) + + # Validate generated code + self.assertExpectedJournal(code) + + # Test correctness - should be (tensor1 + tensor2) * 2 * 3 + expected = (tensor1 + tensor2) * 2 * 3 + torch.testing.assert_close(result, expected) + + def test_enumerate_iteration(self): + """Test iteration using enumerate over tensors.""" + size = (24,) + tensor1 = torch.randn(size, device=DEVICE) + tensor2 = torch.randn(size, device=DEVICE) + tensor3 = torch.randn(size, device=DEVICE) + + tensors = (tensor1, tensor2, tensor3) + + code, result = code_and_output(kernel_enumerate_iteration, (tensors,)) + + # Validate generated code + self.assertExpectedJournal(code) + + # Test correctness - should be tensor1*1 + tensor2*2 + tensor3*3 + expected = tensor1 * 1 + tensor2 * 2 + tensor3 * 3 + torch.testing.assert_close(result, expected) + + def test_enumerate_with_start(self): + """Test enumerate with custom start value.""" + size = (18,) + tensor1 = torch.randn(size, device=DEVICE) + tensor2 = torch.randn(size, device=DEVICE) + + tensors = (tensor1, tensor2) + + code, result = code_and_output(kernel_enumerate_with_start, (tensors,)) + + # Validate generated code + self.assertExpectedJournal(code) + + # Test correctness - should be tensor1*5 + tensor2*6 (start=5) + expected = tensor1 * 5 + tensor2 * 6 + torch.testing.assert_close(result, expected) + + def test_enumerate_constants(self): + """Test enumerate over constants.""" + size = (20,) + x = torch.randn(size, device=DEVICE) + + code, result = code_and_output(kernel_enumerate_constants, (x,)) + + # Validate generated code + self.assertExpectedJournal(code) + + # Test correctness - should be x*(2*0 + 3*1 + 4*2) = x*(0 + 3 + 8) = x*11 + expected = x * 11 + torch.testing.assert_close(result, expected) + + +if __name__ == "__main__": + unittest.main()
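
Usage sketch (editorial, not part of the diff): the features added above compose inside a single kernel, since tuple iteration, the zip() replacement, and hl.static_range() all lower to the same compile-time unrolling. The kernel below is a minimal sketch assuming only the public helion / helion.language APIs exercised by test_unroll_tuples.py; the kernel name fused_weighted_sum and its shapes are hypothetical.

    from __future__ import annotations

    import torch

    import helion
    import helion.language as hl


    @helion.kernel(use_default_config=True)
    def fused_weighted_sum(
        tensors: tuple[torch.Tensor, torch.Tensor],
        scales: tuple[float, float],
    ) -> torch.Tensor:
        # Computes (tensors[0]*scales[0] + tensors[1]*scales[1]) * 4,
        # with every inner loop unrolled at compile time.
        out = torch.zeros_like(tensors[0])
        for tile in hl.tile(out.size(0)):
            acc = torch.zeros([tile], dtype=torch.float32, device=out.device)
            # zip() is replaced at trace time and unrolled: one load and
            # multiply-accumulate per (tensor, scale) pair.
            for t, s in zip(tensors, scales, strict=True):
                acc += t[tile] * s
            # hl.static_range unrolls a constant-trip-count loop the same way;
            # two unrolled doublings multiply the accumulator by 4.
            for _ in hl.static_range(2):
                acc = acc + acc
            out[tile] = acc
        return out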