Skip to content

Commit dcfa500

Browse files
authored
Add tl._experimental_make_tensor_descriptor support (#322)
1 parent 566045a commit dcfa500

File tree

6 files changed

+43
-16
lines changed

6 files changed

+43
-16
lines changed

helion/_compat.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,17 @@ def _supports_tensor_descriptor() -> bool:
2222
major, _ = torch.cuda.get_device_capability(torch.cuda.current_device())
2323
if major < 9:
2424
return False
25-
return hasattr(triton.language, "make_tensor_descriptor")
25+
return hasattr(triton.language, "make_tensor_descriptor") or hasattr(
26+
triton.language, "_experimental_make_tensor_descriptor"
27+
)
28+
29+
30+
@functools.cache
31+
def get_tensor_descriptor_fn_name() -> str:
32+
if hasattr(triton.language, "make_tensor_descriptor"):
33+
return "tl.make_tensor_descriptor"
34+
assert hasattr(triton.language, "_experimental_make_tensor_descriptor")
35+
return "tl._experimental_make_tensor_descriptor"
2636

2737

2838
@functools.cache

helion/_compiler/device_function.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from torch._inductor.codegen.triton import texpr
1818
from torch.fx.graph import _Namespace
1919

20+
from .._compat import get_tensor_descriptor_fn_name
2021
from .ast_extension import ExtendedAST
2122
from .ast_extension import create
2223
from .ast_extension import create_arg
@@ -347,8 +348,9 @@ def tensor_descriptor_arg(
347348
sizes = ", ".join([arg.name for arg in size_args])
348349
strides = ", ".join([arg.name for arg in stride_args])
349350

351+
tensor_descriptor_fn_name = get_tensor_descriptor_fn_name()
350352
descriptor_stmt = statement_from_string(
351-
f"{desc_name} = tl.make_tensor_descriptor({tensor_arg.name}, [{sizes}], [{strides}], [{block_size_expr}])"
353+
f"{desc_name} = {tensor_descriptor_fn_name}({tensor_arg.name}, [{sizes}], [{strides}], [{block_size_expr}])"
352354
)
353355
self.preamble.append(descriptor_stmt)
354356

helion/_testing.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from triton.testing import do_bench
1717

1818
from .runtime.config import Config
19+
from helion._compat import get_tensor_descriptor_fn_name
1920

2021
if TYPE_CHECKING:
2122
import types
@@ -220,6 +221,12 @@ def normalize_id(test_id: str) -> str:
220221
assert match, f"Test ID '{test_id}' does not match expected format"
221222
return match.group(1)
222223

224+
@staticmethod
225+
def normalize_tensor_descriptors(code: str) -> str:
226+
return code.replace(
227+
get_tensor_descriptor_fn_name(), "tl.make_tensor_descriptor"
228+
)
229+
223230
def lookup(self, test_id: str, value: str) -> tuple[str, str]:
224231
test_id = self.normalize_id(test_id)
225232
if self._current_id != test_id:
@@ -234,6 +241,7 @@ def lookup(self, test_id: str, value: str) -> tuple[str, str]:
234241
expected_values.append("")
235242
expected = ""
236243

244+
value = self.normalize_tensor_descriptors(value)
237245
value = value.strip()
238246
if value != expected and os.environ.get("EXPECTTEST_ACCEPT", "0") not in {
239247
"0",

test/test_indexing.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import torch
66

77
import helion
8+
from helion._compat import get_tensor_descriptor_fn_name
89
from helion._compat import supports_tensor_descriptor
910
from helion._testing import DEVICE
1011
from helion._testing import TestCase
@@ -366,6 +367,10 @@ def test_broadcasting_block_ptr_indexing(self):
366367
self.assertExpectedJournal(code)
367368

368369
@unittest.skipIf(not supports_tensor_descriptor(), "TensorDescriptor not supported")
370+
@unittest.skipIf(
371+
get_tensor_descriptor_fn_name() == "tl._experimental_make_tensor_descriptor",
372+
"LLVM ERROR: Illegal shared layout",
373+
)
369374
def test_broadcasting_tensor_descriptor_indexing(self):
370375
x = torch.randn([16, 24, 32], device=DEVICE)
371376
bias1 = torch.randn([1, 24, 32], device=DEVICE)

test/test_persistent_kernels.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import torch
66

77
import helion
8+
from helion._compat import get_tensor_descriptor_fn_name
89
from helion._compat import supports_tensor_descriptor
910
from helion._testing import DEVICE
1011
from helion._testing import TestCase
@@ -999,8 +1000,8 @@ def tensor_descriptor_kernel(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
9991000
torch.testing.assert_close(result_interleaved, expected)
10001001

10011002
# Verify tensor descriptor features in code
1002-
self.assertIn("tl.make_tensor_descriptor", code_blocked)
1003-
self.assertIn("tl.make_tensor_descriptor", code_interleaved)
1003+
self.assertIn(get_tensor_descriptor_fn_name(), code_blocked)
1004+
self.assertIn(get_tensor_descriptor_fn_name(), code_interleaved)
10041005

10051006
# Verify persistent kernel features
10061007
self.assertIn("for virtual_pid in tl.range", code_blocked)

test/test_tensor_descriptor.py

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import torch
66

77
import helion
8+
from helion._compat import get_tensor_descriptor_fn_name
89
from helion._compat import supports_tensor_descriptor
910
from helion._testing import DEVICE
1011
from helion._testing import TestCase
@@ -41,15 +42,15 @@ def kernel_with_permutation(x: torch.Tensor) -> torch.Tensor:
4142
kernel_with_permutation,
4243
(x,),
4344
indexing="tensor_descriptor",
44-
block_sizes=[4, 8],
45+
block_sizes=[8, 8],
4546
)
4647

4748
# Check that the result is correct
4849
expected = x + 1.0
4950
torch.testing.assert_close(result, expected)
5051

5152
# Check that the generated code contains permutation calls
52-
self.assertIn("tl.make_tensor_descriptor", code)
53+
self.assertIn(get_tensor_descriptor_fn_name(), code)
5354
# The tensor descriptor should be created with permuted dimensions
5455
# (sizes and strides should be reordered so stride==1 dim is last)
5556

@@ -77,15 +78,15 @@ def kernel_no_permutation(x: torch.Tensor) -> torch.Tensor:
7778
kernel_no_permutation,
7879
(x,),
7980
indexing="tensor_descriptor",
80-
block_sizes=[4, 8],
81+
block_sizes=[8, 8],
8182
)
8283

8384
# Check that the result is correct
8485
expected = x * 2.0
8586
torch.testing.assert_close(result, expected)
8687

8788
# Check that the generated code contains tensor descriptor
88-
self.assertIn("tl.make_tensor_descriptor", code)
89+
self.assertIn(get_tensor_descriptor_fn_name(), code)
8990
# Should not contain permute calls since no permutation needed
9091
self.assertNotIn("tl.permute", code)
9192

@@ -121,7 +122,7 @@ def kernel_3d_permutation(x: torch.Tensor) -> torch.Tensor:
121122
torch.testing.assert_close(result, expected)
122123

123124
# Should contain both tensor descriptor and permute operations
124-
self.assertIn("tl.make_tensor_descriptor", code)
125+
self.assertIn(get_tensor_descriptor_fn_name(), code)
125126
self.assertIn("tl.permute", code)
126127

127128
@unittest.skipUnless(
@@ -149,15 +150,15 @@ def kernel_transpose_case(x: torch.Tensor) -> torch.Tensor:
149150
kernel_transpose_case,
150151
(x,),
151152
indexing="tensor_descriptor",
152-
block_sizes=[4, 8],
153+
block_sizes=[8, 8],
153154
)
154155

155156
# Check correctness
156157
expected = x * x
157158
torch.testing.assert_close(result, expected)
158159

159160
# Should handle the permutation properly
160-
self.assertIn("tl.make_tensor_descriptor", code)
161+
self.assertIn(get_tensor_descriptor_fn_name(), code)
161162
self.assertIn("tl.permute", code)
162163

163164
@unittest.skipUnless(
@@ -183,14 +184,14 @@ def kernel_different_blocks(x: torch.Tensor) -> torch.Tensor:
183184
kernel_different_blocks,
184185
(x,),
185186
indexing="tensor_descriptor",
186-
block_sizes=[4, 8],
187+
block_sizes=[8, 8],
187188
)
188189

189190
expected = x + 5.0
190191
torch.testing.assert_close(result, expected)
191192

192193
# Should contain permutation and tensor descriptor
193-
self.assertIn("tl.make_tensor_descriptor", code)
194+
self.assertIn(get_tensor_descriptor_fn_name(), code)
194195
self.assertIn("tl.permute", code)
195196

196197
# The block sizes should also be permuted in the tensor descriptor
@@ -223,14 +224,14 @@ def kernel_store_permutation(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
223224
kernel_store_permutation,
224225
(x, y),
225226
indexing="tensor_descriptor",
226-
block_sizes=[4, 8],
227+
block_sizes=[8, 8],
227228
)
228229

229230
expected = x * 3.0
230231
torch.testing.assert_close(result, expected)
231232

232233
# Should have permutation for both load and store
233-
self.assertIn("tl.make_tensor_descriptor", code)
234+
self.assertIn(get_tensor_descriptor_fn_name(), code)
234235
self.assertIn("tl.permute", code)
235236

236237
@unittest.skipUnless(
@@ -301,7 +302,7 @@ def kernel_small_block(x: torch.Tensor) -> torch.Tensor:
301302

302303
# Should fall back to block_ptr or pointer indexing instead of tensor descriptor
303304
# If our fix works, this should NOT contain tensor descriptor
304-
self.assertNotIn("tl.make_tensor_descriptor", code)
305+
self.assertNotIn(get_tensor_descriptor_fn_name(), code)
305306

306307
# But should still work correctly
307308
expected = x + 1.0

0 commit comments

Comments
 (0)