Skip to content

Commit 28aeec1

Browse files
committed
tensor[tile] when tile size is 1 returns a 1D tensor, instead of a scalar
1 parent a44aabb commit 28aeec1

File tree

5 files changed

+139
-99
lines changed

5 files changed

+139
-99
lines changed

helion/_compiler/indexing_strategy.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -378,9 +378,8 @@ def create(
378378
assert len(index_values) == fake_value.ndim
379379
index_expr = []
380380
for i, idx in enumerate(index_values):
381-
if fake_value.size(i) != 1:
382-
stride = state.device_function.tensor_stride(fake_value, i).name
383-
index_expr.append(f"{idx} * {stride}")
381+
stride = state.device_function.tensor_stride(fake_value, i).name
382+
index_expr.append(f"{idx} * {stride}")
384383
if not index_expr:
385384
shape_str = tile_strategy.shape_str(output_size)
386385
index_expr.append(f"tl.zeros({shape_str}, {dtype})")

helion/_compiler/tile_strategy.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -670,11 +670,19 @@ def codegen_device_loop(self, state: CodegenState) -> DeviceLoopState:
670670
type_comment=None,
671671
)
672672
assert for_node.body is body
673-
extra_body = [
674-
statement_from_string(
675-
f"{index_var} = {offset_var} + tl.arange(0, ({block_size_var})).to({dtype})"
676-
),
677-
]
673+
extra_body = []
674+
if block_size == 1:
675+
extra_body.append(
676+
statement_from_string(
677+
f"{index_var} = {offset_var} + tl.zeros([1], {dtype})"
678+
),
679+
)
680+
else:
681+
extra_body.append(
682+
statement_from_string(
683+
f"{index_var} = {offset_var} + tl.arange(0, ({block_size_var})).to({dtype})"
684+
),
685+
)
678686
mask_statement = self._setup_mask( # pyright: ignore[reportAttributeAccessIssue]
679687
state, block_idx, block_size, index_var, end
680688
)

0 commit comments

Comments
 (0)