@@ -66,7 +66,7 @@ def _device_loop_3d_kernel(x, out, out_stride_0, out_stride_1, out_stride_2, out
66
66
indices_1 = offset_1 + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32)
67
67
mask_1 = indices_1 < b
68
68
for offset_3 in tl.range(0, d.to(tl.int32), step=1):
69
- indices_3 = offset_3 + tl.arange(0, 1).to( tl.int32)
69
+ indices_3 = offset_3 + tl.zeros([1], tl.int32)
70
70
load = tl.load(x + (indices_0[:, None, None, None] * x_stride_0 + indices_1[None, :, None, None] * x_stride_1 + indices_2[None, None, :, None] * x_stride_2 + indices_3[None, None, None, :] * x_stride_3), mask_0[:, None, None, None] & mask_1[None, :, None, None] & mask_2[None, None, :, None], other=0)
71
71
v_0 = tl_math.sin(load)
72
72
tl.store(out + (indices_0[:, None, None, None] * out_stride_0 + indices_1[None, :, None, None] * out_stride_1 + indices_2[None, None, :, None] * out_stride_2 + indices_3[None, None, None, :] * out_stride_3), v_0, mask_0[:, None, None, None] & mask_1[None, :, None, None] & mask_2[None, None, :, None])
@@ -197,7 +197,7 @@ def _chebyshev_kernel_kernel(x, w, out, out_stride_0, out_stride_1, w_stride_0,
197
197
v_3 = 2.0
198
198
v_4 = in_x * v_3
199
199
for offset_2 in tl.range(2, 5, step=1):
200
- indices_2 = offset_2 + tl.arange(0, 1).to( tl.int32)
200
+ indices_2 = offset_2 + tl.zeros([1], tl.int32)
201
201
v_4_copy = v_4
202
202
in_x_0_copy = in_x_0
203
203
T0_copy = T0
@@ -245,13 +245,13 @@ import triton
245
245
import triton.language as tl
246
246
247
247
@triton.jit
248
- def _fn_kernel(x, end, out, x_size_0, out_stride_0, x_stride_0, x_stride_1, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_0: tl.constexpr):
248
+ def _fn_kernel(x, end, out, x_size_0, end_stride_0, out_stride_0, x_stride_0, x_stride_1, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_0: tl.constexpr):
249
249
pid_0 = tl.program_id(0)
250
250
offset_1 = pid_0 * _BLOCK_SIZE_1
251
251
indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32)
252
252
mask_1 = indices_1 < x_size_0
253
253
acc = tl.full([_BLOCK_SIZE_1, _BLOCK_SIZE_0], 0.0, tl.float32)
254
- load = tl.load(end + tl.zeros([], tl.int32) , None)
254
+ load = tl.load(end + 0 * end_stride_0 , None)
255
255
for offset_0 in tl.range(0, load.to(tl.int32), step=_BLOCK_SIZE_0):
256
256
indices_0 = offset_0 + tl.arange(0, _BLOCK_SIZE_0).to(tl.int32)
257
257
mask_0 = indices_0 < load
@@ -267,7 +267,7 @@ def fn(x: torch.Tensor, end: torch.Tensor):
267
267
bs = 32
268
268
_BLOCK_SIZE_1 = 32
269
269
_BLOCK_SIZE_0 = 32
270
- _fn_kernel[triton.cdiv(x.size(0), _BLOCK_SIZE_1),](x, end, out, x.size(0), out.stride(0), x.stride(0), x.stride(1), _BLOCK_SIZE_1, _BLOCK_SIZE_0, num_warps=4, num_stages=3)
270
+ _fn_kernel[triton.cdiv(x.size(0), _BLOCK_SIZE_1),](x, end, out, x.size(0), end.stride(0), out.stride(0), x.stride(0), x.stride(1), _BLOCK_SIZE_1, _BLOCK_SIZE_0, num_warps=4, num_stages=3)
271
271
return out
272
272
273
273
def _fn_make_precompiler(x: torch.Tensor, end: torch.Tensor):
@@ -276,7 +276,7 @@ def _fn_make_precompiler(x: torch.Tensor, end: torch.Tensor):
276
276
_BLOCK_SIZE_1 = 32
277
277
_BLOCK_SIZE_0 = 32
278
278
from helion.runtime.precompile_shim import make_precompiler
279
- return make_precompiler(_fn_kernel)(x, end, out, x.size(0), out.stride(0), x.stride(0), x.stride(1), _BLOCK_SIZE_1, _BLOCK_SIZE_0, num_warps=4, num_stages=3)
279
+ return make_precompiler(_fn_kernel)(x, end, out, x.size(0), end.stride(0), out.stride(0), x.stride(0), x.stride(1), _BLOCK_SIZE_1, _BLOCK_SIZE_0, num_warps=4, num_stages=3)
280
280
281
281
--- assertExpectedJournal(TestLoops.test_data_dependent_bounds2)
282
282
from __future__ import annotations
@@ -286,13 +286,13 @@ import triton
286
286
import triton.language as tl
287
287
288
288
@triton.jit
289
- def _fn_kernel(x, end, out, out_size_0, x_size_0, out_stride_0, x_stride_0, x_stride_1, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr):
289
+ def _fn_kernel(x, end, out, out_size_0, x_size_0, end_stride_0, out_stride_0, x_stride_0, x_stride_1, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr):
290
290
pid_0 = tl.program_id(0)
291
291
offset_0 = pid_0 * _BLOCK_SIZE_0
292
292
indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
293
293
mask_0 = indices_0 < x_size_0
294
294
acc = tl.full([_BLOCK_SIZE_0], 0.0, tl.float32)
295
- load = tl.load(end + tl.zeros([], tl.int32) , None)
295
+ load = tl.load(end + 0 * end_stride_0 , None)
296
296
for offset_1 in tl.range(0, load.to(tl.int32), step=_BLOCK_SIZE_1):
297
297
indices_1 = offset_1 + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32)
298
298
mask_1 = indices_1 < load
@@ -307,15 +307,15 @@ def fn(x: torch.Tensor, end: torch.Tensor):
307
307
out = x.new_empty([x.size(0)])
308
308
_BLOCK_SIZE_0 = 32
309
309
_BLOCK_SIZE_1 = 32
310
- _fn_kernel[triton.cdiv(x.size(0), _BLOCK_SIZE_0),](x, end, out, out.size(0), x.size(0), out.stride(0), x.stride(0), x.stride(1), _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
310
+ _fn_kernel[triton.cdiv(x.size(0), _BLOCK_SIZE_0),](x, end, out, out.size(0), x.size(0), end.stride(0), out.stride(0), x.stride(0), x.stride(1), _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
311
311
return out
312
312
313
313
def _fn_make_precompiler(x: torch.Tensor, end: torch.Tensor):
314
314
out = x.new_empty([x.size(0)])
315
315
_BLOCK_SIZE_0 = 32
316
316
_BLOCK_SIZE_1 = 32
317
317
from helion.runtime.precompile_shim import make_precompiler
318
- return make_precompiler(_fn_kernel)(x, end, out, out.size(0), x.size(0), out.stride(0), x.stride(0), x.stride(1), _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
318
+ return make_precompiler(_fn_kernel)(x, end, out, out.size(0), x.size(0), end.stride(0), out.stride(0), x.stride(0), x.stride(1), _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
319
319
320
320
--- assertExpectedJournal(TestLoops.test_data_dependent_bounds3)
321
321
from __future__ import annotations
@@ -325,14 +325,14 @@ import triton
325
325
import triton.language as tl
326
326
327
327
@triton.jit
328
- def _fn_kernel(x, end0, end1, out, x_size_0, out_stride_0, x_stride_0, x_stride_1, x_stride_2, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr):
328
+ def _fn_kernel(x, end0, end1, out, x_size_0, end0_stride_0, end1_stride_0, out_stride_0, x_stride_0, x_stride_1, x_stride_2, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr):
329
329
pid_0 = tl.program_id(0)
330
330
offset_0 = pid_0 * _BLOCK_SIZE_0
331
331
indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
332
332
mask_0 = indices_0 < x_size_0
333
333
acc = tl.full([_BLOCK_SIZE_0], 0.0, tl.float64)
334
- load = tl.load(end0 + tl.zeros([], tl.int32) , None)
335
- load_1 = tl.load(end1 + tl.zeros([], tl.int32) , None)
334
+ load = tl.load(end0 + 0 * end0_stride_0 , None)
335
+ load_1 = tl.load(end1 + 0 * end1_stride_0 , None)
336
336
for offset_1 in tl.range(0, load.to(tl.int32), step=_BLOCK_SIZE_1):
337
337
indices_1 = offset_1 + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32)
338
338
mask_1 = indices_1 < load
@@ -352,7 +352,7 @@ def fn(x: torch.Tensor, end0: torch.Tensor, end1: torch.Tensor):
352
352
_BLOCK_SIZE_0 = 32
353
353
_BLOCK_SIZE_2 = 32
354
354
_BLOCK_SIZE_1 = 32
355
- _fn_kernel[triton.cdiv(x.size(0), _BLOCK_SIZE_0),](x, end0, end1, out, x.size(0), out.stride(0), x.stride(0), x.stride(1), x.stride(2), _BLOCK_SIZE_0, _BLOCK_SIZE_2, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
355
+ _fn_kernel[triton.cdiv(x.size(0), _BLOCK_SIZE_0),](x, end0, end1, out, x.size(0), end0.stride(0), end1.stride(0), out.stride(0), x.stride(0), x.stride(1), x.stride(2), _BLOCK_SIZE_0, _BLOCK_SIZE_2, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
356
356
return out
357
357
358
358
def _fn_make_precompiler(x: torch.Tensor, end0: torch.Tensor, end1: torch.Tensor):
@@ -361,7 +361,7 @@ def _fn_make_precompiler(x: torch.Tensor, end0: torch.Tensor, end1: torch.Tensor
361
361
_BLOCK_SIZE_2 = 32
362
362
_BLOCK_SIZE_1 = 32
363
363
from helion.runtime.precompile_shim import make_precompiler
364
- return make_precompiler(_fn_kernel)(x, end0, end1, out, x.size(0), out.stride(0), x.stride(0), x.stride(1), x.stride(2), _BLOCK_SIZE_0, _BLOCK_SIZE_2, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
364
+ return make_precompiler(_fn_kernel)(x, end0, end1, out, x.size(0), end0.stride(0), end1.stride(0), out.stride(0), x.stride(0), x.stride(1), x.stride(2), _BLOCK_SIZE_0, _BLOCK_SIZE_2, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
365
365
366
366
--- assertExpectedJournal(TestLoops.test_data_dependent_bounds4)
367
367
from __future__ import annotations
@@ -371,14 +371,14 @@ import triton
371
371
import triton.language as tl
372
372
373
373
@triton.jit
374
- def _fn_kernel(x, begin, end, out, x_size_0, out_stride_0, x_stride_0, x_stride_1, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_0: tl.constexpr):
374
+ def _fn_kernel(x, begin, end, out, x_size_0, begin_stride_0, end_stride_0, out_stride_0, x_stride_0, x_stride_1, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_0: tl.constexpr):
375
375
pid_0 = tl.program_id(0)
376
376
offset_1 = pid_0 * _BLOCK_SIZE_1
377
377
indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32)
378
378
mask_1 = indices_1 < x_size_0
379
379
acc = tl.full([_BLOCK_SIZE_1, _BLOCK_SIZE_0], 0.0, tl.float32)
380
- load = tl.load(begin + tl.zeros([], tl.int32) , None)
381
- load_1 = tl.load(end + tl.zeros([], tl.int32) , None)
380
+ load = tl.load(begin + 0 * begin_stride_0 , None)
381
+ load_1 = tl.load(end + 0 * end_stride_0 , None)
382
382
for offset_0 in tl.range(load.to(tl.int32), load_1.to(tl.int32), step=_BLOCK_SIZE_0):
383
383
indices_0 = offset_0 + tl.arange(0, _BLOCK_SIZE_0).to(tl.int32)
384
384
mask_0 = indices_0 < load_1
@@ -394,7 +394,7 @@ def fn(x: torch.Tensor, begin: torch.Tensor, end: torch.Tensor):
394
394
bs = 32
395
395
_BLOCK_SIZE_1 = 32
396
396
_BLOCK_SIZE_0 = 32
397
- _fn_kernel[triton.cdiv(x.size(0), _BLOCK_SIZE_1),](x, begin, end, out, x.size(0), out.stride(0), x.stride(0), x.stride(1), _BLOCK_SIZE_1, _BLOCK_SIZE_0, num_warps=4, num_stages=3)
397
+ _fn_kernel[triton.cdiv(x.size(0), _BLOCK_SIZE_1),](x, begin, end, out, x.size(0), begin.stride(0), end.stride(0), out.stride(0), x.stride(0), x.stride(1), _BLOCK_SIZE_1, _BLOCK_SIZE_0, num_warps=4, num_stages=3)
398
398
return out
399
399
400
400
def _fn_make_precompiler(x: torch.Tensor, begin: torch.Tensor, end: torch.Tensor):
@@ -403,7 +403,7 @@ def _fn_make_precompiler(x: torch.Tensor, begin: torch.Tensor, end: torch.Tensor
403
403
_BLOCK_SIZE_1 = 32
404
404
_BLOCK_SIZE_0 = 32
405
405
from helion.runtime.precompile_shim import make_precompiler
406
- return make_precompiler(_fn_kernel)(x, begin, end, out, x.size(0), out.stride(0), x.stride(0), x.stride(1), _BLOCK_SIZE_1, _BLOCK_SIZE_0, num_warps=4, num_stages=3)
406
+ return make_precompiler(_fn_kernel)(x, begin, end, out, x.size(0), begin.stride(0), end.stride(0), out.stride(0), x.stride(0), x.stride(1), _BLOCK_SIZE_1, _BLOCK_SIZE_0, num_warps=4, num_stages=3)
407
407
408
408
--- assertExpectedJournal(TestLoops.test_data_dependent_bounds5)
409
409
from __future__ import annotations
@@ -413,14 +413,14 @@ import triton
413
413
import triton.language as tl
414
414
415
415
@triton.jit
416
- def _fn_kernel(x, begin, end, out, x_size_0, out_stride_0, x_stride_0, x_stride_1, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr):
416
+ def _fn_kernel(x, begin, end, out, x_size_0, begin_stride_0, end_stride_0, out_stride_0, x_stride_0, x_stride_1, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr):
417
417
pid_0 = tl.program_id(0)
418
418
offset_0 = pid_0 * _BLOCK_SIZE_0
419
419
indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
420
420
mask_0 = indices_0 < x_size_0
421
421
acc = tl.full([_BLOCK_SIZE_0], 0.0, tl.float32)
422
- load = tl.load(begin + tl.zeros([], tl.int32) , None)
423
- load_1 = tl.load(end + tl.zeros([], tl.int32) , None)
422
+ load = tl.load(begin + 0 * begin_stride_0 , None)
423
+ load_1 = tl.load(end + 0 * end_stride_0 , None)
424
424
for offset_1 in tl.range(load.to(tl.int32), load_1.to(tl.int32), step=_BLOCK_SIZE_1):
425
425
indices_1 = offset_1 + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32)
426
426
mask_1 = indices_1 < load_1
@@ -435,15 +435,15 @@ def fn(x: torch.Tensor, begin: torch.Tensor, end: torch.Tensor):
435
435
out = x.new_empty([x.size(0)])
436
436
_BLOCK_SIZE_0 = 32
437
437
_BLOCK_SIZE_1 = 32
438
- _fn_kernel[triton.cdiv(x.size(0), _BLOCK_SIZE_0),](x, begin, end, out, x.size(0), out.stride(0), x.stride(0), x.stride(1), _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
438
+ _fn_kernel[triton.cdiv(x.size(0), _BLOCK_SIZE_0),](x, begin, end, out, x.size(0), begin.stride(0), end.stride(0), out.stride(0), x.stride(0), x.stride(1), _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
439
439
return out
440
440
441
441
def _fn_make_precompiler(x: torch.Tensor, begin: torch.Tensor, end: torch.Tensor):
442
442
out = x.new_empty([x.size(0)])
443
443
_BLOCK_SIZE_0 = 32
444
444
_BLOCK_SIZE_1 = 32
445
445
from helion.runtime.precompile_shim import make_precompiler
446
- return make_precompiler(_fn_kernel)(x, begin, end, out, x.size(0), out.stride(0), x.stride(0), x.stride(1), _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
446
+ return make_precompiler(_fn_kernel)(x, begin, end, out, x.size(0), begin.stride(0), end.stride(0), out.stride(0), x.stride(0), x.stride(1), _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
447
447
448
448
--- assertExpectedJournal(TestLoops.test_l2_grouping_with_register_block_size)
449
449
from __future__ import annotations
0 commit comments