Commit 1e483a9

Improve DCE by marking math functions as pure (#312)
1 parent 38753a6 commit 1e483a9

File tree

3 files changed: +16 −15 lines

  helion/_compiler/ast_read_writes.py
  test/test_examples.expected
  test/test_tensor_descriptor.expected


helion/_compiler/ast_read_writes.py

Lines changed: 16 additions & 0 deletions
@@ -197,6 +197,22 @@ def visit_UnaryOp(self, node: ast.UnaryOp) -> None:
     def visit_Starred(self, node: ast.Starred) -> None:
         self.visit(node.value)
 
+    def visit_Call(self, node: ast.Call) -> None:
+        # Math methods are all pure, so allow them
+        if not (
+            isinstance(node.func, ast.Attribute)
+            and isinstance(node.func.value, ast.Name)
+            and node.func.value.id == "math"
+        ):
+            raise _NotPureException
+
+        # Recurse into children except for func
+        for arg in node.args:
+            self.visit(arg)
+        for keyword in node.keywords:
+            self.visit(keyword.value)
+
 
 def definitely_does_not_have_side_effects(expr: ast.expr) -> bool:
     try:
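To illustrate the effect of the new visit_Call handler, here is a minimal sketch. It assumes definitely_does_not_have_side_effects can be imported from helion._compiler.ast_read_writes (inferred from the file path above) and that the try: block returns False when _NotPureException is raised; both are assumptions, not part of this diff.

    import ast

    # Assumed import path, inferred from helion/_compiler/ast_read_writes.py.
    from helion._compiler.ast_read_writes import definitely_does_not_have_side_effects

    # math.sqrt(head_dim) parses to Call(func=Attribute(value=Name(id="math"), attr="sqrt")),
    # which the new visit_Call accepts as pure and recurses into.
    pure = ast.parse("1.0 / math.sqrt(head_dim)", mode="eval").body

    # foo(head_dim) has a plain Name as its func, so visit_Call raises _NotPureException.
    other = ast.parse("foo(head_dim)", mode="eval").body

    print(definitely_does_not_have_side_effects(pure))   # expected: True, so DCE may drop an unused assignment
    print(definitely_does_not_have_side_effects(other))  # expected: False (assuming the exception maps to False)

The first expression is exactly the one that now disappears from the expected outputs below (sm_scale = 1.0 / math.sqrt(head_dim)).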

test/test_examples.expected

Lines changed: 0 additions & 9 deletions
@@ -37,7 +37,6 @@ def _add_make_precompiler(x: torch.Tensor, y: torch.Tensor):
 --- assertExpectedJournal(TestExamples.test_attention_block_pointer)
 from __future__ import annotations
 
-import math
 import torch
 import triton
 import triton.language as tl
@@ -103,7 +102,6 @@ def attention(q_in: torch.Tensor, k_in: torch.Tensor, v_in: torch.Tensor):
     v_view = v_in.reshape([-1, n_dim, head_dim])
     k_view = k_in.reshape([-1, n_dim, head_dim]).transpose(1, 2)
     out = torch.empty_like(q_view)
-    sm_scale = 1.0 / math.sqrt(head_dim)
     _BLOCK_SIZE_1 = 128
     _BLOCK_SIZE_3 = 64
     _attention_kernel[64 * triton.cdiv(1024, _BLOCK_SIZE_1),](q_view, k_view, v_view, out, _BLOCK_SIZE_1, _BLOCK_SIZE_3, num_warps=4, num_stages=3)
@@ -119,7 +117,6 @@ def _attention_make_precompiler(q_in: torch.Tensor, k_in: torch.Tensor, v_in: to
     v_view = v_in.reshape([-1, n_dim, head_dim])
     k_view = k_in.reshape([-1, n_dim, head_dim]).transpose(1, 2)
     out = torch.empty_like(q_view)
-    sm_scale = 1.0 / math.sqrt(head_dim)
     _BLOCK_SIZE_1 = 128
     _BLOCK_SIZE_3 = 64
     from helion.runtime.precompile_shim import make_precompiler
@@ -128,7 +125,6 @@ def _attention_make_precompiler(q_in: torch.Tensor, k_in: torch.Tensor, v_in: to
 --- assertExpectedJournal(TestExamples.test_attention_dynamic)
 from __future__ import annotations
 
-import math
 import torch
 import triton
 import triton.language as tl
@@ -198,7 +194,6 @@ def attention(q_in: torch.Tensor, k_in: torch.Tensor, v_in: torch.Tensor):
     v_view = v_in.reshape([-1, n_dim, head_dim])
     k_view = k_in.reshape([-1, n_dim, head_dim]).transpose(1, 2)
     out = torch.empty_like(q_view)
-    sm_scale = 1.0 / math.sqrt(head_dim)
     _BLOCK_SIZE_1 = 32
     _RDIM_SIZE_2 = 64
     _BLOCK_SIZE_3 = 32
@@ -215,7 +210,6 @@ def _attention_make_precompiler(q_in: torch.Tensor, k_in: torch.Tensor, v_in: to
     v_view = v_in.reshape([-1, n_dim, head_dim])
     k_view = k_in.reshape([-1, n_dim, head_dim]).transpose(1, 2)
     out = torch.empty_like(q_view)
-    sm_scale = 1.0 / math.sqrt(head_dim)
     _BLOCK_SIZE_1 = 32
     _RDIM_SIZE_2 = 64
     _BLOCK_SIZE_3 = 32
@@ -225,7 +219,6 @@ def _attention_make_precompiler(q_in: torch.Tensor, k_in: torch.Tensor, v_in: to
 --- assertExpectedJournal(TestExamples.test_attention_pointer)
 from __future__ import annotations
 
-import math
 import torch
 import triton
 import triton.language as tl
@@ -291,7 +284,6 @@ def attention(q_in: torch.Tensor, k_in: torch.Tensor, v_in: torch.Tensor):
     v_view = v_in.reshape([-1, n_dim, head_dim])
     k_view = k_in.reshape([-1, n_dim, head_dim]).transpose(1, 2)
     out = torch.empty_like(q_view)
-    sm_scale = 1.0 / math.sqrt(head_dim)
     _BLOCK_SIZE_1 = 64
     _RDIM_SIZE_2 = 64
     _BLOCK_SIZE_3 = 64
@@ -308,7 +300,6 @@ def _attention_make_precompiler(q_in: torch.Tensor, k_in: torch.Tensor, v_in: to
     v_view = v_in.reshape([-1, n_dim, head_dim])
     k_view = k_in.reshape([-1, n_dim, head_dim]).transpose(1, 2)
     out = torch.empty_like(q_view)
-    sm_scale = 1.0 / math.sqrt(head_dim)
     _BLOCK_SIZE_1 = 64
     _RDIM_SIZE_2 = 64
     _BLOCK_SIZE_3 = 64
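All nine removals in this expected file follow from the same change: in the generated host wrappers, sm_scale is never passed to the kernel launch, so the assignment is dead code. It previously survived DCE only because the math.sqrt call made the expression look potentially impure; with math.* calls treated as pure, the assignment is dropped, and the then-unused import math goes with it. Schematically (lines taken from the removed context above):

    out = torch.empty_like(q_view)
    sm_scale = 1.0 / math.sqrt(head_dim)   # unused in the host wrapper; now removed by DCE
    _BLOCK_SIZE_1 = 128

The test/test_tensor_descriptor.expected changes below are the same pattern for the tensor-descriptor variants.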

test/test_tensor_descriptor.expected

Lines changed: 0 additions & 6 deletions
@@ -4,7 +4,6 @@ Update expected outputs by running tests with the EXPECTTEST_ACCEPT=1 environmen
 --- assertExpectedJournal(TestTensorDescriptor.test_attention_td_dynamic)
 from __future__ import annotations
 
-import math
 import torch
 import helion
 import triton
@@ -79,7 +78,6 @@ def attention(q_in: torch.Tensor, k_in: torch.Tensor, v_in: torch.Tensor):
     v_view = v_in.reshape([-1, n_dim, head_dim])
     k_view = k_in.reshape([-1, n_dim, head_dim]).transpose(1, 2)
     out = torch.empty_like(q_view)
-    sm_scale = 1.0 / math.sqrt(head_dim)
     _BLOCK_SIZE_1 = 16
     _BLOCK_SIZE_3 = 16
     _attention_kernel[q_in.size(1) * triton.cdiv(m_dim, _BLOCK_SIZE_1),](q_view, k_view, v_view, out, k_view.size(0), k_view.size(2), out.size(0), out.size(1), q_in.size(1), q_view.size(0), q_view.size(1), v_view.size(0), v_view.size(1), k_view.stride(0), k_view.stride(1), k_view.stride(2), out.stride(0), out.stride(1), out.stride(2), q_view.stride(0), q_view.stride(1), q_view.stride(2), v_view.stride(0), v_view.stride(1), v_view.stride(2), m_dim, n_dim, _BLOCK_SIZE_1, _BLOCK_SIZE_3, num_warps=4, num_stages=3)
@@ -95,7 +93,6 @@ def _attention_make_precompiler(q_in: torch.Tensor, k_in: torch.Tensor, v_in: to
     v_view = v_in.reshape([-1, n_dim, head_dim])
     k_view = k_in.reshape([-1, n_dim, head_dim]).transpose(1, 2)
     out = torch.empty_like(q_view)
-    sm_scale = 1.0 / math.sqrt(head_dim)
     _BLOCK_SIZE_1 = 16
     _BLOCK_SIZE_3 = 16
     from helion.runtime.precompile_shim import make_precompiler
@@ -104,7 +101,6 @@ def _attention_make_precompiler(q_in: torch.Tensor, k_in: torch.Tensor, v_in: to
 --- assertExpectedJournal(TestTensorDescriptor.test_attention_tensor_descriptor)
 from __future__ import annotations
 
-import math
 import torch
 import helion
 import triton
@@ -177,7 +173,6 @@ def attention(q_in: torch.Tensor, k_in: torch.Tensor, v_in: torch.Tensor):
     v_view = v_in.reshape([-1, n_dim, head_dim])
     k_view = k_in.reshape([-1, n_dim, head_dim]).transpose(1, 2)
     out = torch.empty_like(q_view)
-    sm_scale = 1.0 / math.sqrt(head_dim)
     _BLOCK_SIZE_1 = 128
     _BLOCK_SIZE_3 = 64
     _attention_kernel[64 * triton.cdiv(1024, _BLOCK_SIZE_1),](q_view, k_view, v_view, out, _BLOCK_SIZE_1, _BLOCK_SIZE_3, num_warps=4, num_stages=3)
@@ -193,7 +188,6 @@ def _attention_make_precompiler(q_in: torch.Tensor, k_in: torch.Tensor, v_in: to
     v_view = v_in.reshape([-1, n_dim, head_dim])
     k_view = k_in.reshape([-1, n_dim, head_dim]).transpose(1, 2)
     out = torch.empty_like(q_view)
-    sm_scale = 1.0 / math.sqrt(head_dim)
     _BLOCK_SIZE_1 = 128
     _BLOCK_SIZE_3 = 64
     from helion.runtime.precompile_shim import make_precompiler
