Skip to content

Commit 70dbd55

Browse files
committed
- refactor func
- fix trampolines casing bug - fix configuration again - warn about TypeID instead of crash
1 parent 074afc8 commit 70dbd55

File tree

10 files changed

+183
-30
lines changed

10 files changed

+183
-30
lines changed

README.md

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,4 +59,15 @@ Workaround is to delete the prefix token before configuring, like so:
5959

6060
```shell
6161
rm /home/mlevental/dev_projects/mlir_utils/mlir_utils/_configuration/__MLIR_PYTHON_PACKAGE_PREFIX__ && configure-mlir-python-utils mlir
62-
```
62+
```
63+
64+
## Gotchas
65+
66+
There's a `DefaultContext` created when this package is loaded. If you have weird things happen like
67+
68+
```
69+
E error: unknown: 'arith.constant' op requires attribute 'value'
70+
E note: unknown: see current operation: %0 = "arith.constant"() {value = 64 : i32} : () -> i32
71+
```
72+
73+
which looks patently insane (because `value` is in fact there as an attribute), then you have a `Context`s problem.

mlir_utils/_configuration/configuration.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,13 @@ def _get_mlir_package_prefix():
5252

5353
def alias_upstream_bindings():
5454
if mlir_python_package_prefix := _get_mlir_package_prefix():
55+
# check if valid package/module
56+
try:
57+
_host_bindings_mlir = __import__(f"{mlir_python_package_prefix}._mlir_libs")
58+
except (ImportError, ModuleNotFoundError) as e:
59+
print(f"couldn't import {mlir_python_package_prefix=} due to: {e}")
60+
raise e
61+
5562
sys.meta_path.insert(
5663
get_meta_path_insertion_index(),
5764
AliasedModuleFinder({"mlir": mlir_python_package_prefix}),

mlir_utils/_configuration/generate_trampolines.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -72,14 +72,6 @@ def generate_op_trampoline(op_class):
7272
for a in args.args:
7373
a.arg = inflection.underscore(a.arg).lower()
7474

75-
for k in args.kwonlyargs:
76-
k.arg = inflection.underscore(k.arg).lower()
77-
78-
keywords = [
79-
ast.keyword(k.arg, ast.Name(k.arg))
80-
for k, d in zip(args.kwonlyargs, args.kw_defaults)
81-
]
82-
8375
fun_name = op_class.OPERATION_NAME.split(".")[-1].replace("-", "_")
8476
if keyword.iskeyword(fun_name):
8577
fun_name = fun_name + "_"
@@ -88,6 +80,11 @@ def generate_op_trampoline(op_class):
8880
if len(args.args) == 1 and args.args[0].arg == "results_":
8981
args.defaults.append(ast.Constant(None))
9082
body += [ast.parse("results_ = results_ or []").body[0]]
83+
84+
keywords = [
85+
ast.keyword(k.arg, ast.Name(inflection.underscore(k.arg).lower()))
86+
for k, d in zip(args.kwonlyargs, args.kw_defaults)
87+
]
9188
if (
9289
hasattr(op_class, "_ODS_REGIONS")
9390
and op_class._ODS_REGIONS[0] == 1
@@ -103,6 +100,9 @@ def generate_op_trampoline(op_class):
103100
).body[0]
104101
]
105102

103+
for k in args.kwonlyargs:
104+
k.arg = inflection.underscore(k.arg).lower()
105+
106106
args = copy.deepcopy(args)
107107
oper_finder = FindOperands()
108108
oper_finder.visit(init_fn)

mlir_utils/context.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44

55
import mlir.ir
66

7+
from mlir_utils import DefaultContext
8+
79

810
@dataclass
911
class MLIRContext:
@@ -17,7 +19,7 @@ def __str__(self):
1719
@contextmanager
1820
def mlir_mod_ctx(
1921
src: Optional[str] = None,
20-
context: mlir.ir.Context = None,
22+
context: mlir.ir.Context = DefaultContext,
2123
location: mlir.ir.Location = None,
2224
allow_unregistered_dialects=False,
2325
) -> MLIRContext:

mlir_utils/dialects/ext/arith.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,12 @@
1919
Value,
2020
IndexType,
2121
RankedTensorType,
22+
IntegerAttr,
23+
IntegerType,
2224
DenseElementsAttr,
25+
register_attribute_builder,
26+
Context,
27+
Attribute,
2328
)
2429

2530
from mlir_utils.dialects.util import get_result_or_results, maybe_cast
@@ -244,3 +249,53 @@ def isinstance(other: Value):
244249
or _is_index_type(other.type)
245250
or _is_complex_type(other.type)
246251
)
252+
253+
254+
@register_attribute_builder("Arith_CmpIPredicateAttr")
255+
def _arith_CmpIPredicateAttr(predicate: str | Attribute, context: Context):
256+
predicates = {
257+
"eq": 0,
258+
"ne": 1,
259+
"slt": 2,
260+
"sle": 3,
261+
"sgt": 4,
262+
"sge": 5,
263+
"ult": 6,
264+
"ule": 7,
265+
"ugt": 8,
266+
"uge": 9,
267+
}
268+
if isinstance(predicate, Attribute):
269+
return predicate
270+
assert predicate in predicates, f"predicate {predicate} not in predicates"
271+
return IntegerAttr.get(
272+
IntegerType.get_signless(64, context=context), predicates[predicate]
273+
)
274+
275+
276+
@register_attribute_builder("Arith_CmpFPredicateAttr")
277+
def _arith_CmpFPredicateAttr(predicate: str | Attribute, context: Context):
278+
predicates = {
279+
"false": 0,
280+
"oeq": 1,
281+
"ogt": 2,
282+
"oge": 3,
283+
"olt": 4,
284+
"ole": 5,
285+
"one": 6,
286+
"ord": 7,
287+
"ueq": 8,
288+
"ugt": 9,
289+
"uge": 10,
290+
"ult": 11,
291+
"ule": 12,
292+
"une": 13,
293+
"uno": 14,
294+
"true": 15,
295+
}
296+
if isinstance(predicate, Attribute):
297+
return predicate
298+
assert predicate in predicates, f"predicate {predicate} not in predicates"
299+
return IntegerAttr.get(
300+
IntegerType.get_signless(64, context=context), predicates[predicate]
301+
)

mlir_utils/dialects/ext/func.py

Lines changed: 47 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import inspect
2-
from functools import wraps
2+
from functools import wraps, partial
33

44
from mlir.dialects.func import FuncOp, ReturnOp, CallOp
55
from mlir.ir import (
@@ -8,24 +8,45 @@
88
StringAttr,
99
TypeAttr,
1010
FlatSymbolRefAttr,
11+
Type,
1112
)
1213

1314
from mlir_utils.dialects.util import (
1415
get_result_or_results,
1516
make_maybe_no_args_decorator,
17+
maybe_cast,
1618
)
1719

1820

19-
@make_maybe_no_args_decorator
20-
def func(sym_visibility=None, arg_attrs=None, res_attrs=None, loc=None, ip=None):
21+
def func_base(
22+
FuncOp,
23+
ReturnOp,
24+
CallOp,
25+
sym_visibility=None,
26+
arg_attrs=None,
27+
res_attrs=None,
28+
loc=None,
29+
ip=None,
30+
):
2131
ip = ip or InsertionPoint.current
2232

33+
# if this is set to true then wrapper below won't emit a call op
34+
# it is set below by a def emit fn that is attached to the body_builder
35+
# wrapper; thus you can call wrapped_fn.emit() (i.e., without an operands)
36+
# and the func will be emitted.
37+
_emit = False
38+
2339
def builder_wrapper(body_builder):
2440
@wraps(body_builder)
2541
def wrapper(*call_args):
2642
sig = inspect.signature(body_builder)
2743
implicit_return = sig.return_annotation is inspect._empty
28-
input_types = [a.type for a in call_args]
44+
input_types = [p.annotation for p in sig.parameters.values()]
45+
if not (
46+
len(input_types) == len(sig.parameters)
47+
and all(isinstance(t, Type) for t in input_types)
48+
):
49+
input_types = [a.type for a in call_args]
2950
function_type = TypeAttr.get(
3051
FunctionType.get(
3152
inputs=input_types,
@@ -34,7 +55,7 @@ def wrapper(*call_args):
3455
)
3556
# FuncOp is extended but we do really want the base
3657
func_name = body_builder.__name__
37-
func_op = FuncOp.__base__(
58+
func_op = FuncOp(
3859
func_name,
3960
function_type,
4061
sym_visibility=StringAttr.get(str(sym_visibility))
@@ -45,7 +66,7 @@ def wrapper(*call_args):
4566
loc=loc,
4667
ip=ip,
4768
)
48-
func_op.regions[0].blocks.append(*[a.type for a in call_args])
69+
func_op.regions[0].blocks.append(*input_types)
4970
with InsertionPoint(func_op.regions[0].blocks[0]):
5071
results = get_result_or_results(
5172
body_builder(*func_op.regions[0].blocks[0].arguments)
@@ -63,14 +84,27 @@ def wrapper(*call_args):
6384
function_type = FunctionType.get(inputs=input_types, results=return_types)
6485
func_op.attributes["function_type"] = TypeAttr.get(function_type)
6586

66-
call_op = CallOp(
67-
[r.type for r in results], FlatSymbolRefAttr.get(func_name), call_args
68-
)
69-
if results is None:
70-
return func_op
71-
return get_result_or_results(call_op)
87+
if _emit:
88+
return maybe_cast(get_result_or_results(func_op))
89+
else:
90+
call_op = CallOp(
91+
[r.type for r in results],
92+
FlatSymbolRefAttr.get(func_name),
93+
call_args,
94+
)
95+
return maybe_cast(get_result_or_results(call_op))
96+
97+
def emit():
98+
nonlocal _emit
99+
_emit = True
100+
wrapper()
72101

73-
# wrapper.op = op
102+
wrapper.emit = emit
74103
return wrapper
75104

76105
return builder_wrapper
106+
107+
108+
func = make_maybe_no_args_decorator(
109+
partial(func_base, FuncOp=FuncOp.__base__, ReturnOp=ReturnOp, CallOp=CallOp)
110+
)

mlir_utils/dialects/util.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,21 @@
11
import ctypes
22
import inspect
3+
import warnings
34
from collections import defaultdict
45
from functools import wraps
56
from typing import Callable
67

8+
import mlir
79
from mlir.dialects._ods_common import get_op_result_or_value, get_op_results_or_values
8-
from mlir.ir import InsertionPoint, Value, Type, TypeID
10+
from mlir.ir import InsertionPoint, Value, Type
11+
12+
try:
13+
from mlir.ir import TypeID
14+
except ImportError:
15+
warnings.warn(
16+
f"TypeID not supported by {mlir=}; value casting won't work correctly"
17+
)
18+
TypeID = object
919

1020

1121
def get_result_or_results(op):

tests/test_func.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
import inspect
2+
from textwrap import dedent
3+
4+
import pytest
5+
6+
from mlir_utils.dialects.ext.arith import constant
7+
from mlir_utils.dialects.ext.func import func
8+
9+
# noinspection PyUnresolvedReferences
10+
from mlir_utils.testing import mlir_ctx as ctx, filecheck, MLIRContext
11+
12+
# needed since the fix isn't defined here nor conftest.py
13+
pytest.mark.usefixtures("ctx")
14+
15+
16+
def test_emit(ctx: MLIRContext):
17+
@func
18+
def demo_fun1():
19+
one = constant(1)
20+
return one
21+
22+
assert hasattr(demo_fun1, "emit")
23+
assert inspect.isfunction(demo_fun1.emit)
24+
demo_fun1.emit()
25+
correct = dedent(
26+
"""\
27+
module {
28+
func.func @demo_fun1() -> i64 {
29+
%c1_i64 = arith.constant 1 : i64
30+
return %c1_i64 : i64
31+
}
32+
}
33+
"""
34+
)
35+
filecheck(correct, ctx.module)

tests/test_operator_overloading.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,9 @@
1515

1616
def test_tensor_arithmetic(ctx: MLIRContext):
1717
print()
18-
one = constant(1, index_t)
18+
one = constant(1)
1919
assert isinstance(one, Scalar)
20-
two = constant(2, index_t)
20+
two = constant(2)
2121
assert isinstance(two, Scalar)
2222
three = one + two
2323
assert isinstance(three, Scalar)
@@ -34,9 +34,9 @@ def test_tensor_arithmetic(ctx: MLIRContext):
3434
dedent(
3535
"""\
3636
module {
37-
%c1 = arith.constant 1 : index
38-
%c2 = arith.constant 2 : index
39-
%0 = arith.addi %c1, %c2 : index
37+
%c1_i64 = arith.constant 1 : i64
38+
%c2_i64 = arith.constant 2 : i64
39+
%0 = arith.addi %c1_i64, %c2_i64 : i64
4040
%1 = tensor.empty() : tensor<10x10x10xf64>
4141
%2 = tensor.empty() : tensor<10x10x10xf64>
4242
%3 = arith.addf %1, %2 : tensor<10x10x10xf64>

tests/test_value_caster.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@ def test_caster_registration(ctx: MLIRContext):
1919
assert repr(ten) == "Tensor(%0, tensor<?x3x?xf64>)"
2020

2121
def dummy_caster(val):
22-
print(val)
2322
return val
2423

2524
register_value_caster(RankedTensorType.static_typeid, dummy_caster)

0 commit comments

Comments
 (0)