@@ -200,6 +200,141 @@ def _fn_make_precompiler(x):
    return make_precompiler(_fn_kernel)(x, out, out.size(0), out.size(1), x.size(0), x.size(1), out.stride(0), out.stride(1), x.stride(0), x.stride(1), _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3)""",
        )

+    def test_error_in_non_taken_branch(self):
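+        # Reference implementation: gradients of relu(x * y[:, None]) computed via autograd.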
+        def mul_relu_block_back_spec(x, y, dz):
+            z = torch.relu(x * y[:, None])
+            grad_x, grad_y = torch.autograd.grad(z, [x, y], dz, retain_graph=True)
+            return grad_x, grad_y
+
+        @helion.kernel(config=helion.Config(block_sizes=[32, 32]))
+        def mul_relu_block_backward_kernel(
+            x: torch.Tensor,
+            y: torch.Tensor,
+            dz: torch.Tensor,
+            use_atomics: hl.constexpr = False,
+        ):
+            # Get tensor sizes
+            m, n = x.shape
+            # Create output tensor for gradients
+            dx = torch.empty_like(x)
+
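+            # use_atomics is an hl.constexpr, so this branch is decided at compile
+            # time and the generated code is specialized for the chosen value.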
+            if use_atomics:
+                dy = torch.zeros_like(y)
+            else:
+                dy = torch.empty_like(x)
+
+            # Use Helion to tile the computation
+            for tile_i, tile_j in hl.tile([m, n]):
+                # Get input tiles
+                x_tile = x[tile_i, tile_j]
+                y_tile = y[tile_i]
+                dz_tile = dz[tile_i, tile_j]
+
+                # For ReLU, gradient is 1 where input > 0, 0 otherwise
+                relu_mask = (x_tile * y_tile[:, None]) > 0
+                # Chain rule: dx = dz * relu_grad * y
+                relu_grad = torch.where(relu_mask, 1, 0)
+                dx[tile_i, tile_j] = dz_tile * relu_grad * y_tile[:, None]
+
+                # Chain rule: dy = dz * relu_grad * x -> backwards of broadcast(sum)
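+                # With atomics, partial row sums are accumulated directly into the
+                # 1-D dy; otherwise the full 2-D contribution is stored and summed
+                # over the last dim after the loop.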
+                if use_atomics:
+                    local_dy_grad = torch.sum(dz_tile * relu_grad * x_tile, dim=1)
+                    hl.atomic_add(dy, [tile_i], local_dy_grad)
+                else:
+                    local_dy_grad = dz_tile * relu_grad * x_tile
+                    dy[tile_i, tile_j] = local_dy_grad
+
+            if use_atomics:
+                return dx, dy
+            return dx, dy.sum(axis=-1)
+
+        x = torch.randn(512, 1024, device="cuda", requires_grad=True)
+        y = torch.randn(512, device="cuda", requires_grad=True)
+        dz = torch.randn(512, 1024, device="cuda")
+        expected = mul_relu_block_back_spec(x, y, dz)
+        torch.testing.assert_close(
+            mul_relu_block_backward_kernel(x, y, dz, False),
+            expected,
+        )
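+        # Compile the use_atomics=True specialization and capture both the
+        # generated Triton source and the kernel output.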
+        code, output = code_and_output(
+            mul_relu_block_backward_kernel,
+            (x, y, dz, True),
+        )
+        self.assertExpectedInline(
+            code,
+            """\
+from __future__ import annotations
+
+import torch
+import helion.language as hl
+import triton
+import triton.language as tl
+
+@triton.jit
+def _mul_relu_block_backward_kernel_kernel(x, y, dz, dx, dy, dx_stride_0, dx_stride_1, dy_stride_0, dz_stride_0, dz_stride_1, x_stride_0, x_stride_1, y_stride_0, m, n, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr):
+    num_blocks_0 = tl.cdiv(m, _BLOCK_SIZE_0)
+    pid_0 = tl.program_id(0) % num_blocks_0
+    pid_1 = tl.program_id(0) // num_blocks_0
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
+    mask_0 = indices_0 < m
+    offset_1 = pid_1 * _BLOCK_SIZE_1
+    indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32)
+    mask_1 = indices_1 < n
+    x_tile = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_1[None, :] * x_stride_1), mask_0[:, None] & mask_1[None, :], other=0)
+    y_tile = tl.load(y + indices_0 * y_stride_0, mask_0, other=0)
+    dz_tile = tl.load(dz + (indices_0[:, None] * dz_stride_0 + indices_1[None, :] * dz_stride_1), mask_0[:, None] & mask_1[None, :], other=0)
+    subscript = y_tile[:, None]
+    v_0 = x_tile * subscript
+    v_1 = 0.0
+    v_2 = v_0 > v_1
+    v_3 = tl.full([], 0, tl.int64)
+    v_4 = tl.full([], 1, tl.int64)
+    v_5 = v_4[None, None]
+    v_6 = v_3[None, None]
+    v_7 = tl.where(v_2, v_5, v_6)
+    v_8 = v_7.to(tl.float32)
+    v_9 = dz_tile * v_8
+    subscript_1 = y_tile[:, None]
+    v_10 = v_9 * subscript_1
+    tl.store(dx + (indices_0[:, None] * dx_stride_0 + indices_1[None, :] * dx_stride_1), v_10, mask_0[:, None] & mask_1[None, :])
+    v_11 = v_7.to(tl.float32)
+    v_12 = dz_tile * v_11
+    v_13 = v_12 * x_tile
+    local_dy_grad = tl.sum(v_13, 1)
+    tl.atomic_add(dy + indices_0 * dy_stride_0, local_dy_grad, mask=mask_0, sem='relaxed')
+
+def mul_relu_block_backward_kernel(x: torch.Tensor, y: torch.Tensor, dz: torch.Tensor, use_atomics: hl.constexpr=False):
+    m, n = x.shape
+    dx = torch.empty_like(x)
+    if True:
+        dy = torch.zeros_like(y)
+    else:
+        dy = torch.empty_like(x)
+    _BLOCK_SIZE_0 = 32
+    _BLOCK_SIZE_1 = 32
+    _mul_relu_block_backward_kernel_kernel[triton.cdiv(m, _BLOCK_SIZE_0) * triton.cdiv(n, _BLOCK_SIZE_1),](x, y, dz, dx, dy, dx.stride(0), dx.stride(1), dy.stride(0), dz.stride(0), dz.stride(1), x.stride(0), x.stride(1), y.stride(0), m, n, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
+    if True:
+        return (dx, dy)
+    return (dx, dy.sum(axis=-1))
+
+def _mul_relu_block_backward_kernel_make_precompiler(x: torch.Tensor, y: torch.Tensor, dz: torch.Tensor, use_atomics: hl.constexpr=False):
+    m, n = x.shape
+    dx = torch.empty_like(x)
+    if True:
+        dy = torch.zeros_like(y)
+    else:
+        dy = torch.empty_like(x)
+    _BLOCK_SIZE_0 = 32
+    _BLOCK_SIZE_1 = 32
+    from helion.runtime.precompile_shim import make_precompiler
+    return make_precompiler(_mul_relu_block_backward_kernel_kernel)(x, y, dz, dx, dy, dx.stride(0), dx.stride(1), dy.stride(0), dz.stride(0), dz.stride(1), x.stride(0), x.stride(1), y.stride(0), m, n, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3)""",
+        )
+        torch.testing.assert_close(
+            output,
+            expected,
+        )
+

if __name__ == "__main__":
    unittest.main()