Commit 00d13b6

Support dynamic fill value to hl.full (#316)
1 parent 28c3be2 commit 00d13b6

File tree

3 files changed (+75, -5 lines)

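In short: hl.full previously required a compile-time literal fill value, and this commit lets the value come from a tensor element at runtime. A minimal usage sketch, mirroring the new test added below (the DEVICE constant here is an assumption standing in for the test suite's own device setup):

import torch
import helion
import helion.language as hl

DEVICE = "cuda"  # assumption; the test suite uses its own DEVICE constant

@helion.kernel(use_default_config=True)
def fill_and_add(x: torch.Tensor, fill_value: torch.Tensor) -> torch.Tensor:
    B, C = x.shape
    out = torch.empty_like(x)
    for b_tile, c_tile in hl.tile([B, C]):
        # fill_value[0] is a runtime scalar, no longer required to be a literal
        filled = hl.full((b_tile, c_tile), fill_value[0], x.dtype)
        out[b_tile, c_tile] = x[b_tile, c_tile] + filled
    return out

x = torch.randn(4, 8, device=DEVICE)
result = fill_and_add(x, torch.tensor([3.5], device=DEVICE))  # result == x + 3.5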

helion/language/creation_ops.py

Lines changed: 15 additions & 5 deletions
@@ -121,17 +121,27 @@ def _full_codegen(state: CodegenState) -> ast.AST:
     assert isinstance(fake_value, torch.Tensor)
     shape_str = state.device_function.tile_strategy.shape_str(fake_value.size())
     type_str = triton_type(fake_value.dtype)
-    value_str = state.device_function.literal_expr(state.proxy_arg(1))
-    return expr_from_string(f"tl.full({shape_str}, {value_str}, {type_str})")
+
+    # Check if the value is static (literal) or dynamic (node)
+    proxy_value = state.proxy_arg(1)
+    if isinstance(proxy_value, (int, float, bool)):
+        # For static values, use literal_expr to preserve special representations like float('-inf')
+        value_str = state.device_function.literal_expr(proxy_value)
+        return expr_from_string(f"tl.full({shape_str}, {value_str}, {type_str})")
+    # For dynamic values, use ast_arg to get the proper AST representation
+    value_ast = state.ast_arg(1)
+    return expr_from_string(f"tl.full({shape_str}, value, {type_str})", value=value_ast)


 @_decorators.get_masked_value(full)
 def _(
     node: torch.fx.Node,
-) -> float | bool:
+) -> float | bool | None:
     value = node.args[1]
-    assert isinstance(value, (int, float, bool))
-    return value
+    if isinstance(value, (int, float, bool)):
+        return value
+    # Return None for dynamic values (like tensor elements)
+    return None


 def arange(
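The codegen change splits into two paths. For a static fill value, the literal is inlined into the generated tl.full call, preserving spellings like float('-inf'); for a dynamic value, the argument's AST node is spliced into a placeholder, so the generated kernel reads the value at runtime. Roughly, the two outputs look like this (the static line is an illustrative assumption; the dynamic lines appear verbatim in the expected journal below):

# static fill value: the literal is baked into the kernel at compile time
filled = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], float('-inf'), tl.float32)

# dynamic fill value: a scalar loaded from a tensor argument is passed through
load = tl.load(fill_value + tl.zeros([], tl.int32), None)
filled = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], load, tl.float32)

Correspondingly, the get_masked_value hook now returns None when the fill value is dynamic, presumably signaling to the rest of the compiler that the masked value is not known at compile time.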

test/test_loops.expected

Lines changed: 33 additions & 0 deletions
@@ -368,6 +368,39 @@ def fn(x: torch.Tensor, begin: torch.Tensor, end: torch.Tensor, *, _launcher=_de
     _launcher(_fn_kernel, (triton.cdiv(x.size(0), _BLOCK_SIZE_0),), x, begin, end, out, x.size(0), out.stride(0), x.stride(0), x.stride(1), _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
     return out
 
+--- assertExpectedJournal(TestLoops.test_full_with_dynamic_fill_value)
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+from helion.runtime import default_launcher as _default_launcher
+
+@triton.jit
+def _kernel_with_dynamic_fill_kernel(fill_value, x, out, out_stride_0, out_stride_1, x_stride_0, x_stride_1, B, C, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr):
+    num_blocks_0 = tl.cdiv(B, _BLOCK_SIZE_0)
+    pid_0 = tl.program_id(0) % num_blocks_0
+    pid_1 = tl.program_id(0) // num_blocks_0
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
+    mask_0 = indices_0 < B
+    offset_1 = pid_1 * _BLOCK_SIZE_1
+    indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32)
+    mask_1 = indices_1 < C
+    load = tl.load(fill_value + tl.zeros([], tl.int32), None)
+    filled = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], load, tl.float32)
+    load_1 = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_1[None, :] * x_stride_1), mask_0[:, None] & mask_1[None, :], other=0)
+    v_0 = load_1 + filled
+    tl.store(out + (indices_0[:, None] * out_stride_0 + indices_1[None, :] * out_stride_1), v_0, mask_0[:, None] & mask_1[None, :])
+
+def kernel_with_dynamic_fill(x: torch.Tensor, fill_value: torch.Tensor, *, _launcher=_default_launcher):
+    B, C = x.shape
+    out = torch.empty_like(x)
+    _BLOCK_SIZE_0 = 4
+    _BLOCK_SIZE_1 = 8
+    _launcher(_kernel_with_dynamic_fill_kernel, (triton.cdiv(B, _BLOCK_SIZE_0) * triton.cdiv(C, _BLOCK_SIZE_1),), fill_value, x, out, out.stride(0), out.stride(1), x.stride(0), x.stride(1), B, C, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
+    return out
+
 --- assertExpectedJournal(TestLoops.test_l2_grouping_3d)
 from __future__ import annotations
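Two lines in this journal carry the new behavior: load = tl.load(fill_value + tl.zeros([], tl.int32), None) reads element 0 of the fill tensor (tl.zeros([], tl.int32) is a zero-dimensional zero offset), and the following tl.full broadcasts that runtime scalar across the tile. The fill value thus arrives as a kernel argument instead of being baked into the generated source.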

test/test_loops.py

Lines changed: 27 additions & 0 deletions
@@ -962,6 +962,33 @@ def add_3d_kernel_reordered(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
         )  # Original dim 1 = second fastest varying
         self.assertIn("offset_0 = pid_2", code)  # Original dim 0 = slowest varying
 
+    def test_full_with_dynamic_fill_value(self):
+        """Test hl.full with dynamic fill value from scalar tensor."""
+
+        @helion.kernel(use_default_config=True)
+        def kernel_with_dynamic_fill(
+            x: torch.Tensor, fill_value: torch.Tensor
+        ) -> torch.Tensor:
+            B, C = x.shape
+            out = torch.empty_like(x)
+
+            for b_tile, c_tile in hl.tile([B, C]):
+                # Use scalar tensor as fill value
+                filled = hl.full((b_tile, c_tile), fill_value[0], x.dtype)
+                out[b_tile, c_tile] = x[b_tile, c_tile] + filled
+
+            return out
+
+        x = torch.randn(4, 8, device=DEVICE, dtype=torch.float32)
+        fill_value = torch.tensor([3.5], device=DEVICE, dtype=torch.float32)
+
+        code, result = code_and_output(kernel_with_dynamic_fill, (x, fill_value))
+        self.assertExpectedJournal(code)
+
+        # Verify correctness
+        expected = x + fill_value[0]
+        torch.testing.assert_close(result, expected)
+
 
 if __name__ == "__main__":
     unittest.main()
