Skip to content

Implement static tuple unrolling and hl.static_range #329

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 16, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions helion/_compiler/device_ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -502,11 +502,32 @@ def _extract_tile_begin_end(self, for_node: ast.For) -> tuple[object, object]:
end = self.visit(args[1])
return begin, end

def _handle_tuple_unrolling(
self,
node: ast.For,
) -> None:
"""Handle unrolling of loops that iterate over tuples of tensors."""
# Get the sequence of tensors to iterate over
sequence_value = self.visit(node.iter)
assert isinstance(sequence_value, (tuple, list)), (
f"Expected tuple or list, got {type(sequence_value)}"
)
# Unroll the loop by executing the body for each tensor in the sequence
for tensor_value in sequence_value:
self._assign(node.target, tensor_value)
self._body(node.body)

def visit_For(self, node: ast.For) -> None:
assert isinstance(node, ExtendedAST)
assert not node.orelse
assert isinstance(node.iter, ExtendedAST)
iter_type = node.iter._type_info

# Check if we're iterating directly over a sequence (tuple unrolling)
if isinstance(iter_type, SequenceType):
self._handle_tuple_unrolling(node)
return

if not isinstance(iter_type, IterType):
raise exc.InvalidDeviceForLoop(iter_type)
inner_type: TypeInfo = iter_type.inner
Expand Down
1 change: 1 addition & 0 deletions helion/language/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from .device_print import device_print as device_print
from .inline_asm_ops import inline_asm_elementwise as inline_asm_elementwise
from .loops import grid as grid
from .loops import static_range as static_range
from .loops import tile as tile
from .memory_ops import atomic_add as atomic_add
from .memory_ops import load as load
Expand Down
267 changes: 266 additions & 1 deletion helion/language/loops.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,12 @@
from .._compiler.compile_environment import CompileEnvironment
from .._compiler.type_propagation import GridIndexType
from .._compiler.type_propagation import IterType
from .._compiler.type_propagation import LiteralType
from .._compiler.type_propagation import Origin
from .._compiler.type_propagation import SequenceType
from .._compiler.type_propagation import TileIndexType
from .._compiler.type_propagation import TypeInfo
from .._compiler.variable_origin import GetItemOrigin
from ..autotuner.config_spec import ConfigSpec
from ..autotuner.config_spec import FlattenLoopSpec
from ..autotuner.config_spec import L2GroupingSpec
Expand All @@ -48,7 +50,7 @@
from .._compiler.inductor_lowering import CodegenState


__all__ = ["grid", "tile"]
__all__ = ["grid", "static_range", "tile"]


