
Commit 4ab101e

dense bmm support
ghstack-source-id: f7286f4
Pull-Request-resolved: #39
1 parent 85ae546 commit 4ab101e

File tree: 4 files changed (+138, -5 lines)


examples/bmm.py

Lines changed: 43 additions & 0 deletions
@@ -0,0 +1,43 @@
+from __future__ import annotations
+
+import torch
+
+import helion
+import helion.language as hl
+
+
+# static_shapes=True gives a performance boost for matmuls
+@helion.kernel(static_shapes=True)
+def bmm(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
+    # A: [B, M, K], B: [B, K, N], Out: [B, M, N]  # dense bmm
+    b, m, k = A.size()
+    b, k, n = B.size()
+    out = torch.empty(
+        [b, m, n], device=A.device, dtype=torch.promote_types(A.dtype, B.dtype)
+    )
+    for tile_b, tile_m, tile_n in hl.tile([b, m, n]):
+        acc = hl.zeros([tile_b, tile_m, tile_n], dtype=torch.float32)
+        for tile_k in hl.tile(k):
+            acc = torch.baddbmm(
+                acc, A[tile_b, tile_m, tile_k], B[tile_b, tile_k, tile_n]
+            )
+        out[tile_b, tile_m, tile_n] = acc
+    return out
+
+
+def check(b: int, m: int, k: int, n: int) -> None:
+    from triton.testing import do_bench
+
+    x = torch.randn([b, m, k], device="cuda", dtype=torch.float16)
+    y = torch.randn([b, k, n], device="cuda", dtype=torch.float16)
+    result = bmm(x, y)
+    torch.testing.assert_close(result, x @ y, rtol=1e-2, atol=1e-1)
+    ms = do_bench(lambda: bmm(x, y))  # do_bench reports milliseconds
+    baseline_ms = do_bench(lambda: torch.bmm(x, y))
+    print(
+        f"Helion time: {ms:.4f}ms, torch time: {baseline_ms:.4f}ms, speedup: {baseline_ms / ms:.2f}x"
+    )
+
+
+if __name__ == "__main__":
+    check(16, 512, 768, 1024)
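The inner hl.tile(k) loop is a blocked reduction over K. As a point of reference, the sketch below (plain PyTorch, not part of the commit; block_k is an arbitrary illustrative block size) computes the same result on full tensors, folding one K-slice at a time into a float32 accumulator exactly as the kernel's acc does:

import torch

def bmm_reference(A: torch.Tensor, B: torch.Tensor, block_k: int = 64) -> torch.Tensor:
    # Blocked reduction over K, mirroring the kernel's inner loop: each
    # step adds a partial product into a float32 accumulator via
    # torch.baddbmm (acc + A_slice @ B_slice).
    b, m, k = A.size()
    _, _, n = B.size()
    acc = torch.zeros([b, m, n], device=A.device, dtype=torch.float32)
    for k0 in range(0, k, block_k):
        ks = slice(k0, min(k0 + block_k, k))
        acc = torch.baddbmm(acc, A[:, :, ks].float(), B[:, ks, :].float())
    return acc.to(torch.promote_types(A.dtype, B.dtype))

The difference in the Helion version is that hl.tile also parallelizes the B/M/N dimensions across Triton programs, while the K loop stays sequential inside each program.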

helion/_compiler/inductor_lowering.py

Lines changed: 23 additions & 2 deletions
@@ -584,8 +584,12 @@ def apply_dot_requirements(handler: CodegenHandler, node: torch.fx.Node) -> Lowe
     lproxy, rproxy = map_arg(node.args[-2:], lambda arg: arg.meta["val"])
     assert isinstance(lproxy, torch.Tensor)
     assert isinstance(rproxy, torch.Tensor)
-    n, k = lproxy.size()
-    _, m = rproxy.size()
+    lshape = lproxy.size()
+    rshape = rproxy.size()
+    # use last two dimensions for dot (supports 2D and batched 3D tensors)
+    n, k = lshape[-2], lshape[-1]
+    k2, m = rshape[-2], rshape[-1]
+    assert k == k2, f"Mismatched k dimensions for dot: {k} vs {k2}"
     a, b, c = min_dot_size(lproxy.device, lproxy.dtype, rproxy.dtype)
     env = CompileEnvironment.current()
     for shape, min_size in [(n, a), (k, b), (m, c)]:
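This hunk generalizes the dot-size check from strictly 2D operands to anything whose last two axes carry the matmul: for 2D [M, K] x [K, N] and batched 3D [B, M, K] x [B, K, N], shape[-2]/shape[-1] pick out the same (n, k, m) triple. A standalone sketch of that indexing (illustrative only, not repo code):

import torch

def dot_dims(lhs: torch.Tensor, rhs: torch.Tensor) -> tuple[int, int, int]:
    # The contraction is always over the last axis of lhs and the
    # second-to-last axis of rhs, regardless of leading batch dims.
    n, k = lhs.shape[-2], lhs.shape[-1]
    k2, m = rhs.shape[-2], rhs.shape[-1]
    assert k == k2, f"Mismatched k dimensions for dot: {k} vs {k2}"
    return n, k, m

assert dot_dims(torch.empty(512, 768), torch.empty(768, 1024)) == (512, 768, 1024)
assert dot_dims(torch.empty(16, 512, 768), torch.empty(16, 768, 1024)) == (512, 768, 1024)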
@@ -625,6 +629,23 @@ def codegen_addmm(ctx: GraphInterpreter, node: torch.fx.Node) -> ast.AST:
     )
 
 
+# pyre-fixme[56]
+@register_lowering(torch.ops.aten.baddbmm.default, apply_dot_requirements)
+def codegen_baddbmm(ctx: GraphInterpreter, node: torch.fx.Node) -> ast.AST:
+    assert not node.kwargs, "baddbmm kwargs not supported"
+    acc, lhs, rhs = map_arg(node.args, lambda arg: ctx.env[arg])
+    assert isinstance(acc, ast.AST)
+    assert isinstance(lhs, ast.AST)
+    assert isinstance(rhs, ast.AST)
+    tf32 = CompileEnvironment.current().settings.dot_precision
+    return expr_from_string(
+        f"tl.dot(lhs, rhs, acc=acc, input_precision={tf32!r})",
+        lhs=lhs,
+        rhs=rhs,
+        acc=acc,
+    )
+
+
 class GenerateASTFromInductor(DefaultHandler):
     def __init__(self, cg: GenerateAST, input_name_lookup: dict[str, ast.AST]) -> None:
         super().__init__()
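This lowering works because torch.baddbmm with its default beta=alpha=1 computes input + batch1 @ batch2, which is exactly the accumulator form of tl.dot (and tl.dot accepts the 3D operands seen in the generated code below, performing a batched matmul over the last two axes). A quick eager-mode check of that identity:

import torch

# torch.baddbmm(acc, a, b) with default beta=alpha=1 is acc + a @ b,
# the same contract as tl.dot(a, b, acc=acc) in the emitted expression.
b, m, k, n = 4, 8, 16, 32
acc = torch.randn(b, m, n)
a = torch.randn(b, m, k)
bmat = torch.randn(b, k, n)
torch.testing.assert_close(torch.baddbmm(acc, a, bmat), acc + a @ bmat)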

helion/runtime/kernel.py

Lines changed: 6 additions & 3 deletions
@@ -247,9 +247,12 @@ def __init__(self, kernel: Kernel, args: tuple[object, ...]) -> None:
                 constexpr_args[name] = arg
             else:
                 self.fake_args.append(self.env.to_fake(arg, ArgumentOrigin(name)))
-        self.host_fn: HostFunction = HostFunction(
-            self.kernel.fn, self.fake_args, constexpr_args
-        )
+        with torch.fx.experimental._config.patch(  # pyre-ignore[16]
+            skip_dtype_check_in_meta_registrations=True
+        ):
+            self.host_fn: HostFunction = HostFunction(
+                self.kernel.fn, self.fake_args, constexpr_args
+            )
         if len(kernel.configs) == 1:
             self.set_config(kernel.configs[0])
 
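The config patch is needed because the bmm example feeds a float32 accumulator and float16 operands to torch.baddbmm, and the op's dtype check (the same one that fires in eager mode, and, by assumption here, its counterpart in the meta registration used during fake-tensor tracing) would otherwise reject the mixed dtypes while HostFunction traces the kernel. A minimal eager-mode illustration of the underlying rule (not part of the commit):

import torch

# baddbmm rejects a float32 accumulator with float16 inputs; tracing hits
# the equivalent check in the meta registration unless it is skipped.
acc = torch.zeros(2, 4, 8, dtype=torch.float32)
a = torch.zeros(2, 4, 6, dtype=torch.float16)
b = torch.zeros(2, 6, 8, dtype=torch.float16)
try:
    torch.baddbmm(acc, a, b)
except RuntimeError as e:
    print("dtype mismatch:", e)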

test/test_examples.py

Lines changed: 66 additions & 0 deletions
@@ -141,6 +141,72 @@ def _matmul_make_precompiler(x: torch.Tensor, y: torch.Tensor):
 return make_precompiler(_matmul_kernel)(x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=3)""",
         )
 
+    def test_bmm(self):
+        args = (
+            torch.randn([16, 512, 768], device=DEVICE, dtype=torch.float16),
+            torch.randn([16, 768, 1024], device=DEVICE, dtype=torch.float16),
+        )
+        self.assertExpectedInline(
+            run_example(
+                "bmm",
+                args,
+                torch.bmm(args[0], args[1]),
+                block_sizes=[[16, 16, 16], 16],
+                l2_grouping=4,
+            ),
+            """\
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+
+@triton.jit
+def _bmm_kernel(A, B, out, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr, _BLOCK_SIZE_3: tl.constexpr):
+    num_blocks_0 = tl.cdiv(16, _BLOCK_SIZE_0)
+    num_blocks_1 = tl.cdiv(512, _BLOCK_SIZE_1)
+    pid_0 = tl.program_id(0) % num_blocks_0
+    pid_1 = tl.program_id(0) // num_blocks_0 % num_blocks_1
+    pid_2 = tl.program_id(0) // (num_blocks_0 * num_blocks_1)
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_0 = offset_0 + tl.arange(0, _BLOCK_SIZE_0).to(tl.int32)
+    offset_1 = pid_1 * _BLOCK_SIZE_1
+    indices_1 = offset_1 + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32)
+    offset_2 = pid_2 * _BLOCK_SIZE_2
+    indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_2).to(tl.int32)
+    acc = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2], 0.0, tl.float32)
+    for offset_3 in range(0, 768, _BLOCK_SIZE_3):
+        indices_3 = offset_3 + tl.arange(0, _BLOCK_SIZE_3).to(tl.int32)
+        acc_copy = acc
+        load = tl.load(A + (indices_0[:, None, None] * 393216 + indices_1[None, :, None] * 768 + indices_3[None, None, :] * 1), None)
+        load_1 = tl.load(B + (indices_0[:, None, None] * 786432 + indices_3[None, :, None] * 1024 + indices_2[None, None, :] * 1), None)
+        acc = tl.dot(load, load_1, acc=acc_copy, input_precision='tf32')
+    v_0 = acc.to(tl.float16)
+    tl.store(out + (indices_0[:, None, None] * 524288 + indices_1[None, :, None] * 1024 + indices_2[None, None, :] * 1), v_0, None)
+
+def bmm(A: torch.Tensor, B: torch.Tensor):
+    b, m, k = A.size()
+    b, k, n = B.size()
+    out = torch.empty([b, m, n], device=A.device, dtype=torch.promote_types(A.dtype, B.dtype))
+    _BLOCK_SIZE_0 = 16
+    _BLOCK_SIZE_1 = 16
+    _BLOCK_SIZE_2 = 16
+    _BLOCK_SIZE_3 = 16
+    _bmm_kernel[triton.cdiv(16, _BLOCK_SIZE_0) * triton.cdiv(512, _BLOCK_SIZE_1) * triton.cdiv(1024, _BLOCK_SIZE_2),](A, B, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, _BLOCK_SIZE_3, num_warps=4, num_stages=3)
+    return out
+
+def _bmm_make_precompiler(A: torch.Tensor, B: torch.Tensor):
+    b, m, k = A.size()
+    b, k, n = B.size()
+    out = torch.empty([b, m, n], device=A.device, dtype=torch.promote_types(A.dtype, B.dtype))
+    _BLOCK_SIZE_0 = 16
+    _BLOCK_SIZE_1 = 16
+    _BLOCK_SIZE_2 = 16
+    _BLOCK_SIZE_3 = 16
+    from helion.runtime.precompile_shim import make_precompiler
+    return make_precompiler(_bmm_kernel)(A, B, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, _BLOCK_SIZE_3, num_warps=4, num_stages=3)""",
+        )
+
     def test_template_via_closure0(self):
         bias = torch.randn([1, 1024], device=DEVICE, dtype=torch.float16)
         args = (
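The expected code launches a flattened 1D grid, triton.cdiv(16, 16) * triton.cdiv(512, 16) * triton.cdiv(1024, 16) = 1 * 32 * 64 = 2048 programs, and _bmm_kernel re-derives the three block indices from the single program id. A pure-Python sketch of that decomposition (illustrative names, not repo code):

from math import ceil

def decompose(pid: int, b: int, m: int, block: int = 16) -> tuple[int, int, int]:
    # Mirrors the pid_0/pid_1/pid_2 arithmetic in _bmm_kernel: a flat
    # program id splits into (batch, row, col) block indices.
    num_blocks_0 = ceil(b / block)
    num_blocks_1 = ceil(m / block)
    pid_0 = pid % num_blocks_0
    pid_1 = pid // num_blocks_0 % num_blocks_1
    pid_2 = pid // (num_blocks_0 * num_blocks_1)
    return pid_0, pid_1, pid_2

# 2048 programs tile the [16, 512, 1024] output in 16x16x16 blocks.
assert decompose(0, 16, 512) == (0, 0, 0)
assert decompose(2047, 16, 512) == (0, 31, 63)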
