@@ -25,10 +25,12 @@
 from .._compiler.compile_environment import CompileEnvironment
 from .._compiler.type_propagation import GridIndexType
 from .._compiler.type_propagation import IterType
+from .._compiler.type_propagation import LiteralType
 from .._compiler.type_propagation import Origin
 from .._compiler.type_propagation import SequenceType
 from .._compiler.type_propagation import TileIndexType
 from .._compiler.type_propagation import TypeInfo
+from .._compiler.variable_origin import GetItemOrigin
 from ..autotuner.config_spec import ConfigSpec
 from ..autotuner.config_spec import FlattenLoopSpec
 from ..autotuner.config_spec import L2GroupingSpec
@@ -48,7 +50,7 @@
 from .._compiler.inductor_lowering import CodegenState
 
 
-__all__ = ["grid", "tile"]
+__all__ = ["grid", "static_range", "tile"]
 
 
 @overload
@@ -633,3 +635,266 @@ def _(
 @_decorators.codegen(grid)
 def _(state: CodegenState) -> ast.AST:
     return _codegen_loop_helper(state)
+
+
+@_decorators.device_func_replacement(builtins.zip)
+@_decorators.api(is_device_only=True, cache_type=True)
+def _zip_replacement(
+    *args: tuple[object, ...] | list[object],
+    strict: bool = False,
+) -> tuple[tuple[object, ...], ...]:
+    """
+    Device replacement for zip() that returns tuples for unrolling.
+
+    This replacement enables zip() to work in device kernels by converting
+    the zip result to a tuple of tuples, which can then be unrolled by the
+    existing tuple iteration logic.
+
+    Args:
+        *args: Sequences to zip together
+
+    Returns:
+        Tuple of tuples containing zipped elements
+
+    Examples:
+        .. code-block:: python
+
+            @helion.kernel
+            def kernel_with_zip(a_tensors, b_tensors):
+                result = torch.zeros_like(a_tensors[0])
+                for a, b in zip(a_tensors, b_tensors):
+                    # This gets unrolled at compile time
+                    result += a * b
+                return result
+    """
+    raise exc.NotInsideKernel
+
+
+@_decorators.type_propagation(_zip_replacement)
+def _(
+    *args: TypeInfo,
+    origin: Origin,
+    **kwargs: object,
+) -> TypeInfo:
+    """Type propagation for zip replacement that preserves tensor types."""
+    # Accept but ignore the strict keyword argument
+    if not args:
+        return SequenceType(origin, ())
+
+    # Convert all arguments to SequenceType
+    sequences = []
+    for arg in args:
+        if not isinstance(arg, SequenceType):
+            raise exc.TypeInferenceError(
+                f"zip() argument must be a sequence, got {arg}"
+            )
+        sequences.append(arg.unpack())
+
+    # Check all sequences have the same length
+    length = 0
+    if sequences:
+        length = len(sequences[0])
+        for i, seq in enumerate(sequences[1:], 1):
+            if len(seq) != length:
+                raise exc.TypeInferenceError(
+                    f"zip() argument {i} has length {len(seq)}, expected {length}"
+                )
+
+    # Build result as tuple of tuples, preserving existing TypeInfo objects
+    result_elements = []
+    for i in range(length):
+        # Create a tuple containing the i-th element from each sequence
+        tuple_elements = tuple(seq[i] for seq in sequences)
+        tuple_type = SequenceType(GetItemOrigin(origin, i), tuple_elements)
+        result_elements.append(tuple_type)
+
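+    # For example, zipping SequenceTypes (a0, a1) and (b0, b1) produces a
+    # SequenceType of ((a0, b0), (a1, b1)), mirroring
+    # [*zip((1, 2), (3, 4))] == [(1, 3), (2, 4)]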
+    return SequenceType(origin, tuple(result_elements))
+
+
+@_decorators.register_to_device_ir(_zip_replacement)
+def _(
+    tracer: object,
+    *flat_args: object,
+) -> object:
+    """Device IR handler for zip - returns the zipped result for unrolling."""
+    # flat_args contains the prepared arguments: (tensor_sequences, strict_value)
+    if not flat_args:
+        return ()
+
+    # Extract sequences and strict parameter
+    if len(flat_args) == 2:
+        sequences = flat_args[0]  # This should be the tuple of sequences
+        strict = flat_args[1]  # This should be the strict parameter
+        assert isinstance(strict, bool)
+    else:
+        assert len(flat_args) == 1
+        sequences = flat_args[0]
+        strict = False
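+    # e.g. sequences = ((a, b), (c, d)) -> [(a, c), (b, d)]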
+    return [*builtins.zip(*sequences, strict=strict)]  # type: ignore[arg-type]
+
+
+@_decorators.device_func_replacement(builtins.enumerate)
+@_decorators.api(is_device_only=True, cache_type=True)
+def _enumerate_replacement(
+    iterable: tuple[object, ...] | list[object],
+    start: int = 0,
+) -> tuple[tuple[int, object], ...]:
+    """
+    Device replacement for enumerate() that returns tuples for unrolling.
+
+    This replacement enables enumerate() to work in device kernels by converting
+    the enumerate result to a tuple of (index, value) tuples, which can then be
+    unrolled by the existing tuple iteration logic.
+
+    Args:
+        iterable: Sequence to enumerate
+        start: Starting value for the counter (default: 0)
+
+    Returns:
+        Tuple of (index, value) tuples
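+
+    Examples:
+        A minimal usage sketch (assuming a list of same-shaped tensors),
+        mirroring the zip() example above:
+
+        .. code-block:: python
+
+            @helion.kernel
+            def kernel_with_enumerate(tensors):
+                result = torch.zeros_like(tensors[0])
+                for i, t in enumerate(tensors):
+                    # i is a compile-time constant; the loop is unrolled
+                    result += t * i
+                return result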
+    """
+    raise exc.NotInsideKernel
+
+
+@_decorators.type_propagation(_enumerate_replacement)
+def _(
+    iterable: TypeInfo,
+    start: TypeInfo | None = None,
+    *,
+    origin: Origin,
+) -> TypeInfo:
+    """Type propagation for enumerate replacement that preserves tensor types."""
+    if not isinstance(iterable, SequenceType):
+        raise exc.TypeInferenceError(
+            f"enumerate() argument must be a sequence, got {iterable}"
+        )
+
+    # Get the start value
+    start_value = 0
+    if start is not None and start.is_literal():
+        start_val = start.as_literal()
+        if isinstance(start_val, int):
+            start_value = start_val
+
+    # Build result as tuple of (index, value) tuples
+    sequence_elements = iterable.unpack()
+    result_elements = []
+
+    for i, element in enumerate(sequence_elements):
+        # Create (index, value) tuple
+        index_literal = LiteralType(origin, start_value + i)
+        tuple_elements = (index_literal, element)
+        tuple_type = SequenceType(GetItemOrigin(origin, i), tuple_elements)
+        result_elements.append(tuple_type)
+
+    return SequenceType(origin, tuple(result_elements))
+
+
+@_decorators.register_to_device_ir(_enumerate_replacement)
+def _(
+    tracer: object,
+    *flat_args: object,
+) -> object:
+    """Device IR handler for enumerate - returns the enumerated result for unrolling."""
+    if len(flat_args) == 2:
+        iterable = flat_args[0]
+        start = flat_args[1]
+        assert isinstance(start, int)
+    else:
+        assert len(flat_args) == 1
+        iterable = flat_args[0]
+        start = 0
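+    # e.g. iterable = (a, b), start = 1 -> [(1, a), (2, b)]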
+    return [*builtins.enumerate(iterable, start=start)]  # type: ignore[arg-type]
+
+
+@_decorators.api(is_device_only=True, cache_type=True)
+def static_range(
+    begin_or_end: int,
+    end_or_none: int | None = None,
+    /,
+    step: int = 1,
+) -> Iterator[int]:
+    """
+    Create a range that gets unrolled at compile time by iterating over constant integer values.
+
+    This function is similar to Python's built-in range(), but it generates a sequence
+    of integer constants that triggers loop unrolling in Helion kernels. The loop
+    is completely unrolled at compile time, with each iteration becoming separate
+    instructions in the generated code.
+
+    Args:
+        begin_or_end: With one positional argument, the end of the range (integer);
+            with two, the start of the range.
+        end_or_none: With two positional arguments, the end of the range (integer).
+        step: Step size for iteration (integer, default: 1)
+
+    Returns:
+        Iterator[int]: Iterator over constant integer values
+
+    Examples:
+        Simple unrolled loop:
+
+        .. code-block:: python
+
+            @helion.kernel
+            def unrolled_example(x: torch.Tensor) -> torch.Tensor:
+                result = torch.zeros_like(x)
+
+                for tile in hl.tile(x.size(0)):
+                    acc = torch.zeros([tile], dtype=x.dtype, device=x.device)
+                    # This loop gets completely unrolled
+                    for i in hl.static_range(3):
+                        acc += x[tile] * i
+                    result[tile] = acc
+
+                return result
+
+        Range with start and step:
+
+        .. code-block:: python
+
+            @helion.kernel
+            def kernel_stepped_unroll(x: torch.Tensor) -> torch.Tensor:
+                result = torch.zeros_like(x)
+
+                for tile in hl.tile(x.size(0)):
+                    acc = torch.zeros([tile], dtype=x.dtype, device=x.device)
+                    # Unroll loop from 2 to 8 with step 2: [2, 4, 6]
+                    for i in hl.static_range(2, 8, 2):
+                        acc += x[tile] * i
+                    result[tile] = acc
+
+                return result
+
+    Note:
+        - Only constant integer values are supported
+        - The range must be small enough to avoid compilation timeouts
+        - Each iteration becomes separate instructions in the generated Triton code
+        - Use for small, fixed iteration counts where unrolling is beneficial
+    """
+    raise exc.NotInsideKernel
+
+
+@_decorators.register_fake(static_range)
+def _(
+    begin_or_end: int,
+    end_or_none: int | None = None,
+    /,
+    step: int = 1,
+) -> tuple[int, ...]:
+    """Fake function for static_range - validates integer constants and returns tuple(range(...))."""
+    # Unpack positional arguments following range()'s calling convention
+    if end_or_none is not None:
+        begin_val = begin_or_end
+        end_val = end_or_none
+    else:
+        begin_val = 0
+        end_val = begin_or_end
+
+    # Validate that the bounds and step are compile-time integer constants
+    if (
+        not isinstance(begin_val, int)
+        or not isinstance(end_val, int)
+        or not isinstance(step, int)
+    ):
+        raise exc.TypeInferenceError("static_range requires constant integer arguments")
+
+    # Return tuple(range(...)), which triggers the existing tuple/list unrolling
+    return tuple(range(begin_val, end_val, step))
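
How the three additions compose: a minimal sketch (assuming `helion` and `helion.language as hl` imports as in the docstrings above, plus lists of same-shaped tensors; an illustration, not code from this commit):

    import torch
    import helion
    import helion.language as hl


    @helion.kernel
    def weighted_sum(weights: list[torch.Tensor], xs: list[torch.Tensor]) -> torch.Tensor:
        out = torch.zeros_like(xs[0])
        for tile in hl.tile(out.size(0)):
            acc = torch.zeros([tile], dtype=out.dtype, device=out.device)
            # zip() and enumerate() are unrolled at compile time by the
            # device replacements defined above
            for i, (w, x) in enumerate(zip(weights, xs)):
                acc += w[tile] * x[tile] * (i + 1)
            # static_range() unrolls a fixed, constant iteration count
            for _ in hl.static_range(2):
                acc = acc + acc
            out[tile] = acc
        return out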