diff --git a/README.md b/README.md index fac68dcf..c8c9680d 100644 --- a/README.md +++ b/README.md @@ -216,6 +216,10 @@ parameter for `tl.range()` calls. `True` sets `warp_specialize=True`, Only available on CUDA devices with Blackwell or newer architectures when `allow_warp_specialize` setting is enabled. +* **static\_ranges** (`list[bool]`): +Contains one entry per loop dimension with static bounds, controlling whether to use +`tl.static_range()` calls. `True` generates `tl.static_range()` and ignores range_* configs for that loop. `False` generates `tl.range()`. + * **reduction\_loops** (`list[int | None]`): Contains one entry per reduction dimension (see `examples/softmax.py`). Using `None` triggers a persistent reduction, diff --git a/helion/_compiler/reduction_strategy.py b/helion/_compiler/reduction_strategy.py index c1323a38..9e40e24d 100644 --- a/helion/_compiler/reduction_strategy.py +++ b/helion/_compiler/reduction_strategy.py @@ -253,12 +253,17 @@ def codegen_device_loop(self, state: CodegenState) -> DeviceLoopState: ) ) - range_extra = self.get_tl_range_kwargs(state, self.block_index) for_node = create( ast.For, target=create(ast.Name, id=offset_var, ctx=ast.Store()), iter=expr_from_string( - f"tl.range(0, ({state.sympy_expr(numel)}), {block_size_var}{range_extra})" + self.get_range_call_str( + state, + [self.block_index], + begin="0", + end=state.sympy_expr(numel), + step=block_size_var, + ), ), body=body, orelse=[], diff --git a/helion/_compiler/tile_strategy.py b/helion/_compiler/tile_strategy.py index bec1e9ff..b9de5bd7 100644 --- a/helion/_compiler/tile_strategy.py +++ b/helion/_compiler/tile_strategy.py @@ -125,7 +125,7 @@ def mask_var(self, block_idx: int) -> str | None: def block_size_var(self, block_idx: int) -> str | None: return self.fn.block_size_var_cache.get((block_idx,)) - def get_tl_range_kwargs(self, state: CodegenState, block_idx: int) -> str: + def get_tl_range_kwargs(self, state: CodegenState, block_idx: int) -> list[str]: """Get the range_extra string for loop unroll factor and num_stages based on config.""" env = CompileEnvironment.current() kwargs = [] @@ -159,10 +159,37 @@ def get_tl_range_kwargs(self, state: CodegenState, block_idx: int) -> str: ) if range_flatten is not None: kwargs.append(f"flatten={range_flatten}") + return kwargs - if kwargs: - return f", {', '.join(kwargs)}" - return "" + def get_range_call_str( + self, + state: CodegenState, + block_ids: list[int], + *, + begin: str | None = None, + end: str, + step: str | None = None, + ) -> str: + env = CompileEnvironment.current() + use_static_range = all( + env.config_spec.static_ranges.config_get( + state.config.static_ranges, block_idx, None + ) + is True + for block_idx in block_ids + ) + + range_args = [] + if begin is not None: + range_args.append(begin) + range_args.append(end) + if step is not None: + range_args.append(f"step={step}") + + if use_static_range: + return f"tl.static_range({', '.join(range_args)})" + range_kwargs = self.get_tl_range_kwargs(state, block_ids[0]) + return f"tl.range({', '.join(range_args + range_kwargs)})" def user_size(self, block_index: int) -> sympy.Expr: raise NotImplementedError @@ -407,12 +434,12 @@ def codegen_device_loop(self, state: CodegenState) -> DeviceLoopState: ) dtype = CompileEnvironment.current().triton_index_type() lid = self.new_var("lid") - range_extra = self.get_tl_range_kwargs(state, self.block_ids[0]) + end_var = f"tl.cdiv({state.sympy_expr(total_numel)}, {block_size_var})" for_node = create( ast.For, target=create(ast.Name, id=lid, 
ctx=ast.Store()), iter=expr_from_string( - f"tl.range(tl.cdiv({state.sympy_expr(total_numel)}, {block_size_var}){range_extra})" + self.get_range_call_str(state, self.block_ids, end=end_var) ), body=( body := [ @@ -624,12 +651,17 @@ def codegen_device_loop(self, state: CodegenState) -> DeviceLoopState: end_expr=self._fold_tile_end_op(state, proxy_end, block_size), ) - range_extra = self.get_tl_range_kwargs(state, block_idx) for_node = create( ast.For, target=create(ast.Name, id=offset_var, ctx=ast.Store()), iter=expr_from_string( - f"tl.range(begin, end, {block_size_var}{range_extra})", + self.get_range_call_str( + state, + [block_idx], + begin="begin", + end="end", + step=block_size_var, + ), begin=self._to_ast(begin, to_dtype=dtype), end=self._to_ast(end, to_dtype=dtype), ), diff --git a/helion/autotuner/block_id_sequence.py b/helion/autotuner/block_id_sequence.py index a405756c..17171fff 100644 --- a/helion/autotuner/block_id_sequence.py +++ b/helion/autotuner/block_id_sequence.py @@ -110,6 +110,10 @@ def block_id_lookup(self, block_id: int) -> _BlockIdItemT: """Return the index of the block_id in the config.""" return self._data[self._block_id_to_index[block_id]] + def valid_block_ids(self) -> list[int]: + """Return the list of valid block_ids.""" + return list(self._block_id_to_index.keys()) + def disable_block_id(self, block_id: int) -> None: """Remove configuration choice for the given block_id.""" self._data = [x for x in self._data if block_id not in x.block_ids] @@ -132,6 +136,24 @@ def _flat_config( """Map a flattened version of the config using the given function.""" return [spec._flat_config(base, fn) for spec in self._data] + def _reset_config_to_default( + self, name: str, values: object, *, block_ids: list[int] | None = None + ) -> list[object]: + """Set the config values to the default values. 
If block_ids is provided, only set those values.""" + if not values: + return [] + assert isinstance(values, list) + assert len(values) == len(self) + + if block_ids is None: + block_ids = self.valid_block_ids() + for block_id in block_ids: + if block_id not in self._block_id_to_index: + continue + index = self._block_id_to_index[block_id] + values[index] = self._data[index]._fill_missing() + return values + def _normalize( self, name: str, values: object, *, flatten: bool = False ) -> list[object]: diff --git a/helion/autotuner/config_spec.py b/helion/autotuner/config_spec.py index df7c9fba..d478da53 100644 --- a/helion/autotuner/config_spec.py +++ b/helion/autotuner/config_spec.py @@ -44,6 +44,7 @@ "range_num_stages", "range_multi_buffers", "range_flattens", + "static_ranges", "num_warps", "num_stages", "pid_type", @@ -85,6 +86,9 @@ class ConfigSpec: range_flattens: BlockIdSequence[RangeFlattenSpec] = dataclasses.field( default_factory=BlockIdSequence ) + static_ranges: BlockIdSequence[StaticRangeSpec] = dataclasses.field( + default_factory=BlockIdSequence + ) user_defined_tunables: dict[str, ConfigSpecFragment] = dataclasses.field( default_factory=dict ) @@ -109,6 +113,7 @@ def _remove_duplicates(self) -> None: self.range_num_stages._remove_duplicates() self.range_multi_buffers._remove_duplicates() self.range_flattens._remove_duplicates() + self.static_ranges._remove_duplicates() def disallow_pid_type(self, pid_type: PidTypeLiteral) -> None: """Disallow a pid_type from being used in the config.""" @@ -135,6 +140,7 @@ def normalize(self, config: helion.Config | dict[str, object]) -> None: "range_num_stage", "range_multi_buffer", "range_flatten", + "static_range", ): if name in config: names = f"{name}s" @@ -153,11 +159,32 @@ def normalize(self, config: helion.Config | dict[str, object]) -> None: ("range_num_stages", self.range_num_stages, True), ("range_multi_buffers", self.range_multi_buffers, True), ("range_flattens", self.range_flattens, True), + ("static_ranges", self.static_ranges, True), ]: config[name] = mapping._normalize( name, config.get(name, ()), flatten=flatten ) + static_range_block_ids = [] + for block_id in self.static_ranges.valid_block_ids(): + use_static_range = self.static_ranges.config_get( + config.get("static_ranges", ()), # pyre-ignore[6] + block_id, + ) + if use_static_range: + static_range_block_ids.append(block_id) + + for name, mapping in ( + ("range_unroll_factors", self.range_unroll_factors), + ("range_warp_specializes", self.range_warp_specialize), + ("range_num_stages", self.range_num_stages), + ("range_multi_buffers", self.range_multi_buffers), + ("range_flattens", self.range_flattens), + ): + config[name] = mapping._reset_config_to_default( + name, config.get(name, ()), block_ids=static_range_block_ids + ) + for name in ( "loop_orders", "l2_groupings", @@ -168,6 +195,7 @@ def normalize(self, config: helion.Config | dict[str, object]) -> None: "range_num_stages", "range_multi_buffers", "range_flattens", + "static_ranges", ): if not config[name]: config.pop(name) @@ -209,6 +237,7 @@ def flat_config(self, fn: Callable[[ConfigSpecFragment], object]) -> helion.Conf "range_num_stages": self.range_num_stages._flat_config(self, fn), "range_multi_buffers": self.range_multi_buffers._flat_config(self, fn), "range_flattens": self.range_flattens._flat_config(self, fn), + "static_ranges": self.static_ranges._flat_config(self, fn), "num_warps": fn(NumWarpsFragment(1, 32, DEFAULT_NUM_WARPS)), "num_stages": fn(IntegerFragment(1, 8, DEFAULT_NUM_STAGES)), "indexing": 
fn(EnumFragment(self._valid_indexing_types())), @@ -228,6 +257,7 @@ def flat_config(self, fn: Callable[[ConfigSpecFragment], object]) -> helion.Conf "range_num_stages", "range_multi_buffers", "range_flattens", + "static_ranges", ): if not config[name]: config.pop(name) @@ -416,6 +446,20 @@ class RangeFlattenSpec(_OptionalBoolSpec): pass +class StaticRangeSpec(_BlockIdItem): + def _fragment(self, base: ConfigSpec) -> BooleanFragment: + return BooleanFragment() + + def _normalize(self, name: str, value: object) -> bool: + if not isinstance(value, bool): + raise InvalidConfig(f"{name} must be a boolean, got {value!r}") + return value + + def _fill_missing(self) -> bool: + """Provide a value when not provided by the user.""" + return False + + def _product(seq: Sequence[int]) -> int: """Return the product of the elements in the sequence.""" return functools.reduce(operator.mul, seq, 1) diff --git a/helion/language/loops.py b/helion/language/loops.py index addd8226..bcbb5be7 100644 --- a/helion/language/loops.py +++ b/helion/language/loops.py @@ -33,6 +33,7 @@ from ..autotuner.config_spec import RangeNumStagesSpec from ..autotuner.config_spec import RangeUnrollFactorSpec from ..autotuner.config_spec import RangeWarpSpecializeSpec +from ..autotuner.config_spec import StaticRangeSpec from . import _decorators from helion.language.tile_proxy import Tile @@ -151,6 +152,23 @@ def _check_matching(a: object, b: object) -> None: ) +def _is_constexpr_int(a: object) -> bool: + """Check if the arg is specialized.""" + return isinstance(a, int) + # TODO(joydddd): render SymInt backed by Int as constexpr. + # Now the specialized constexpr is assigned to a dynamic variable first + # and then used as a variable. However args to static_range must be constexpr. + # e.g. 
+ # hl.specialize(x.size(0)) + # for i in hl.grid(x.size(0)) + # -> + # symbol_0 = 64 + # for i in tl.static_range(symbol_0): + # + # if isinstance(a, torch.SymInt): + # return isinstance(a._sympy_(), sympy.Integer) + + def _normalize_begin_end( begin_or_end: TypeInfo, end_or_none: TypeInfo | None, @@ -225,6 +243,10 @@ def _( [x.block_id for x in results], is_tile=True, has_begin=not all((isinstance(x, int) and x == 0) for x in proxy_begin), + is_static=all( + _is_constexpr_int(x) or x is None + for x in (*proxy_begin, *proxy_end, *proxy_block_size) + ), ) if unpack: (result,) = results @@ -234,7 +256,11 @@ def _( def _add_config_choices( - block_ids: list[int], *, is_tile: bool = False, has_begin: bool = False + block_ids: list[int], + *, + is_tile: bool = False, + has_begin: bool = False, + is_static: bool = False, ) -> None: config_spec = CompileEnvironment.current().config_spec @@ -254,6 +280,8 @@ def _add_config_choices( else: params = inspect.signature(triton.language.range).parameters for block_id in block_ids: + if is_static: + config_spec.static_ranges.append(StaticRangeSpec([block_id])) if "loop_unroll_factor" in params: config_spec.range_unroll_factors.append( RangeUnrollFactorSpec([block_id]) @@ -420,6 +448,10 @@ def _( [x.block_id for x in results], is_tile=False, has_begin=not all((isinstance(x, int) and x == 0) for x in proxy_begin), + is_static=all( + _is_constexpr_int(x) or x is None + for x in (*proxy_begin, *proxy_end, *proxy_step) + ), ) if unpack: (result,) = results diff --git a/helion/runtime/config.py b/helion/runtime/config.py index 41885420..f811f66e 100644 --- a/helion/runtime/config.py +++ b/helion/runtime/config.py @@ -31,6 +31,7 @@ def __init__( range_num_stages: list[int] | None = None, range_multi_buffers: list[bool | None] | None = None, range_flattens: list[bool | None] | None = None, + static_ranges: list[bool] | None = None, num_warps: int | None = None, num_stages: int | None = None, pid_type: PidTypeLiteral | None = None, @@ -51,6 +52,7 @@ def __init__( range_num_stages: Number of stages for tl.range calls. range_multi_buffers: Controls disallow_acc_multi_buffer for tl.range calls. range_flattens: Controls flatten parameter for tl.range calls. + static_ranges: Whether to use tl.static_range instead tl.range. num_warps: Number of warps per block. num_stages: Number of stages for software pipelining. pid_type: Program ID type strategy ("flat", "xyz", "persistent_blocked", "persistent_interleaved"). @@ -69,6 +71,7 @@ def __init__( "range_num_stages": range_num_stages, "range_multi_buffers": range_multi_buffers, "range_flattens": range_flattens, + "static_ranges": static_ranges, "num_warps": num_warps, "num_stages": num_stages, "indexing": indexing, @@ -174,6 +177,10 @@ def range_multi_buffers(self) -> list[bool | None]: def range_flattens(self) -> list[bool | None]: return cast("list[bool | None]", self.config.get("range_flattens", [])) + @property + def static_ranges(self) -> list[bool]: + return cast("list[bool]", self.config.get("static_ranges", [])) + @property def indexing(self) -> IndexingLiteral: return self.config.get("indexing", "pointer") # type: ignore diff --git a/test/test_autotuner.expected b/test/test_autotuner.expected index b2d53550..84cf25ee 100644 --- a/test/test_autotuner.expected +++ b/test/test_autotuner.expected @@ -2,16 +2,16 @@ This file is automatically generated by assertExpectedJournal calls in test_auto Update expected outputs by running tests with the EXPECTTEST_ACCEPT=1 environment variable set. 
--- assertExpectedJournal(TestAutotuner.test_config_fragment0) -helion.Config(block_sizes=[16, 16, 16], loop_orders=[[0, 1]], l2_groupings=[1], range_unroll_factors=[0], range_warp_specializes=[None], range_num_stages=[0], range_multi_buffers=[None], range_flattens=[None], num_warps=4, num_stages=3, indexing='pointer', pid_type='flat') -helion.Config(block_sizes=[16, 64, 32], loop_orders=[[1, 0]], l2_groupings=[8], range_unroll_factors=[1], range_warp_specializes=[False], range_num_stages=[1], range_multi_buffers=[True], range_flattens=[False], num_warps=1, num_stages=7, indexing='tensor_descriptor', pid_type='flat') -helion.Config(block_sizes=[32, 32, 16], loop_orders=[[1, 0]], l2_groupings=[2], range_unroll_factors=[4], range_warp_specializes=[True], range_num_stages=[1], range_multi_buffers=[True], range_flattens=[False], num_warps=16, num_stages=6, indexing='block_ptr', pid_type='persistent_interleaved') -helion.Config(block_sizes=[32, 16, 16], loop_orders=[[0, 1]], l2_groupings=[4], range_unroll_factors=[3], range_warp_specializes=[True], range_num_stages=[0], range_multi_buffers=[True], range_flattens=[None], num_warps=1, num_stages=4, indexing='block_ptr', pid_type='persistent_interleaved') -helion.Config(block_sizes=[32, 16, 16], loop_orders=[[1, 0]], l2_groupings=[4], range_unroll_factors=[4], range_warp_specializes=[False], range_num_stages=[1], range_multi_buffers=[True], range_flattens=[True], num_warps=16, num_stages=2, indexing='pointer', pid_type='persistent_interleaved') -helion.Config(block_sizes=[16, 32, 64], loop_orders=[[0, 1]], l2_groupings=[8], range_unroll_factors=[1], range_warp_specializes=[True], range_num_stages=[1], range_multi_buffers=[True], range_flattens=[None], num_warps=4, num_stages=2, indexing='tensor_descriptor', pid_type='flat') -helion.Config(block_sizes=[16, 16, 32], loop_orders=[[0, 1]], l2_groupings=[16], range_unroll_factors=[4], range_warp_specializes=[None], range_num_stages=[2], range_multi_buffers=[None], range_flattens=[None], num_warps=16, num_stages=1, indexing='tensor_descriptor', pid_type='persistent_blocked') -helion.Config(block_sizes=[16, 16, 16], loop_orders=[[1, 0]], l2_groupings=[32], range_unroll_factors=[2], range_warp_specializes=[None], range_num_stages=[0], range_multi_buffers=[True], range_flattens=[None], num_warps=32, num_stages=7, indexing='block_ptr', pid_type='flat') -helion.Config(block_sizes=[16, 32, 64], loop_orders=[[0, 1]], l2_groupings=[4], range_unroll_factors=[3], range_warp_specializes=[False], range_num_stages=[1], range_multi_buffers=[None], range_flattens=[False], num_warps=16, num_stages=6, indexing='block_ptr', pid_type='flat') -helion.Config(block_sizes=[32, 32, 16], loop_orders=[[1, 0]], l2_groupings=[64], range_unroll_factors=[2], range_warp_specializes=[True], range_num_stages=[2], range_multi_buffers=[None], range_flattens=[True], num_warps=4, num_stages=3, indexing='tensor_descriptor', pid_type='persistent_blocked') +helion.Config(block_sizes=[16, 16, 16], loop_orders=[[0, 1]], l2_groupings=[1], range_unroll_factors=[0], range_warp_specializes=[None], range_num_stages=[0], range_multi_buffers=[None], range_flattens=[None], static_ranges=[False], num_warps=4, num_stages=3, indexing='pointer', pid_type='flat') +helion.Config(block_sizes=[16, 64, 32], loop_orders=[[1, 0]], l2_groupings=[8], range_unroll_factors=[1], range_warp_specializes=[False], range_num_stages=[1], range_multi_buffers=[True], range_flattens=[False], static_ranges=[False], num_warps=8, num_stages=4, indexing='tensor_descriptor', 
pid_type='persistent_blocked') +helion.Config(block_sizes=[64, 64, 32], loop_orders=[[0, 1]], l2_groupings=[32], range_unroll_factors=[3], range_warp_specializes=[True], range_num_stages=[2], range_multi_buffers=[False], range_flattens=[True], static_ranges=[True], num_warps=4, num_stages=1, indexing='pointer', pid_type='flat') +helion.Config(block_sizes=[16, 16, 32], loop_orders=[[0, 1]], l2_groupings=[1], range_unroll_factors=[1], range_warp_specializes=[False], range_num_stages=[4], range_multi_buffers=[False], range_flattens=[None], static_ranges=[True], num_warps=32, num_stages=7, indexing='tensor_descriptor', pid_type='flat') +helion.Config(block_sizes=[128, 64, 16], loop_orders=[[0, 1]], l2_groupings=[16], range_unroll_factors=[2], range_warp_specializes=[None], range_num_stages=[4], range_multi_buffers=[True], range_flattens=[True], static_ranges=[False], num_warps=8, num_stages=7, indexing='pointer', pid_type='persistent_interleaved') +helion.Config(block_sizes=[16, 16, 16], loop_orders=[[1, 0]], l2_groupings=[16], range_unroll_factors=[0], range_warp_specializes=[True], range_num_stages=[2], range_multi_buffers=[None], range_flattens=[False], static_ranges=[False], num_warps=8, num_stages=3, indexing='block_ptr', pid_type='flat') +helion.Config(block_sizes=[32, 64, 64], loop_orders=[[0, 1]], l2_groupings=[4], range_unroll_factors=[2], range_warp_specializes=[False], range_num_stages=[3], range_multi_buffers=[False], range_flattens=[True], static_ranges=[False], num_warps=32, num_stages=5, indexing='pointer', pid_type='flat') +helion.Config(block_sizes=[16, 16, 16], loop_orders=[[1, 0]], l2_groupings=[2], range_unroll_factors=[2], range_warp_specializes=[False], range_num_stages=[4], range_multi_buffers=[True], range_flattens=[None], static_ranges=[True], num_warps=4, num_stages=7, indexing='block_ptr', pid_type='flat') +helion.Config(block_sizes=[16, 16, 16], loop_orders=[[1, 0]], l2_groupings=[16], range_unroll_factors=[2], range_warp_specializes=[False], range_num_stages=[0], range_multi_buffers=[True], range_flattens=[False], static_ranges=[True], num_warps=4, num_stages=6, indexing='pointer', pid_type='persistent_interleaved') +helion.Config(block_sizes=[16, 16, 16], loop_orders=[[1, 0]], l2_groupings=[16], range_unroll_factors=[3], range_warp_specializes=[True], range_num_stages=[3], range_multi_buffers=[False], range_flattens=[None], static_ranges=[True], num_warps=2, num_stages=8, indexing='pointer', pid_type='persistent_interleaved') --- assertExpectedJournal(TestAutotuner.test_config_fragment1) helion.Config(block_sizes=[8, 16, 16], loop_orders=[[0, 1, 2]], flatten_loops=[False], num_warps=4, num_stages=3, indexing='pointer', pid_type='flat') @@ -43,4 +43,3 @@ helion.Config(block_sizes=[1, 512, 4], loop_orders=[[2, 1, 0]], flatten_loops=[F "indexing": "block_ptr", "l2_grouping": 32 } - diff --git a/test/test_examples.expected b/test/test_examples.expected index ec8eb86a..db40d189 100644 --- a/test/test_examples.expected +++ b/test/test_examples.expected @@ -55,7 +55,7 @@ def _attention_kernel(q_view, k_view, v_view, out, _BLOCK_SIZE_1: tl.constexpr, l_i = tl.full([1, _BLOCK_SIZE_1], 1.0, tl.float32) acc = tl.full([1, _BLOCK_SIZE_1, 64], 0.0, tl.float32) q = tl.load(tl.make_block_ptr(q_view, [64, 1024, 64], [65536, 64, 1], [offset_0, offset_1, 0], [1, _BLOCK_SIZE_1, 64], [2, 1, 0]), boundary_check=[0, 1, 2], padding_option='zero') - for offset_2 in tl.range(0, 512, _BLOCK_SIZE_3): + for offset_2 in tl.range(0, 512, step=_BLOCK_SIZE_3): q_copy = q m_i_copy = m_i l_i_copy = 
l_i @@ -159,7 +159,7 @@ def _attention_kernel(q_view, k_view, v_view, out, q_in_size_1, k_view_stride_0, l_i = tl.full([1, _BLOCK_SIZE_1], 1.0, tl.float32) acc = tl.full([1, _BLOCK_SIZE_1, 64], 0.0, tl.float32) q = tl.load(q_view + (indices_0[:, None, None] * q_view_stride_0 + indices_1[None, :, None] * q_view_stride_1 + indices_4[None, None, :] * q_view_stride_2), mask_1[None, :, None], other=0) - for offset_2 in tl.range(0, n_dim.to(tl.int32), _BLOCK_SIZE_3): + for offset_2 in tl.range(0, n_dim.to(tl.int32), step=_BLOCK_SIZE_3): indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_3).to(tl.int32) mask_3 = indices_2 < n_dim q_copy = q @@ -257,7 +257,7 @@ def _attention_kernel(q_view, k_view, v_view, out, _BLOCK_SIZE_1: tl.constexpr, l_i = tl.full([1, _BLOCK_SIZE_1], 1.0, tl.float32) acc = tl.full([1, _BLOCK_SIZE_1, 64], 0.0, tl.float32) q = tl.load(q_view + (indices_0[:, None, None] * 32768 + indices_1[None, :, None] * 64 + indices_4[None, None, :] * 1), None) - for offset_2 in tl.range(0, 512, _BLOCK_SIZE_3): + for offset_2 in tl.range(0, 512, step=_BLOCK_SIZE_3): indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_3).to(tl.int32) q_copy = q m_i_copy = m_i @@ -349,7 +349,7 @@ def _bmm_kernel(A, B, out, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.conste offset_2 = pid_2 * _BLOCK_SIZE_2 indices_2 = (offset_2 + tl.arange(0, _BLOCK_SIZE_2)).to(tl.int32) acc = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2], 0.0, tl.float32) - for offset_3 in tl.range(0, 768, _BLOCK_SIZE_3): + for offset_3 in tl.range(0, 768, step=_BLOCK_SIZE_3): indices_3 = offset_3 + tl.arange(0, _BLOCK_SIZE_3).to(tl.int32) acc_copy = acc acc_copy_0 = acc_copy @@ -575,7 +575,7 @@ def _jagged_dense_add_2d_kernel(x_offsets, x_data, y, out, out_size_0, out_size_ ends = tl.load(x_offsets + v_1 * x_offsets_stride_0, None) v_2 = ends - starts max_nnz = tl.max(v_2, 0) - for offset_1 in tl.range(0, max_nnz.to(tl.int32), _BLOCK_SIZE_1): + for offset_1 in tl.range(0, max_nnz.to(tl.int32), step=_BLOCK_SIZE_1): indices_1 = offset_1 + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32) mask_1 = indices_1 < max_nnz starts_copy = starts @@ -594,7 +594,7 @@ def _jagged_dense_add_2d_kernel(x_offsets, x_data, y, out, out_size_0, out_size_ load_1 = tl.load(tl.make_block_ptr(y, [y_size_0, y_size_1], [y_stride_0, y_stride_1], [offset_0, offset_1], [1, _BLOCK_SIZE_1], [1, 0]), boundary_check=[0, 1], padding_option='zero') v_7 = load_1 + x_slice tl.store(tl.make_block_ptr(out, [out_size_0, out_size_1], [out_stride_0, out_stride_1], [offset_0, offset_1], [1, _BLOCK_SIZE_1], [1, 0]), v_7, boundary_check=[0, 1]) - for offset_2 in tl.range(max_nnz.to(tl.int32), y_size_1.to(tl.int32), _BLOCK_SIZE_2): + for offset_2 in tl.range(max_nnz.to(tl.int32), y_size_1.to(tl.int32), step=_BLOCK_SIZE_2): load = tl.load(tl.make_block_ptr(y, [y_size_0, y_size_1], [y_stride_0, y_stride_1], [offset_0, offset_2], [1, _BLOCK_SIZE_2], [1, 0]), boundary_check=[0, 1], padding_option='zero') tl.store(tl.make_block_ptr(out, [out_size_0, out_size_1], [out_stride_0, out_stride_1], [offset_0, offset_2], [1, _BLOCK_SIZE_2], [1, 0]), load, boundary_check=[0, 1]) @@ -668,7 +668,7 @@ def _matmul_kernel(x, y, out, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.con offset_1 = pid_1 * _BLOCK_SIZE_1 indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32) acc = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], 0.0, tl.float32) - for offset_2 in tl.range(0, 128, _BLOCK_SIZE_2): + for offset_2 in tl.range(0, 128, step=_BLOCK_SIZE_2): indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_2).to(tl.int32) 
acc_copy = acc acc_copy_0 = acc_copy @@ -716,7 +716,7 @@ def _matmul_layernorm_kernel(bias, x, y, weight, out, bias_size_0, bias_stride_0 indices_0 = tl.arange(0, _RDIM_SIZE_0).to(tl.int32) mask_0 = indices_0 < bias_size_0 acc = tl.full([_BLOCK_SIZE_1, _RDIM_SIZE_0], 0.0, tl.float32) - for offset_2 in tl.range(0, k.to(tl.int32), _BLOCK_SIZE_2): + for offset_2 in tl.range(0, k.to(tl.int32), step=_BLOCK_SIZE_2): indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_2).to(tl.int32) mask_2 = indices_2 < k acc_copy = acc @@ -790,7 +790,7 @@ def _matmul_layernorm_kernel(x, y, weight, bias, out, out_stride_0, _BLOCK_SIZE_ indices_0 = tl.arange(0, _RDIM_SIZE_0).to(tl.int32) mask_0 = indices_0 < 400 acc = tl.full([_BLOCK_SIZE_1, _RDIM_SIZE_0], 0.0, tl.float32) - for offset_2 in tl.range(0, 256, _BLOCK_SIZE_2): + for offset_2 in tl.range(0, 256, step=_BLOCK_SIZE_2): indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_2).to(tl.int32) acc_copy = acc acc_copy_0 = acc_copy @@ -872,7 +872,7 @@ def _matmul_split_k_kernel(x, y, out, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1 offset_2 = pid_2 * _BLOCK_SIZE_2 acc = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], 0.0, tl.float32) tile_end = tl.minimum(offset_2 + _BLOCK_SIZE_2, 1024) - for offset_3 in tl.range(offset_2.to(tl.int32), tile_end.to(tl.int32), _BLOCK_SIZE_3): + for offset_3 in tl.range(offset_2.to(tl.int32), tile_end.to(tl.int32), step=_BLOCK_SIZE_3): indices_3 = offset_3 + tl.arange(0, _BLOCK_SIZE_3).to(tl.int32) mask_3 = indices_3 < tile_end acc_copy = acc @@ -930,10 +930,10 @@ def _moe_matmul_ogs_kernel(expert_token_offsets, expert_token_counts, sorted_to_ start_copy = start num_tokens_copy_0 = num_tokens_copy start_copy_0 = start_copy - for offset_1 in tl.range(0, max_T_per_expert.to(tl.int32), _BLOCK_SIZE_1): + for offset_1 in tl.range(0, max_T_per_expert.to(tl.int32), step=_BLOCK_SIZE_1): indices_1 = offset_1 + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32) mask_1 = indices_1 < max_T_per_expert - for offset_2 in tl.range(0, N.to(tl.int32), _BLOCK_SIZE_2): + for offset_2 in tl.range(0, N.to(tl.int32), step=_BLOCK_SIZE_2): indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_2).to(tl.int32) mask_2 = indices_2 < N num_tokens_copy_0_copy = num_tokens_copy_0 @@ -950,7 +950,7 @@ def _moe_matmul_ogs_kernel(expert_token_offsets, expert_token_counts, sorted_to_ squeeze = tl.reshape(v_8, [_BLOCK_SIZE_1]) expert_orig_token_indices = tl.load(sorted_to_orig_token_idx + squeeze * sorted_to_orig_token_idx_stride_0, mask_1, other=0) acc = tl.full([_BLOCK_SIZE_1, _BLOCK_SIZE_2], 0.0, tl.float32) - for offset_3 in tl.range(0, K.to(tl.int32), _BLOCK_SIZE_3): + for offset_3 in tl.range(0, K.to(tl.int32), step=_BLOCK_SIZE_3): indices_3 = offset_3 + tl.arange(0, _BLOCK_SIZE_3).to(tl.int32) mask_3 = indices_3 < K expert_orig_token_indices_copy = expert_orig_token_indices @@ -1079,7 +1079,7 @@ def _softmax_kernel(x, out, out_size_0, out_size_1, x_size_0, x_size_1, out_stri pid_0 = tl.program_id(0) offset_0 = pid_0 amax_acc = tl.full([1, _REDUCTION_BLOCK_1], float('-inf'), tl.float32) - for roffset_1 in tl.range(0, _m, _REDUCTION_BLOCK_1): + for roffset_1 in tl.range(0, _m, step=_REDUCTION_BLOCK_1): rindex_1 = roffset_1 + tl.arange(0, _REDUCTION_BLOCK_1).to(tl.int32) mask_1 = rindex_1 < _m load = tl.load(tl.make_block_ptr(x, [x_size_0, x_size_1], [x_stride_0, x_stride_1], [offset_0, roffset_1], [1, _REDUCTION_BLOCK_1], [1, 0]), boundary_check=[0, 1], padding_option='zero') @@ -1088,7 +1088,7 @@ def _softmax_kernel(x, out, out_size_0, out_size_1, x_size_0, x_size_1, out_stri amax_acc = v_0 amax = 
tl.reshape(tl.max(amax_acc, 1), [1, 1]) sum_1_acc = tl.full([1, _REDUCTION_BLOCK_1], 0, tl.float32) - for roffset_1 in tl.range(0, _m, _REDUCTION_BLOCK_1): + for roffset_1 in tl.range(0, _m, step=_REDUCTION_BLOCK_1): rindex_1 = roffset_1 + tl.arange(0, _REDUCTION_BLOCK_1).to(tl.int32) mask_1 = rindex_1 < _m amax_copy = amax @@ -1099,7 +1099,7 @@ def _softmax_kernel(x, out, out_size_0, out_size_1, x_size_0, x_size_1, out_stri v_3 = sum_1_acc + _mask_to_1 sum_1_acc = v_3 sum_1 = tl.reshape(tl.sum(sum_1_acc, 1), [1, 1]) - for roffset_1 in tl.range(0, _m, _REDUCTION_BLOCK_1): + for roffset_1 in tl.range(0, _m, step=_REDUCTION_BLOCK_1): rindex_1 = roffset_1 + tl.arange(0, _REDUCTION_BLOCK_1).to(tl.int32) mask_1 = rindex_1 < _m amax_copy_1 = amax @@ -1140,7 +1140,7 @@ def _softmax_two_pass_kernel(x, out, out_stride_0, out_stride_1, x_stride_0, x_s indices_0 = offset_0 + tl.zeros([1], tl.int32) mi = tl.full([1], float('-inf'), tl.float32) di = tl.full([1], 0.0, tl.float32) - for offset_2 in tl.range(0, n.to(tl.int32), _BLOCK_SIZE_1): + for offset_2 in tl.range(0, n.to(tl.int32), step=_BLOCK_SIZE_1): indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32) mask_1 = indices_2 < n mi_copy = mi @@ -1161,7 +1161,7 @@ def _softmax_two_pass_kernel(x, out, out_stride_0, out_stride_1, x_stride_0, x_s sum_1 = tl.sum(_mask_to_1, 1) di = v_3 + sum_1 mi = v_0 - for offset_2 in tl.range(0, n.to(tl.int32), _BLOCK_SIZE_1): + for offset_2 in tl.range(0, n.to(tl.int32), step=_BLOCK_SIZE_1): indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32) mask_2 = indices_2 < n mi_copy_1 = mi @@ -1211,7 +1211,7 @@ def _softmax_two_pass_kernel(x, out, out_size_0, out_size_1, x_size_0, x_size_1, mask_0 = indices_0 < m mi = tl.full([_BLOCK_SIZE_0], float('-inf'), tl.float32) di = tl.full([_BLOCK_SIZE_0], 0.0, tl.float32) - for offset_2 in tl.range(0, n.to(tl.int32), _BLOCK_SIZE_1): + for offset_2 in tl.range(0, n.to(tl.int32), step=_BLOCK_SIZE_1): indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32) mask_1 = indices_2 < n mi_copy = mi @@ -1232,7 +1232,7 @@ def _softmax_two_pass_kernel(x, out, out_size_0, out_size_1, x_size_0, x_size_1, sum_1 = tl.sum(_mask_to_1, 1) di = v_3 + sum_1 mi = v_0 - for offset_2 in tl.range(0, n.to(tl.int32), _BLOCK_SIZE_1): + for offset_2 in tl.range(0, n.to(tl.int32), step=_BLOCK_SIZE_1): indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32) mi_copy_1 = mi di_copy_1 = di @@ -1291,7 +1291,7 @@ def _matmul_with_epilogue_kernel(x, y, epilogue_closure_0, out, _BLOCK_SIZE_0: t offset_1 = pid_1 * _BLOCK_SIZE_1 indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32) acc = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], 0.0, tl.float32) - for offset_2 in tl.range(0, 1024, _BLOCK_SIZE_2): + for offset_2 in tl.range(0, 1024, step=_BLOCK_SIZE_2): indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_2).to(tl.int32) acc_copy = acc acc_copy_0 = acc_copy @@ -1351,7 +1351,7 @@ def _matmul_with_epilogue_kernel(x, y, epilogue_closure_0, out, _BLOCK_SIZE_0: t offset_0 = pid_0 * _BLOCK_SIZE_0 offset_1 = pid_1 * _BLOCK_SIZE_1 acc = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], 0.0, tl.float32) - for offset_2 in tl.range(0, 1024, _BLOCK_SIZE_2): + for offset_2 in tl.range(0, 1024, step=_BLOCK_SIZE_2): acc_copy = acc acc_copy_0 = acc_copy load = tl.load(tl.make_block_ptr(x, [1024, 1024], [1024, 1], [offset_0, offset_2], [_BLOCK_SIZE_0, _BLOCK_SIZE_2], [1, 0]), boundary_check=[0, 1], padding_option='zero') @@ -1410,7 +1410,7 @@ def _matmul_with_epilogue_kernel(x, y, out, _BLOCK_SIZE_0: 
tl.constexpr, _BLOCK_ offset_0 = pid_0 * _BLOCK_SIZE_0 offset_1 = pid_1 * _BLOCK_SIZE_1 acc = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], 0.0, tl.float32) - for offset_2 in tl.range(0, 1024, _BLOCK_SIZE_2): + for offset_2 in tl.range(0, 1024, step=_BLOCK_SIZE_2): acc_copy = acc acc_copy_0 = acc_copy load = tl.load(tl.make_block_ptr(x, [1024, 1024], [1024, 1], [offset_0, offset_2], [_BLOCK_SIZE_0, _BLOCK_SIZE_2], [1, 0]), boundary_check=[0, 1], padding_option='zero') @@ -1442,4 +1442,3 @@ def _matmul_with_epilogue_make_precompiler(x: Tensor, y: Tensor, epilogue: Calla _BLOCK_SIZE_2 = 16 from helion.runtime.precompile_shim import make_precompiler return make_precompiler(_matmul_with_epilogue_kernel)(x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=2, num_stages=4) - diff --git a/test/test_grid.expected b/test/test_grid.expected index 474b0830..eade8a21 100644 --- a/test/test_grid.expected +++ b/test/test_grid.expected @@ -12,13 +12,13 @@ import triton.language as tl def _grid_1d_kernel(x, y, out, _BLOCK_SIZE_2: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_3: tl.constexpr): pid_0 = tl.program_id(0) offset_0 = pid_0 - for offset_1 in tl.range(0, 16, _BLOCK_SIZE_1): + for offset_1 in tl.range(0, 16, step=_BLOCK_SIZE_1): indices_1 = offset_1 + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32) - for offset_2 in tl.range(0, 4, _BLOCK_SIZE_2): + for offset_2 in tl.range(0, 4, step=_BLOCK_SIZE_2): indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_2).to(tl.int32) mask_2 = indices_2 < 4 acc = tl.full([_BLOCK_SIZE_1, _BLOCK_SIZE_2], 0.0, tl.float32) - for offset_3 in tl.range(0, 32, _BLOCK_SIZE_3): + for offset_3 in tl.range(0, 32, step=_BLOCK_SIZE_3): indices_3 = offset_3 + tl.arange(0, _BLOCK_SIZE_3).to(tl.int32) acc_copy = acc acc_copy_0 = acc_copy @@ -61,10 +61,10 @@ import triton.language as tl def _grid_1d_kernel(x, y, out, _BLOCK_SIZE_2: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_3: tl.constexpr): pid_0 = tl.program_id(0) offset_0 = pid_0 - for offset_1 in tl.range(0, 16, _BLOCK_SIZE_1): - for offset_2 in tl.range(0, 4, _BLOCK_SIZE_2): + for offset_1 in tl.range(0, 16, step=_BLOCK_SIZE_1): + for offset_2 in tl.range(0, 4, step=_BLOCK_SIZE_2): acc = tl.full([_BLOCK_SIZE_1, _BLOCK_SIZE_2], 0.0, tl.float32) - for offset_3 in tl.range(0, 32, _BLOCK_SIZE_3): + for offset_3 in tl.range(0, 32, step=_BLOCK_SIZE_3): acc_copy = acc acc_copy_0 = acc_copy load = tl.reshape(tl.load(tl.make_block_ptr(x, [8, 16, 32], [512, 32, 1], [offset_0, offset_1, offset_3], [1, _BLOCK_SIZE_1, _BLOCK_SIZE_3], [2, 1, 0]), boundary_check=[1, 2], padding_option='zero'), [_BLOCK_SIZE_1, _BLOCK_SIZE_3]) @@ -109,12 +109,12 @@ def _grid_2d_idx_list_kernel(x, y, out, _BLOCK_SIZE_3: tl.constexpr, _BLOCK_SIZE pid_1 = tl.program_id(0) // num_blocks_0 offset_0 = pid_0 offset_1 = pid_1 - for offset_2 in tl.range(0, 64, _BLOCK_SIZE_2): + for offset_2 in tl.range(0, 64, step=_BLOCK_SIZE_2): indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_2).to(tl.int32) - for offset_3 in tl.range(0, 16, _BLOCK_SIZE_3): + for offset_3 in tl.range(0, 16, step=_BLOCK_SIZE_3): indices_3 = offset_3 + tl.arange(0, _BLOCK_SIZE_3).to(tl.int32) acc = tl.full([_BLOCK_SIZE_2, _BLOCK_SIZE_3], 0.0, tl.float32) - for offset_4 in tl.range(0, 32, _BLOCK_SIZE_4): + for offset_4 in tl.range(0, 32, step=_BLOCK_SIZE_4): indices_4 = offset_4 + tl.arange(0, _BLOCK_SIZE_4).to(tl.int32) acc_copy = acc acc_copy_0 = acc_copy @@ -160,10 +160,10 @@ def _grid_2d_idx_list_kernel(x, y, out, _BLOCK_SIZE_3: tl.constexpr, _BLOCK_SIZE pid_1 = tl.program_id(0) // 
num_blocks_0 offset_0 = pid_0 offset_1 = pid_1 - for offset_2 in tl.range(0, 64, _BLOCK_SIZE_2): - for offset_3 in tl.range(0, 16, _BLOCK_SIZE_3): + for offset_2 in tl.range(0, 64, step=_BLOCK_SIZE_2): + for offset_3 in tl.range(0, 16, step=_BLOCK_SIZE_3): acc = tl.full([_BLOCK_SIZE_2, _BLOCK_SIZE_3], 0.0, tl.float32) - for offset_4 in tl.range(0, 32, _BLOCK_SIZE_4): + for offset_4 in tl.range(0, 32, step=_BLOCK_SIZE_4): acc_copy = acc acc_copy_0 = acc_copy load = tl.reshape(tl.load(tl.make_block_ptr(x, [3, 4, 64, 32], [8192, 2048, 32, 1], [offset_0, offset_1, offset_2, offset_4], [1, 1, _BLOCK_SIZE_2, _BLOCK_SIZE_4], [3, 2, 1, 0]), boundary_check=[2, 3], padding_option='zero'), [_BLOCK_SIZE_2, _BLOCK_SIZE_4]) @@ -205,13 +205,13 @@ import triton.language as tl def _grid_2d_idx_nested_kernel(x, y, out, _BLOCK_SIZE_3: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr, _BLOCK_SIZE_4: tl.constexpr): pid_0 = tl.program_id(0) offset_0 = pid_0 - for offset_1 in tl.range(0, 4, 1): - for offset_2 in tl.range(0, 64, _BLOCK_SIZE_2): + for offset_1 in tl.range(0, 4, step=1): + for offset_2 in tl.range(0, 64, step=_BLOCK_SIZE_2): indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_2).to(tl.int32) - for offset_3 in tl.range(0, 16, _BLOCK_SIZE_3): + for offset_3 in tl.range(0, 16, step=_BLOCK_SIZE_3): indices_3 = offset_3 + tl.arange(0, _BLOCK_SIZE_3).to(tl.int32) acc = tl.full([_BLOCK_SIZE_2, _BLOCK_SIZE_3], 0.0, tl.float32) - for offset_4 in tl.range(0, 32, _BLOCK_SIZE_4): + for offset_4 in tl.range(0, 32, step=_BLOCK_SIZE_4): indices_4 = offset_4 + tl.arange(0, _BLOCK_SIZE_4).to(tl.int32) acc_copy = acc acc_copy_0 = acc_copy @@ -413,7 +413,7 @@ def _range_step_kernel_kernel(out, x, out_stride_0, x_stride_0, batch, _BLOCK_SI offset_0 = pid_0 * _BLOCK_SIZE_0 indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32) mask_0 = indices_0 < batch - for offset_1 in tl.range(1, 10, _BLOCK_SIZE_1): + for offset_1 in tl.range(1, 10, step=_BLOCK_SIZE_1): load = tl.load(out + indices_0 * out_stride_0, mask_0, other=0) load_1 = tl.load(x + indices_0 * x_stride_0, mask_0, other=0) v_0 = offset_1.to(tl.float32) @@ -466,4 +466,3 @@ def _tile_begin_end_make_precompiler(x: torch.Tensor): _BLOCK_SIZE_0 = 4 from helion.runtime.precompile_shim import make_precompiler return make_precompiler(_tile_begin_end_kernel)(x, out, out.stride(0), x.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3) - diff --git a/test/test_loops.expected b/test/test_loops.expected index 0bcf1578..7b4ff877 100644 --- a/test/test_loops.expected +++ b/test/test_loops.expected @@ -14,13 +14,13 @@ def _device_loop_3d_kernel(x, out, out_stride_0, out_stride_1, out_stride_2, out pid_0 = tl.program_id(0) offset_0 = pid_0 indices_0 = offset_0 + tl.zeros([1], tl.int32) - for offset_1 in tl.range(0, b.to(tl.int32), _BLOCK_SIZE_1): + for offset_1 in tl.range(0, b.to(tl.int32), step=_BLOCK_SIZE_1): indices_1 = offset_1 + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32) mask_1 = indices_1 < b - for offset_2 in tl.range(0, c.to(tl.int32), _BLOCK_SIZE_2): + for offset_2 in tl.range(0, c.to(tl.int32), step=_BLOCK_SIZE_2): indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_2).to(tl.int32) mask_2 = indices_2 < c - for offset_3 in tl.range(0, d.to(tl.int32), _BLOCK_SIZE_3): + for offset_3 in tl.range(0, d.to(tl.int32), step=_BLOCK_SIZE_3): indices_3 = offset_3 + tl.arange(0, _BLOCK_SIZE_3).to(tl.int32) mask_3 = indices_3 < d load = tl.load(x + (indices_0[:, None, None, None] * x_stride_0 + indices_1[None, :, None, None] * x_stride_1 + indices_2[None, None, :, None] * x_stride_2 + 
indices_3[None, None, None, :] * x_stride_3), mask_1[None, :, None, None] & mask_2[None, None, :, None] & mask_3[None, None, None, :], other=0) @@ -59,13 +59,13 @@ def _device_loop_3d_kernel(x, out, out_stride_0, out_stride_1, out_stride_2, out offset_0 = pid_0 * _BLOCK_SIZE_0 indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32) mask_0 = indices_0 < a - for offset_2 in tl.range(0, c.to(tl.int32), _BLOCK_SIZE_2): + for offset_2 in tl.range(0, c.to(tl.int32), step=_BLOCK_SIZE_2): indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_2).to(tl.int32) mask_2 = indices_2 < c - for offset_1 in tl.range(0, b.to(tl.int32), _BLOCK_SIZE_1): + for offset_1 in tl.range(0, b.to(tl.int32), step=_BLOCK_SIZE_1): indices_1 = offset_1 + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32) mask_1 = indices_1 < b - for offset_3 in tl.range(0, d.to(tl.int32), 1): + for offset_3 in tl.range(0, d.to(tl.int32), step=1): indices_3 = offset_3 + tl.arange(0, 1).to(tl.int32) load = tl.load(x + (indices_0[:, None, None, None] * x_stride_0 + indices_1[None, :, None, None] * x_stride_1 + indices_2[None, None, :, None] * x_stride_2 + indices_3[None, None, None, :] * x_stride_3), mask_0[:, None, None, None] & mask_1[None, :, None, None] & mask_2[None, None, :, None], other=0) v_0 = tl_math.sin(load) @@ -141,9 +141,9 @@ from torch._inductor.runtime.triton_helpers import math as tl_math def _device_loop_3d_kernel(x, out, out_size_0, out_size_1, out_size_2, out_size_3, x_size_0, x_size_1, x_size_2, x_size_3, out_stride_0, out_stride_1, out_stride_2, out_stride_3, x_stride_0, x_stride_1, x_stride_2, x_stride_3, b, c, d, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr): pid_0 = tl.program_id(0) offset_0 = pid_0 * _BLOCK_SIZE_0 - for offset_3 in tl.range(0, d.to(tl.int32), 1): - for offset_1 in tl.range(0, b.to(tl.int32), _BLOCK_SIZE_1): - for offset_2 in tl.range(0, c.to(tl.int32), _BLOCK_SIZE_2): + for offset_3 in tl.range(0, d.to(tl.int32), step=1): + for offset_1 in tl.range(0, b.to(tl.int32), step=_BLOCK_SIZE_1): + for offset_2 in tl.range(0, c.to(tl.int32), step=_BLOCK_SIZE_2): load = tl.load(tl.make_block_ptr(x, [x_size_0, x_size_1, x_size_2, x_size_3], [x_stride_0, x_stride_1, x_stride_2, x_stride_3], [offset_0, offset_1, offset_2, offset_3], [_BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, 1], [3, 2, 1, 0]), boundary_check=[0, 1, 2, 3], padding_option='zero') v_0 = tl_math.sin(load) tl.store(tl.make_block_ptr(out, [out_size_0, out_size_1, out_size_2, out_size_3], [out_stride_0, out_stride_1, out_stride_2, out_stride_3], [offset_0, offset_1, offset_2, offset_3], [_BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, 1], [3, 2, 1, 0]), v_0, boundary_check=[0, 1, 2, 3]) @@ -196,7 +196,7 @@ def _chebyshev_kernel_kernel(x, w, out, out_stride_0, out_stride_1, w_stride_0, v_2 = v_0 + v_1 v_3 = 2.0 v_4 = in_x * v_3 - for offset_2 in tl.range(2, 5, 1): + for offset_2 in tl.range(2, 5, step=1): indices_2 = offset_2 + tl.arange(0, 1).to(tl.int32) v_4_copy = v_4 in_x_0_copy = in_x_0 @@ -252,7 +252,7 @@ def _fn_kernel(x, end, out, x_size_0, out_stride_0, x_stride_0, x_stride_1, _BLO mask_1 = indices_1 < x_size_0 acc = tl.full([_BLOCK_SIZE_1, _BLOCK_SIZE_0], 0.0, tl.float32) load = tl.load(end + tl.zeros([], tl.int32), None) - for offset_0 in tl.range(0, load.to(tl.int32), _BLOCK_SIZE_0): + for offset_0 in tl.range(0, load.to(tl.int32), step=_BLOCK_SIZE_0): indices_0 = offset_0 + tl.arange(0, _BLOCK_SIZE_0).to(tl.int32) mask_0 = indices_0 < load acc_copy = acc @@ -293,7 +293,7 @@ def _fn_kernel(x, end, out, 
out_size_0, x_size_0, out_stride_0, x_stride_0, x_st mask_0 = indices_0 < x_size_0 acc = tl.full([_BLOCK_SIZE_0], 0.0, tl.float32) load = tl.load(end + tl.zeros([], tl.int32), None) - for offset_1 in tl.range(0, load.to(tl.int32), _BLOCK_SIZE_1): + for offset_1 in tl.range(0, load.to(tl.int32), step=_BLOCK_SIZE_1): indices_1 = offset_1 + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32) mask_1 = indices_1 < load acc_copy = acc @@ -333,10 +333,10 @@ def _fn_kernel(x, end0, end1, out, x_size_0, out_stride_0, x_stride_0, x_stride_ acc = tl.full([_BLOCK_SIZE_0], 0.0, tl.float64) load = tl.load(end0 + tl.zeros([], tl.int32), None) load_1 = tl.load(end1 + tl.zeros([], tl.int32), None) - for offset_1 in tl.range(0, load.to(tl.int32), _BLOCK_SIZE_1): + for offset_1 in tl.range(0, load.to(tl.int32), step=_BLOCK_SIZE_1): indices_1 = offset_1 + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32) mask_1 = indices_1 < load - for offset_2 in tl.range(0, load_1.to(tl.int32), _BLOCK_SIZE_2): + for offset_2 in tl.range(0, load_1.to(tl.int32), step=_BLOCK_SIZE_2): indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_2).to(tl.int32) mask_2 = indices_2 < load_1 acc_copy = acc @@ -379,7 +379,7 @@ def _fn_kernel(x, begin, end, out, x_size_0, out_stride_0, x_stride_0, x_stride_ acc = tl.full([_BLOCK_SIZE_1, _BLOCK_SIZE_0], 0.0, tl.float32) load = tl.load(begin + tl.zeros([], tl.int32), None) load_1 = tl.load(end + tl.zeros([], tl.int32), None) - for offset_0 in tl.range(load.to(tl.int32), load_1.to(tl.int32), _BLOCK_SIZE_0): + for offset_0 in tl.range(load.to(tl.int32), load_1.to(tl.int32), step=_BLOCK_SIZE_0): indices_0 = offset_0 + tl.arange(0, _BLOCK_SIZE_0).to(tl.int32) mask_0 = indices_0 < load_1 acc_copy = acc @@ -421,7 +421,7 @@ def _fn_kernel(x, begin, end, out, x_size_0, out_stride_0, x_stride_0, x_stride_ acc = tl.full([_BLOCK_SIZE_0], 0.0, tl.float32) load = tl.load(begin + tl.zeros([], tl.int32), None) load_1 = tl.load(end + tl.zeros([], tl.int32), None) - for offset_1 in tl.range(load.to(tl.int32), load_1.to(tl.int32), _BLOCK_SIZE_1): + for offset_1 in tl.range(load.to(tl.int32), load_1.to(tl.int32), step=_BLOCK_SIZE_1): indices_1 = offset_1 + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32) mask_1 = indices_1 < load_1 acc_copy = acc @@ -532,7 +532,7 @@ def _fn_kernel(x, out, out_size_0, out_size_1, out_size_2, x_size_0, x_size_1, x pid_1 = tl.program_id(0) // num_blocks_0 offset_0 = pid_0 * _BLOCK_SIZE_0 offset_1 = pid_1 * _BLOCK_SIZE_1 - for offset_2 in tl.range(0, c.to(tl.int32), _BLOCK_SIZE_2): + for offset_2 in tl.range(0, c.to(tl.int32), step=_BLOCK_SIZE_2): load = tl.load(tl.make_block_ptr(x, [x_size_0, x_size_1, x_size_2], [x_stride_0, x_stride_1, x_stride_2], [offset_0, offset_1, offset_2], [_BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2], [2, 1, 0]), boundary_check=[0, 1, 2], padding_option='zero') v_0 = tl_math.sin(load) tl.store(tl.make_block_ptr(out, [out_size_0, out_size_1, out_size_2], [out_stride_0, out_stride_1, out_stride_2], [offset_0, offset_1, offset_2], [_BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2], [2, 1, 0]), v_0, boundary_check=[0, 1, 2]) @@ -714,7 +714,7 @@ def _addToBoth_kernel(x0, x1, x2, x0_stride_0, x0_stride_1, x1_stride_0, x1_stri offset_0 = pid_0 * _BLOCK_SIZE_0 indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32) mask_0 = indices_0 < a_n - for offset_1 in tl.range(0, a_m.to(tl.int32), _BLOCK_SIZE_1): + for offset_1 in tl.range(0, a_m.to(tl.int32), step=_BLOCK_SIZE_1): indices_1 = offset_1 + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32) mask_1 = indices_1 < a_m load = tl.load(x0 + (indices_0[:, 
None] * x0_stride_0 + indices_1[None, :] * x0_stride_1), mask_0[:, None] & mask_1[None, :], other=0) @@ -727,7 +727,7 @@ def _addToBoth_kernel(x0, x1, x2, x0_stride_0, x0_stride_1, x1_stride_0, x1_stri offset_2 = pid_1 * _BLOCK_SIZE_2 indices_2 = (offset_2 + tl.arange(0, _BLOCK_SIZE_2)).to(tl.int32) mask_2 = indices_2 < b_n - for offset_3 in tl.range(0, b_m.to(tl.int32), _BLOCK_SIZE_3): + for offset_3 in tl.range(0, b_m.to(tl.int32), step=_BLOCK_SIZE_3): indices_3 = offset_3 + tl.arange(0, _BLOCK_SIZE_3).to(tl.int32) mask_3 = indices_3 < b_m load_1 = tl.load(x1 + (indices_2[:, None] * x1_stride_0 + indices_3[None, :] * x1_stride_1), mask_2[:, None] & mask_3[None, :], other=0) @@ -740,7 +740,7 @@ def _addToBoth_kernel(x0, x1, x2, x0_stride_0, x0_stride_1, x1_stride_0, x1_stri offset_4 = pid_2 * _BLOCK_SIZE_4 indices_4 = (offset_4 + tl.arange(0, _BLOCK_SIZE_4)).to(tl.int32) mask_4 = indices_4 < c_n - for offset_5 in tl.range(0, c_m.to(tl.int32), _BLOCK_SIZE_5): + for offset_5 in tl.range(0, c_m.to(tl.int32), step=_BLOCK_SIZE_5): indices_5 = offset_5 + tl.arange(0, _BLOCK_SIZE_5).to(tl.int32) mask_5 = indices_5 < c_m load_2 = tl.load(x2 + (indices_4[:, None] * x2_stride_0 + indices_5[None, :] * x2_stride_1), mask_4[:, None] & mask_5[None, :], other=0) @@ -879,7 +879,7 @@ def _pointwise_device_loop_kernel(x, out, out_stride_0, out_stride_1, x_stride_0 offset_0 = pid_0 * _BLOCK_SIZE_0 indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32) mask_0 = indices_0 < n - for offset_1 in tl.range(0, m.to(tl.int32), _BLOCK_SIZE_1): + for offset_1 in tl.range(0, m.to(tl.int32), step=_BLOCK_SIZE_1): indices_1 = offset_1 + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32) mask_1 = indices_1 < m load = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_1[None, :] * x_stride_1), mask_0[:, None] & mask_1[None, :], other=0) @@ -917,7 +917,7 @@ def _nested_loop_kernel_kernel(x, out, x_size_0, x_size_1, out_stride_0, out_str offset_0 = pid_0 * _BLOCK_SIZE_0 indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32) mask_0 = indices_0 < x_size_0 - for offset_1 in tl.range(0, x_size_1.to(tl.int32), _BLOCK_SIZE_1, loop_unroll_factor=2): + for offset_1 in tl.range(0, x_size_1.to(tl.int32), step=_BLOCK_SIZE_1, loop_unroll_factor=2): indices_1 = offset_1 + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32) mask_1 = indices_1 < x_size_1 load = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_1[None, :] * x_stride_1), mask_0[:, None] & mask_1[None, :], other=0) @@ -988,10 +988,10 @@ def _matmul_kernel(x, y, out, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.con pid_0 = tl.program_id(0) offset_0 = pid_0 * _BLOCK_SIZE_0 indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32) - for offset_1 in tl.range(0, 128, _BLOCK_SIZE_1): + for offset_1 in tl.range(0, 128, step=_BLOCK_SIZE_1): indices_1 = offset_1 + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32) acc = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], 0.0, tl.float32) - for offset_2 in tl.range(0, 512, _BLOCK_SIZE_2): + for offset_2 in tl.range(0, 512, step=_BLOCK_SIZE_2): indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_2).to(tl.int32) acc_copy = acc acc_copy_0 = acc_copy @@ -1021,4 +1021,3 @@ def _matmul_make_precompiler(x: torch.Tensor, y: torch.Tensor): _BLOCK_SIZE_2 = 64 from helion.runtime.precompile_shim import make_precompiler return make_precompiler(_matmul_kernel)(x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=3) - diff --git a/test/test_loops.py b/test/test_loops.py index 9b5c8a9a..200745df 100644 --- a/test/test_loops.py 
+++ b/test/test_loops.py
@@ -659,10 +659,11 @@ def test_range_num_stages(self):
         self.assertNotEqual(code0, code3)
         # Check that range_num_stages parameter appears in tl.range call
         self.assertNotIn(
-            "tl.range(0, x_size_1.to(tl.int32), _BLOCK_SIZE_1, num_stages=", code0
+            "tl.range(0, x_size_1.to(tl.int32), step=_BLOCK_SIZE_1, num_stages=", code0
         )
         self.assertIn(
-            "tl.range(0, x_size_1.to(tl.int32), _BLOCK_SIZE_1, num_stages=3)", code3
+            "tl.range(0, x_size_1.to(tl.int32), step=_BLOCK_SIZE_1, num_stages=3)",
+            code3,
         )
 
     def test_range_multi_buffers(self):
