Skip to content

Commit 2f600d2

Browse files
jansel authored and pytorchmergebot committed
Codegen explicit broadcasts when ranks don't match (#51)
Pull Request resolved: #51 Approved by: https://github.com/oulgen, https://github.com/yf225 ghstack dependencies: #50
1 parent 4165a69 commit 2f600d2

File tree

2 files changed

+65
-4
lines changed

2 files changed

+65
-4
lines changed

helion/_compiler/inductor_lowering.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -274,11 +274,20 @@ class InductorLowering(Lowering):
274274
input_names: list[str]
275275

276276
def input_asts(self, ctx: GraphInterpreter, node: torch.fx.Node) -> list[ast.AST]:
277+
def visit(n: torch.fx.Node) -> None:
278+
ast_val = ctx.env[n]
279+
if isinstance(fake_val := n.meta["val"], torch.Tensor):
280+
if fake_val.ndim < ndim:
281+
# Broadcast to force ranks to match
282+
expand = ["None"] * (ndim - fake_val.ndim) + [":"] * fake_val.ndim
283+
ast_val = expr_from_string(
284+
"tensor[" + ", ".join(expand) + "]", tensor=ast_val
285+
)
286+
input_asts.append(ast_val)
287+
288+
ndim: int = max([x.ndim for x in self.input_fake_tensors(node)] or (0,))
277289
input_asts: list[ast.AST] = []
278-
map_arg(
279-
(node.args, node.kwargs),
280-
lambda arg: input_asts.append(ctx.env[arg]),
281-
)
290+
map_arg((node.args, node.kwargs), visit)
282291
assert len(input_asts) == len(self.input_names)
283292
return input_asts
284293

test/test_broadcasting.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -357,6 +357,58 @@ def _fn_make_precompiler(a, idx1):
357357
return make_precompiler(_fn_kernel)(a, out0, out1, out2, a.size(0), a.size(1), a.stride(0), a.stride(1), out0.stride(0), out0.stride(1), out1.stride(0), out1.stride(1), out2.stride(0), out2.stride(1), idx1, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3)""",
358358
)
359359

360+
def test_implicit_broadcast(self):
361+
@helion.kernel
362+
def fn(a, b):
363+
out = torch.empty_like(a)
364+
for tile0, tile1 in hl.tile(a.size()):
365+
out[tile0, tile1] = a[tile0, tile1] + b[tile1]
366+
return out
367+
368+
args = (torch.randn(512, 512, device=DEVICE), torch.randn(512, device=DEVICE))
369+
code, out = code_and_output(fn, args, block_size=[16, 16])
370+
torch.testing.assert_close(out, sum(args))
371+
self.assertExpectedInline(
372+
code,
373+
"""\
374+
from __future__ import annotations
375+
376+
import torch
377+
import triton
378+
import triton.language as tl
379+
380+
@triton.jit
381+
def _fn_kernel(a, b, out, a_size_0, a_size_1, a_stride_0, a_stride_1, b_stride_0, out_stride_0, out_stride_1, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr):
382+
num_blocks_0 = tl.cdiv(a_size_0, _BLOCK_SIZE_0)
383+
pid_0 = tl.program_id(0) % num_blocks_0
384+
pid_1 = tl.program_id(0) // num_blocks_0
385+
offset_0 = pid_0 * _BLOCK_SIZE_0
386+
indices_0 = offset_0 + tl.arange(0, _BLOCK_SIZE_0).to(tl.int32)
387+
mask_0 = indices_0 < a_size_0
388+
offset_1 = pid_1 * _BLOCK_SIZE_1
389+
indices_1 = offset_1 + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32)
390+
mask_1 = indices_1 < a_size_1
391+
load = tl.load(a + (indices_0[:, None] * a_stride_0 + indices_1[None, :] * a_stride_1), mask_0[:, None] & mask_1[None, :], other=0)
392+
load_1 = tl.load(b + indices_1 * b_stride_0, mask_1, other=0)
393+
v_0 = load_1[None, :]
394+
v_1 = load + v_0
395+
tl.store(out + (indices_0[:, None] * out_stride_0 + indices_1[None, :] * out_stride_1), v_1, mask_0[:, None] & mask_1[None, :])
396+
397+
def fn(a, b):
398+
out = torch.empty_like(a)
399+
_BLOCK_SIZE_0 = 16
400+
_BLOCK_SIZE_1 = 16
401+
_fn_kernel[triton.cdiv(a.size(0), _BLOCK_SIZE_0) * triton.cdiv(a.size(1), _BLOCK_SIZE_1),](a, b, out, a.size(0), a.size(1), a.stride(0), a.stride(1), b.stride(0), out.stride(0), out.stride(1), _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
402+
return out
403+
404+
def _fn_make_precompiler(a, b):
405+
out = torch.empty_like(a)
406+
_BLOCK_SIZE_0 = 16
407+
_BLOCK_SIZE_1 = 16
408+
from helion.runtime.precompile_shim import make_precompiler
409+
return make_precompiler(_fn_kernel)(a, b, out, a.size(0), a.size(1), a.stride(0), a.stride(1), b.stride(0), out.stride(0), out.stride(1), _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3)""",
410+
)
411+
360412

361413
if __name__ == "__main__":
362414
unittest.main()

0 commit comments

Comments (0)