Skip to content

Commit 053923a

Browse files
committed
Add indirect pointer to barrier support in hl.signal & hl.wait (as_ptrs)
stack-info: PR: #261, branch: joydddd/stack/13
1 parent 41f371d commit 053923a

File tree

4 files changed

+304
-39
lines changed

4 files changed

+304
-39
lines changed

helion/language/signal_wait.py

Lines changed: 91 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -21,13 +21,14 @@
2121
@_decorators.api(tiles_as_sizes=True)
2222
def wait(
2323
signal_pad: torch.Tensor,
24-
index: list[object],
24+
index: list[object] | None = None,
2525
signal: int = 1,
2626
update: int | None = None,
2727
op: str = "ld",
2828
sem: str = "acquire",
2929
scope: str = "gpu",
3030
skip_sync: bool = False,
31+
as_ptrs: bool = False,
3132
) -> None:
3233
"""Wait until all entries of the signal_pad slice are equal to the signal value.
3334
Args:
@@ -39,6 +40,7 @@ def wait(
3940
sem: The memory semantic for acquiring the lock (default: 'acquire')
4041
scope: The scope of the lock (default: 'gpu')
4142
skip_sync: Skip the syncthreads after the wait (default: False)
43+
as_ptrs: Treat signal_pad as pointers to global memory barriers (default: False)
4244
4345
Returns:
4446
None
@@ -49,14 +51,15 @@ def wait(
4951
@_decorators.prepare_args(wait)
5052
def _(
5153
signal_pad: torch.Tensor,
52-
index: list[object],
54+
index: list[object] | None = None,
5355
signal: int = 1,
5456
update: int | None = None,
5557
op: str = "ld",
5658
sem: str = "acquire",
5759
scope: str = "gpu",
5860
skip_sync: bool = False,
59-
) -> tuple[torch.Tensor, object, int, int | None, str, str, str, bool]:
61+
as_ptrs: bool = False,
62+
) -> tuple[torch.Tensor, object, int, int | None, str, str, str, bool, bool]:
6063
from helion.language.tile_proxy import Tile
6164

6265
valid_ops = {"ld", "atomic_cas"}
@@ -88,22 +91,37 @@ def _(
8891
if scope not in valid_scopes:
8992
raise ValueError(f"Invalid scope '{scope}'. Must be one of {valid_scopes}.")
9093

94+
if as_ptrs:
95+
if index is not None:
96+
raise ValueError(
97+
f"When as_ptrs=True, signal_pad must be used without indexing. "
98+
f"Expected 0 indices but got {len(index)}. "
99+
)
100+
if signal_pad.dtype not in (torch.uint64, torch.int64):
101+
raise ValueError(
102+
f"When as_ptrs=True, signal_pad must have dtype torch.uint64 or torch.int64 "
103+
f"to represent memory pointers. Got dtype {signal_pad.dtype}. "
104+
)
105+
if index is None:
106+
index = []
107+
91108
index = Tile._prepare_index(index)
92109
index = Tile._tiles_to_sizes(index)
93110

94-
return (signal_pad, index, signal, update, op, sem, scope, skip_sync)
111+
return (signal_pad, index, signal, update, op, sem, scope, skip_sync, as_ptrs)
95112

96113

97114
@_decorators.register_fake(wait)
98115
def _(
99116
signal_pad: torch.Tensor,
100-
index: list[object],
117+
index: list[object] | None = None,
101118
signal: int = 1,
102119
update: int | None = None,
103120
op: str = "ld",
104121
sem: str = "acquire",
105122
scope: str = "sys",
106123
skip_sync: bool = False,
124+
as_ptrs: bool = False,
107125
) -> None:
108126
return None
109127

@@ -123,35 +141,38 @@ def _(state: CodegenState) -> ast.AST:
123141
sem = state.proxy_arg(5)
124142
scope = state.proxy_arg(6)
125143
skip_sync = state.proxy_arg(7)
144+
as_ptrs = state.proxy_arg(8)
126145

127146
assert isinstance(signal_pad, torch.Tensor)
128147
assert isinstance(index, (list))
129148

130-
indices = SubscriptIndexing.create(state, signal_pad, index)
131-
signal_pad_name = state.device_function.tensor_arg(signal_pad).name
132-
133-
signal_expr = ast.Constant(value=signal) # pyright: ignore[reportArgumentType]
134-
update_expr = ast.Constant(value=update) # pyright: ignore[reportArgumentType]
135-
136149
assert type(op) is str
137150
assert type(sem) is str
138151
assert type(scope) is str
139152

140-
bar_tensor_shape = SubscriptIndexing.compute_shape(signal_pad, index)
141-
is_scalar = len(bar_tensor_shape) == 0
142-
143-
if is_scalar:
144-
call_triton_wait_signal = f"helion.runtime.triton_wait_signal(addr={signal_pad_name} + offset, expect=signal, update=update, sem='{sem}', scope='{scope}', op='{op}', skip_sync={skip_sync})"
153+
if as_ptrs:
154+
bar_tensor_shape = signal_pad.shape
155+
bar_addrs = "signal_pad_arg.to(tl.pointer_type(tl.int32))"
145156
else:
157+
indices = SubscriptIndexing.create(state, signal_pad, index)
146158
if signal_pad.dtype not in (torch.int32, torch.uint32):
147159
raise NotImplementedError(
148160
f"Unsupported signal pad dtype: {signal_pad.dtype}. Must be of torch.int32 or torch.uint32."
149161
)
150-
call_triton_wait_signal = f"helion.runtime.triton_wait_multiple_signal(addr={signal_pad_name} + offset, expect=signal, update=update, sem='{sem}', scope='{scope}', op='{op}', skip_sync={skip_sync})"
162+
signal_pad_name = state.device_function.tensor_arg(signal_pad).name
163+
bar_tensor_shape = SubscriptIndexing.compute_shape(signal_pad, index)
164+
bar_addrs = f"{signal_pad_name} + signal_pad_arg"
165+
166+
signal_expr = ast.Constant(value=signal) # pyright: ignore[reportArgumentType]
167+
update_expr = ast.Constant(value=update) # pyright: ignore[reportArgumentType]
168+
169+
is_scalar = len(bar_tensor_shape) == 0
170+
171+
call_triton_wait_signal = f"helion.runtime.triton_wait_{'' if is_scalar else 'multiple_'}signal(addr={bar_addrs}, expect=signal, update=update, sem='{sem}', scope='{scope}', op='{op}', skip_sync={skip_sync})"
151172

152173
return expr_from_string(
153174
call_triton_wait_signal,
154-
offset=indices.index_expr,
175+
signal_pad_arg=state.ast_arg(0) if as_ptrs else indices.index_expr, # pyright: ignore[reportPossiblyUnboundVariable]
155176
signal=signal_expr,
156177
update=update_expr,
157178
)
@@ -161,13 +182,14 @@ def _(state: CodegenState) -> ast.AST:
161182
@_decorators.api(tiles_as_sizes=True)
162183
def signal(
163184
signal_pad: torch.Tensor,
164-
index: list[object],
185+
index: list[object] | None = None,
165186
signal: int = 1,
166187
wait_for: int | None = None,
167188
op: str = "atomic_xchg",
168189
sem: str = "release",
169190
scope: str = "gpu",
170191
skip_sync: bool = False,
192+
as_ptrs: bool = False,
171193
) -> torch.Tensor:
172194
"""Set the signal_pad slice to the signal value.
173195
Args:
@@ -179,21 +201,25 @@ def signal(
179201
sem: The memory semantic for acquiring the lock (default: 'release')
180202
scope: The scope of the lock (default: 'gpu')
181203
skip_sync: Skip the syncthreads before sending signal (default: False)
204+
as_ptrs: Treat signal_pad as pointers to global memory barriers (default: False)
205+
Returns:
206+
The old value of the signal_pad slice before the update.
182207
"""
183208
raise exc.NotInsideKernel
184209

185210

186211
@_decorators.prepare_args(signal)
187212
def _(
188213
signal_pad: torch.Tensor,
189-
index: list[object],
214+
index: list[object] | None = None,
190215
signal: int = 1,
191216
wait_for: int | None = None,
192217
op: str = "atomic_xchg",
193218
sem: str = "release",
194219
scope: str = "gpu",
195220
skip_sync: bool = False,
196-
) -> tuple[torch.Tensor, object, int, int | None, str, str, str, bool]:
221+
as_ptrs: bool = False,
222+
) -> tuple[torch.Tensor, object, int, int | None, str, str, str, bool, bool]:
197223
from helion.language.tile_proxy import Tile
198224

199225
valid_ops = {"atomic_add", "atomic_xchg", "atomic_cas"}
@@ -220,23 +246,42 @@ def _(
220246
if scope not in valid_scopes:
221247
raise ValueError(f"Invalid scope '{scope}'. Must be one of {valid_scopes}.")
222248

249+
if as_ptrs:
250+
if index is not None:
251+
raise ValueError(
252+
f"When as_ptrs=True, signal_pad must be used without indexing. "
253+
f"Expected 0 indices but got {len(index)}. "
254+
)
255+
if signal_pad.dtype not in (torch.uint64, torch.int64):
256+
raise ValueError(
257+
f"When as_ptrs=True, signal_pad must have dtype torch.uint64 or torch.int64 "
258+
f"to represent memory pointers. Got dtype {signal_pad.dtype}. "
259+
)
260+
if index is None:
261+
index = []
262+
223263
index = Tile._prepare_index(index)
224264
index = Tile._tiles_to_sizes(index)
225265

226-
return (signal_pad, index, signal, wait_for, op, sem, scope, skip_sync)
266+
return (signal_pad, index, signal, wait_for, op, sem, scope, skip_sync, as_ptrs)
227267

228268

229269
@_decorators.register_fake(signal)
230270
def _(
231271
signal_pad: torch.Tensor,
232-
index: list[object],
272+
index: list[object] | None = None,
233273
signal: int = 1,
234274
wait_for: int | None = None,
235275
op: str = "atomic_xchg",
236276
sem: str = "release",
237277
scope: str = "gpu",
238278
skip_sync: bool = False,
279+
as_ptrs: bool = False,
239280
) -> torch.Tensor:
281+
if index is None:
282+
index = []
283+
if as_ptrs:
284+
return signal_pad.new_empty(signal_pad.shape)
240285
return signal_pad.new_empty(SubscriptIndexing.compute_shape(signal_pad, index))
241286

242287

@@ -255,43 +300,51 @@ def _(state: CodegenState) -> ast.AST:
255300
sem = state.proxy_arg(5)
256301
scope = state.proxy_arg(6)
257302
skip_sync = state.proxy_arg(7)
303+
as_ptrs = state.proxy_arg(8)
258304

259305
assert isinstance(signal_pad, torch.Tensor)
260306
assert isinstance(index, list)
261307

262-
indices = SubscriptIndexing.create(state, signal_pad, index)
263-
signal_pad_name = state.device_function.tensor_arg(signal_pad).name
308+
assert type(op) is str
309+
assert type(sem) is str
310+
assert type(scope) is str
311+
312+
if as_ptrs:
313+
bar_tensor_shape = signal_pad.shape
314+
bar_addrs = "signal_pad_arg.to(tl.pointer_type(tl.int32))"
315+
else:
316+
indices = SubscriptIndexing.create(state, signal_pad, index)
317+
if signal_pad.dtype not in (torch.int32, torch.uint32):
318+
raise NotImplementedError(
319+
f"Unsupported signal pad dtype: {signal_pad.dtype}. Must be of torch.int32 or torch.uint32."
320+
)
321+
signal_pad_name = state.device_function.tensor_arg(signal_pad).name
322+
bar_tensor_shape = SubscriptIndexing.compute_shape(signal_pad, index)
323+
bar_addrs = f"{signal_pad_name} + signal_pad_arg"
324+
325+
is_scalar = len(bar_tensor_shape) == 0
264326

265327
signal_expr = ast.Constant(value=signal) # pyright: ignore[reportArgumentType]
266328
if wait_for is not None:
267329
wait_for_expr = ast.Constant(value=wait_for) # pyright: ignore[reportArgumentType]
268330
else:
269331
wait_for_expr = ast.Constant(value=0)
270332
skip_sync_expr = ast.Constant(value=skip_sync) # pyright: ignore[reportArgumentType]
271-
assert type(op) is str
272-
assert type(sem) is str
273-
assert type(scope) is str
274333

275334
if op == "atomic_cas":
276-
bar_tensor_shape = SubscriptIndexing.compute_shape(signal_pad, index)
277-
is_scalar = len(bar_tensor_shape) == 0
278-
if is_scalar:
279-
call_triton_wait_signal = f"helion.runtime.triton_wait_signal(addr={signal_pad_name} + offset, expect=wait_for, update=signal, sem='{sem}', scope='{scope}', op='{op}', skip_sync=True, sync_before=(not skip_sync))"
280-
else:
281-
call_triton_wait_signal = f"helion.runtime.triton_wait_multiple_signal(addr={signal_pad_name} + offset, expect=wait_for, update=signal, sem='{sem}', scope='{scope}', op='{op}', skip_sync=True, sync_before=(not skip_sync))"
282-
335+
call_triton_wait_signal = f"helion.runtime.triton_wait_{'' if is_scalar else 'multiple_'}signal(addr={bar_addrs}, expect=wait_for, update=signal, sem='{sem}', scope='{scope}', op='{op}', skip_sync=True, sync_before=(not skip_sync))"
283336
return expr_from_string(
284337
call_triton_wait_signal,
285-
offset=indices.index_expr,
338+
signal_pad_arg=state.ast_arg(0) if as_ptrs else indices.index_expr, # pyright: ignore[reportPossiblyUnboundVariable]
286339
wait_for=wait_for_expr,
287340
signal=signal_expr,
288341
skip_sync=skip_sync_expr,
289342
)
290-
call_triton_send_signal = f"helion.runtime.triton_send_signal(addr={signal_pad_name} + offset, update=signal, sem='{sem}', scope='{scope}', op='{op}', skip_sync=skip_sync)"
343+
call_triton_send_signal = f"helion.runtime.triton_send_signal(addr={bar_addrs}, update=signal, sem='{sem}', scope='{scope}', op='{op}', skip_sync=skip_sync)"
291344

292345
return expr_from_string(
293346
call_triton_send_signal,
294-
offset=indices.index_expr,
347+
signal_pad_arg=state.ast_arg(0) if as_ptrs else indices.index_expr, # pyright: ignore[reportPossiblyUnboundVariable]
295348
signal=signal_expr,
296349
skip_sync=skip_sync_expr,
297350
)

helion/runtime/triton_helpers.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,11 @@ def triton_wait_multiple_signal(
171171
"Invalid barrier value type. Only supports int32 for multi barrier signal. ",
172172
)
173173

174+
if sync_before:
175+
tl.inline_asm_elementwise(
176+
"bar.sync 0;", "=r", [], dtype=tl.int32, is_pure=False, pack=1
177+
)
178+
174179
addr = tl.ravel(addr)
175180

176181
tl.static_assert(len(addr.shape) == 1, "addr must be a 1D tensor. ")

test/test_signal_wait.expected

Lines changed: 72 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,40 @@ def _gmem_signal_tensor_bar_kernel_make_precompiler(signal_pad: torch.Tensor):
134134
from helion.runtime.precompile_shim import make_precompiler
135135
return make_precompiler(_gmem_signal_tensor_bar_kernel_kernel)(signal_pad, signal_pad.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3)
136136

137+
--- assertExpectedJournal(TestWait.test_signal_pointers)
138+
from __future__ import annotations
139+
140+
import torch
141+
import helion
142+
import helion.language as hl
143+
import triton
144+
import triton.language as tl
145+
146+
@triton.jit
147+
def _gmem_signal_pointers_kernel_kernel(signal_pad_ptrs, signal_pad_ptrs_stride_0, N, _BLOCK_SIZE_1: tl.constexpr):
148+
pid_0 = tl.program_id(0)
149+
offset_0 = pid_0
150+
for offset_1 in tl.range(0, N.to(tl.int32), step=_BLOCK_SIZE_1):
151+
indices_1 = offset_1 + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32)
152+
mask_1 = indices_1 < N
153+
load = tl.load(signal_pad_ptrs + indices_1 * signal_pad_ptrs_stride_0, mask_1, other=0)
154+
symnode_0 = 4 * offset_0
155+
v_0 = symnode_0.to(tl.uint64)
156+
v_1 = load + v_0
157+
helion.runtime.triton_send_signal(addr=v_1.to(tl.pointer_type(tl.int32)), update=1, sem='release', scope='gpu', op='atomic_xchg', skip_sync=False)
158+
159+
def gmem_signal_pointers_kernel(signal_pad_ptrs: torch.Tensor, pad_shape: hl.constexpr):
160+
N = signal_pad_ptrs.size(0)
161+
_BLOCK_SIZE_1 = N
162+
_gmem_signal_pointers_kernel_kernel[4,](signal_pad_ptrs, signal_pad_ptrs.stride(0), N, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
163+
return signal_pad_ptrs
164+
165+
def _gmem_signal_pointers_kernel_make_precompiler(signal_pad_ptrs: torch.Tensor, pad_shape: hl.constexpr):
166+
N = signal_pad_ptrs.size(0)
167+
_BLOCK_SIZE_1 = N
168+
from helion.runtime.precompile_shim import make_precompiler
169+
return make_precompiler(_gmem_signal_pointers_kernel_kernel)(signal_pad_ptrs, signal_pad_ptrs.stride(0), N, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
170+
137171
--- assertExpectedJournal(TestWait.test_wait_2d_tile)
138172
from __future__ import annotations
139173

@@ -210,7 +244,7 @@ import helion
210244
import triton
211245
import triton.language as tl
212246

213-
import test.test_signal_wait as _source_module
247+
import __main__ as _source_module
214248

215249
@triton.jit
216250
def _gmem_wait_multi_bar_kernel_kernel(signal_pad, out, out_stride_0, signal_pad_stride_0, _BLOCK_SIZE_0: tl.constexpr):
@@ -265,3 +299,40 @@ def _gmem_wait_multi_bar_kernel_cas_make_precompiler(signal_pad: torch.Tensor):
265299
_BLOCK_SIZE_0 = 4
266300
from helion.runtime.precompile_shim import make_precompiler
267301
return make_precompiler(_gmem_wait_multi_bar_kernel_cas_kernel)(signal_pad, signal_pad.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3)
302+
303+
--- assertExpectedJournal(TestWait.test_wait_pointers)
304+
from __future__ import annotations
305+
306+
import torch
307+
import helion
308+
import helion.language as hl
309+
import triton
310+
import triton.language as tl
311+
312+
@triton.jit
313+
def _gmem_wait_pointers_kernel_kernel(signal_pad_ptrs, out, out_stride_0, signal_pad_ptrs_stride_0, N, _BLOCK_SIZE_1: tl.constexpr):
314+
pid_0 = tl.program_id(0)
315+
offset_0 = pid_0
316+
for offset_1 in tl.range(0, N.to(tl.int32), step=_BLOCK_SIZE_1):
317+
indices_1 = offset_1 + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32)
318+
mask_1 = indices_1 < N
319+
load = tl.load(signal_pad_ptrs + indices_1 * signal_pad_ptrs_stride_0, mask_1, other=0)
320+
symnode_0 = 4 * offset_0
321+
v_0 = symnode_0.to(tl.uint64)
322+
v_1 = load + v_0
323+
helion.runtime.triton_wait_multiple_signal(addr=v_1.to(tl.pointer_type(tl.int32)), expect=1, update=0, sem='acquire', scope='gpu', op='ld', skip_sync=False)
324+
tl.store(out + offset_0 * out_stride_0, offset_0, None)
325+
326+
def gmem_wait_pointers_kernel(signal_pad_ptrs: torch.Tensor, pad_shape: hl.constexpr):
327+
out = torch.empty(4, device=signal_pad_ptrs.device, dtype=torch.int32)
328+
N = signal_pad_ptrs.size(0)
329+
_BLOCK_SIZE_1 = N
330+
_gmem_wait_pointers_kernel_kernel[4,](signal_pad_ptrs, out, out.stride(0), signal_pad_ptrs.stride(0), N, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
331+
return out
332+
333+
def _gmem_wait_pointers_kernel_make_precompiler(signal_pad_ptrs: torch.Tensor, pad_shape: hl.constexpr):
334+
out = torch.empty(4, device=signal_pad_ptrs.device, dtype=torch.int32)
335+
N = signal_pad_ptrs.size(0)
336+
_BLOCK_SIZE_1 = N
337+
from helion.runtime.precompile_shim import make_precompiler
338+
return make_precompiler(_gmem_wait_pointers_kernel_kernel)(signal_pad_ptrs, out, out.stride(0), signal_pad_ptrs.stride(0), N, _BLOCK_SIZE_1, num_warps=4, num_stages=3)

0 commit comments

Comments
 (0)