@@ -368,6 +368,39 @@ def fn(x: torch.Tensor, begin: torch.Tensor, end: torch.Tensor, *, _launcher=_de
368
368
_launcher(_fn_kernel, (triton.cdiv(x.size(0), _BLOCK_SIZE_0),), x, begin, end, out, x.size(0), out.stride(0), x.stride(0), x.stride(1), _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
369
369
return out
370
370
371
+ --- assertExpectedJournal(TestLoops.test_full_with_dynamic_fill_value)
372
+ from __future__ import annotations
373
+
374
+ import torch
375
+ import triton
376
+ import triton.language as tl
377
+ from helion.runtime import default_launcher as _default_launcher
378
+
379
@triton.jit
def _kernel_with_dynamic_fill_kernel(fill_value, x, out, out_stride_0, out_stride_1, x_stride_0, x_stride_1, B, C, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr):
    # Each program handles one (_BLOCK_SIZE_0 x _BLOCK_SIZE_1) tile of `x`/`out`.
    # The launch grid is one-dimensional, so the flat program id is split into
    # a tile index along dim 0 (remainder) and along dim 1 (quotient).
    flat_pid = tl.program_id(0)
    tiles_along_0 = tl.cdiv(B, _BLOCK_SIZE_0)
    tile_0 = flat_pid % tiles_along_0
    tile_1 = flat_pid // tiles_along_0

    # Element indices covered by this tile along each dimension; the masks
    # switch off lanes that fall past the B / C extents.
    idx_0 = (tile_0 * _BLOCK_SIZE_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
    idx_1 = (tile_1 * _BLOCK_SIZE_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32)
    in_bounds_0 = idx_0 < B
    in_bounds_1 = idx_1 < C
    tile_mask = in_bounds_0[:, None] & in_bounds_1[None, :]

    # Broadcastable views of the indices: dim-0 indices vary down the tile,
    # dim-1 indices vary across it.
    idx_0_col = idx_0[:, None]
    idx_1_row = idx_1[None, :]

    # Read the single value at offset 0 of `fill_value` (unmasked load) and
    # expand it into a float32 tile.
    fill_scalar = tl.load(fill_value + tl.zeros([], tl.int32))
    fill_tile = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], fill_scalar, tl.float32)

    # Masked load of the input tile; masked-off lanes read as 0.
    x_tile = tl.load(
        x + (idx_0_col * x_stride_0 + idx_1_row * x_stride_1),
        mask=tile_mask,
        other=0,
    )

    # out[i, j] = x[i, j] + fill, written only for in-bounds lanes.
    result = x_tile + fill_tile
    tl.store(
        out + (idx_0_col * out_stride_0 + idx_1_row * out_stride_1),
        result,
        mask=tile_mask,
    )
395
+
396
def kernel_with_dynamic_fill(x: torch.Tensor, fill_value: torch.Tensor, *, _launcher=_default_launcher):
    """Launch ``_kernel_with_dynamic_fill_kernel`` over ``x`` and return its output.

    The kernel reads the element at offset 0 of ``fill_value`` and adds it to
    every element of ``x``; the sum is written into a new tensor allocated
    with ``torch.empty_like(x)``.

    Args:
        x: Input tensor. Its shape is unpacked into exactly two extents, so a
            tensor of any other rank raises ``ValueError`` here.
        fill_value: Tensor whose first element supplies the value to add.
        _launcher: Callable used to dispatch the kernel; it receives the
            kernel, the grid tuple, the kernel arguments, and the
            ``num_warps`` / ``num_stages`` launch options.

    Returns:
        The newly allocated output tensor.
    """
    rows, cols = x.shape
    result = torch.empty_like(x)

    # Fixed tile extents along dim 0 and dim 1.
    tile_rows = 4
    tile_cols = 8

    # One program per tile, flattened into a 1-D grid.
    grid = (triton.cdiv(rows, tile_rows) * triton.cdiv(cols, tile_cols),)

    _launcher(
        _kernel_with_dynamic_fill_kernel,
        grid,
        fill_value,
        x,
        result,
        result.stride(0),
        result.stride(1),
        x.stride(0),
        x.stride(1),
        rows,
        cols,
        tile_rows,
        tile_cols,
        num_warps=4,
        num_stages=3,
    )
    return result
403
+
371
404
--- assertExpectedJournal(TestLoops.test_l2_grouping_3d)
372
405
from __future__ import annotations
373
406
0 commit comments