Skip to content

Commit 46b617d

Browse files
authored
Codegen if tl.sum(one_elem_tensor): instead of if one_elem_tensor (#158)
1 parent f2a137b commit 46b617d

File tree

3 files changed

+74
-7
lines changed

3 files changed

+74
-7
lines changed

helion/_compiler/device_ir.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
from .ast_extension import LoopType
3636
from .ast_extension import NodeVisitor
3737
from .ast_extension import create
38+
from .ast_extension import expr_from_string
3839
from .ast_read_writes import ReadWrites
3940
from .compile_environment import CompileEnvironment
4041
from .host_function import HostFunction
@@ -232,6 +233,14 @@ def name(self) -> str:
232233

233234
def codegen(self, state: CodegenState) -> list[object]:
234235
test = state.ast_arg(0)
236+
237+
test_proxy = state.proxy_arg(0)
238+
if isinstance(test_proxy, torch.Tensor) and test_proxy.numel() == 1:
239+
# Triton does not support `if one_elem_tensor:` but supports `if scalar:`,
240+
# so we need to use tl.sum to extract the scalar.
241+
test_code = ast.unparse(test)
242+
test = expr_from_string(f"tl.sum({test_code})")
243+
235244
args = state.ast_args[2]
236245
assert isinstance(args, list)
237246
assert all(isinstance(x, ast.AST) for x in args)

test/test_control_flow.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,70 @@ def _fn_make_precompiler(x, v):
8686
return make_precompiler(_fn_kernel)(x, out, x.size(0), x.size(1), out.stride(0), out.stride(1), x.stride(0), x.stride(1), v, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3)""",
8787
)
8888

89+
def test_if_arg_one_element_tensor(self):
90+
@helion.kernel
91+
def fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
92+
output = torch.zeros_like(x)
93+
94+
for idx in hl.grid(x.shape[0]):
95+
# Since `y[idx]` is a one-element tensor, comparing it against 0 will also create a one-element tensor.
96+
if y[idx] != 0:
97+
output[idx] = x[idx] * 2
98+
if (
99+
y[idx] == 0
100+
): # TODO(yf225): `else:` raises MLIR error in Triton, so we use a second if.
101+
output[idx] = x[idx]
102+
103+
return output
104+
105+
x = torch.tensor([1.0, 2.0, 3.0, 4.0], device=DEVICE)
106+
y = torch.tensor([0, 1, 0, 1], device=DEVICE, dtype=torch.int32)
107+
expected = torch.tensor([1.0, 4.0, 3.0, 8.0], device=DEVICE)
108+
code, result = code_and_output(
109+
fn,
110+
(x, y),
111+
)
112+
torch.testing.assert_close(result, expected)
113+
self.assertExpectedInline(
114+
code,
115+
"""\
116+
from __future__ import annotations
117+
118+
import torch
119+
import triton
120+
import triton.language as tl
121+
122+
@triton.jit
123+
def _fn_kernel(x, y, output, output_stride_0, x_stride_0, y_stride_0):
124+
pid_0 = tl.program_id(0)
125+
offset_0 = pid_0
126+
indices_0 = offset_0 + tl.zeros([1], tl.int32)
127+
load = tl.load(y + indices_0 * y_stride_0, None)
128+
v_0 = tl.full([], 0, tl.int32)
129+
v_1 = load != v_0
130+
if tl.sum(v_1):
131+
load_1 = tl.load(x + indices_0 * x_stride_0, None)
132+
v_2 = 2.0
133+
v_3 = load_1 * v_2
134+
tl.store(output + indices_0 * output_stride_0, v_3, None)
135+
load_2 = tl.load(y + indices_0 * y_stride_0, None)
136+
v_4 = tl.full([], 0, tl.int32)
137+
v_5 = load_2 == v_4
138+
if tl.sum(v_5):
139+
load_3 = tl.load(x + indices_0 * x_stride_0, None)
140+
tl.store(output + indices_0 * output_stride_0, load_3, None)
141+
142+
def fn(x: torch.Tensor, y: torch.Tensor):
143+
output = torch.zeros_like(x)
144+
_fn_kernel[x.size(0),](x, y, output, output.stride(0), x.stride(0), y.stride(0), num_warps=4, num_stages=3)
145+
return output
146+
147+
def _fn_make_precompiler(x: torch.Tensor, y: torch.Tensor):
148+
output = torch.zeros_like(x)
149+
from helion.runtime.precompile_shim import make_precompiler
150+
return make_precompiler(_fn_kernel)(x, y, output, output.stride(0), x.stride(0), y.stride(0), num_warps=4, num_stages=3)""",
151+
)
152+
89153
def test_constant_true(self):
90154
@helion.kernel(
91155
config={

test/test_examples.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
from expecttest import TestCase
77
from packaging import version
88
import torch
9-
from torch._environment import is_fbcode
109

1110
from helion._testing import DEVICE
1211
from helion._testing import code_and_output
@@ -1627,11 +1626,6 @@ def _jagged_dense_add_2d_make_precompiler(x_data: torch.Tensor, x_offsets: torch
16271626
return make_precompiler(_jagged_dense_add_2d_kernel)(x_offsets, x_data, y, out, out.size(0), out.size(1), x_offsets.size(0), y.size(0), y.size(1), out.stride(0), out.stride(1), x_data.stride(0), x_offsets.stride(0), y.stride(0), y.stride(1), _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=8, num_stages=4)""",
16281627
)
16291628

1630-
@unittest.skipIf(
1631-
"RTX 30" in torch.cuda.get_device_name(0),
1632-
"Triton internal error on RTX 30XX series",
1633-
)
1634-
@unittest.skipIf(is_fbcode(), "Triton internal error on fbcode Triton pin")
16351629
def test_moe_matmul_ogs(self):
16361630
mod = import_path(examples_dir / "moe_matmul_ogs.py")
16371631

@@ -1670,7 +1664,7 @@ def _moe_matmul_ogs_kernel(expert_token_offsets, expert_token_counts, sorted_to_
16701664
num_tokens = tl.load(expert_token_counts + indices_0 * expert_token_counts_stride_0, None)
16711665
v_0 = tl.full([], 0, tl.int32)
16721666
v_1 = num_tokens != v_0
1673-
if v_1:
1667+
if tl.sum(v_1):
16741668
num_tokens_copy = num_tokens
16751669
start_copy = start
16761670
for offset_1 in range(0, max_T_per_expert.to(tl.int32), _BLOCK_SIZE_1):

0 commit comments

Comments
 (0)