@@ -53,6 +53,37 @@ def _kernel_make_precompiler(a_list, b_dict, b_tuple, c_named_tuple, d_dataclass
53
53
from helion.runtime.precompile_shim import make_precompiler
54
54
return make_precompiler(_kernel_kernel)(a0, o0, o1, a0.size(0), a0.stride(0), o0.stride(0), o1.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3)
55
55
56
+ --- assertExpectedJournal(TestMisc.test_scalar_tensor_item_method)
57
+ from __future__ import annotations
58
+
59
+ import torch
60
+ import triton
61
+ import triton.language as tl
62
+
63
+ @triton.jit
64
+ def _kernel_with_scalar_item_kernel(x, result, x_size_0, result_stride_0, x_stride_0, scalar_val, _BLOCK_SIZE_0: tl.constexpr):
65
+ pid_0 = tl.program_id(0)
66
+ offset_0 = pid_0 * _BLOCK_SIZE_0
67
+ indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
68
+ mask_0 = indices_0 < x_size_0
69
+ load = tl.load(x + indices_0 * x_stride_0, mask_0, other=0)
70
+ v_0 = load + scalar_val
71
+ tl.store(result + indices_0 * result_stride_0, v_0, mask_0)
72
+
73
+ def kernel_with_scalar_item(x: torch.Tensor, scalar_tensor: torch.Tensor):
74
+ result = torch.empty_like(x)
75
+ scalar_val = scalar_tensor.item()
76
+ _BLOCK_SIZE_0 = 128
77
+ _kernel_with_scalar_item_kernel[triton.cdiv(x.size(0), _BLOCK_SIZE_0),](x, result, x.size(0), result.stride(0), x.stride(0), scalar_val, _BLOCK_SIZE_0, num_warps=4, num_stages=3)
78
+ return result
79
+
80
+ def _kernel_with_scalar_item_make_precompiler(x: torch.Tensor, scalar_tensor: torch.Tensor):
81
+ result = torch.empty_like(x)
82
+ scalar_val = scalar_tensor.item()
83
+ _BLOCK_SIZE_0 = 128
84
+ from helion.runtime.precompile_shim import make_precompiler
85
+ return make_precompiler(_kernel_with_scalar_item_kernel)(x, result, x.size(0), result.stride(0), x.stride(0), scalar_val, _BLOCK_SIZE_0, num_warps=4, num_stages=3)
86
+
56
87
--- assertExpectedJournal(TestMisc.test_tile_block_size_constexpr_fix)
57
88
from __future__ import annotations
58
89
0 commit comments