
Commit 690b9a2

Add hl.signal
stack-info: PR: #233, branch: joydddd/stack/8
1 parent 90baafc commit 690b9a2

File tree

6 files changed, +275 -1 lines changed


helion/language/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -15,6 +15,7 @@
 from .scan_ops import associative_scan as associative_scan
 from .scan_ops import cumprod as cumprod
 from .scan_ops import cumsum as cumsum
+from .signal_wait import signal as signal
 from .signal_wait import wait as wait
 from .tile_ops import tile_begin as tile_begin
 from .tile_ops import tile_block_size as tile_block_size

helion/language/signal_wait.py

Lines changed: 108 additions & 0 deletions
@@ -6,13 +6,16 @@
 from torch.fx import has_side_effect
 
 from .. import exc
+from .._compiler.indexing_strategy import SubscriptIndexing
 from . import _decorators
 
 if TYPE_CHECKING:
     import ast
 
     from .._compiler.inductor_lowering import CodegenState
 
+__all__ = ["signal", "wait"]
+
 
 @has_side_effect
 @_decorators.api(tiles_as_sizes=True)
@@ -153,3 +156,108 @@ def _(state: CodegenState) -> ast.AST:
         signal=signal_expr,
         update=update_expr,
     )
+
+
+@has_side_effect
+@_decorators.api(tiles_as_sizes=True)
+def signal(
+    signal_pad: torch.Tensor,
+    index: list[object],
+    signal: int = 1,
+    op: str = "atomic_xchg",
+    sem: str = "release",
+    scope: str = "gpu",
+    skip_sync: bool = False,
+) -> torch.Tensor:
+    """Set the signal_pad slice to the signal value.
+    Args:
+        signal_pad: The signal pad to signal
+        index: Indices to index into the signal_pad tensor
+        signal: the value to send
+        op: The atomic memory op used to send the signal (default: 'atomic_xchg')
+        sem: The memory semantics of the atomic op (default: 'release')
+        scope: The scope of the atomic op (default: 'gpu')
+        skip_sync: Skip the syncthreads before sending the signal (default: False)
+    """
+    raise exc.NotInsideKernel
+
+
+@_decorators.prepare_args(signal)
+def _(
+    signal_pad: torch.Tensor,
+    index: list[object],
+    signal: int = 1,
+    op: str = "atomic_xchg",
+    sem: str = "release",
+    scope: str = "gpu",
+    skip_sync: bool = False,
+) -> tuple[torch.Tensor, object, int, str, str, str, bool]:
+    from helion.language.tile_proxy import Tile
+
+    valid_ops = {"atomic_add", "atomic_xchg"}
+    valid_sems = {"relaxed", "release", "acq_rel"}
+    valid_scopes = {"sys", "gpu"}
+
+    if op not in valid_ops:
+        raise ValueError(f"Invalid signal op '{op}'. Must be one of {valid_ops}.")
+
+    if sem not in valid_sems:
+        raise ValueError(
+            f"Invalid memory semantic '{sem}'. Must be one of {valid_sems}."
+        )
+
+    if scope not in valid_scopes:
+        raise ValueError(f"Invalid scope '{scope}'. Must be one of {valid_scopes}.")
+
+    index = Tile._prepare_index(index)
+    index = Tile._tiles_to_sizes(index)
+
+    return (signal_pad, index, signal, op, sem, scope, skip_sync)
+
+
+@_decorators.register_fake(signal)
+def _(
+    signal_pad: torch.Tensor,
+    index: list[object],
+    signal: int = 1,
+    op: str = "atomic_xchg",
+    sem: str = "release",
+    scope: str = "gpu",
+    skip_sync: bool = False,
+) -> torch.Tensor:
+    return signal_pad.new_empty(SubscriptIndexing.compute_shape(signal_pad, index))
+
+
+@_decorators.codegen(signal)
+def _(state: CodegenState) -> ast.AST:
+    import ast
+
+    from .._compiler.ast_extension import expr_from_string
+    from .._compiler.indexing_strategy import SubscriptIndexing
+
+    signal_pad = state.proxy_arg(0)
+    index = state.proxy_arg(1)
+    signal = state.proxy_arg(2)
+    op = state.proxy_arg(3)
+    sem = state.proxy_arg(4)
+    scope = state.proxy_arg(5)
+    skip_sync = state.proxy_arg(6)
+
+    assert isinstance(signal_pad, torch.Tensor)
+    assert isinstance(index, list)
+
+    indices = SubscriptIndexing.create(state, signal_pad, index)
+    signal_pad_name = state.device_function.tensor_arg(signal_pad).name
+
+    signal_expr = ast.Constant(value=signal)
+    assert type(op) is str
+    assert type(sem) is str
+    assert type(scope) is str
+
+    hl_ext_call_triton_send_signal = f"helion.runtime.triton_send_signal(addr={signal_pad_name} + offset, update=signal, sem='{sem}', scope='{scope}', op='{op}', skip_sync={skip_sync})"
+
+    return expr_from_string(
+        hl_ext_call_triton_send_signal,
+        offset=indices.index_expr,
+        signal=signal_expr,
+    )
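
For orientation, here is a minimal sketch of how the new hl.signal device op is meant to be called, modeled directly on the tests added later in this commit; the kernel name, pad size, and the "cuda" device string are illustrative assumptions, not part of the API:

import torch
import helion
import helion.language as hl

@helion.kernel
def send_signals(signal_pad: torch.Tensor) -> torch.Tensor:
    (n,) = signal_pad.shape
    for i in hl.grid(n):
        # Each program atomically sets its slot of the pad to 1
        # (atomic_xchg with release semantics at gpu scope by default).
        hl.signal(signal_pad, [i], signal=1)
    return signal_pad

# One int32 slot per program; after the launch every slot should read 1.
signal_pad = torch.zeros(4, device="cuda", dtype=torch.int32)
send_signals(signal_pad)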

helion/runtime/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -8,6 +8,7 @@
 from .config import Config as Config
 from .kernel import Kernel as Kernel
 from .kernel import kernel as kernel
+from .triton_helpers import triton_send_signal as triton_send_signal
 from .triton_helpers import triton_wait_signal as triton_wait_signal
 
 
helion/runtime/triton_helpers.py

Lines changed: 61 additions & 1 deletion
@@ -3,7 +3,53 @@
 import triton
 import triton.language as tl
 
-__all__ = ["triton_wait_signal"]
+__all__ = ["triton_send_signal", "triton_wait_multiple_signal", "triton_wait_signal"]
+
+
+@triton.jit
+def triton_send_signal(
+    addr: tl.tensor,
+    update: tl.constexpr,
+    sem: tl.constexpr,
+    scope: tl.constexpr,
+    op: tl.constexpr,
+    skip_sync: tl.constexpr,
+) -> None:
+    """
+    Signal global memory barrier(s).
+
+    This function atomically sets global memory barriers to an update value,
+    signaling to other CTAs waiting on the barrier(s).
+
+    Args:
+        addr: Memory address of the barrier(s) to signal
+        update: Value to set the barrier(s) to
+        sem: Memory semantics for the atomic operation. Options: "release", "relaxed".
+        scope: Scope of the atomic operation. Options: "gpu", "sys"
+        op: Atomic operation type: "atomic_xchg", "atomic_add"
+        skip_sync: Skip CTA synchronization before setting the barrier. (default: False)
+    """
+    if not skip_sync:
+        tl.inline_asm_elementwise(
+            "bar.sync 0;", "=r", [], dtype=tl.int32, is_pure=False, pack=1
+        )
+
+    tl.static_assert(
+        sem == "release" or sem == "relaxed",
+        "Invalid memory semantic. Options: 'release', 'relaxed'.",
+    )
+    tl.static_assert(
+        scope == "gpu" or scope == "sys", "Invalid scope. Options: 'gpu', 'sys'."
+    )
+
+    if op == "atomic_xchg":
+        tl.atomic_xchg(addr, update, sem=sem, scope=scope)
+    elif op == "atomic_add":
+        tl.atomic_add(addr, update, sem=sem, scope=scope)
+    else:
+        raise NotImplementedError(
+            f"Unsupported op '{op}' for sending a signal on a gmem barrier."
+        )
 
 
 @triton.jit
753

854

955
@triton.jit
@@ -71,3 +117,17 @@ def triton_wait_signal(
             "bar.sync 0;", "=r", [], dtype=tl.int32, is_pure=False, pack=1
         )
     # tl.debug_barrier() cause significant performance loss. (Perhaps breaks triton prefetching?)
+
+
+@triton.jit
+def triton_wait_multiple_signal(
+    addr: tl.tensor,
+    expect: tl.constexpr,  # wait until the lock is set to expect
+    update: tl.constexpr,  # update the lock once it is acquired.
+    sem: tl.constexpr,
+    scope: tl.constexpr,
+    op: tl.constexpr,
+    skip_sync: tl.constexpr,
+) -> None:
+    raise NotImplementedError("Waiting on multiple barriers is not implemented yet.")
+    # TODO(joydddd): wait on multiple barriers at the same time, where each thread waits on a different barrier
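
To make the lowering concrete, here is a hand-written Triton kernel calling the new helper directly, mirroring the call that Helion's codegen emits in the expected outputs below; the kernel name and launch configuration are illustrative assumptions:

import torch
import triton
import triton.language as tl
from helion.runtime import triton_send_signal

@triton.jit
def send_one_signal_kernel(signal_pad, signal_pad_stride_0):
    pid_0 = tl.program_id(0)
    # Atomically exchange slot pid_0 of the pad to 1 with release semantics at gpu scope,
    # syncing the CTA first (skip_sync=False), as the generated Helion code does.
    triton_send_signal(addr=signal_pad + pid_0 * signal_pad_stride_0, update=1, sem='release', scope='gpu', op='atomic_xchg', skip_sync=False)

signal_pad = torch.zeros(4, device="cuda", dtype=torch.int32)
send_one_signal_kernel[(4,)](signal_pad, signal_pad.stride(0))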

test/test_signal_wait.expected

Lines changed: 51 additions & 0 deletions
@@ -1,6 +1,57 @@
 This file is automatically generated by assertExpectedJournal calls in test_signal_wait.py.
 Update expected outputs by running tests with the EXPECTTEST_ACCEPT=1 environment variable set.
 
+--- assertExpectedJournal(TestWait.test_signal_basic)
+from __future__ import annotations
+
+import torch
+import helion
+import triton
+import triton.language as tl
+
+@triton.jit
+def _gmem_signal_scalar_bar_kernel_kernel(signal_pad, signal_pad_stride_0):
+    pid_0 = tl.program_id(0)
+    offset_0 = pid_0
+    helion.runtime.triton_send_signal(addr=signal_pad + offset_0 * signal_pad_stride_0, update=1, sem='release', scope='gpu', op='atomic_xchg', skip_sync=False)
+
+def gmem_signal_scalar_bar_kernel(signal_pad: torch.Tensor):
+    n, = signal_pad.shape
+    _gmem_signal_scalar_bar_kernel_kernel[n,](signal_pad, signal_pad.stride(0), num_warps=4, num_stages=3)
+    return signal_pad
+
+def _gmem_signal_scalar_bar_kernel_make_precompiler(signal_pad: torch.Tensor):
+    n, = signal_pad.shape
+    from helion.runtime.precompile_shim import make_precompiler
+    return make_precompiler(_gmem_signal_scalar_bar_kernel_kernel)(signal_pad, signal_pad.stride(0), num_warps=4, num_stages=3)
+
+--- assertExpectedJournal(TestWait.test_signal_multiple)
+from __future__ import annotations
+
+import torch
+import helion
+import triton
+import triton.language as tl
+
+@triton.jit
+def _gmem_signal_tensor_bar_kernel_kernel(signal_pad, signal_pad_stride_0, _BLOCK_SIZE_0: tl.constexpr):
+    pid_0 = tl.program_id(0)
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
+    helion.runtime.triton_send_signal(addr=signal_pad + indices_0 * signal_pad_stride_0, update=1, sem='release', scope='gpu', op='atomic_xchg', skip_sync=False)
+
+def gmem_signal_tensor_bar_kernel(signal_pad: torch.Tensor):
+    n, = signal_pad.shape
+    _BLOCK_SIZE_0 = 4
+    _gmem_signal_tensor_bar_kernel_kernel[triton.cdiv(n, _BLOCK_SIZE_0),](signal_pad, signal_pad.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3)
+    return signal_pad
+
+def _gmem_signal_tensor_bar_kernel_make_precompiler(signal_pad: torch.Tensor):
+    n, = signal_pad.shape
+    _BLOCK_SIZE_0 = 4
+    from helion.runtime.precompile_shim import make_precompiler
+    return make_precompiler(_gmem_signal_tensor_bar_kernel_kernel)(signal_pad, signal_pad.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3)
+
 --- assertExpectedJournal(TestWait.test_wait_2d_tile)
 from __future__ import annotations
 
test/test_signal_wait.py

Lines changed: 53 additions & 0 deletions
@@ -54,6 +54,59 @@ def wait_for_2d_tile_kernel(
         torch.testing.assert_close(result, x)
         self.assertExpectedJournal(code)
 
+    def test_signal_basic(self):
+        @helion.kernel
+        def gmem_signal_scalar_bar_kernel(signal_pad: torch.Tensor) -> torch.Tensor:
+            (n,) = signal_pad.shape
+            for i in hl.grid(n):
+                hl.signal(signal_pad, [i], signal=1)
+            return signal_pad
+
+        signal_pad = torch.zeros(4, device=DEVICE, dtype=torch.int32)
+        code, result = code_and_output(gmem_signal_scalar_bar_kernel, (signal_pad,))
+        torch.testing.assert_close(
+            result, torch.ones(4, device=DEVICE, dtype=torch.int32)
+        )
+        self.assertExpectedJournal(code)
+
+    def test_signal_multiple(self):
+        @helion.kernel
+        def gmem_signal_tensor_bar_kernel(signal_pad: torch.Tensor) -> torch.Tensor:
+            (n,) = signal_pad.shape
+            for tile in hl.tile(n):
+                hl.signal(signal_pad, [tile], signal=1)
+            return signal_pad
+
+        signal_pad = torch.zeros(16, device=DEVICE, dtype=torch.int32)
+        code, result = code_and_output(
+            gmem_signal_tensor_bar_kernel,
+            (signal_pad,),
+            block_size=[4],
+        )
+        torch.testing.assert_close(
+            result, torch.ones(16, device=DEVICE, dtype=torch.int32)
+        )
+        self.assertExpectedJournal(code)
+
+    def test_sent_recieve_cta(self):
+        @helion.kernel
+        def gmem_signal_n_wait_kernel(signal_pad: torch.Tensor) -> torch.Tensor:
+            (n,) = signal_pad.shape
+            for i in hl.grid(n):  # first N CTAs send the signal
+                hl.signal(signal_pad, [i], signal=1)
+            for i in hl.grid(n):  # last N CTAs wait for the signal
+                hl.wait(signal_pad, [i], signal=1)
+            return signal_pad
+
+        signal_pad = torch.zeros(4, device=DEVICE, dtype=torch.int32)
+
+        code, result = code_and_output(gmem_signal_n_wait_kernel, (signal_pad,))
+        torch.testing.assert_close(
+            result, torch.ones(4, device=DEVICE, dtype=torch.int32)
+        )
+        self.assertIn("helion.runtime.triton_send_signal", code)
+        self.assertIn("helion.runtime.triton_wait_signal", code)
+
 
 if __name__ == "__main__":
     unittest.main()

0 commit comments