21
21
@_decorators .api (tiles_as_sizes = True )
22
22
def wait (
23
23
signal_pad : torch .Tensor ,
24
- index : list [object ],
24
+ index : list [object ] = [] ,
25
25
signal : int = 1 ,
26
26
update : int | None = None ,
27
27
op : str = "ld" ,
28
28
sem : str = "acquire" ,
29
29
scope : str = "gpu" ,
30
30
skip_sync : bool = False ,
31
+ as_ptrs : bool = False ,
31
32
) -> None :
32
33
"""Wait until all entries of the signal_pad slice are equal to the signal value.
33
34
Args:
@@ -39,6 +40,7 @@ def wait(
39
40
sem: The memory sematic for acquring the lock (default: 'acquire')
40
41
scope: The scope of the lock (default: 'gpu')
41
42
skip_sync: Skip the syncthreads after the wait (default: False)
43
+ as_ptrs: Treat signal_pad as pointers to global memory barriers (default: False)
42
44
43
45
Returns:
44
46
None
@@ -49,14 +51,15 @@ def wait(
49
51
@_decorators .prepare_args (wait )
50
52
def _ (
51
53
signal_pad : torch .Tensor ,
52
- index : list [object ],
54
+ index : list [object ] = [] ,
53
55
signal : int = 1 ,
54
56
update : int | None = None ,
55
57
op : str = "ld" ,
56
58
sem : str = "acquire" ,
57
59
scope : str = "gpu" ,
58
60
skip_sync : bool = False ,
59
- ) -> tuple [torch .Tensor , object , int , int | None , str , str , str , bool ]:
61
+ as_ptrs : bool = False ,
62
+ ) -> tuple [torch .Tensor , object , int , int | None , str , str , str , bool , bool ]:
60
63
from helion .language .tile_proxy import Tile
61
64
62
65
valid_ops = {"ld" , "atomic_cas" }
@@ -88,22 +91,35 @@ def _(
88
91
if scope not in valid_scopes :
89
92
raise ValueError (f"Invalid scope '{ scope } '. Must be one of { valid_scopes } ." )
90
93
94
+ if as_ptrs :
95
+ if len (index ) != 0 :
96
+ raise ValueError (
97
+ f"When as_ptrs=True, signal_pad must be used without indexing. "
98
+ f"Expected 0 indices but got { len (index )} . "
99
+ )
100
+ if signal_pad .dtype not in (torch .uint64 , torch .int64 ):
101
+ raise ValueError (
102
+ f"When as_ptrs=True, signal_pad must have dtype torch.uint64 or torch.int64 "
103
+ f"to represent memory pointers. Got dtype { signal_pad .dtype } . "
104
+ )
105
+
91
106
index = Tile ._prepare_index (index )
92
107
index = Tile ._tiles_to_sizes (index )
93
108
94
- return (signal_pad , index , signal , update , op , sem , scope , skip_sync )
109
+ return (signal_pad , index , signal , update , op , sem , scope , skip_sync , as_ptrs )
95
110
96
111
97
112
@_decorators .register_fake (wait )
98
113
def _ (
99
114
signal_pad : torch .Tensor ,
100
- index : list [object ],
115
+ index : list [object ] = [] ,
101
116
signal : int = 1 ,
102
117
update : int | None = None ,
103
118
op : str = "ld" ,
104
119
sem : str = "acquire" ,
105
120
scope : str = "sys" ,
106
121
skip_sync : bool = False ,
122
+ as_ptrs : bool = False ,
107
123
) -> None :
108
124
return None
109
125
@@ -123,35 +139,38 @@ def _(state: CodegenState) -> ast.AST:
123
139
sem = state .proxy_arg (5 )
124
140
scope = state .proxy_arg (6 )
125
141
skip_sync = state .proxy_arg (7 )
142
+ as_ptrs = state .proxy_arg (8 )
126
143
127
144
assert isinstance (signal_pad , torch .Tensor )
128
145
assert isinstance (index , (list ))
129
146
130
- indices = SubscriptIndexing .create (state , signal_pad , index )
131
- signal_pad_name = state .device_function .tensor_arg (signal_pad ).name
132
-
133
- signal_expr = ast .Constant (value = signal )
134
- update_expr = ast .Constant (value = update )
135
-
136
147
assert type (op ) is str
137
148
assert type (sem ) is str
138
149
assert type (scope ) is str
139
150
140
- bar_tensor_shape = SubscriptIndexing .compute_shape (signal_pad , index )
141
- is_scalar = len (bar_tensor_shape ) == 0
142
-
143
- if is_scalar :
144
- call_triton_wait_signal = f"helion.runtime.triton_wait_signal(addr={ signal_pad_name } + offset, expect=signal, update=update, sem='{ sem } ', scope='{ scope } ', op='{ op } ', skip_sync={ skip_sync } )"
151
+ if as_ptrs :
152
+ bar_tensor_shape = signal_pad .shape
153
+ bar_addrs = "signal_pad_arg.to(tl.pointer_type(tl.int32))"
145
154
else :
155
+ indices = SubscriptIndexing .create (state , signal_pad , index )
146
156
if signal_pad .dtype not in (torch .int32 , torch .uint32 ):
147
157
raise NotImplementedError (
148
158
f"Unsupported signal pad dtype: { signal_pad .dtype } . Must be of torch.int32 or torch.uint32."
149
159
)
150
- call_triton_wait_signal = f"helion.runtime.triton_wait_multiple_signal(addr={ signal_pad_name } + offset, expect=signal, update=update, sem='{ sem } ', scope='{ scope } ', op='{ op } ', skip_sync={ skip_sync } )"
160
+ signal_pad_name = state .device_function .tensor_arg (signal_pad ).name
161
+ bar_tensor_shape = SubscriptIndexing .compute_shape (signal_pad , index )
162
+ bar_addrs = f"{ signal_pad_name } + signal_pad_arg"
163
+
164
+ signal_expr = ast .Constant (value = signal )
165
+ update_expr = ast .Constant (value = update )
166
+
167
+ is_scalar = len (bar_tensor_shape ) == 0
168
+
169
+ call_triton_wait_signal = f"helion.runtime.triton_wait_{ '' if is_scalar else 'multiple_' } signal(addr={ bar_addrs } , expect=signal, update=update, sem='{ sem } ', scope='{ scope } ', op='{ op } ', skip_sync={ skip_sync } )"
151
170
152
171
return expr_from_string (
153
172
call_triton_wait_signal ,
154
- offset = indices .index_expr ,
173
+ signal_pad_arg = state . ast_arg ( 0 ) if as_ptrs else indices .index_expr ,
155
174
signal = signal_expr ,
156
175
update = update_expr ,
157
176
)
@@ -161,13 +180,14 @@ def _(state: CodegenState) -> ast.AST:
161
180
@_decorators .api (tiles_as_sizes = True )
162
181
def signal (
163
182
signal_pad : torch .Tensor ,
164
- index : list [object ],
183
+ index : list [object ] = [] ,
165
184
signal : int = 1 ,
166
185
wait_for : int | None = None ,
167
186
op : str = "atomic_xchg" ,
168
187
sem : str = "release" ,
169
188
scope : str = "gpu" ,
170
189
skip_sync : bool = False ,
190
+ as_ptrs : bool = False ,
171
191
) -> torch .Tensor :
172
192
"""Set the signal_pad slice to the signal value.
173
193
Args:
@@ -179,21 +199,25 @@ def signal(
179
199
sem: The memory sematic for acquring the lock (default: 'release')
180
200
scope: The scope of the lock (default: 'gpu')
181
201
skip_sync: Skip the syncthreads before sending signal (default: False)
202
+ as_ptrs: Treat signal_pad as pointers to global memory barriers (default: False)
203
+ Returns:
204
+ The old value of the signal_pad slice before the update.
182
205
"""
183
206
raise exc .NotInsideKernel
184
207
185
208
186
209
@_decorators .prepare_args (signal )
187
210
def _ (
188
211
signal_pad : torch .Tensor ,
189
- index : list [object ],
212
+ index : list [object ] = [] ,
190
213
signal : int = 1 ,
191
214
wait_for : int | None = None ,
192
215
op : str = "atomic_xchg" ,
193
216
sem : str = "release" ,
194
217
scope : str = "gpu" ,
195
218
skip_sync : bool = False ,
196
- ) -> tuple [torch .Tensor , object , int , int | None , str , str , str , bool ]:
219
+ as_ptrs : bool = False ,
220
+ ) -> tuple [torch .Tensor , object , int , int | None , str , str , str , bool , bool ]:
197
221
from helion .language .tile_proxy import Tile
198
222
199
223
valid_ops = {"atomic_add" , "atomic_xchg" , "atomic_cas" }
@@ -220,23 +244,38 @@ def _(
220
244
if scope not in valid_scopes :
221
245
raise ValueError (f"Invalid scope '{ scope } '. Must be one of { valid_scopes } ." )
222
246
247
+ if as_ptrs :
248
+ if len (index ) != 0 :
249
+ raise ValueError (
250
+ f"When as_ptrs=True, signal_pad must be used without indexing. "
251
+ f"Expected 0 indices but got { len (index )} . "
252
+ )
253
+ if signal_pad .dtype not in (torch .uint64 , torch .int64 ):
254
+ raise ValueError (
255
+ f"When as_ptrs=True, signal_pad must have dtype torch.uint64 or torch.int64 "
256
+ f"to represent memory pointers. Got dtype { signal_pad .dtype } . "
257
+ )
258
+
223
259
index = Tile ._prepare_index (index )
224
260
index = Tile ._tiles_to_sizes (index )
225
261
226
- return (signal_pad , index , signal , wait_for , op , sem , scope , skip_sync )
262
+ return (signal_pad , index , signal , wait_for , op , sem , scope , skip_sync , as_ptrs )
227
263
228
264
229
265
@_decorators .register_fake (signal )
230
266
def _ (
231
267
signal_pad : torch .Tensor ,
232
- index : list [object ],
268
+ index : list [object ] = [] ,
233
269
signal : int = 1 ,
234
270
wait_for : int | None = None ,
235
271
op : str = "atomic_xchg" ,
236
272
sem : str = "release" ,
237
273
scope : str = "gpu" ,
238
274
skip_sync : bool = False ,
275
+ as_ptrs : bool = False ,
239
276
) -> torch .Tensor :
277
+ if as_ptrs :
278
+ return signal_pad .new_empty (signal_pad .shape )
240
279
return signal_pad .new_empty (SubscriptIndexing .compute_shape (signal_pad , index ))
241
280
242
281
@@ -255,43 +294,51 @@ def _(state: CodegenState) -> ast.AST:
255
294
sem = state .proxy_arg (5 )
256
295
scope = state .proxy_arg (6 )
257
296
skip_sync = state .proxy_arg (7 )
297
+ as_ptrs = state .proxy_arg (8 )
258
298
259
299
assert isinstance (signal_pad , torch .Tensor )
260
300
assert isinstance (index , list )
261
301
262
- indices = SubscriptIndexing .create (state , signal_pad , index )
263
- signal_pad_name = state .device_function .tensor_arg (signal_pad ).name
302
+ assert type (op ) is str
303
+ assert type (sem ) is str
304
+ assert type (scope ) is str
305
+
306
+ if as_ptrs :
307
+ bar_tensor_shape = signal_pad .shape
308
+ bar_addrs = "signal_pad_arg.to(tl.pointer_type(tl.int32))"
309
+ else :
310
+ indices = SubscriptIndexing .create (state , signal_pad , index )
311
+ if signal_pad .dtype not in (torch .int32 , torch .uint32 ):
312
+ raise NotImplementedError (
313
+ f"Unsupported signal pad dtype: { signal_pad .dtype } . Must be of torch.int32 or torch.uint32."
314
+ )
315
+ signal_pad_name = state .device_function .tensor_arg (signal_pad ).name
316
+ bar_tensor_shape = SubscriptIndexing .compute_shape (signal_pad , index )
317
+ bar_addrs = f"{ signal_pad_name } + signal_pad_arg"
264
318
265
319
signal_expr = ast .Constant (value = signal )
266
320
if wait_for is not None :
267
321
wait_for_expr = ast .Constant (value = wait_for )
268
322
else :
269
323
wait_for_expr = ast .Constant (value = 0 )
270
324
skip_sync_expr = ast .Constant (value = skip_sync )
271
- assert type (op ) is str
272
- assert type (sem ) is str
273
- assert type (scope ) is str
274
325
275
326
if op == "atomic_cas" :
276
327
bar_tensor_shape = SubscriptIndexing .compute_shape (signal_pad , index )
277
328
is_scalar = len (bar_tensor_shape ) == 0
278
- if is_scalar :
279
- call_triton_wait_signal = f"helion.runtime.triton_wait_signal(addr={ signal_pad_name } + offset, expect=wait_for, update=signal, sem='{ sem } ', scope='{ scope } ', op='{ op } ', skip_sync=True, sync_before=(not skip_sync))"
280
- else :
281
- call_triton_wait_signal = f"helion.runtime.triton_wait_multiple_signal(addr={ signal_pad_name } + offset, expect=wait_for, update=signal, sem='{ sem } ', scope='{ scope } ', op='{ op } ', skip_sync=True, sync_before=(not skip_sync))"
282
-
329
+ call_triton_wait_signal = f"helion.runtime.triton_wait_{ '' if is_scalar else 'multiple_' } signal(addr={ bar_addrs } , expect=wait_for, update=signal, sem='{ sem } ', scope='{ scope } ', op='{ op } ', skip_sync=True, sync_before=(not skip_sync))"
283
330
return expr_from_string (
284
331
call_triton_wait_signal ,
285
- offset = indices .index_expr ,
332
+ signal_pad_arg = state . ast_arg ( 0 ) if as_ptrs else indices .index_expr ,
286
333
wait_for = wait_for_expr ,
287
334
signal = signal_expr ,
288
335
skip_sync = skip_sync_expr ,
289
336
)
290
- call_triton_send_signal = f"helion.runtime.triton_send_signal(addr={ signal_pad_name } + offset , update=signal, sem='{ sem } ', scope='{ scope } ', op='{ op } ', skip_sync=skip_sync)"
337
+ call_triton_send_signal = f"helion.runtime.triton_send_signal(addr={ bar_addrs } , update=signal, sem='{ sem } ', scope='{ scope } ', op='{ op } ', skip_sync=skip_sync)"
291
338
292
339
return expr_from_string (
293
340
call_triton_send_signal ,
294
- offset = indices .index_expr ,
341
+ signal_pad_arg = state . ast_arg ( 0 ) if as_ptrs else indices .index_expr ,
295
342
signal = signal_expr ,
296
343
skip_sync = skip_sync_expr ,
297
344
)
0 commit comments