21
21
@_decorators .api (tiles_as_sizes = True , allow_host_tensor = True )
22
22
def wait (
23
23
signal_pad : torch .Tensor ,
24
- index : list [object ],
24
+ index : list [object ] | None = None ,
25
25
signal : int = 1 ,
26
26
update : int | None = None ,
27
27
op : str = "ld" ,
28
28
sem : str = "acquire" ,
29
29
scope : str = "gpu" ,
30
30
skip_sync : bool = False ,
31
+ as_ptrs : bool = False ,
31
32
) -> None :
32
33
"""Wait until all entries of the signal_pad slice are equal to the signal value.
33
34
Args:
@@ -39,6 +40,7 @@ def wait(
39
40
sem: The memory semantic for acquiring the lock (default: 'acquire')
40
41
scope: The scope of the lock (default: 'gpu')
41
42
skip_sync: Skip the syncthreads after the wait (default: False)
43
+ as_ptrs: Treat signal_pad as pointers to global memory barriers (default: False)
42
44
43
45
Returns:
44
46
None
@@ -49,14 +51,15 @@ def wait(
49
51
@_decorators .prepare_args (wait )
50
52
def _ (
51
53
signal_pad : torch .Tensor ,
52
- index : list [object ],
54
+ index : list [object ] | None = None ,
53
55
signal : int = 1 ,
54
56
update : int | None = None ,
55
57
op : str = "ld" ,
56
58
sem : str = "acquire" ,
57
59
scope : str = "gpu" ,
58
60
skip_sync : bool = False ,
59
- ) -> tuple [torch .Tensor , object , int , int | None , str , str , str , bool ]:
61
+ as_ptrs : bool = False ,
62
+ ) -> tuple [torch .Tensor , object , int , int | None , str , str , str , bool , bool ]:
60
63
from helion .language .tile_proxy import Tile
61
64
62
65
valid_ops = {"ld" , "atomic_cas" }
@@ -88,22 +91,37 @@ def _(
88
91
if scope not in valid_scopes :
89
92
raise ValueError (f"Invalid scope '{ scope } '. Must be one of { valid_scopes } ." )
90
93
94
+ if as_ptrs :
95
+ if index is not None :
96
+ raise ValueError (
97
+ f"When as_ptrs=True, signal_pad must be used without indexing. "
98
+ f"Expected 0 indices but got { len (index )} . "
99
+ )
100
+ if signal_pad .dtype not in (torch .uint64 , torch .int64 ):
101
+ raise ValueError (
102
+ f"When as_ptrs=True, signal_pad must have dtype torch.uint64 or torch.int64 "
103
+ f"to represent memory pointers. Got dtype { signal_pad .dtype } . "
104
+ )
105
+ if index is None :
106
+ index = []
107
+
91
108
index = Tile ._prepare_index (index )
92
109
index = Tile ._tiles_to_sizes (index )
93
110
94
- return (signal_pad , index , signal , update , op , sem , scope , skip_sync )
111
+ return (signal_pad , index , signal , update , op , sem , scope , skip_sync , as_ptrs )
95
112
96
113
97
114
@_decorators.register_fake(wait)
def _(
    signal_pad: torch.Tensor,
    index: list[object] | None = None,
    signal: int = 1,
    update: int | None = None,
    op: str = "ld",
    sem: str = "acquire",
    scope: str = "gpu",  # fixed: was "sys", inconsistent with the wait() API default
    skip_sync: bool = False,
    as_ptrs: bool = False,
) -> None:
    """Fake (tracing/meta) implementation of :func:`wait`.

    ``wait`` produces no value, so the fake ignores all arguments and
    returns ``None``. The defaults here mirror the public ``wait()``
    signature so that tracing with omitted keyword arguments sees the
    same values as the real op.
    """
    return None
109
127
@@ -123,35 +141,38 @@ def _(state: CodegenState) -> ast.AST:
123
141
sem = state .proxy_arg (5 )
124
142
scope = state .proxy_arg (6 )
125
143
skip_sync = state .proxy_arg (7 )
144
+ as_ptrs = state .proxy_arg (8 )
126
145
127
146
assert isinstance (signal_pad , torch .Tensor )
128
147
assert isinstance (index , (list ))
129
148
130
- indices = SubscriptIndexing .create (state , signal_pad , index )
131
- signal_pad_name = state .device_function .tensor_arg (signal_pad ).name
132
-
133
- signal_expr = ast .Constant (value = signal ) # pyright: ignore[reportArgumentType]
134
- update_expr = ast .Constant (value = update ) # pyright: ignore[reportArgumentType]
135
-
136
149
assert type (op ) is str
137
150
assert type (sem ) is str
138
151
assert type (scope ) is str
139
152
140
- bar_tensor_shape = SubscriptIndexing .compute_shape (signal_pad , index )
141
- is_scalar = len (bar_tensor_shape ) == 0
142
-
143
- if is_scalar :
144
- call_triton_wait_signal = f"helion.runtime.triton_wait_signal(addr={ signal_pad_name } + offset, expect=signal, update=update, sem='{ sem } ', scope='{ scope } ', op='{ op } ', skip_sync={ skip_sync } )"
153
+ if as_ptrs :
154
+ bar_tensor_shape = signal_pad .shape
155
+ bar_addrs = "signal_pad_arg.to(tl.pointer_type(tl.int32))"
145
156
else :
157
+ indices = SubscriptIndexing .create (state , signal_pad , index )
146
158
if signal_pad .dtype not in (torch .int32 , torch .uint32 ):
147
159
raise NotImplementedError (
148
160
f"Unsupported signal pad dtype: { signal_pad .dtype } . Must be of torch.int32 or torch.uint32."
149
161
)
150
- call_triton_wait_signal = f"helion.runtime.triton_wait_multiple_signal(addr={ signal_pad_name } + offset, expect=signal, update=update, sem='{ sem } ', scope='{ scope } ', op='{ op } ', skip_sync={ skip_sync } )"
162
+ signal_pad_name = state .device_function .tensor_arg (signal_pad ).name
163
+ bar_tensor_shape = SubscriptIndexing .compute_shape (signal_pad , index )
164
+ bar_addrs = f"{ signal_pad_name } + signal_pad_arg"
165
+
166
+ signal_expr = ast .Constant (value = signal ) # pyright: ignore[reportArgumentType]
167
+ update_expr = ast .Constant (value = update ) # pyright: ignore[reportArgumentType]
168
+
169
+ is_scalar = len (bar_tensor_shape ) == 0
170
+
171
+ call_triton_wait_signal = f"helion.runtime.triton_wait_{ '' if is_scalar else 'multiple_' } signal(addr={ bar_addrs } , expect=signal, update=update, sem='{ sem } ', scope='{ scope } ', op='{ op } ', skip_sync={ skip_sync } )"
151
172
152
173
return expr_from_string (
153
174
call_triton_wait_signal ,
154
- offset = indices .index_expr ,
175
+ signal_pad_arg = state . ast_arg ( 0 ) if as_ptrs else indices .index_expr , # pyright: ignore[reportPossiblyUnboundVariable]
155
176
signal = signal_expr ,
156
177
update = update_expr ,
157
178
)
@@ -161,13 +182,14 @@ def _(state: CodegenState) -> ast.AST:
161
182
@_decorators .api (tiles_as_sizes = True , allow_host_tensor = True )
162
183
def signal (
163
184
signal_pad : torch .Tensor ,
164
- index : list [object ],
185
+ index : list [object ] | None = None ,
165
186
signal : int = 1 ,
166
187
wait_for : int | None = None ,
167
188
op : str = "atomic_xchg" ,
168
189
sem : str = "release" ,
169
190
scope : str = "gpu" ,
170
191
skip_sync : bool = False ,
192
+ as_ptrs : bool = False ,
171
193
) -> torch .Tensor :
172
194
"""Set the signal_pad slice to the signal value.
173
195
Args:
@@ -179,21 +201,25 @@ def signal(
179
201
sem: The memory semantic for acquiring the lock (default: 'release')
180
202
scope: The scope of the lock (default: 'gpu')
181
203
skip_sync: Skip the syncthreads before sending signal (default: False)
204
+ as_ptrs: Treat signal_pad as pointers to global memory barriers (default: False)
205
+ Returns:
206
+ The old value of the signal_pad slice before the update.
182
207
"""
183
208
raise exc .NotInsideKernel
184
209
185
210
186
211
@_decorators .prepare_args (signal )
187
212
def _ (
188
213
signal_pad : torch .Tensor ,
189
- index : list [object ],
214
+ index : list [object ] | None = None ,
190
215
signal : int = 1 ,
191
216
wait_for : int | None = None ,
192
217
op : str = "atomic_xchg" ,
193
218
sem : str = "release" ,
194
219
scope : str = "gpu" ,
195
220
skip_sync : bool = False ,
196
- ) -> tuple [torch .Tensor , object , int , int | None , str , str , str , bool ]:
221
+ as_ptrs : bool = False ,
222
+ ) -> tuple [torch .Tensor , object , int , int | None , str , str , str , bool , bool ]:
197
223
from helion .language .tile_proxy import Tile
198
224
199
225
valid_ops = {"atomic_add" , "atomic_xchg" , "atomic_cas" }
@@ -220,23 +246,42 @@ def _(
220
246
if scope not in valid_scopes :
221
247
raise ValueError (f"Invalid scope '{ scope } '. Must be one of { valid_scopes } ." )
222
248
249
+ if as_ptrs :
250
+ if index is not None :
251
+ raise ValueError (
252
+ f"When as_ptrs=True, signal_pad must be used without indexing. "
253
+ f"Expected 0 indices but got { len (index )} . "
254
+ )
255
+ if signal_pad .dtype not in (torch .uint64 , torch .int64 ):
256
+ raise ValueError (
257
+ f"When as_ptrs=True, signal_pad must have dtype torch.uint64 or torch.int64 "
258
+ f"to represent memory pointers. Got dtype { signal_pad .dtype } . "
259
+ )
260
+ if index is None :
261
+ index = []
262
+
223
263
index = Tile ._prepare_index (index )
224
264
index = Tile ._tiles_to_sizes (index )
225
265
226
- return (signal_pad , index , signal , wait_for , op , sem , scope , skip_sync )
266
+ return (signal_pad , index , signal , wait_for , op , sem , scope , skip_sync , as_ptrs )
227
267
228
268
229
269
@_decorators.register_fake(signal)
def _(
    signal_pad: torch.Tensor,
    index: list[object] | None = None,
    signal: int = 1,
    wait_for: int | None = None,
    op: str = "atomic_xchg",
    sem: str = "release",
    scope: str = "gpu",
    skip_sync: bool = False,
    as_ptrs: bool = False,
) -> torch.Tensor:
    """Fake (tracing/meta) implementation of :func:`signal`.

    Produces an empty tensor with the shape of the values the real op
    would return: the whole pad when ``as_ptrs`` is set, otherwise the
    slice selected by ``index``.
    """
    result_shape = (
        signal_pad.shape
        if as_ptrs
        else SubscriptIndexing.compute_shape(
            signal_pad, [] if index is None else index
        )
    )
    return signal_pad.new_empty(result_shape)
241
286
242
287
@@ -255,43 +300,51 @@ def _(state: CodegenState) -> ast.AST:
255
300
sem = state .proxy_arg (5 )
256
301
scope = state .proxy_arg (6 )
257
302
skip_sync = state .proxy_arg (7 )
303
+ as_ptrs = state .proxy_arg (8 )
258
304
259
305
assert isinstance (signal_pad , torch .Tensor )
260
306
assert isinstance (index , list )
261
307
262
- indices = SubscriptIndexing .create (state , signal_pad , index )
263
- signal_pad_name = state .device_function .tensor_arg (signal_pad ).name
308
+ assert type (op ) is str
309
+ assert type (sem ) is str
310
+ assert type (scope ) is str
311
+
312
+ if as_ptrs :
313
+ bar_tensor_shape = signal_pad .shape
314
+ bar_addrs = "signal_pad_arg.to(tl.pointer_type(tl.int32))"
315
+ else :
316
+ indices = SubscriptIndexing .create (state , signal_pad , index )
317
+ if signal_pad .dtype not in (torch .int32 , torch .uint32 ):
318
+ raise NotImplementedError (
319
+ f"Unsupported signal pad dtype: { signal_pad .dtype } . Must be of torch.int32 or torch.uint32."
320
+ )
321
+ signal_pad_name = state .device_function .tensor_arg (signal_pad ).name
322
+ bar_tensor_shape = SubscriptIndexing .compute_shape (signal_pad , index )
323
+ bar_addrs = f"{ signal_pad_name } + signal_pad_arg"
324
+
325
+ is_scalar = len (bar_tensor_shape ) == 0
264
326
265
327
signal_expr = ast .Constant (value = signal ) # pyright: ignore[reportArgumentType]
266
328
if wait_for is not None :
267
329
wait_for_expr = ast .Constant (value = wait_for ) # pyright: ignore[reportArgumentType]
268
330
else :
269
331
wait_for_expr = ast .Constant (value = 0 )
270
332
skip_sync_expr = ast .Constant (value = skip_sync ) # pyright: ignore[reportArgumentType]
271
- assert type (op ) is str
272
- assert type (sem ) is str
273
- assert type (scope ) is str
274
333
275
334
if op == "atomic_cas" :
276
- bar_tensor_shape = SubscriptIndexing .compute_shape (signal_pad , index )
277
- is_scalar = len (bar_tensor_shape ) == 0
278
- if is_scalar :
279
- call_triton_wait_signal = f"helion.runtime.triton_wait_signal(addr={ signal_pad_name } + offset, expect=wait_for, update=signal, sem='{ sem } ', scope='{ scope } ', op='{ op } ', skip_sync=True, sync_before=(not skip_sync))"
280
- else :
281
- call_triton_wait_signal = f"helion.runtime.triton_wait_multiple_signal(addr={ signal_pad_name } + offset, expect=wait_for, update=signal, sem='{ sem } ', scope='{ scope } ', op='{ op } ', skip_sync=True, sync_before=(not skip_sync))"
282
-
335
+ call_triton_wait_signal = f"helion.runtime.triton_wait_{ '' if is_scalar else 'multiple_' } signal(addr={ bar_addrs } , expect=wait_for, update=signal, sem='{ sem } ', scope='{ scope } ', op='{ op } ', skip_sync=True, sync_before=(not skip_sync))"
283
336
return expr_from_string (
284
337
call_triton_wait_signal ,
285
- offset = indices .index_expr ,
338
+ signal_pad_arg = state . ast_arg ( 0 ) if as_ptrs else indices .index_expr , # pyright: ignore[reportPossiblyUnboundVariable]
286
339
wait_for = wait_for_expr ,
287
340
signal = signal_expr ,
288
341
skip_sync = skip_sync_expr ,
289
342
)
290
- call_triton_send_signal = f"helion.runtime.triton_send_signal(addr={ signal_pad_name } + offset , update=signal, sem='{ sem } ', scope='{ scope } ', op='{ op } ', skip_sync=skip_sync)"
343
+ call_triton_send_signal = f"helion.runtime.triton_send_signal(addr={ bar_addrs } , update=signal, sem='{ sem } ', scope='{ scope } ', op='{ op } ', skip_sync=skip_sync)"
291
344
292
345
return expr_from_string (
293
346
call_triton_send_signal ,
294
- offset = indices .index_expr ,
347
+ signal_pad_arg = state . ast_arg ( 0 ) if as_ptrs else indices .index_expr , # pyright: ignore[reportPossiblyUnboundVariable]
295
348
signal = signal_expr ,
296
349
skip_sync = skip_sync_expr ,
297
350
)
0 commit comments