Skip to content

Commit 25af263

Browse files
committed
Implement static tuple unrolling and hl.static_range
stack-info: PR: #329, branch: jansel/stack/115
1 parent 56e78ad commit 25af263

File tree

5 files changed

+1171
-1
lines changed

5 files changed

+1171
-1
lines changed

helion/_compiler/device_ir.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -502,11 +502,32 @@ def _extract_tile_begin_end(self, for_node: ast.For) -> tuple[object, object]:
502502
end = self.visit(args[1])
503503
return begin, end
504504

505+
def _handle_tuple_unrolling(
506+
self,
507+
node: ast.For,
508+
) -> None:
509+
"""Handle unrolling of loops that iterate over tuples of tensors."""
510+
# Get the sequence of tensors to iterate over
511+
sequence_value = self.visit(node.iter)
512+
assert isinstance(sequence_value, (tuple, list)), (
513+
f"Expected tuple or list, got {type(sequence_value)}"
514+
)
515+
# Unroll the loop by executing the body for each tensor in the sequence
516+
for tensor_value in sequence_value:
517+
self._assign(node.target, tensor_value)
518+
self._body(node.body)
519+
505520
def visit_For(self, node: ast.For) -> None:
506521
assert isinstance(node, ExtendedAST)
507522
assert not node.orelse
508523
assert isinstance(node.iter, ExtendedAST)
509524
iter_type = node.iter._type_info
525+
526+
# Check if we're iterating directly over a sequence (tuple unrolling)
527+
if isinstance(iter_type, SequenceType):
528+
self._handle_tuple_unrolling(node)
529+
return
530+
510531
if not isinstance(iter_type, IterType):
511532
raise exc.InvalidDeviceForLoop(iter_type)
512533
inner_type: TypeInfo = iter_type.inner

helion/language/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from .device_print import device_print as device_print
99
from .inline_asm_ops import inline_asm_elementwise as inline_asm_elementwise
1010
from .loops import grid as grid
11+
from .loops import static_range as static_range
1112
from .loops import tile as tile
1213
from .memory_ops import atomic_add as atomic_add
1314
from .memory_ops import load as load

helion/language/loops.py

Lines changed: 266 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,12 @@
2525
from .._compiler.compile_environment import CompileEnvironment
2626
from .._compiler.type_propagation import GridIndexType
2727
from .._compiler.type_propagation import IterType
28+
from .._compiler.type_propagation import LiteralType
2829
from .._compiler.type_propagation import Origin
2930
from .._compiler.type_propagation import SequenceType
3031
from .._compiler.type_propagation import TileIndexType
3132
from .._compiler.type_propagation import TypeInfo
33+
from .._compiler.variable_origin import GetItemOrigin
3234
from ..autotuner.config_spec import ConfigSpec
3335
from ..autotuner.config_spec import FlattenLoopSpec
3436
from ..autotuner.config_spec import L2GroupingSpec
@@ -48,7 +50,7 @@
4850
from .._compiler.inductor_lowering import CodegenState
4951

5052

51-
__all__ = ["grid", "tile"]
53+
__all__ = ["grid", "static_range", "tile"]
5254

5355