@@ -725,6 +726,135 @@ def test_range_flatten(self):
         self.assertIn("flatten=True", code_true)
         self.assertIn("flatten=False", code_false)
+
+    def test_static_range_2d(self):
+        @helion.kernel()
+        def nested_loop_kernel_2d(x: torch.Tensor) -> torch.Tensor:
+            out = torch.empty_like(x)
+            # The return value of hl.specialize is a LiteralType and thus a tl.constexpr.
+            # TODO(joydddd): support static_range in for tile_m in hl.tile([x.size(1)])
+            M = hl.specialize(x.size(1))
+            N = hl.specialize(x.size(2))
+            # Outer loop becomes grid (no tl.range)
+            for tile_outer in hl.tile(x.size(0)):
+                # Inner loop becomes device loop with tl.range / tl.static_range
+                # Specialize on x.size(1) to allow static_ranges
+                for tile_m, tile_n in hl.tile([M, N]):
+                    out[tile_outer, tile_m, tile_n] = x[tile_outer, tile_m, tile_n] + 1
+            return out
+
+        args = (torch.randn([64, 32, 4], device=DEVICE),)
+
+        # Test with static_ranges = [True] (use tl.static_range for device loop)
+        code_true, result_true = code_and_output(
+            nested_loop_kernel_2d, args, block_sizes=[16, 16, 1], static_ranges=[True]
+        )
+
+        # Test with static_ranges = [False] (use tl.range for device loop)
+        code_false, result_false = code_and_output(
+            nested_loop_kernel_2d, args, block_sizes=[16, 16, 1], static_ranges=[False]
+        )
+
+        # Test default
+        code_default, result_default = code_and_output(
+            nested_loop_kernel_2d, args, block_sizes=[16, 16, 1]
+        )
+
+        # Ignore range_* kwargs when static_ranges is set to True.
+        code_ignore, result_ignore = code_and_output(
+            nested_loop_kernel_2d,
+            args,
+            block_sizes=[16, 16, 1],
+            static_ranges=[True],
+            range_unroll_factors=[2],
+            range_num_stages=[3],
+            range_multi_buffers=[True],
+            range_flattens=[True],
+        )
+
+        torch.testing.assert_close(result_false, result_true)
+        torch.testing.assert_close(result_true, args[0] + 1)
+        self.assertEqual(code_default, code_false)
+        self.assertEqual(code_ignore, code_true)
+        self.assertNotEqual(code_true, code_false)
+        # Check that tl.range / tl.static_range is used according to setups.
+ self.assertIn("tl.range", code_false) + self.assertIn("tl.static_range", code_true) + + def test_static_range_scalar(self): + @helion.kernel() + def nested_loop_kernel_scalar(x: torch.Tensor) -> torch.Tensor: + world_size = 4 + # Outer loop becomes grid (no tl.range) + for tile_outer in hl.tile(x.size(0)): + # Inner loop becomes device loop with tl.range / tl.static_range + # Specialize on x.size(1) to allow range_staitic + for _rank in range(world_size): + x[tile_outer] = x[tile_outer] + 1 + return x + + x = torch.randn([64], device=DEVICE) + + # Test with static_ranges = [True] (use tl.static_range for device loop) + code_true, result_true = code_and_output( + nested_loop_kernel_scalar, + (x.clone(),), + block_sizes=[16], + static_ranges=[True], + ) + + # Test with static_ranges = [False] (use tl.range for device loop) + code_false, result_false = code_and_output( + nested_loop_kernel_scalar, + (x.clone(),), + block_sizes=[16], + static_ranges=[False], + ) + + # Test default + code_default, result_default = code_and_output( + nested_loop_kernel_scalar, + (x.clone(),), + block_sizes=[ + 16, + ], + ) + + torch.testing.assert_close(result_default, result_true) + torch.testing.assert_close(result_default, result_false) + torch.testing.assert_close(result_default, x + 4) + self.assertNotEqual(code_default, code_true) + self.assertNotEqual(code_true, code_false) + self.assertEqual(code_default, code_false) + # Check that tl.range / tl.static_range is used according to setups. + self.assertIn("tl.range", code_false) + self.assertIn("tl.static_range", code_true) + + @unittest.skip("TODO(joydddd): handle constexpr type casting.") + def test_static_range_casting(self): + @helion.kernel() + def nested_loop_kernel_w_casting(x: torch.Tensor) -> torch.Tensor: + world_size = 4 + # Outer loop becomes grid (no tl.range) + for tile_outer in hl.tile(x.size(0)): + # Inner loop becomes device loop with tl.range / tl.static_range + # Specialize on x.size(1) to allow range_staitic + for rank in range(world_size): + x[tile_outer] = x[tile_outer] + rank + return x + + x = torch.randn([64], device=DEVICE) + + # Test with static_ranges = [True] (use tl.static_range for device loop) + code, result = code_and_output( + nested_loop_kernel_w_casting, + (x.clone(),), + block_sizes=[16], + static_ranges=[True], + ) + + torch.testing.assert_close(result, x + 5) + self.assertIn("tl.static_range", code) + if __name__ == "__main__": unittest.main() diff --git a/test/test_masking.expected b/test/test_masking.expected index 141ae3f9..6b6aa866 100644 --- a/test/test_masking.expected +++ b/test/test_masking.expected @@ -15,7 +15,7 @@ def _fn_kernel(out, out_stride_0, m, n, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32) mask_1 = indices_1 < m acc = tl.full([_BLOCK_SIZE_1, _BLOCK_SIZE_0], 0.0, tl.float32) - for offset_0 in tl.range(0, n.to(tl.int32), _BLOCK_SIZE_0): + for offset_0 in tl.range(0, n.to(tl.int32), step=_BLOCK_SIZE_0): indices_0 = offset_0 + tl.arange(0, _BLOCK_SIZE_0).to(tl.int32) mask_0 = indices_0 < n acc_copy = acc @@ -66,7 +66,7 @@ def _add1mm_kernel(x, y, out, out_stride_0, out_stride_1, x_stride_0, x_stride_1 indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32) mask_1 = indices_1 < n acc = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], 0.0, tl.float32) - for offset_2 in tl.range(0, k.to(tl.int32), _BLOCK_SIZE_2): + for offset_2 in tl.range(0, k.to(tl.int32), step=_BLOCK_SIZE_2): indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_2).to(tl.int32) 
mask_2 = indices_2 < k acc_copy = acc @@ -114,7 +114,7 @@ def _fn_kernel(x, out, out_size_0, x_size_0, x_size_1, out_stride_0, x_stride_0, pid_0 = tl.program_id(0) offset_1 = pid_0 * _BLOCK_SIZE_1 acc = tl.full([_BLOCK_SIZE_1, _BLOCK_SIZE_0], 0.0, tl.float32) - for offset_0 in tl.range(0, n.to(tl.int32), _BLOCK_SIZE_0): + for offset_0 in tl.range(0, n.to(tl.int32), step=_BLOCK_SIZE_0): acc_copy = acc acc_copy_0 = acc_copy load = tl.load(tl.make_block_ptr(x, [x_size_0, x_size_1], [x_stride_0, x_stride_1], [offset_1, offset_0], [_BLOCK_SIZE_1, _BLOCK_SIZE_0], [1, 0]), boundary_check=[0, 1], padding_option='zero') @@ -139,4 +139,3 @@ def _fn_make_precompiler(x): _BLOCK_SIZE_0 = 32 from helion.runtime.precompile_shim import make_precompiler return make_precompiler(_fn_kernel)(x, out, out.size(0), x.size(0), x.size(1), out.stride(0), x.stride(0), x.stride(1), n, _BLOCK_SIZE_1, _BLOCK_SIZE_0, num_warps=4, num_stages=3) - diff --git a/test/test_matmul.expected b/test/test_matmul.expected index 4f671839..b5668bf6 100644 --- a/test/test_matmul.expected +++ b/test/test_matmul.expected @@ -25,7 +25,7 @@ def _matmul_without_addmm_kernel(x, y, out, out_stride_0, out_stride_1, x_stride indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32) mask_1 = indices_1 < n acc = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], 0.0, tl.float32) - for offset_2 in tl.range(0, k.to(tl.int32), _BLOCK_SIZE_2): + for offset_2 in tl.range(0, k.to(tl.int32), step=_BLOCK_SIZE_2): indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_2).to(tl.int32) mask_2 = indices_2 < k acc_copy = acc @@ -73,7 +73,7 @@ def _matmul_kernel(x, y, out, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_0: tl.con offset_0 = pid_1 * _BLOCK_SIZE_0 indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32) acc = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], 0.0, tl.float32) - for offset_2 in tl.range(0, 128, _BLOCK_SIZE_2): + for offset_2 in tl.range(0, 128, step=_BLOCK_SIZE_2): indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_2).to(tl.int32) acc_copy = acc acc_copy_0 = acc_copy @@ -128,7 +128,7 @@ def _matmul_with_addmm_kernel(x, y, out, out_stride_0, out_stride_1, x_stride_0, indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32) mask_1 = indices_1 < n acc = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], 0.0, tl.float32) - for offset_2 in tl.range(0, k.to(tl.int32), _BLOCK_SIZE_2): + for offset_2 in tl.range(0, k.to(tl.int32), step=_BLOCK_SIZE_2): indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_2).to(tl.int32) mask_2 = indices_2 < k acc_copy = acc @@ -178,7 +178,7 @@ def _matmul_kernel(x, y, out, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.con offset_0 = pid_0 * _BLOCK_SIZE_0 offset_1 = pid_1 * _BLOCK_SIZE_1 acc = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], 0.0, tl.float32) - for offset_2 in tl.range(0, 128, _BLOCK_SIZE_2): + for offset_2 in tl.range(0, 128, step=_BLOCK_SIZE_2): acc_copy = acc acc_copy_0 = acc_copy load = tl.load(tl.make_block_ptr(x, [128, 128], [128, 1], [offset_0, offset_2], [_BLOCK_SIZE_0, _BLOCK_SIZE_2], [1, 0]), boundary_check=[0, 1], padding_option='zero') @@ -231,7 +231,7 @@ def _matmul_split_k_kernel(x, y, out, x_size_0, x_size_1, y_size_0, y_size_1, ou offset_2 = pid_2 * _BLOCK_SIZE_2 acc = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], 0.0, tl.float32) tile_end = tl.minimum(offset_2 + _BLOCK_SIZE_2, k) - for offset_3 in tl.range(offset_2.to(tl.int32), tile_end.to(tl.int32), _BLOCK_SIZE_3): + for offset_3 in tl.range(offset_2.to(tl.int32), tile_end.to(tl.int32), step=_BLOCK_SIZE_3): acc_copy = acc acc_copy_0 = acc_copy load = 
tl.load(tl.make_block_ptr(x, [x_size_0, x_size_1], [x_stride_0, x_stride_1], [offset_0, offset_3], [_BLOCK_SIZE_0, _BLOCK_SIZE_3], [1, 0]), boundary_check=[0, 1], padding_option='zero') @@ -283,7 +283,7 @@ def _matmul_static_shapes_kernel(x, y, out, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_ offset_1 = pid_1 * _BLOCK_SIZE_1 indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32) acc = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], 0.0, tl.float32) - for offset_2 in tl.range(0, 128, _BLOCK_SIZE_2): + for offset_2 in tl.range(0, 128, step=_BLOCK_SIZE_2): indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_2).to(tl.int32) acc_copy = acc acc_copy_0 = acc_copy @@ -337,7 +337,7 @@ def _matmul_static_shapes_kernel(x, y, out, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_ offset_1 = pid_1 * _BLOCK_SIZE_1 indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32) acc = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], 0.0, tl.float32) - for offset_2 in tl.range(0, 128, _BLOCK_SIZE_2): + for offset_2 in tl.range(0, 128, step=_BLOCK_SIZE_2): indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_2).to(tl.int32) acc_copy = acc acc_copy_0 = acc_copy @@ -391,7 +391,7 @@ def _matmul_static_shapes_kernel(x, y, out, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_ offset_1 = pid_1 * _BLOCK_SIZE_1 indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32) acc = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], 0.0, tl.float32) - for offset_2 in tl.range(0, 127, _BLOCK_SIZE_2): + for offset_2 in tl.range(0, 127, step=_BLOCK_SIZE_2): indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_2).to(tl.int32) mask_2 = indices_2 < 127 acc_copy = acc @@ -448,7 +448,7 @@ def _matmul_static_shapes_kernel(x, y, out, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_ indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32) mask_1 = indices_1 < 127 acc = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], 0.0, tl.float32) - for offset_2 in tl.range(0, 128, _BLOCK_SIZE_2): + for offset_2 in tl.range(0, 128, step=_BLOCK_SIZE_2): indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_2).to(tl.int32) acc_copy = acc acc_copy_0 = acc_copy @@ -506,7 +506,7 @@ def _matmul_kernel(x, y, out, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.con offset_0 = pid_0 * _BLOCK_SIZE_0 offset_1 = pid_1 * _BLOCK_SIZE_1 acc = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], 0.0, tl.float32) - for offset_2 in tl.range(0, 128, _BLOCK_SIZE_2): + for offset_2 in tl.range(0, 128, step=_BLOCK_SIZE_2): acc_copy = acc acc_copy_0 = acc_copy load = x_desc.load([offset_0, offset_2]) @@ -535,4 +535,3 @@ def _matmul_make_precompiler(x: torch.Tensor, y: torch.Tensor): _BLOCK_SIZE_2 = 16 from helion.runtime.precompile_shim import make_precompiler return make_precompiler(_matmul_kernel)(x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=3) - diff --git a/test/test_misc.expected b/test/test_misc.expected index ce2e4621..5c1ab3ec 100644 --- a/test/test_misc.expected +++ b/test/test_misc.expected @@ -67,7 +67,7 @@ def _fn_kernel(x, out, out_stride_0, x_stride_0, x_stride_1, m, n, _BLOCK_SIZE_1 indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32) mask_1 = indices_1 < m acc = tl.full([_BLOCK_SIZE_1, _BLOCK_SIZE_0], 0, tl.float32) - for offset_0 in tl.range(0, n.to(tl.int32), _BLOCK_SIZE_0): + for offset_0 in tl.range(0, n.to(tl.int32), step=_BLOCK_SIZE_0): indices_0 = offset_0 + tl.arange(0, _BLOCK_SIZE_0).to(tl.int32) mask_0 = indices_0 < n acc_copy = acc @@ -94,4 +94,3 @@ def _fn_make_precompiler(x: torch.Tensor): _BLOCK_SIZE_0 = 64 from helion.runtime.precompile_shim import make_precompiler return 
make_precompiler(_fn_kernel)(x, out, out.stride(0), x.stride(0), x.stride(1), m, n, _BLOCK_SIZE_1, _BLOCK_SIZE_0, num_warps=4, num_stages=3) - diff --git a/test/test_persistent_kernels.expected b/test/test_persistent_kernels.expected index c4bc44c1..f77e850e 100644 --- a/test/test_persistent_kernels.expected +++ b/test/test_persistent_kernels.expected @@ -334,7 +334,7 @@ def _matmul_kernel_kernel(A, B, result, A_stride_0, A_stride_1, B_stride_0, B_st indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32) mask_1 = indices_1 < N acc = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], 0.0, tl.float32) - for offset_2 in tl.range(0, K.to(tl.int32), _BLOCK_SIZE_2): + for offset_2 in tl.range(0, K.to(tl.int32), step=_BLOCK_SIZE_2): indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_2).to(tl.int32) mask_2 = indices_2 < K acc_copy = acc @@ -388,7 +388,7 @@ def _matmul_kernel_kernel(A, B, result, A_stride_0, A_stride_1, B_stride_0, B_st indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32) mask_1 = indices_1 < N acc = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], 0.0, tl.float32) - for offset_2 in tl.range(0, K.to(tl.int32), _BLOCK_SIZE_2): + for offset_2 in tl.range(0, K.to(tl.int32), step=_BLOCK_SIZE_2): indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_2).to(tl.int32) mask_2 = indices_2 < K acc_copy = acc @@ -786,7 +786,7 @@ def _matmul_kernel_kernel(A, B, result, A_stride_0, A_stride_1, B_stride_0, B_st indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32) mask_1 = indices_1 < N acc = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], 0.0, tl.float32) - for offset_2 in tl.range(0, K.to(tl.int32), _BLOCK_SIZE_2): + for offset_2 in tl.range(0, K.to(tl.int32), step=_BLOCK_SIZE_2): indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_2).to(tl.int32) mask_2 = indices_2 < K acc_copy = acc @@ -840,7 +840,7 @@ def _matmul_kernel_kernel(A, B, result, A_stride_0, A_stride_1, B_stride_0, B_st indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32) mask_1 = indices_1 < N acc = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], 0.0, tl.float32) - for offset_2 in tl.range(0, K.to(tl.int32), _BLOCK_SIZE_2): + for offset_2 in tl.range(0, K.to(tl.int32), step=_BLOCK_SIZE_2): indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_2).to(tl.int32) mask_2 = indices_2 < K acc_copy = acc @@ -1912,4 +1912,3 @@ def _simple_add_make_precompiler(x: torch.Tensor, y: torch.Tensor): _BLOCK_SIZE_1 = 16 from helion.runtime.precompile_shim import make_precompiler return make_precompiler(_simple_add_kernel)(x, y, result, x.size(0), x.size(1), result.stride(0), result.stride(1), x.stride(0), x.stride(1), y.stride(0), y.stride(1), _NUM_SM, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3) - diff --git a/test/test_reductions.expected b/test/test_reductions.expected index d9b3791b..ae0eaa04 100644 --- a/test/test_reductions.expected +++ b/test/test_reductions.expected @@ -52,7 +52,7 @@ def _reduce_kernel_kernel(x, out, out_size_0, x_size_0, x_size_1, out_stride_0, offset_0 = pid_0 argmax_acc = tl.full([1, _REDUCTION_BLOCK_1], float('-inf'), tl.float32) argmax_acc_index = tl.full([1, _REDUCTION_BLOCK_1], 2147483647, tl.int32) - for roffset_1 in tl.range(0, _m, _REDUCTION_BLOCK_1): + for roffset_1 in tl.range(0, _m, step=_REDUCTION_BLOCK_1): rindex_1 = roffset_1 + tl.arange(0, _REDUCTION_BLOCK_1).to(tl.int32) mask_1 = rindex_1 < _m load = tl.load(tl.make_block_ptr(x, [x_size_0, x_size_1], [x_stride_0, x_stride_1], [offset_0, roffset_1], [1, _REDUCTION_BLOCK_1], [1, 0]), boundary_check=[0, 1], padding_option='zero') @@ -249,7 +249,7 @@ 
def _sum_kernel_kernel(x, out, out_stride_0, x_stride_0, x_stride_1, n, _m, _BLO indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32) mask_0 = indices_0 < n sum_1_acc = tl.full([_BLOCK_SIZE_0, _REDUCTION_BLOCK_1], 0, tl.float32) - for roffset_1 in tl.range(0, _m, _REDUCTION_BLOCK_1): + for roffset_1 in tl.range(0, _m, step=_REDUCTION_BLOCK_1): rindex_1 = roffset_1 + tl.arange(0, _REDUCTION_BLOCK_1).to(tl.int32) mask_1 = rindex_1 < _m load = tl.load(x + (indices_0[:, None] * x_stride_0 + rindex_1[None, :] * x_stride_1), mask_0[:, None] & mask_1[None, :], other=0) @@ -273,4 +273,3 @@ def _sum_kernel_make_precompiler(x: torch.Tensor): _REDUCTION_BLOCK_1 = 64 from helion.runtime.precompile_shim import make_precompiler return make_precompiler(_sum_kernel_kernel)(x, out, out.stride(0), x.stride(0), x.stride(1), n, _m, _BLOCK_SIZE_0, _REDUCTION_BLOCK_1, num_warps=4, num_stages=3) - diff --git a/test/test_register_tunable.expected b/test/test_register_tunable.expected index ddc89ce1..3c600002 100644 --- a/test/test_register_tunable.expected +++ b/test/test_register_tunable.expected @@ -66,7 +66,7 @@ def _matmul_split_k_kernel(x, y, out, out_stride_0, out_stride_1, x_stride_0, x_ mask_0 = indices_0 < m acc = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], 0.0, tl.float32) tile_end = tl.minimum(offset_2 + _BLOCK_SIZE_2, k) - for offset_3 in tl.range(offset_2.to(tl.int32), tile_end.to(tl.int32), _BLOCK_SIZE_3): + for offset_3 in tl.range(offset_2.to(tl.int32), tile_end.to(tl.int32), step=_BLOCK_SIZE_3): indices_3 = offset_3 + tl.arange(0, _BLOCK_SIZE_3).to(tl.int32) mask_3 = indices_3 < tile_end acc_copy = acc @@ -176,4 +176,3 @@ def _fn_make_precompiler(x: torch.Tensor): _BLOCK_SIZE_0 = 64 from helion.runtime.precompile_shim import make_precompiler return make_precompiler(_fn_kernel)(x, partial, partial.stride(0), x.stride(0), m, _BLOCK_SIZE_0, num_warps=4, num_stages=3) - diff --git a/test/test_tensor_descriptor.expected b/test/test_tensor_descriptor.expected index 20fedfa6..c9c0b04e 100644 --- a/test/test_tensor_descriptor.expected +++ b/test/test_tensor_descriptor.expected @@ -31,7 +31,7 @@ def _attention_kernel(q_view, k_view, v_view, out, k_view_size_0, k_view_size_2, l_i = tl.full([1, _BLOCK_SIZE_1], 1.0, tl.float32) acc = tl.full([1, _BLOCK_SIZE_1, 64], 0.0, tl.float32) q = q_view_desc.load([offset_0, offset_1, 0]) - for offset_2 in tl.range(0, n_dim.to(tl.int32), _BLOCK_SIZE_3): + for offset_2 in tl.range(0, n_dim.to(tl.int32), step=_BLOCK_SIZE_3): indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_3).to(tl.int32) mask_3 = indices_2 < n_dim q_copy = q @@ -133,7 +133,7 @@ def _attention_kernel(q_view, k_view, v_view, out, _BLOCK_SIZE_1: tl.constexpr, l_i = tl.full([1, _BLOCK_SIZE_1], 1.0, tl.float32) acc = tl.full([1, _BLOCK_SIZE_1, 64], 0.0, tl.float32) q = q_view_desc.load([offset_0, offset_1, 0]) - for offset_2 in tl.range(0, 512, _BLOCK_SIZE_3): + for offset_2 in tl.range(0, 512, step=_BLOCK_SIZE_3): q_copy = q m_i_copy = m_i l_i_copy = l_i @@ -206,4 +206,3 @@ def _attention_make_precompiler(q_in: torch.Tensor, k_in: torch.Tensor, v_in: to _BLOCK_SIZE_3 = 64 from helion.runtime.precompile_shim import make_precompiler return make_precompiler(_attention_kernel)(q_view, k_view, v_view, out, _BLOCK_SIZE_1, _BLOCK_SIZE_3, num_warps=4, num_stages=3) -