@@ -53,6 +53,54 @@ def _kernel_make_precompiler(a_list, b_dict, b_tuple, c_named_tuple, d_dataclass
53
53
from helion.runtime.precompile_shim import make_precompiler
54
54
return make_precompiler(_kernel_kernel)(a0, o0, o1, a0.size(0), a0.stride(0), o0.stride(0), o1.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3)
55
55
56
+ --- assertExpectedJournal(TestMisc.test_tile_block_size_constexpr_fix)
57
+ from __future__ import annotations
58
+
59
+ import torch
60
+ import triton
61
+ import triton.language as tl
62
+ from torch._inductor.runtime.triton_compat import libdevice
63
+
64
@triton.jit
def _test_tile_block_size_usage_kernel(out, x_size_0, out_stride_0, _BLOCK_SIZE_0: tl.constexpr):
    # Flags the last position of each _BLOCK_SIZE_0-wide tile: stores 1 at
    # indices where (index mod _BLOCK_SIZE_0) == _BLOCK_SIZE_0 - 1, else 0.
    #
    # out           -- base pointer the flags are stored through
    # x_size_0      -- upper bound for valid indices (stores beyond it are masked)
    # out_stride_0  -- multiplier applied to each index to form the store offset
    # _BLOCK_SIZE_0 -- compile-time tile width; one program covers one tile
    tile_id = tl.program_id(0)
    tile_start = tile_id * _BLOCK_SIZE_0
    idx = (tile_start + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
    in_bounds = idx < x_size_0
    block_size = _BLOCK_SIZE_0
    divisor = block_size.to(tl.int32)
    # Remainder followed by a sign correction: when the remainder is non-zero
    # and its sign differs from the divisor's, the divisor is added back.
    rem = idx % divisor
    zero_i32 = tl.full([], 0, tl.int32)
    rem_nonzero = rem != zero_i32
    # The sign test is picked by dtype: signbit for float32, "< 0" otherwise.
    rem_negative = libdevice.signbit(rem) != 0 if rem.dtype is tl.float32 else rem < 0
    divisor_negative = libdevice.signbit(divisor) != 0 if divisor.dtype is tl.float32 else divisor < 0
    signs_differ = rem_negative != divisor_negative
    needs_fixup = rem_nonzero & signs_differ
    shifted = rem + divisor
    modulo = tl.where(needs_fixup, shifted, rem)
    # Position of the final element inside a tile.
    last_slot = -1 + _BLOCK_SIZE_0
    last_slot_i32 = last_slot.to(tl.int32)
    is_last = modulo == last_slot_i32
    # Scalar int64 constants, given a leading axis before being selected between.
    zero_i64 = tl.full([], 0, tl.int64)
    one_i64 = tl.full([], 1, tl.int64)
    one_vec = one_i64[None]
    zero_vec = zero_i64[None]
    flags = tl.where(is_last, one_vec, zero_vec)
    flags_i32 = flags.to(tl.int32)
    # Masked store so indices at or past x_size_0 are not written.
    tl.store(out + idx * out_stride_0, flags_i32, in_bounds)
91
+
92
def test_tile_block_size_usage(x: torch.Tensor):
    """Launch the tile-block-size kernel over ``x`` and return its output.

    Allocates an int32 tensor shaped like ``x`` (initialised to zero), launches
    ``_test_tile_block_size_usage_kernel`` with a fixed block size of 32 on a
    one-dimensional grid of ``ceil(x.size(0) / 32)`` programs, and returns the
    tensor the kernel wrote into.
    """
    result = torch.zeros_like(x, dtype=torch.int32)
    block_size = 32
    length = x.size(0)
    # One program per block-sized tile along dimension 0.
    grid = (triton.cdiv(length, block_size),)
    _test_tile_block_size_usage_kernel[grid](
        result,
        length,
        result.stride(0),
        block_size,
        num_warps=4,
        num_stages=3,
    )
    return result
97
+
98
def _test_tile_block_size_usage_make_precompiler(x: torch.Tensor):
    """Build the kernel's launch arguments and hand them to ``make_precompiler``.

    Prepares the same output tensor, size, stride, block size (32) and launch
    options that ``test_tile_block_size_usage`` uses, but instead of launching
    the kernel on a grid it passes them to
    ``make_precompiler(_test_tile_block_size_usage_kernel)`` and returns
    whatever that call produces.
    NOTE(review): presumably the result is a precompile callable -- confirm
    against ``helion.runtime.precompile_shim``.
    """
    result = torch.zeros_like(x, dtype=torch.int32)
    block_size = 32
    # Imported lazily, inside the function, as in the generated original.
    from helion.runtime.precompile_shim import make_precompiler

    wrapped_kernel = make_precompiler(_test_tile_block_size_usage_kernel)
    return wrapped_kernel(
        result,
        x.size(0),
        result.stride(0),
        block_size,
        num_warps=4,
        num_stages=3,
    )
103
+
56
104
--- assertExpectedJournal(TestMisc.test_torch_alloc)
57
105
from __future__ import annotations
58
106
0 commit comments