@@ -357,6 +357,58 @@ def _fn_make_precompiler(a, idx1):
     return make_precompiler(_fn_kernel)(a, out0, out1, out2, a.size(0), a.size(1), a.stride(0), a.stride(1), out0.stride(0), out0.stride(1), out1.stride(0), out1.stride(1), out2.stride(0), out2.stride(1), idx1, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3)""",
         )
 
+    def test_implicit_broadcast(self):
+        @helion.kernel
+        def fn(a, b):
+            out = torch.empty_like(a)
+            for tile0, tile1 in hl.tile(a.size()):
+                out[tile0, tile1] = a[tile0, tile1] + b[tile1]
+            return out
+
+        args = (torch.randn(512, 512, device=DEVICE), torch.randn(512, device=DEVICE))
+        code, out = code_and_output(fn, args, block_size=[16, 16])
+        torch.testing.assert_close(out, sum(args))
+        self.assertExpectedInline(
+            code,
+            """\
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+
+@triton.jit
+def _fn_kernel(a, b, out, a_size_0, a_size_1, a_stride_0, a_stride_1, b_stride_0, out_stride_0, out_stride_1, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr):
+    num_blocks_0 = tl.cdiv(a_size_0, _BLOCK_SIZE_0)
+    pid_0 = tl.program_id(0) % num_blocks_0
+    pid_1 = tl.program_id(0) // num_blocks_0
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_0 = offset_0 + tl.arange(0, _BLOCK_SIZE_0).to(tl.int32)
+    mask_0 = indices_0 < a_size_0
+    offset_1 = pid_1 * _BLOCK_SIZE_1
+    indices_1 = offset_1 + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32)
+    mask_1 = indices_1 < a_size_1
+    load = tl.load(a + (indices_0[:, None] * a_stride_0 + indices_1[None, :] * a_stride_1), mask_0[:, None] & mask_1[None, :], other=0)
+    load_1 = tl.load(b + indices_1 * b_stride_0, mask_1, other=0)
+    v_0 = load_1[None, :]
+    v_1 = load + v_0
+    tl.store(out + (indices_0[:, None] * out_stride_0 + indices_1[None, :] * out_stride_1), v_1, mask_0[:, None] & mask_1[None, :])
+
+def fn(a, b):
+    out = torch.empty_like(a)
+    _BLOCK_SIZE_0 = 16
+    _BLOCK_SIZE_1 = 16
+    _fn_kernel[triton.cdiv(a.size(0), _BLOCK_SIZE_0) * triton.cdiv(a.size(1), _BLOCK_SIZE_1),](a, b, out, a.size(0), a.size(1), a.stride(0), a.stride(1), b.stride(0), out.stride(0), out.stride(1), _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
+    return out
+
+def _fn_make_precompiler(a, b):
+    out = torch.empty_like(a)
+    _BLOCK_SIZE_0 = 16
+    _BLOCK_SIZE_1 = 16
+    from helion.runtime.precompile_shim import make_precompiler
+    return make_precompiler(_fn_kernel)(a, b, out, a.size(0), a.size(1), a.stride(0), a.stride(1), b.stride(0), out.stride(0), out.stride(1), _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3)""",
+        )
+
 
 if __name__ == "__main__":
     unittest.main()
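
Note: the implicit-broadcast pattern exercised by test_implicit_broadcast can be reproduced outside the test harness roughly as follows. This is a sketch, not part of the diff; it assumes the usual "import helion" / "import helion.language as hl" entry points and a CUDA device, and the kernel name add_broadcast is illustrative.

import torch
import helion
import helion.language as hl

@helion.kernel
def add_broadcast(a, b):
    # Same body as the new test's fn: a is 2-D, b is 1-D.
    out = torch.empty_like(a)
    for tile0, tile1 in hl.tile(a.size()):
        # Indexing b with only tile1 yields a 1-D block that is implicitly
        # broadcast across tile0 (b[None, :] in the generated Triton code).
        out[tile0, tile1] = a[tile0, tile1] + b[tile1]
    return out

# Expected to agree with eager-mode broadcasting:
# a = torch.randn(512, 512, device="cuda")
# b = torch.randn(512, device="cuda")
# torch.testing.assert_close(add_broadcast(a, b), a + b)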