Skip to content

Commit 03b58ca

Browse files
authored
Fix bug with tensor descriptor and small block size (#296)
1 parent 18a07a3 commit 03b58ca

File tree

3 files changed

+85
-25
lines changed

3 files changed

+85
-25
lines changed

helion/_compiler/compile_environment.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -441,7 +441,7 @@ def symbol(self) -> sympy.Symbol:
441441
return self.var._sympy_()
442442

443443
def from_config(self, config: Config) -> int | torch.SymInt | None:
444-
return self.block_size_source.from_config(config, self.block_id)
444+
return self.block_size_source.from_config(config, self)
445445

446446
def from_config_assert(self, config: Config) -> int | torch.SymInt:
447447
val = self.from_config(config)
@@ -461,7 +461,9 @@ def update_min_block(self, value: int, *, allow_flattened: bool = True) -> None:
461461

462462

463463
class BlockSizeSource:
464-
def from_config(self, config: Config, block_id: int) -> int | torch.SymInt | None:
464+
def from_config(
465+
self, config: Config, block_size_info: BlockSizeInfo
466+
) -> int | torch.SymInt | None:
465467
raise NotImplementedError
466468

467469
def l2_grouping(self, config: Config) -> int:
@@ -472,15 +474,17 @@ def l2_grouping(self, config: Config) -> int:
472474
class FixedBlockSizeSource(BlockSizeSource):
473475
value: int | torch.SymInt
474476

475-
def from_config(self, config: Config, block_id: int) -> int | torch.SymInt:
477+
def from_config(
478+
self, config: Config, block_size_info: BlockSizeInfo
479+
) -> int | torch.SymInt:
476480
return self.value
477481

478482

479483
@dataclasses.dataclass
480484
class LoopSpecBlockSizeSource(BlockSizeSource):
481-
def from_config(self, config: Config, block_id: int) -> int:
485+
def from_config(self, config: Config, block_size_info: BlockSizeInfo) -> int:
482486
index = CompileEnvironment.current().config_spec.block_sizes.block_id_to_index(
483-
block_id
487+
block_size_info.block_id
484488
)
485489
return config.block_sizes[index]
486490

@@ -489,7 +493,12 @@ def from_config(self, config: Config, block_id: int) -> int:
489493
class ReductionLoopBlockSizeSource(BlockSizeSource):
490494
reduction_loop: int
491495

492-
def from_config(self, config: Config, block_id: int) -> int | None:
496+
def from_config(self, config: Config, block_size_info: BlockSizeInfo) -> int | None:
497+
if (
498+
len(config.reduction_loops) <= self.reduction_loop
499+
or config.reduction_loops[self.reduction_loop] is None
500+
):
501+
return next_power_of_2(block_size_info.size_hint())
493502
return config.reduction_loops[self.reduction_loop]
494503

495504

helion/_compiler/indexing_strategy.py

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from .. import exc
1313
from .ast_extension import expr_from_string
1414
from .compile_environment import CompileEnvironment
15+
from .device_function import DeviceFunction
1516
from .host_function import HostFunction
1617
from .tile_strategy import DeviceLoopState
1718
from .variable_origin import BlockSizeOrigin
@@ -178,9 +179,40 @@ def is_supported(
178179
byte_stride = stride * element_size
179180
if byte_stride % 16 != 0:
180181
return False
182+
if stride_one_count != 1:
183+
# There should be exactly one dimension with stride==1
184+
return False
185+
186+
def valid_block_size(
187+
block_size: int | torch.SymInt | None, stride: int | torch.SymInt
188+
) -> bool:
189+
if not isinstance(block_size, int):
190+
return False
191+
# was getting some illegal memory accesses (IMAs) with small block sizes even in non-stride-1 dims
192+
return block_size * element_size >= 16 or (block_size == 1 and stride != 1)
193+
194+
# 4) Check minimum 16 bytes in each dimension
195+
size_stride = collections.deque(
196+
zip(fake_tensor.size(), fake_tensor.stride(), strict=True)
197+
)
198+
config = DeviceFunction.current().config
199+
for k in subscript:
200+
if k is None:
201+
continue
202+
size, stride = size_stride.popleft()
203+
if str(k) == "slice(None, None, None)":
204+
block_size = env.allocate_reduction_dimension(size).from_config(config)
205+
if not valid_block_size(block_size, stride):
206+
return False
207+
elif isinstance(k, torch.SymInt):
208+
block_id = env.get_block_id(k)
209+
if block_id is None:
210+
return False
211+
block_size = env.block_sizes[block_id].from_config(config)
212+
if not valid_block_size(block_size, stride):
213+
return False
181214

182-
# TODO(jansel): check that base_ptr is aligned to 16 bytes
183-
return stride_one_count == 1
215+
return True
184216

185217
def codegen_load(
186218
self,

test/test_tensor_descriptor.py

Lines changed: 36 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -109,27 +109,11 @@ def kernel_3d_permutation(x: torch.Tensor) -> torch.Tensor:
109109
base_tensor = torch.randn(storage_size, device=DEVICE, dtype=torch.float32)
110110
x = base_tensor.as_strided([4, 8, 4], [64, 1, 4])
111111

112-
# Verify stride pattern - middle dimension should have stride 1, others 16-byte aligned
113-
self.assertEqual(x.stride(), (64, 1, 4)) # Expected stride pattern
114-
self.assertEqual(x.stride()[1], 1) # middle dimension has stride 1
115-
116-
# Check 16-byte alignment for non-stride-1 dimensions
117-
element_size = x.element_size()
118-
for dim in range(x.ndim):
119-
stride = x.stride(dim)
120-
if stride != 1:
121-
byte_stride = stride * element_size
122-
self.assertEqual(
123-
byte_stride % 16,
124-
0,
125-
f"Dim {dim} not 16-byte aligned: stride={stride}, byte_stride={byte_stride}",
126-
)
127-
128112
code, result = code_and_output(
129113
kernel_3d_permutation,
130114
(x,),
131115
indexing="tensor_descriptor",
132-
block_sizes=[2, 4, 2],
116+
block_sizes=[8, 8, 8],
133117
)
134118

135119
# Check correctness
@@ -288,6 +272,41 @@ def test_attention_td_dynamic(self):
288272
)
289273
)
290274

275+
@unittest.skipUnless(
276+
supports_tensor_descriptor(), "Tensor descriptor support is required"
277+
)
278+
def test_minimum_16_byte_block_size_fallback(self):
279+
"""Test that tensor descriptor falls back when block size is too small."""
280+
281+
@helion.kernel(use_default_config=True)
282+
def kernel_small_block(x: torch.Tensor) -> torch.Tensor:
283+
result = torch.zeros_like(x)
284+
for tile in hl.tile(x.size()):
285+
result[tile] = x[tile] + 1.0
286+
return result
287+
288+
# Create a tensor with proper stride alignment
289+
x = torch.randn([8, 16], device=DEVICE, dtype=torch.float32)
290+
291+
# Use small block sizes that would result in < 16 bytes in last dimension
292+
# block_sizes=[4, 2] means last dimension block size = 2
293+
# 2 * 4 bytes (float32) = 8 bytes < 16 bytes required
294+
# With the fix, this should fall back to another indexing strategy
295+
code, result = code_and_output(
296+
kernel_small_block,
297+
(x,),
298+
indexing="tensor_descriptor", # Request tensor descriptor
299+
block_sizes=[4, 2], # Small block size in last dimension
300+
)
301+
302+
# Should fall back to block_ptr or pointer indexing instead of tensor descriptor
303+
# If our fix works, this should NOT contain tensor descriptor
304+
self.assertNotIn("tl.make_tensor_descriptor", code)
305+
306+
# But should still work correctly
307+
expected = x + 1.0
308+
torch.testing.assert_close(result, expected)
309+
291310

292311
if __name__ == "__main__":
293312
unittest.main()

0 commit comments

Comments (0)