
Commit 96253e7

jansel authored and pytorchmergebot committed
Fix compatibility with torch 2.7 (#45)
Pull Request resolved: #45
Approved by: https://github.com/oulgen, https://github.com/yf225
1 parent 4ab101e · commit 96253e7

3 files changed: +20 −3 lines

examples/bmm.py

2 additions & 0 deletions

@@ -40,4 +40,6 @@ def check(b: int, m: int, k: int, n: int) -> None:
 
 
 if __name__ == "__main__":
+    # torch.baddbmm support for 16-bit tensors requires torch 2.8+
+    assert torch.__version__.split(".")[:2] >= ["2", "8"], "Requires torch 2.8+"
     check(16, 512, 768, 1024)
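
As an aside, the same 2.8+ gate can be written with packaging.version, mirroring the test change further down; a minimal sketch, assuming the packaging dependency is available:

# Sketch: equivalent torch 2.8+ guard via packaging.version; any local build
# suffix (e.g. "+cu126") is stripped before parsing.
import torch
from packaging import version

assert version.parse(torch.__version__.split("+")[0]) >= version.parse("2.8"), (
    "Requires torch 2.8+"
)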

helion/runtime/kernel.py

12 additions & 3 deletions

@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 from collections.abc import Callable
+import contextlib
 import dataclasses
 import functools
 import inspect
@@ -247,9 +248,7 @@ def __init__(self, kernel: Kernel, args: tuple[object, ...]) -> None:
                 constexpr_args[name] = arg
             else:
                 self.fake_args.append(self.env.to_fake(arg, ArgumentOrigin(name)))
-        with torch.fx.experimental._config.patch(  # pyre-ignore[16]
-            skip_dtype_check_in_meta_registrations=True
-        ):
+        with _maybe_skip_dtype_check_in_meta_registrations():
             self.host_fn: HostFunction = HostFunction(
                 self.kernel.fn, self.fake_args, constexpr_args
             )
@@ -542,3 +541,13 @@ def _find_device(args: tuple[object, ...]) -> torch.device:
         except exc.NoTensorArgs:
             pass
     raise exc.NoTensorArgs
+
+
+def _maybe_skip_dtype_check_in_meta_registrations() -> (
+    contextlib.AbstractContextManager[None, None]
+):
+    if hasattr(torch.fx.experimental._config, "skip_dtype_check_in_meta_registrations"):
+        return torch.fx.experimental._config.patch(  # pyre-ignore[16]
+            skip_dtype_check_in_meta_registrations=True
+        )
+    return contextlib.nullcontext()
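
The new helper is a feature-detection shim: it patches skip_dtype_check_in_meta_registrations only when the installed torch exposes that flag, and degrades to a no-op context on torch 2.7. A minimal self-contained sketch of the same pattern (names outside the diff, such as _maybe_skip_dtype_check and the trailing with block, are illustrative):

# Sketch: patch the meta-registration dtype-check flag only when this torch
# build has it; otherwise return a no-op context so torch 2.7 still works.
from __future__ import annotations

import contextlib

from torch.fx.experimental import _config as fx_config  # private torch config module


def _maybe_skip_dtype_check() -> contextlib.AbstractContextManager[None]:
    # Newer torch builds expose this flag; torch 2.7 does not, so probe for it.
    if hasattr(fx_config, "skip_dtype_check_in_meta_registrations"):
        return fx_config.patch(skip_dtype_check_in_meta_registrations=True)
    return contextlib.nullcontext()


with _maybe_skip_dtype_check():
    pass  # fake-tensor tracing that would otherwise hit the dtype checks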

test/test_examples.py

6 additions & 0 deletions

@@ -1,8 +1,10 @@
 from __future__ import annotations
 
 from pathlib import Path
+import unittest
 
 from expecttest import TestCase
+from packaging import version
 import torch
 
 from helion._testing import DEVICE
@@ -141,6 +143,10 @@ def _matmul_make_precompiler(x: torch.Tensor, y: torch.Tensor):
     return make_precompiler(_matmul_kernel)(x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=3)""",
         )
 
+    @unittest.skipIf(
+        version.parse(torch.__version__.split("+")[0]) < version.parse("2.8"),
+        "Requires torch 2.8+",
+    )
     def test_bmm(self):
         args = (
             torch.randn([16, 512, 768], device=DEVICE, dtype=torch.float16),
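
The skip condition parses only the release portion of torch.__version__ (anything after "+" is a local build tag such as cu126). A quick sketch of how the gate evaluates for a few illustrative version strings (the strings themselves are made up for illustration):

# Sketch: evaluate the skipIf condition against some illustrative versions.
from packaging import version

for v in ("2.7.1", "2.8.0+cu126", "2.9.0.dev20250101+cpu"):
    skip = version.parse(v.split("+")[0]) < version.parse("2.8")
    print(f"{v}: skip test_bmm -> {skip}")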
