Skip to content

Commit 46b96f9

Browse files
authored
Fix TensorDescriptor handling in _find_device (#35)
1 parent 57b631d commit 46b96f9

File tree

4 files changed

+25
-9
lines changed

4 files changed

+25
-9
lines changed

helion/_compat.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,14 +23,20 @@ def _supports_tensor_descriptor() -> bool:
2323
if major < 9:
2424
return False
2525
try:
26-
return get_triton_tensor_descriptor_import_path() is not None
26+
return get_triton_tensor_descriptor_class() is not None
2727
except ImportError:
2828
return False
2929

3030

3131
@functools.cache
32-
def get_triton_tensor_descriptor_import_path() -> str:
33-
"""Attempt to import TensorDescriptor object from known Triton modules."""
32+
def get_triton_tensor_descriptor_class_import_path() -> str:
33+
cls = get_triton_tensor_descriptor_class()
34+
return f"from {cls.__module__} import {cls.__qualname__}"
35+
36+
37+
@functools.cache
38+
def get_triton_tensor_descriptor_class() -> type[object]:
39+
"""Attempt to import TensorDescriptor class from known Triton modules."""
3440
possible_modules = [
3541
"triton.tools.experimental_descriptor",
3642
"triton.tools.tensor_descriptor",
@@ -39,10 +45,12 @@ def get_triton_tensor_descriptor_import_path() -> str:
3945
try:
4046
module = importlib.import_module(module_name)
4147
if hasattr(module, "TensorDescriptor"):
42-
return f"from {module_name} import TensorDescriptor"
48+
return module.TensorDescriptor
4349
except ImportError:
4450
continue
45-
raise ImportError("TensorDescriptor not found in any of the known Triton modules.")
51+
raise ImportError(
52+
"TensorDescriptor class not found in any of the known Triton modules."
53+
)
4654

4755

4856
@functools.cache

helion/_compiler/output_header.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
from .. import exc
66
from .ast_read_writes import ReadWrites
7-
from helion._compat import get_triton_tensor_descriptor_import_path
7+
from helion._compat import get_triton_tensor_descriptor_class_import_path
88
from helion._compat import supports_tensor_descriptor
99

1010
if TYPE_CHECKING:
@@ -24,7 +24,9 @@
2424
}
2525

2626
if supports_tensor_descriptor():
27-
library_imports["TensorDescriptor"] = get_triton_tensor_descriptor_import_path()
27+
library_imports["TensorDescriptor"] = (
28+
get_triton_tensor_descriptor_class_import_path()
29+
)
2830

2931
disallowed_names: dict[str, None] = dict.fromkeys(
3032
[

helion/runtime/kernel.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
from torch._inductor.codecache import PyCodeCache
1414

1515
from .. import exc
16+
from .._compat import get_triton_tensor_descriptor_class
17+
from .._compat import supports_tensor_descriptor
1618
from .._compiler.ast_extension import unparse
1719
from .._compiler.compile_environment import CompileEnvironment
1820
from .._compiler.generate_ast import generate_ast
@@ -503,6 +505,10 @@ def _find_device(args: tuple[object, ...]) -> torch.device:
503505
for arg in args:
504506
if isinstance(arg, torch.Tensor):
505507
return arg.device
508+
if supports_tensor_descriptor() and isinstance(
509+
arg, get_triton_tensor_descriptor_class()
510+
):
511+
return arg.base.device # pyre-ignore[16]
506512
if isinstance(arg, (tuple, list)):
507513
for item in arg:
508514
try:

test/test_matmul.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
import helion
1010
from helion import Config
11-
from helion._compat import get_triton_tensor_descriptor_import_path
11+
from helion._compat import get_triton_tensor_descriptor_class_import_path
1212
from helion._compat import supports_tensor_descriptor
1313
from helion._testing import DEVICE
1414
from helion._testing import code_and_output
@@ -355,7 +355,7 @@ def test_matmul_tensor_descriptor(self):
355355
import torch
356356
import triton
357357
import triton.language as tl
358-
{get_triton_tensor_descriptor_import_path()}
358+
{get_triton_tensor_descriptor_class_import_path()}
359359
360360
@triton.jit
361361
def _matmul_kernel(x_desc, y_desc, out_desc, m, n, k, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr):

0 commit comments

Comments
 (0)