@@ -86,6 +86,70 @@ def _fn_make_precompiler(x, v):
86
86
return make_precompiler(_fn_kernel)(x, out, x.size(0), x.size(1), out.stride(0), out.stride(1), x.stride(0), x.stride(1), v, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3)""" ,
87
87
)
88
88
89
def test_if_arg_one_element_tensor(self):
    """Branch on a one-element tensor comparison inside a Helion kernel.

    The kernel writes ``x[idx] * 2`` where ``y[idx] != 0`` and ``x[idx]``
    where ``y[idx] == 0``.  The test checks both the numerical output and
    the generated Triton source against a golden string.
    """

    @helion.kernel
    def fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        output = torch.zeros_like(x)

        for idx in hl.grid(x.shape[0]):
            # Since `y[idx]` is a one-element tensor, comparing it against 0 will also create a one-element tensor.
            if y[idx] != 0:
                output[idx] = x[idx] * 2
            # TODO(yf225): `else:` raises MLIR error in Triton, so we use a second if.
            if y[idx] == 0:
                output[idx] = x[idx]

        return output

    values = torch.tensor([1.0, 2.0, 3.0, 4.0], device=DEVICE)
    flags = torch.tensor([0, 1, 0, 1], device=DEVICE, dtype=torch.int32)
    generated, actual = code_and_output(fn, (values, flags))

    # Entries whose flag is nonzero (indices 1 and 3) are doubled; the
    # others are copied through unchanged.
    want = torch.tensor([1.0, 4.0, 3.0, 8.0], device=DEVICE)
    torch.testing.assert_close(actual, want)

    # The golden text below is compared verbatim; do not reformat it.
    self.assertExpectedInline(
        generated,
        """\
from __future__ import annotations

import torch
import triton
import triton.language as tl

@triton.jit
def _fn_kernel(x, y, output, output_stride_0, x_stride_0, y_stride_0):
    pid_0 = tl.program_id(0)
    offset_0 = pid_0
    indices_0 = offset_0 + tl.zeros([1], tl.int32)
    load = tl.load(y + indices_0 * y_stride_0, None)
    v_0 = tl.full([], 0, tl.int32)
    v_1 = load != v_0
    if tl.sum(v_1):
        load_1 = tl.load(x + indices_0 * x_stride_0, None)
        v_2 = 2.0
        v_3 = load_1 * v_2
        tl.store(output + indices_0 * output_stride_0, v_3, None)
    load_2 = tl.load(y + indices_0 * y_stride_0, None)
    v_4 = tl.full([], 0, tl.int32)
    v_5 = load_2 == v_4
    if tl.sum(v_5):
        load_3 = tl.load(x + indices_0 * x_stride_0, None)
        tl.store(output + indices_0 * output_stride_0, load_3, None)

def fn(x: torch.Tensor, y: torch.Tensor):
    output = torch.zeros_like(x)
    _fn_kernel[x.size(0),](x, y, output, output.stride(0), x.stride(0), y.stride(0), num_warps=4, num_stages=3)
    return output

def _fn_make_precompiler(x: torch.Tensor, y: torch.Tensor):
    output = torch.zeros_like(x)
    from helion.runtime.precompile_shim import make_precompiler
    return make_precompiler(_fn_kernel)(x, y, output, output.stride(0), x.stride(0), y.stride(0), num_warps=4, num_stages=3)""",
    )
152
+
89
153
def test_constant_true (self ):
90
154
@helion .kernel (
91
155
config = {
0 commit comments