 from __future__ import annotations
 
+import unittest
+
 from expecttest import TestCase
 import pytest
 import torch
@@ -53,3 +55,66 @@ def add3(x, y):
         code_and_output(add2, (x, x))
 
         code_and_output(add3, (x, x))
+
+    def test_inputs(self):
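+        # arguments arrive nested in a list, a dict, and a tuple; the generated
+        # Triton kernel below flattens (and de-duplicates) these containers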
+        @helion.kernel
+        def kernel(a_list, b_dict, b_tuple):
+            a0, a1 = a_list
+            b0 = b_dict["b0"]
+            (b1,) = b_tuple
+            c0, c1 = torch.empty_like(a0), torch.empty_like(a1)
+            for tile in hl.tile(a0.size()):
+                c0[tile] = a0[tile] + b0[tile]
+                c1[tile] = a1[tile] + b1[tile]
+            return [c0, c1]
+
+        x = torch.randn(4, device=DEVICE)
+        code, result = code_and_output(kernel, ([x, x], {"b0": x}, (x,)))
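+        # every container holds the same tensor x, so both outputs equal x + x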
+        torch.testing.assert_close(result[0], 2 * x)
+        torch.testing.assert_close(result[1], 2 * x)
+        self.assertExpectedInline(
+            code,
+            """\
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+
+@triton.jit
+def _kernel_kernel(a0, c0, c1, a0_size_0, a0_stride_0, c0_stride_0, c1_stride_0, _BLOCK_SIZE_0: tl.constexpr):
+    pid_0 = tl.program_id(0)
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_0 = offset_0 + tl.arange(0, _BLOCK_SIZE_0).to(tl.int32)
+    mask_0 = indices_0 < a0_size_0
+    load = tl.load(a0 + indices_0 * a0_stride_0, mask_0, other=0)
+    load_1 = tl.load(a0 + indices_0 * a0_stride_0, mask_0, other=0)
+    v_0 = load + load_1
+    tl.store(c0 + indices_0 * c0_stride_0, v_0, mask_0)
+    load_2 = tl.load(a0 + indices_0 * a0_stride_0, mask_0, other=0)
+    load_3 = tl.load(a0 + indices_0 * a0_stride_0, mask_0, other=0)
+    v_1 = load_2 + load_3
+    tl.store(c1 + indices_0 * c1_stride_0, v_1, mask_0)
+
+def kernel(a_list, b_dict, b_tuple):
+    a0, a1 = a_list
+    b0 = b_dict['b0']
+    b1, = b_tuple
+    c0, c1 = (torch.empty_like(a0), torch.empty_like(a1))
+    _BLOCK_SIZE_0 = 4
+    _kernel_kernel[triton.cdiv(a0.size(0), _BLOCK_SIZE_0),](a0, c0, c1, a0.size(0), a0.stride(0), c0.stride(0), c1.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3)
+    return [c0, c1]
+
+def _kernel_make_precompiler(a_list, b_dict, b_tuple):
+    a0, a1 = a_list
+    b0 = b_dict['b0']
+    b1, = b_tuple
+    c0, c1 = (torch.empty_like(a0), torch.empty_like(a1))
+    _BLOCK_SIZE_0 = 4
+    from helion.runtime.precompile_shim import make_precompiler
+    return make_precompiler(_kernel_kernel)(a0, c0, c1, a0.size(0), a0.stride(0), c0.stride(0), c1.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3)""",
+        )
+
+
+if __name__ == "__main__":
+    unittest.main()