
Commit 4c0ad72

Fix bug with errors on unreachable if branch (#138)
1 parent f7ed720 commit 4c0ad72

File tree

helion/_compiler/type_propagation.py
test/test_control_flow.py

2 files changed: +143 -7 lines changed


helion/_compiler/type_propagation.py

Lines changed: 8 additions & 7 deletions
@@ -1950,16 +1950,17 @@ def visit_AugAssign(self, node: ast.AugAssign) -> TypeInfo:
 
     def visit_If(self, node: ast.If) -> TypeInfo:
         test = self.visit(node.test)
-        body = self._body(node.body)
-        orelse = self._body(node.orelse)
         try:
             truth_val = test.truth_value()
-            if truth_val:
-                self.scope.merge(body)
-            else:
-                self.scope.merge(orelse)
+            has_truth_val = True
         except NotImplementedError:
-            self.scope.merge_if_else(body, orelse)
+            truth_val = None
+            has_truth_val = False
+        if has_truth_val:
+            # For constant conditions, only type propagate one branch
+            self.scope.merge(self._body(node.body if truth_val else node.orelse))
+        else:
+            self.scope.merge_if_else(self._body(node.body), self._body(node.orelse))
         return NoType(origin=self.origin())
 
     def _body(self, stmts: list[ast.stmt]) -> LocalScope:
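For context, here is a minimal sketch, not part of this commit, of the pattern this change supports: when the if test is a compile-time constant (for example an hl.constexpr flag), only the taken branch is type-propagated, so the branch that is never taken may contain code that would otherwise fail type propagation. The kernel name, shapes, and body below are illustrative assumptions modeled on the new test, not code from the repository.

# Hypothetical sketch (illustrative names, not from this commit): a constexpr
# flag selects between an atomics-based path and a plain-store path. With the
# fix above, only the branch matching the constant flag is type-propagated, so
# the differently shaped output and hl.atomic_add in the untaken branch no
# longer raise errors during compilation.
import torch
import helion
import helion.language as hl


@helion.kernel(config=helion.Config(block_sizes=[32, 32]))
def row_sum(x: torch.Tensor, use_atomics: hl.constexpr = False):
    m, n = x.shape
    # The two branches produce outputs of different shapes; only the one
    # selected by the constant use_atomics value is type-propagated.
    if use_atomics:
        out = torch.zeros([m], device=x.device, dtype=x.dtype)
    else:
        out = torch.empty_like(x)
    for tile_i, tile_j in hl.tile([m, n]):
        if use_atomics:
            hl.atomic_add(out, [tile_i], torch.sum(x[tile_i, tile_j], dim=1))
        else:
            out[tile_i, tile_j] = x[tile_i, tile_j]
    if use_atomics:
        return out
    return out.sum(axis=-1)

Calling such a kernel with use_atomics=True and use_atomics=False would compile two different specializations, which is what the new test below exercises via code_and_output.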

test/test_control_flow.py

Lines changed: 135 additions & 0 deletions
@@ -200,6 +200,141 @@ def _fn_make_precompiler(x):
     return make_precompiler(_fn_kernel)(x, out, out.size(0), out.size(1), x.size(0), x.size(1), out.stride(0), out.stride(1), x.stride(0), x.stride(1), _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3)""",
         )
 
+    def test_error_in_non_taken_branch(self):
+        def mul_relu_block_back_spec(x, y, dz):
+            z = torch.relu(x * y[:, None])
+            grad_x, grad_y = torch.autograd.grad(z, [x, y], dz, retain_graph=True)
+            return grad_x, grad_y
+
+        @helion.kernel(config=helion.Config(block_sizes=[32, 32]))
+        def mul_relu_block_backward_kernel(
+            x: torch.Tensor,
+            y: torch.Tensor,
+            dz: torch.Tensor,
+            use_atomics: hl.constexpr = False,
+        ):
+            # Get tensor sizes
+            m, n = x.shape
+            # Create output tensor for gradients
+            dx = torch.empty_like(x)
+
+            if use_atomics:
+                dy = torch.zeros_like(y)
+            else:
+                dy = torch.empty_like(x)
+
+            # Use Helion to tile the computation
+            for tile_i, tile_j in hl.tile([m, n]):
+                # Get input tiles
+                x_tile = x[tile_i, tile_j]
+                y_tile = y[tile_i]
+                dz_tile = dz[tile_i, tile_j]
+
+                # For ReLU, gradient is 1 where input > 0, 0 otherwise
+                relu_mask = (x_tile * y_tile[:, None]) > 0
+                # Chain rule: dx = dz * relu_grad * y
+                relu_grad = torch.where(relu_mask, 1, 0)
+                dx[tile_i, tile_j] = dz_tile * relu_grad * y_tile[:, None]
+
+                # Chain rule: dy = dz * relu_grad * x -> backwards of broadcast(sum)
+                if use_atomics:
+                    local_dy_grad = torch.sum(dz_tile * relu_grad * x_tile, dim=1)
+                    hl.atomic_add(dy, [tile_i], local_dy_grad)
+                else:
+                    local_dy_grad = dz_tile * relu_grad * x_tile
+                    dy[tile_i, tile_j] = local_dy_grad
+
+            if use_atomics:
+                return dx, dy
+            return dx, dy.sum(axis=-1)
+
+        x = torch.randn(512, 1024, device="cuda", requires_grad=True)
+        y = torch.randn(512, device="cuda", requires_grad=True)
+        dz = torch.randn(512, 1024, device="cuda")
+        expected = mul_relu_block_back_spec(x, y, dz)
+        torch.testing.assert_close(
+            mul_relu_block_backward_kernel(x, y, dz, False),
+            expected,
+        )
+        code, output = code_and_output(
+            mul_relu_block_backward_kernel,
+            (x, y, dz, True),
+        )
+        self.assertExpectedInline(
+            code,
+            """\
+from __future__ import annotations
+
+import torch
+import helion.language as hl
+import triton
+import triton.language as tl
+
+@triton.jit
+def _mul_relu_block_backward_kernel_kernel(x, y, dz, dx, dy, dx_stride_0, dx_stride_1, dy_stride_0, dz_stride_0, dz_stride_1, x_stride_0, x_stride_1, y_stride_0, m, n, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr):
+    num_blocks_0 = tl.cdiv(m, _BLOCK_SIZE_0)
+    pid_0 = tl.program_id(0) % num_blocks_0
+    pid_1 = tl.program_id(0) // num_blocks_0
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
+    mask_0 = indices_0 < m
+    offset_1 = pid_1 * _BLOCK_SIZE_1
+    indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32)
+    mask_1 = indices_1 < n
+    x_tile = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_1[None, :] * x_stride_1), mask_0[:, None] & mask_1[None, :], other=0)
+    y_tile = tl.load(y + indices_0 * y_stride_0, mask_0, other=0)
+    dz_tile = tl.load(dz + (indices_0[:, None] * dz_stride_0 + indices_1[None, :] * dz_stride_1), mask_0[:, None] & mask_1[None, :], other=0)
+    subscript = y_tile[:, None]
+    v_0 = x_tile * subscript
+    v_1 = 0.0
+    v_2 = v_0 > v_1
+    v_3 = tl.full([], 0, tl.int64)
+    v_4 = tl.full([], 1, tl.int64)
+    v_5 = v_4[None, None]
+    v_6 = v_3[None, None]
+    v_7 = tl.where(v_2, v_5, v_6)
+    v_8 = v_7.to(tl.float32)
+    v_9 = dz_tile * v_8
+    subscript_1 = y_tile[:, None]
+    v_10 = v_9 * subscript_1
+    tl.store(dx + (indices_0[:, None] * dx_stride_0 + indices_1[None, :] * dx_stride_1), v_10, mask_0[:, None] & mask_1[None, :])
+    v_11 = v_7.to(tl.float32)
+    v_12 = dz_tile * v_11
+    v_13 = v_12 * x_tile
+    local_dy_grad = tl.sum(v_13, 1)
+    tl.atomic_add(dy + indices_0 * dy_stride_0, local_dy_grad, mask=mask_0, sem='relaxed')
+
+def mul_relu_block_backward_kernel(x: torch.Tensor, y: torch.Tensor, dz: torch.Tensor, use_atomics: hl.constexpr=False):
+    m, n = x.shape
+    dx = torch.empty_like(x)
+    if True:
+        dy = torch.zeros_like(y)
+    else:
+        dy = torch.empty_like(x)
+    _BLOCK_SIZE_0 = 32
+    _BLOCK_SIZE_1 = 32
+    _mul_relu_block_backward_kernel_kernel[triton.cdiv(m, _BLOCK_SIZE_0) * triton.cdiv(n, _BLOCK_SIZE_1),](x, y, dz, dx, dy, dx.stride(0), dx.stride(1), dy.stride(0), dz.stride(0), dz.stride(1), x.stride(0), x.stride(1), y.stride(0), m, n, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
+    if True:
+        return (dx, dy)
+    return (dx, dy.sum(axis=-1))
+
+def _mul_relu_block_backward_kernel_make_precompiler(x: torch.Tensor, y: torch.Tensor, dz: torch.Tensor, use_atomics: hl.constexpr=False):
+    m, n = x.shape
+    dx = torch.empty_like(x)
+    if True:
+        dy = torch.zeros_like(y)
+    else:
+        dy = torch.empty_like(x)
+    _BLOCK_SIZE_0 = 32
+    _BLOCK_SIZE_1 = 32
+    from helion.runtime.precompile_shim import make_precompiler
+    return make_precompiler(_mul_relu_block_backward_kernel_kernel)(x, y, dz, dx, dy, dx.stride(0), dx.stride(1), dy.stride(0), dz.stride(0), dz.stride(1), x.stride(0), x.stride(1), y.stride(0), m, n, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3)""",
+        )
+        torch.testing.assert_close(
+            output,
+            expected,
+        )
+
 
 if __name__ == "__main__":
     unittest.main()