5456
@overload
@@ -633,3 +635,266 @@ def _(
633635
@_decorators.codegen(grid)
634636
def _(state: CodegenState) -> ast.AST:
635637
return _codegen_loop_helper(state)
638+
639+
640+
@_decorators.device_func_replacement(builtins.zip)
641+
@_decorators.api(is_device_only=True, cache_type=True)
642+
def _zip_replacement(
643+
*args: tuple[object, ...] | list[object],
644+
strict: bool = False,
645+
) -> tuple[tuple[object, ...], ...]:
646+
"""
647+
Device replacement for zip() that returns tuples for unrolling.
648+
649+
This replacement enables zip() to work in device kernels by converting
650+
the zip result to a tuple of tuples, which can then be unrolled by the
651+
existing tuple iteration logic.
652+
653+
Args:
654+
*args: Sequences to zip together
655+
656+
Returns:
657+
Tuple of tuples containing zipped elements
658+
659+
Examples:
660+
.. code-block:: python
661+
662+
@helion.kernel
663+
def kernel_with_zip(a_tensors, b_tensors):
664+
for a, b in zip(a_tensors, b_tensors):
665+
# This gets unrolled at compile time
666+
result += a * b
667+
"""
668+
raise exc.NotInsideKernel
669+
670+
671+
@_decorators.type_propagation(_zip_replacement)
672+
def _(
673+
*args: TypeInfo,
674+
origin: Origin,
675+
**kwargs: object,
676+
) -> TypeInfo:
677+
"""Type propagation for zip replacement that preserves tensor types."""
678+
# Accept but ignore the strict keyword argument
679+
if not args:
680+
return SequenceType(origin, ())
681+
682+
# Convert all arguments to SequenceType
683+
sequences = []
684+
for arg in args:
685+
if not isinstance(arg, SequenceType):
686+
raise exc.TypeInferenceError(
687+
f"zip() argument must be a sequence, got {arg}"
688+
)
689+
sequences.append(arg.unpack())
690+
691+
# Check all sequences have the same length
692+
length = 0
693+
if sequences:
694+
length = len(sequences[0])
695+
for i, seq in enumerate(sequences[1:], 1):
696+
if len(seq) != length:
697+
raise exc.TypeInferenceError(
698+
f"zip() argument {i} has length {len(seq)}, expected {length}"
699+
)
700+
701+
# Build result as tuple of tuples, preserving existing TypeInfo objects
702+
result_elements = []
703+
for i in range(length):
704+
# Create a tuple containing the i-th element from each sequence
705+
tuple_elements = tuple(seq[i] for seq in sequences)
706+
tuple_type = SequenceType(GetItemOrigin(origin, i), tuple_elements)
707+
result_elements.append(tuple_type)
708+
709+
return SequenceType(origin, tuple(result_elements))
710+
711+
712+
@_decorators.register_to_device_ir(_zip_replacement)
713+
def _(
714+
tracer: object,
715+
*flat_args: object,
716+
) -> object:
717+
"""Device IR handler for zip - returns the zipped result for unrolling."""
718+
# flat_args contains the prepared arguments: (tensor_sequences, strict_value)
719+
if not flat_args:
720+
return ()
721+
722+
# Extract sequences and strict parameter
723+
if len(flat_args) == 2:
724+
sequences = flat_args[0] # This should be the tuple of sequences
725+
strict = flat_args[1] # This should be the strict parameter
726+
assert isinstance(strict, bool)
727+
else:
728+
assert len(flat_args) == 1
729+
sequences = flat_args[0]
730+
strict = False
731+
return [*builtins.zip(*sequences, strict=strict)] # type: ignore[arg-type]
732+
733+
734+
@_decorators.device_func_replacement(builtins.enumerate)
735+
@_decorators.api(is_device_only=True, cache_type=True)
736+
def _enumerate_replacement(
737+
iterable: tuple[object, ...] | list[object],
738+
start: int = 0,
739+
) -> tuple[tuple[int, object], ...]:
740+
"""
741+
Device replacement for enumerate() that returns tuples for unrolling.
742+
743+
This replacement enables enumerate() to work in device kernels by converting
744+
the enumerate result to a tuple of (index, value) tuples, which can then be
745+
unrolled by the existing tuple iteration logic.
746+
747+
Args:
748+
iterable: Sequence to enumerate
749+
start: Starting value for the counter (default: 0)
750+
751+
Returns:
752+
Tuple of (index, value) tuples
753+
"""
754+
raise exc.NotInsideKernel
755+
756+
757+
@_decorators.type_propagation(_enumerate_replacement)
758+
def _(
759+
iterable: TypeInfo,
760+
start: TypeInfo | None = None,
761+
*,
762+
origin: Origin,
763+
) -> TypeInfo:
764+
"""Type propagation for enumerate replacement that preserves tensor types."""
765+
if not isinstance(iterable, SequenceType):
766+
raise exc.TypeInferenceError(
767+
f"enumerate() argument must be a sequence, got {iterable}"
768+
)
769+
770+
# Get the start value
771+
start_value = 0
772+
if start is not None and start.is_literal():
773+
start_val = start.as_literal()
774+
if isinstance(start_val, int):
775+
start_value = start_val
776+
777+
# Build result as tuple of (index, value) tuples
778+
sequence_elements = iterable.unpack()
779+
result_elements = []
780+
781+
for i, element in enumerate(sequence_elements):
782+
# Create (index, value) tuple
783+
index_literal = LiteralType(origin, start_value + i)
784+
tuple_elements = (index_literal, element)
785+
tuple_type = SequenceType(GetItemOrigin(origin, i), tuple_elements)
786+
result_elements.append(tuple_type)
787+
788+
return SequenceType(origin, tuple(result_elements))
789+
790+
791+
@_decorators.register_to_device_ir(_enumerate_replacement)
792+
def _(
793+
tracer: object,
794+
*flat_args: object,
795+
) -> object:
796+
"""Device IR handler for enumerate - returns the enumerated result for unrolling."""
797+
if len(flat_args) == 2:
798+
iterable = flat_args[0]
799+
start = flat_args[1]
800+
assert isinstance(start, int)
801+
else:
802+
assert len(flat_args) == 1
803+
iterable = flat_args[0]
804+
start = 0
805+
return [*builtins.enumerate(iterable, start=start)] # type: ignore[arg-type]
806+
807+
808+
@_decorators.api(is_device_only=True, cache_type=True)
809+
def static_range(
810+
begin_or_end: int,
811+
end_or_none: int | None = None,
812+
/,
813+
step: int = 1,
814+
) -> Iterator[int]:
815+
"""
816+
Create a range that gets unrolled at compile time by iterating over constant integer values.
817+
818+
This function is similar to Python's built-in range(), but it generates a sequence
819+
of integer constants that triggers loop unrolling behavior in Helion kernels. The loop
820+
is completely unrolled at compile time, with each iteration becoming separate
821+
instructions in the generated code.
822+
823+
Args:
824+
begin_or_end: If 2+ positional args provided, the start of range (integer).
825+
Otherwise, the end of range (integer).
826+
end_or_none: If 2+ positional args provided, the end of range (integer).
827+
step: Step size for iteration (integer, default: 1)
828+
829+
Returns:
830+
Iterator[int]: Iterator over constant integer values
831+
832+
Examples:
833+
Simple unrolled loop:
834+
835+
.. code-block:: python
836+
837+
@helion.kernel
838+
def unrolled_example(x: torch.Tensor) -> torch.Tensor:
839+
result = torch.zeros_like(x)
840+
841+
for tile in hl.tile(x.size(0)):
842+
acc = torch.zeros([tile], dtype=x.dtype, device=x.device)
843+
# This loop gets completely unrolled
844+
for i in hl.static_range(3):
845+
acc += x[tile] * i
846+
result[tile] = acc
847+
848+
return result
849+
850+
Range with start and step:
851+
852+
.. code-block:: python
853+
854+
@helion.kernel
855+
def kernel_stepped_unroll(x: torch.Tensor) -> torch.Tensor:
856+
result = torch.zeros_like(x)
857+
858+
for tile in hl.tile(x.size(0)):
859+
acc = torch.zeros([tile], dtype=x.dtype, device=x.device)
860+
# Unroll loop from 2 to 8 with step 2: [2, 4, 6]
861+
for i in hl.static_range(2, 8, 2):
862+
acc += x[tile] * i
863+
result[tile] = acc
864+
865+
return result
866+
867+
Note:
868+
- Only constant integer values are supported
869+
- The range must be small enough to avoid compilation timeouts
870+
- Each iteration becomes separate instructions in the generated Triton code
871+
- Use for small, fixed iteration counts where unrolling is beneficial
872+
"""
873+
raise exc.NotInsideKernel
874+
875+
876+
@_decorators.register_fake(static_range)
877+
def _(
878+
begin_or_end: int,
879+
end_or_none: int | None = None,
880+
/,
881+
step: int = 1,
882+
) -> tuple[int, ...]:
883+
"""Fake function for static_range - validates integer constants and returns tuple(range(...))."""
884+
# Validate that inputs are compile-time constants
885+
if end_or_none is not None:
886+
begin_val = begin_or_end
887+
end_val = end_or_none
888+
else:
889+
begin_val = 0
890+
end_val = begin_or_end
891+
892+
if (
893+
not isinstance(begin_val, int)
894+
or not isinstance(end_val, int)
895+
or not isinstance(step, int)
896+
):
897+
raise exc.TypeInferenceError("static_range requires constant integer arguments")
898+
899+
# Return tuple(range(...)) which will trigger existing tuple/list unrolling
900+
return tuple(range(begin_val, end_val, step))

0 commit comments

Comments
 (0)