Skip to content

Commit 1cd37d6

Browse files
committed
improve location tracking (and test it)
1 parent 49c5f5f commit 1cd37d6

File tree

6 files changed

+160
-33
lines changed

6 files changed

+160
-33
lines changed

mlir_utils/dialects/ext/arith.py

Lines changed: 26 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -13,21 +13,22 @@
1313
_is_index_type,
1414
)
1515
from mlir.ir import (
16+
Attribute,
17+
Context,
18+
DenseElementsAttr,
19+
IndexType,
20+
IntegerAttr,
21+
IntegerType,
22+
Location,
1623
OpView,
1724
Operation,
25+
RankedTensorType,
1826
Type,
1927
Value,
20-
IndexType,
21-
RankedTensorType,
22-
IntegerAttr,
23-
IntegerType,
24-
DenseElementsAttr,
2528
register_attribute_builder,
26-
Context,
27-
Attribute,
2829
)
2930

30-
from mlir_utils.util import get_result_or_results, maybe_cast
31+
from mlir_utils.util import get_result_or_results, maybe_cast, get_user_code_loc
3132

3233
try:
3334
from mlir_utils.dialects.arith import *
@@ -41,6 +42,8 @@ def constant(
4142
value: Union[int, float, bool, np.ndarray],
4243
type: Optional[Type] = None,
4344
index: Optional[bool] = None,
45+
*,
46+
loc: Location = None,
4447
) -> arith_dialect.ConstantOp:
4548
"""Instantiate arith.constant with value `value`.
4649
@@ -56,6 +59,8 @@ def constant(
5659
Returns:
5760
ir.OpView instance that corresponds to instantiated arith.constant op.
5861
"""
62+
if loc is None:
63+
loc = get_user_code_loc()
5964
if index is not None and index:
6065
type = IndexType.get()
6166
if type is None:
@@ -73,8 +78,9 @@ def constant(
7378
value,
7479
type=type,
7580
)
76-
77-
return maybe_cast(get_result_or_results(arith_dialect.ConstantOp(type, value)))
81+
return maybe_cast(
82+
get_result_or_results(arith_dialect.ConstantOp(type, value, loc=loc))
83+
)
7884

7985

8086
class ArithValueMeta(type(Value)):
@@ -217,7 +223,12 @@ def _arith_CmpFPredicateAttr(predicate: str | Attribute, context: Context):
217223

218224

219225
def _binary_op(
220-
lhs: "ArithValue", rhs: "ArithValue", op: str, predicate: str = None
226+
lhs: "ArithValue",
227+
rhs: "ArithValue",
228+
op: str,
229+
predicate: str = None,
230+
*,
231+
loc: Location = None,
221232
) -> "ArithValue":
222233
"""Generic for handling infix binary operator dispatch.
223234
@@ -230,6 +241,8 @@ def _binary_op(
230241
Returns:
231242
Result of binary operation. This will be a handle to an arith(add|sub|mul) op.
232243
"""
244+
if loc is None:
245+
loc = get_user_code_loc()
233246
if not isinstance(rhs, lhs.__class__):
234247
rhs = lhs.__class__(rhs, dtype=lhs.type)
235248

@@ -258,9 +271,9 @@ def _binary_op(
258271
predicate = "s" + predicate
259272
else:
260273
predicate = "u" + predicate
261-
return lhs.__class__(op(predicate, lhs, rhs), dtype=lhs.dtype)
274+
return lhs.__class__(op(predicate, lhs, rhs, loc=loc), dtype=lhs.dtype)
262275
else:
263-
return lhs.__class__(op(lhs, rhs), dtype=lhs.dtype)
276+
return lhs.__class__(op(lhs, rhs, loc=loc), dtype=lhs.dtype)
264277

265278

266279
class ArithValue(Value, metaclass=ArithValueMeta):

mlir_utils/dialects/ext/func.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
TypeAttr,
99
FlatSymbolRefAttr,
1010
Type,
11+
Location,
1112
)
1213