@overload
Expand Down Expand Up @@ -633,3 +635,266 @@ def _(
@_decorators.codegen(grid)
def _(state: CodegenState) -> ast.AST:
return _codegen_loop_helper(state)


@_decorators.device_func_replacement(builtins.zip)
@_decorators.api(is_device_only=True, cache_type=True)
def _zip_replacement(
*args: tuple[object, ...] | list[object],
strict: bool = False,
) -> tuple[tuple[object, ...], ...]:
"""
Device replacement for zip() that returns tuples for unrolling.

This replacement enables zip() to work in device kernels by converting
the zip result to a tuple of tuples, which can then be unrolled by the
existing tuple iteration logic.

Args:
*args: Sequences to zip together

Returns:
Tuple of tuples containing zipped elements

Examples:
.. code-block:: python

@helion.kernel
def kernel_with_zip(a_tensors, b_tensors):
for a, b in zip(a_tensors, b_tensors):
# This gets unrolled at compile time
result += a * b
"""
raise exc.NotInsideKernel


@_decorators.type_propagation(_zip_replacement)
def _(
*args: TypeInfo,
origin: Origin,
**kwargs: object,
) -> TypeInfo:
"""Type propagation for zip replacement that preserves tensor types."""
# Accept but ignore the strict keyword argument
if not args:
return SequenceType(origin, ())

# Convert all arguments to SequenceType
sequences = []
for arg in args:
if not isinstance(arg, SequenceType):
raise exc.TypeInferenceError(
f"zip() argument must be a sequence, got {arg}"
)
sequences.append(arg.unpack())

# Check all sequences have the same length
length = 0
if sequences:
length = len(sequences[0])
for i, seq in enumerate(sequences[1:], 1):
if len(seq) != length:
raise exc.TypeInferenceError(
f"zip() argument {i} has length {len(seq)}, expected {length}"
)

# Build result as tuple of tuples, preserving existing TypeInfo objects
result_elements = []
for i in range(length):
# Create a tuple containing the i-th element from each sequence
tuple_elements = tuple(seq[i] for seq in sequences)
tuple_type = SequenceType(GetItemOrigin(origin, i), tuple_elements)
result_elements.append(tuple_type)

return SequenceType(origin, tuple(result_elements))


@_decorators.register_to_device_ir(_zip_replacement)
def _(
tracer: object,
*flat_args: object,
) -> object:
"""Device IR handler for zip - returns the zipped result for unrolling."""
# flat_args contains the prepared arguments: (tensor_sequences, strict_value)
if not flat_args:
return ()

# Extract sequences and strict parameter
if len(flat_args) == 2:
sequences = flat_args[0] # This should be the tuple of sequences
strict = flat_args[1] # This should be the strict parameter
assert isinstance(strict, bool)
else:
assert len(flat_args) == 1
sequences = flat_args[0]
strict = False
return [*builtins.zip(*sequences, strict=strict)] # type: ignore[arg-type]


@_decorators.device_func_replacement(builtins.enumerate)
@_decorators.api(is_device_only=True, cache_type=True)
def _enumerate_replacement(
iterable: tuple[object, ...] | list[object],
start: int = 0,
) -> tuple[tuple[int, object], ...]:
"""
Device replacement for enumerate() that returns tuples for unrolling.

This replacement enables enumerate() to work in device kernels by converting
the enumerate result to a tuple of (index, value) tuples, which can then be
unrolled by the existing tuple iteration logic.

Args:
iterable: Sequence to enumerate
start: Starting value for the counter (default: 0)

Returns:
Tuple of (index, value) tuples
"""
raise exc.NotInsideKernel


@_decorators.type_propagation(_enumerate_replacement)
def _(
iterable: TypeInfo,
start: TypeInfo | None = None,
*,
origin: Origin,
) -> TypeInfo:
"""Type propagation for enumerate replacement that preserves tensor types."""
if not isinstance(iterable, SequenceType):
raise exc.TypeInferenceError(
f"enumerate() argument must be a sequence, got {iterable}"
)

# Get the start value
start_value = 0
if start is not None and start.is_literal():
start_val = start.as_literal()
if isinstance(start_val, int):
start_value = start_val

# Build result as tuple of (index, value) tuples
sequence_elements = iterable.unpack()
result_elements = []

for i, element in enumerate(sequence_elements):
# Create (index, value) tuple
index_literal = LiteralType(origin, start_value + i)
tuple_elements = (index_literal, element)
tuple_type = SequenceType(GetItemOrigin(origin, i), tuple_elements)
result_elements.append(tuple_type)

return SequenceType(origin, tuple(result_elements))


@_decorators.register_to_device_ir(_enumerate_replacement)
def _(
tracer: object,
*flat_args: object,
) -> object:
"""Device IR handler for enumerate - returns the enumerated result for unrolling."""
if len(flat_args) == 2:
iterable = flat_args[0]
start = flat_args[1]
assert isinstance(start, int)
else:
assert len(flat_args) == 1
iterable = flat_args[0]
start = 0
return [*builtins.enumerate(iterable, start=start)] # type: ignore[arg-type]


@_decorators.api(is_device_only=True, cache_type=True)
def static_range(
begin_or_end: int,
end_or_none: int | None = None,
/,
step: int = 1,
) -> Iterator[int]:
"""
Create a range that gets unrolled at compile time by iterating over constant integer values.

This function is similar to Python's built-in range(), but it generates a sequence
of integer constants that triggers loop unrolling behavior in Helion kernels. The loop
is completely unrolled at compile time, with each iteration becoming separate
instructions in the generated code.

Args:
begin_or_end: If 2+ positional args provided, the start of range (integer).
Otherwise, the end of range (integer).
end_or_none: If 2+ positional args provided, the end of range (integer).
step: Step size for iteration (integer, default: 1)

Returns:
Iterator[int]: Iterator over constant integer values

Examples:
Simple unrolled loop:

.. code-block:: python

@helion.kernel
def unrolled_example(x: torch.Tensor) -> torch.Tensor:
result = torch.zeros_like(x)

for tile in hl.tile(x.size(0)):
acc = torch.zeros([tile], dtype=x.dtype, device=x.device)
# This loop gets completely unrolled
for i in hl.static_range(3):
acc += x[tile] * i
result[tile] = acc

return result

Range with start and step:

.. code-block:: python

@helion.kernel
def kernel_stepped_unroll(x: torch.Tensor) -> torch.Tensor:
result = torch.zeros_like(x)

for tile in hl.tile(x.size(0)):
acc = torch.zeros([tile], dtype=x.dtype, device=x.device)
# Unroll loop from 2 to 8 with step 2: [2, 4, 6]
for i in hl.static_range(2, 8, 2):
acc += x[tile] * i
result[tile] = acc

return result

Note:
- Only constant integer values are supported
- The range must be small enough to avoid compilation timeouts
- Each iteration becomes separate instructions in the generated Triton code
- Use for small, fixed iteration counts where unrolling is beneficial
"""
raise exc.NotInsideKernel


@_decorators.register_fake(static_range)
def _(
begin_or_end: int,
end_or_none: int | None = None,
/,
step: int = 1,
) -> tuple[int, ...]:
"""Fake function for static_range - validates integer constants and returns tuple(range(...))."""
# Validate that inputs are compile-time constants
if end_or_none is not None:
begin_val = begin_or_end
end_val = end_or_none
else:
begin_val = 0
end_val = begin_or_end

if (
not isinstance(begin_val, int)
or not isinstance(end_val, int)
or not isinstance(step, int)
):
raise exc.TypeInferenceError("static_range requires constant integer arguments")

# Return tuple(range(...)) which will trigger existing tuple/list unrolling
return tuple(range(begin_val, end_val, step))
Loading
Loading