Skip to content

Commit e8cf178

Browse files
authored
Add host side dead code elimination (#289)
1 parent b092b6c commit e8cf178

11 files changed

+151
-67
lines changed

helion/_compiler/ast_read_writes.py

Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,3 +132,130 @@ def ast_delete_assignments(body: list[ast.AST], to_remove: set[str]) -> list[ast
132132
if new_node is not None:
133133
new_body.append(new_node)
134134
return new_body
135+
136+
137+
class _NotPureException(Exception):
138+
pass
139+
140+
141+
class _PureExpressionVisitor(ast.NodeVisitor):
142+
"""
143+
AST visitor that determines if an expression is guaranteed to be pure.
144+
"""
145+
146+
def generic_visit(self, node: ast.AST) -> None:
147+
# Anything without a specific visitor is not pure
148+
raise _NotPureException
149+
150+
def visit_Constant(self, node: ast.Constant) -> None:
151+
pass
152+
153+
def visit_Num(self, node: ast.Num) -> None:
154+
pass
155+
156+
def visit_Str(self, node: ast.Str) -> None:
157+
pass
158+
159+
def visit_Bytes(self, node: ast.Bytes) -> None:
160+
pass
161+
162+
def visit_NameConstant(self, node: ast.NameConstant) -> None:
163+
pass
164+
165+
def visit_Ellipsis(self, node: ast.Ellipsis) -> None:
166+
pass
167+
168+
def visit_Name(self, node: ast.Name) -> None:
169+
pass
170+
171+
def visit_Tuple(self, node: ast.Tuple) -> None:
172+
for elt in node.elts:
173+
self.visit(elt)
174+
175+
def visit_List(self, node: ast.List) -> None:
176+
for elt in node.elts:
177+
self.visit(elt)
178+
179+
def visit_Set(self, node: ast.Set) -> None:
180+
for elt in node.elts:
181+
self.visit(elt)
182+
183+
def visit_Dict(self, node: ast.Dict) -> None:
184+
for key in node.keys:
185+
if key is not None: # Handle dict unpacking
186+
self.visit(key)
187+
for value in node.values:
188+
self.visit(value)
189+
190+
def visit_BinOp(self, node: ast.BinOp) -> None:
191+
self.visit(node.left)
192+
self.visit(node.right)
193+
194+
def visit_UnaryOp(self, node: ast.UnaryOp) -> None:
195+
self.visit(node.operand)
196+
197+
def visit_Starred(self, node: ast.Starred) -> None:
198+
self.visit(node.value)
199+
200+
201+
def definitely_does_not_have_side_effects(expr: ast.expr) -> bool:
    """Return True if *expr* is provably free of side effects.

    A False result does not mean the expression has side effects — only
    that purity could not be proven.
    """
    visitor = _PureExpressionVisitor()
    try:
        visitor.visit(expr)
    except _NotPureException:
        return False
    return True
207+
208+
209+
class _DeletePureExpressions(ast.NodeTransformer):
    """Transformer that drops statement-expressions whose value is
    provably pure (removing them cannot change program behavior)."""

    def visit_Expr(self, node: ast.Expr) -> ast.Expr | None:
        # Returning None from a NodeTransformer deletes the statement.
        return None if definitely_does_not_have_side_effects(node.value) else node
214+
215+
216+
def dead_assignment_elimination(
    body: list[ast.AST],
    dce_vars: list[str],
    num_iterations: int = 8,
    input_rw: ReadWrites | None = None,
) -> None:
    """
    Eliminates dead assignments from *body* in place.

    Args:
        body: statements to rewrite; mutated via slice assignment.
        dce_vars: names that are candidates for elimination.
        num_iterations: maximum fixed-point passes (removing one dead
            assignment can make another variable dead).
        input_rw: optional precomputed read/write summary. Only valid with
            a single pass, since it would go stale after any removal.
    """
    # A caller-supplied summary is only sound for exactly one pass.
    assert num_iterations == 1 or input_rw is None, (
        "input_rw is incompatible with num_iterations > 1"
    )
    for _ in range(num_iterations):
        rw = input_rw if input_rw is not None else ReadWrites.from_list(body)
        # A candidate is dead when it is written but never read.
        to_remove = {
            name for name in dce_vars if name in rw.writes and name not in rw.reads
        }
        if not to_remove:
            break  # fixed point reached
        body[:] = ast_delete_assignments(body, to_remove)
237+
238+
239+
def is_string_expr(node: ast.AST) -> bool:
    """Return True if *node* is a bare string-literal statement
    (e.g. a docstring or triple-quoted "comment")."""
    if not isinstance(node, ast.Expr):
        return False
    value = node.value
    return isinstance(value, ast.Constant) and isinstance(value.value, str)
245+
246+
247+
def dead_expression_elimination(body: list[ast.AST]) -> None:
    """
    Eliminates dead (provably pure) expression statements from *body* in place.
    """
    # The transformer is stateless, so one instance serves every statement.
    transformer = _DeletePureExpressions()
    kept: list[ast.AST] = []
    for stmt in body:
        if is_string_expr(stmt):
            # Triple-quoted comments and docstrings are indistinguishable
            # from ordinary string expressions — never eliminate them.
            kept.append(stmt)
        else:
            replacement = transformer.visit(stmt)
            if replacement is not None:
                kept.append(replacement)
    body[:] = kept

helion/_compiler/device_function.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,8 @@
2424
from .ast_extension import expr_from_string
2525
from .ast_extension import statement_from_string
2626
from .ast_read_writes import ReadWrites
27-
from .ast_read_writes import ast_delete_assignments
2827
from .ast_read_writes import ast_rename
28+
from .ast_read_writes import dead_assignment_elimination
2929
from .compile_environment import CompileEnvironment
3030
from .host_function import HostFunction
3131
from .host_function import NoCurrentFunction
@@ -463,14 +463,8 @@ def dead_code_elimination(self) -> None:
463463

464464
for _ in range(8):
465465
rw = ReadWrites.from_list([*self.preamble, *self.body])
466-
to_remove = set()
467-
for name in self.dce_vars:
468-
if name in rw.writes and name not in rw.reads:
469-
to_remove.add(name)
470-
if not to_remove:
471-
break
472-
self.body[:] = ast_delete_assignments(self.body, to_remove)
473-
self.preamble[:] = ast_delete_assignments(self.preamble, to_remove)
466+
dead_assignment_elimination(self.body, self.dce_vars, 1, rw)
467+
dead_assignment_elimination(self.preamble, self.dce_vars, 1, rw)
474468

475469
# drop any unused args
476470
args_to_remove = {

helion/_compiler/generate_ast.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
from typing import TYPE_CHECKING
77
from typing import NamedTuple
88

9+
from torch.utils._ordered_set import OrderedSet
10+
911
from .. import exc
1012
from ..language._decorators import is_api_func
1113
from ..runtime.precompile_shim import make_precompiler
@@ -15,6 +17,9 @@
1517
from .ast_extension import create
1618
from .ast_extension import expr_from_string
1719
from .ast_extension import statement_from_string
20+
from .ast_read_writes import dead_assignment_elimination
21+
from .ast_read_writes import dead_expression_elimination
22+
from .ast_read_writes import definitely_does_not_have_side_effects
1823
from .compile_environment import CompileEnvironment
1924
from .device_function import DeviceFunction
2025
from .helper_function import CodegenInterface
@@ -322,6 +327,21 @@ def visit_Call(self, node: ast.Call) -> ast.AST:
322327
)
323328
return self.generic_visit(node)
324329

330+
def host_dead_code_elimination(self) -> None:
    """Run dead-code elimination over the generated host statements.

    Collects every variable that is the sole (Name) target of an assignment
    whose right-hand side is provably pure, then removes the ones that are
    never read, followed by a pure-expression sweep.
    """
    candidates: OrderedSet[str] = OrderedSet()
    for stmt in self.host_statements:
        if not isinstance(stmt, ast.Assign):
            continue
        # Only assignments whose RHS is side-effect free may be deleted.
        if not definitely_does_not_have_side_effects(stmt.value):
            continue
        # Skip tuple/attribute/subscript targets — only plain names are safe.
        if not all(isinstance(target, ast.Name) for target in stmt.targets):
            continue
        for target in stmt.targets:
            assert isinstance(target, ast.Name)
            candidates.add(target.id)

    dead_assignment_elimination(self.host_statements, list(candidates))
    dead_expression_elimination(self.host_statements)
344+
325345

326346
class TensorReference(NamedTuple):
327347
node: ast.AST
@@ -413,6 +433,7 @@ def generate_ast(func: HostFunction, config: Config) -> ast.AST:
413433
for stmt in func.body:
414434
codegen.add_statement(codegen.visit(stmt))
415435
kernel_def = codegen.device_function.codegen_function_def()
436+
codegen.host_dead_code_elimination()
416437
host_def = func.codegen_function_def(codegen.host_statements)
417438
precompile_def = codegen_precompile_def(
418439
host_def, codegen.device_function.name

test/test_broadcasting.expected

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -242,7 +242,6 @@ def fn(a, idx1):
242242
out0 = torch.empty_like(a)
243243
out1 = torch.empty_like(a)
244244
out2 = torch.empty_like(a)
245-
idx0 = 11
246245
_BLOCK_SIZE_0 = 16
247246
_BLOCK_SIZE_1 = 16
248247
_fn_kernel[triton.cdiv(a.size(0), _BLOCK_SIZE_0) * triton.cdiv(a.size(1), _BLOCK_SIZE_1),](a, out0, out1, out2, a.size(0), a.size(1), a.stride(0), a.stride(1), out0.stride(0), out0.stride(1), out1.stride(0), out1.stride(1), out2.stride(0), out2.stride(1), idx1, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
@@ -252,7 +251,6 @@ def _fn_make_precompiler(a, idx1):
252251
out0 = torch.empty_like(a)
253252
out1 = torch.empty_like(a)
254253
out2 = torch.empty_like(a)
255-
idx0 = 11
256254
_BLOCK_SIZE_0 = 16
257255
_BLOCK_SIZE_1 = 16
258256
from helion.runtime.precompile_shim import make_precompiler

test/test_control_flow.expected

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,15 +22,13 @@ def _fn_kernel(x, out, out_size_0, out_size_1, x_size_0, x_size_1, out_stride_0,
2222

2323
def fn(x):
2424
out = torch.empty_like(x)
25-
v = 15
2625
_BLOCK_SIZE_0 = 32
2726
_BLOCK_SIZE_1 = 32
2827
_fn_kernel[triton.cdiv(x.size(0), _BLOCK_SIZE_0) * triton.cdiv(x.size(1), _BLOCK_SIZE_1),](x, out, out.size(0), out.size(1), x.size(0), x.size(1), out.stride(0), out.stride(1), x.stride(0), x.stride(1), _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
2928
return out
3029

3130
def _fn_make_precompiler(x):
3231
out = torch.empty_like(x)
33-
v = 15
3432
_BLOCK_SIZE_0 = 32
3533
_BLOCK_SIZE_1 = 32
3634
from helion.runtime.precompile_shim import make_precompiler
@@ -55,14 +53,12 @@ def _fn_kernel(x, out, x_size_0, x_size_1, out_stride_0, out_stride_1, x_stride_
5553

5654
def fn(x):
5755
out = torch.empty_like(x)
58-
v = 4
5956
_BLOCK_SIZE_0_1 = 128
6057
_fn_kernel[triton.cdiv(x.size(0) * x.size(1), _BLOCK_SIZE_0_1), 1, 1](x, out, x.size(0), x.size(1), out.stride(0), out.stride(1), x.stride(0), x.stride(1), _BLOCK_SIZE_0_1, num_warps=4, num_stages=3)
6158
return out
6259

6360
def _fn_make_precompiler(x):
6461
out = torch.empty_like(x)
65-
v = 4
6662
_BLOCK_SIZE_0_1 = 128
6763
from helion.runtime.precompile_shim import make_precompiler
6864
return make_precompiler(_fn_kernel)(x, out, x.size(0), x.size(1), out.stride(0), out.stride(1), x.stride(0), x.stride(1), _BLOCK_SIZE_0_1, num_warps=4, num_stages=3)

test/test_examples.expected

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -104,9 +104,7 @@ def attention(q_in: torch.Tensor, k_in: torch.Tensor, v_in: torch.Tensor):
104104
k_view = k_in.reshape([-1, n_dim, head_dim]).transpose(1, 2)
105105
out = torch.empty_like(q_view)
106106
sm_scale = 1.0 / math.sqrt(head_dim)
107-
qk_scale = sm_scale * 1.44269504
108107
_BLOCK_SIZE_1 = 128
109-
_RDIM_SIZE_2 = 64
110108
_BLOCK_SIZE_3 = 64
111109
_attention_kernel[64 * triton.cdiv(1024, _BLOCK_SIZE_1),](q_view, k_view, v_view, out, _BLOCK_SIZE_1, _BLOCK_SIZE_3, num_warps=4, num_stages=3)
112110
return out.view(q_in.size())
@@ -122,9 +120,7 @@ def _attention_make_precompiler(q_in: torch.Tensor, k_in: torch.Tensor, v_in: to
122120
k_view = k_in.reshape([-1, n_dim, head_dim]).transpose(1, 2)
123121
out = torch.empty_like(q_view)
124122
sm_scale = 1.0 / math.sqrt(head_dim)
125-
qk_scale = sm_scale * 1.44269504
126123
_BLOCK_SIZE_1 = 128
127-
_RDIM_SIZE_2 = 64
128124
_BLOCK_SIZE_3 = 64
129125
from helion.runtime.precompile_shim import make_precompiler
130126
return make_precompiler(_attention_kernel)(q_view, k_view, v_view, out, _BLOCK_SIZE_1, _BLOCK_SIZE_3, num_warps=4, num_stages=3)
@@ -208,7 +204,6 @@ def attention(q_in: torch.Tensor, k_in: torch.Tensor, v_in: torch.Tensor):
208204
k_view = k_in.reshape([-1, n_dim, head_dim]).transpose(1, 2)
209205
out = torch.empty_like(q_view)
210206
sm_scale = 1.0 / math.sqrt(head_dim)
211-
qk_scale = sm_scale * 1.44269504
212207
_BLOCK_SIZE_1 = 128
213208
_RDIM_SIZE_2 = 64
214209
_BLOCK_SIZE_3 = 16
@@ -226,7 +221,6 @@ def _attention_make_precompiler(q_in: torch.Tensor, k_in: torch.Tensor, v_in: to
226221
k_view = k_in.reshape([-1, n_dim, head_dim]).transpose(1, 2)
227222
out = torch.empty_like(q_view)
228223
sm_scale = 1.0 / math.sqrt(head_dim)
229-
qk_scale = sm_scale * 1.44269504
230224
_BLOCK_SIZE_1 = 128
231225
_RDIM_SIZE_2 = 64
232226
_BLOCK_SIZE_3 = 16
@@ -303,7 +297,6 @@ def attention(q_in: torch.Tensor, k_in: torch.Tensor, v_in: torch.Tensor):
303297
k_view = k_in.reshape([-1, n_dim, head_dim]).transpose(1, 2)
304298
out = torch.empty_like(q_view)
305299
sm_scale = 1.0 / math.sqrt(head_dim)
306-
qk_scale = sm_scale * 1.44269504
307300
_BLOCK_SIZE_1 = 64
308301
_RDIM_SIZE_2 = 64
309302
_BLOCK_SIZE_3 = 64
@@ -321,7 +314,6 @@ def _attention_make_precompiler(q_in: torch.Tensor, k_in: torch.Tensor, v_in: to
321314
k_view = k_in.reshape([-1, n_dim, head_dim]).transpose(1, 2)
322315
out = torch.empty_like(q_view)
323316
sm_scale = 1.0 / math.sqrt(head_dim)
324-
qk_scale = sm_scale * 1.44269504
325317
_BLOCK_SIZE_1 = 64
326318
_RDIM_SIZE_2 = 64
327319
_BLOCK_SIZE_3 = 64
@@ -1570,17 +1562,13 @@ def _softmax_two_pass_kernel(x, out, out_stride_0, out_stride_1, x_stride_0, x_s
15701562
def softmax_two_pass(x: torch.Tensor):
15711563
m, n = x.size()
15721564
out = torch.empty_like(x)
1573-
block_size_m = 1
1574-
block_size_n = 128
15751565
_BLOCK_SIZE_1 = 128
15761566
_softmax_two_pass_kernel[m,](x, out, out.stride(0), out.stride(1), x.stride(0), x.stride(1), n, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
15771567
return out
15781568

15791569
def _softmax_two_pass_make_precompiler(x: torch.Tensor):
15801570
m, n = x.size()
15811571
out = torch.empty_like(x)
1582-
block_size_m = 1
1583-
block_size_n = 128
15841572
_BLOCK_SIZE_1 = 128
15851573
from helion.runtime.precompile_shim import make_precompiler
15861574
return make_precompiler(_softmax_two_pass_kernel)(x, out, out.stride(0), out.stride(1), x.stride(0), x.stride(1), n, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
@@ -1640,8 +1628,6 @@ def _softmax_two_pass_kernel(x, out, out_size_0, out_size_1, x_size_0, x_size_1,
16401628
def softmax_two_pass(x: torch.Tensor):
16411629
m, n = x.size()
16421630
out = torch.empty_like(x)
1643-
block_size_m = 8
1644-
block_size_n = 64
16451631
_BLOCK_SIZE_0 = 8
16461632
_BLOCK_SIZE_1 = 64
16471633
_softmax_two_pass_kernel[triton.cdiv(m, _BLOCK_SIZE_0),](x, out, out.size(0), out.size(1), x.size(0), x.size(1), out.stride(0), out.stride(1), x.stride(0), x.stride(1), m, n, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
@@ -1650,8 +1636,6 @@ def softmax_two_pass(x: torch.Tensor):
16501636
def _softmax_two_pass_make_precompiler(x: torch.Tensor):
16511637
m, n = x.size()
16521638
out = torch.empty_like(x)
1653-
block_size_m = 8
1654-
block_size_n = 64
16551639
_BLOCK_SIZE_0 = 8
16561640
_BLOCK_SIZE_1 = 64
16571641
from helion.runtime.precompile_shim import make_precompiler

0 commit comments

Comments
 (0)