1314
from mlir_utils.util import (
@@ -103,13 +104,16 @@ def emit(self):
103104
# this is the func op itself (funcs never have a resulting ssa value)
104105
return maybe_cast(get_result_or_results(func_op))
105106

106-
def __call__(self, *call_args):
107+
def __call__(self, *call_args, loc: Location = None):
108+
if loc is None:
109+
loc = get_user_code_loc()
107110
if not self.emitted:
108111
self.emit()
109112
call_op = self.call_op_ctor(
110113
[r.type for r in self.results],
111114
FlatSymbolRefAttr.get(self.func_name),
112115
call_args,
116+
loc=loc,
113117
)
114118
return maybe_cast(get_result_or_results(call_op))
115119

mlir_utils/dialects/ext/scf.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import libcst as cst
77
import libcst.matchers as m
88
from bytecode import ConcreteBytecode, ConcreteInstr
9-
from mlir.dialects import scf
9+
from mlir.dialects.scf import IfOp, ForOp
1010
from mlir.ir import InsertionPoint, Value, OpResultList, OpResult
1111

1212
from mlir_utils.ast.canonicalize import (
@@ -24,6 +24,7 @@
2424
maybe_cast,
2525
_update_caller_vars,
2626
get_result_or_results,
27+
get_user_code_loc,
2728
)
2829

2930
logger = logging.getLogger(__name__)
@@ -49,7 +50,9 @@ def _for(
4950
stop = constant(stop, index=True)
5051
if isinstance(step, int):
5152
step = constant(step, index=True)
52-
return scf.ForOp(start, stop, step, iter_args, loc=loc, ip=ip)
53+
if loc is None:
54+
loc = get_user_code_loc()
55+
return ForOp(start, stop, step, iter_args, loc=loc, ip=ip)
5356

5457

5558
for_ = region_op(_for, terminator=yield__)
@@ -91,7 +94,9 @@ def _if(cond, results_=None, *, has_else=False, loc=None, ip=None):
9194
results_ = []
9295
if results_:
9396
has_else = True
94-
return scf.IfOp(cond, results_, hasElse=has_else, loc=loc, ip=ip)
97+
if loc is None:
98+
loc = get_user_code_loc()
99+
return IfOp(cond, results_, hasElse=has_else, loc=loc, ip=ip)
95100

96101

97102
if_ = region_op(_if, terminator=yield__)
@@ -100,7 +105,7 @@ def _if(cond, results_=None, *, has_else=False, loc=None, ip=None):
100105

101106

102107
class IfStack:
103-
__current_if_op: list[scf.IfOp] = []
108+
__current_if_op: list[IfOp] = []
104109
__if_ip: list[InsertionPoint] = []
105110

106111
@staticmethod
@@ -423,6 +428,7 @@ def patch_bytecode(self, code: ConcreteBytecode, f):
423428
f.__globals__[end_if.__name__] = end_if
424429
f.__globals__[stack_if.__name__] = stack_if
425430
f.__globals__[stack_yield.__name__] = stack_yield
431+
f.__globals__[yield_.__name__] = yield_
426432
f.__globals__["_placeholder_opaque_t"] = _placeholder_opaque_t
427433
return code
428434

mlir_utils/util.py

Lines changed: 7 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,8 @@ def builder_wrapper(body_builder):
134134
f"for {body_builder=} either missing a type annotation or type annotation isn't a mlir type: {sig}"
135135
)
136136

137-
op.regions[0].blocks.append(*types)
137+
arg_locs = [get_user_code_loc()] * len(sig.parameters)
138+
op.regions[0].blocks.append(*types, arg_locs=arg_locs)
138139
with InsertionPoint(op.regions[0].blocks[0]):
139140
results = body_builder(
140141
*[maybe_cast(a) for a in op.regions[0].blocks[0].arguments]
@@ -209,17 +210,9 @@ def get_user_code_loc():
209210
mlir_utis_root_path = Path(mlir_utils.__path__[0])
210211

211212
prev_frame = inspect.currentframe().f_back
212-
stack = traceback.StackSummary.extract(traceback.walk_stack(prev_frame))
213-
214-
user_frame = next(
215-
(
216-
fr
217-
for fr in stack
218-
if not Path(fr.filename).is_relative_to(mlir_utis_root_path)
219-
),
220-
None,
213+
while Path(prev_frame.f_code.co_filename).is_relative_to(mlir_utis_root_path):
214+
prev_frame = prev_frame.f_back
215+
frame_info = inspect.getframeinfo(prev_frame)
216+
return Location.file(
217+
frame_info.filename, frame_info.lineno, frame_info.positions.col_offset
221218
)
222-
if user_frame is None:
223-
warnings.warn("couldn't find user code frame")
224-
return
225-
return Location.file(user_frame.filename, user_frame.lineno, user_frame.colno or 0)

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "mlir-python-utils"
3-
version = "0.0.2"
3+
version = "0.0.3"
44
description = "The missing pieces (as far as boilerplate reduction goes) of the upstream MLIR python bindings."
55
requires-python = ">=3.11"
66
license = { file = "LICENSE" }

tests/test_location_tracking.py

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
from pathlib import Path
2+
from textwrap import dedent
3+
from os import sep
4+
import pytest
5+
6+
from mlir_utils.ast.canonicalize import canonicalize
7+
from mlir_utils.dialects.ext.arith import constant
8+
from mlir_utils.dialects.ext.scf import (
9+
canonicalizer,
10+
stack_if,
11+
)
12+
from mlir_utils.dialects.ext.tensor import S
13+
from mlir_utils.dialects.tensor import generate, yield_ as tensor_yield, rank
14+
15+
# noinspection PyUnresolvedReferences
16+
from mlir_utils.testing import mlir_ctx as ctx, filecheck, MLIRContext
17+
from mlir_utils.types import f64_t, index_t, tensor_t
18+
19+
# needed since the fix isn't defined here nor conftest.py
20+
pytest.mark.usefixtures("ctx")
21+
22+
23+
THIS_DIR = str(Path(__file__).parent.absolute())
24+
25+
26+
def get_asm(operation):
27+
return operation.get_asm(enable_debug_info=True, pretty_debug_info=True).replace(
28+
THIS_DIR, "THIS_DIR"
29+
)
30+
31+
32+
def test_if_replace_yield_5(ctx: MLIRContext):
33+
@canonicalize(using=canonicalizer)
34+
def iffoo():
35+
one = constant(1.0)
36+
two = constant(2.0)
37+
if res := stack_if(one < two, (f64_t, f64_t, f64_t)):
38+
three = constant(3.0)
39+
yield three, three, three
40+
else:
41+
four = constant(4.0)
42+
yield four, four, four
43+
return
44+
45+
iffoo()
46+
ctx.module.operation.verify()
47+
correct = dedent(
48+
f"""\
49+
module {{
50+
%cst = arith.constant 1.000000e+00 : f64 THIS_DIR{sep}test_location_tracking.py:35:10
51+
%cst_0 = arith.constant 2.000000e+00 : f64 THIS_DIR{sep}test_location_tracking.py:36:10
52+
%0 = arith.cmpf olt, %cst, %cst_0 : f64 THIS_DIR{sep}test_location_tracking.py:37:23
53+
%1:3 = scf.if %0 -> (f64, f64, f64) {{
54+
%cst_1 = arith.constant 3.000000e+00 : f64 THIS_DIR{sep}test_location_tracking.py:38:16
55+
scf.yield %cst_1, %cst_1, %cst_1 : f64, f64, f64 THIS_DIR{sep}test_location_tracking.py:39:8
56+
}} else {{
57+
%cst_1 = arith.constant 4.000000e+00 : f64 THIS_DIR{sep}test_location_tracking.py:41:24
58+
scf.yield %cst_1, %cst_1, %cst_1 : f64, f64, f64 THIS_DIR{sep}test_location_tracking.py:42:8
59+
}} THIS_DIR{sep}test_location_tracking.py:37:14
60+
}} [unknown]
61+
#loc = [unknown]
62+
#loc1 = THIS_DIR{sep}test_location_tracking.py:35:10
63+
#loc2 = THIS_DIR{sep}test_location_tracking.py:36:10
64+
#loc3 = THIS_DIR{sep}test_location_tracking.py:37:23
65+
#loc4 = THIS_DIR{sep}test_location_tracking.py:37:14
66+
#loc5 = THIS_DIR{sep}test_location_tracking.py:38:16
67+
#loc6 = THIS_DIR{sep}test_location_tracking.py:39:8
68+
#loc7 = THIS_DIR{sep}test_location_tracking.py:41:24
69+
#loc8 = THIS_DIR{sep}test_location_tracking.py:42:8
70+
"""
71+
)
72+
asm = get_asm(ctx.module.operation)
73+
filecheck(correct, asm)
74+
75+
76+
def test_block_args(ctx: MLIRContext):
77+
one = constant(1, index_t)
78+
two = constant(2, index_t)
79+
80+
@generate(tensor_t(S, 3, S, f64_t), dynamic_extents=[one, two])
81+
def demo_fun1(i: index_t, j: index_t, k: index_t):
82+
one = constant(1.0)
83+
tensor_yield(one)
84+
85+
r = rank(demo_fun1)
86+
87+
ctx.module.operation.verify()
88+
89+
correct = dedent(
90+
f"""\
91+
#loc3 = THIS_DIR{sep}test_location_tracking.py:80:5
92+
module {{
93+
%c1 = arith.constant 1 : index THIS_DIR{sep}test_location_tracking.py:77:10
94+
%c2 = arith.constant 2 : index THIS_DIR{sep}test_location_tracking.py:78:10
95+
%generated = tensor.generate %c1, %c2 {{
96+
^bb0(%arg0: index THIS_DIR{sep}test_location_tracking.py:80:5, %arg1: index THIS_DIR{sep}test_location_tracking.py:80:5, %arg2: index THIS_DIR{sep}test_location_tracking.py:80:5):
97+
%cst = arith.constant 1.000000e+00 : f64 THIS_DIR{sep}test_location_tracking.py:82:14
98+
tensor.yield %cst : f64 THIS_DIR{sep}test_location_tracking.py:83:8
99+
}} : tensor<?x3x?xf64> THIS_DIR{sep}test_location_tracking.py:80:5
100+
%rank = tensor.rank %generated : tensor<?x3x?xf64> THIS_DIR{sep}test_location_tracking.py:85:8
101+
}} [unknown]
102+
#loc = [unknown]
103+
#loc1 = THIS_DIR{sep}test_location_tracking.py:77:10
104+
#loc2 = THIS_DIR{sep}test_location_tracking.py:78:10
105+
#loc4 = THIS_DIR{sep}test_location_tracking.py:82:14
106+
#loc5 = THIS_DIR{sep}test_location_tracking.py:83:8
107+
#loc6 = THIS_DIR{sep}test_location_tracking.py:85:8
108+
"""
109+
)
110+
asm = get_asm(ctx.module.operation)
111+
filecheck(correct, asm)

0 commit comments

Comments
 (0)