Commit a137e4d

refactor func to be a class hierarchy
1 parent 91cb08f commit a137e4d

6 files changed: +170 -89 lines changed

mlir_utils/_configuration/configuration.py

Lines changed: 3 additions & 0 deletions
@@ -5,13 +5,16 @@
 from base64 import urlsafe_b64encode
 from importlib.metadata import distribution, packages_distributions
 from importlib.resources import files
+from importlib.resources.readers import MultiplexedPath
 from pathlib import Path

 from .module_alias_map import get_meta_path_insertion_index, AliasedModuleFinder

 __MLIR_PYTHON_PACKAGE_PREFIX__ = "__MLIR_PYTHON_PACKAGE_PREFIX__"
 PACKAGE = __package__.split(".")[0]
 PACKAGE_ROOT_PATH = files(PACKAGE)
+if isinstance(PACKAGE_ROOT_PATH, MultiplexedPath):
+    PACKAGE_ROOT_PATH = PACKAGE_ROOT_PATH._paths[0]
 DIST = distribution(packages_distributions()[PACKAGE][0])
 MLIR_PYTHON_PACKAGE_PREFIX_TOKEN_PATH = (
     Path(__file__).parent / __MLIR_PYTHON_PACKAGE_PREFIX__
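
Background on the new guard: `importlib.resources.files()` can return a `MultiplexedPath` for a namespace package whose contents are spread across more than one directory, and such a path has no single filesystem location, so the commit picks the first underlying concrete path. A minimal sketch of the same check, assuming Python 3.11+ for the `importlib.resources.readers` location and a hypothetical package name ("my_namespace_pkg" is not from the repo):

import importlib.resources as resources

try:
    # the readers module moved between Python versions; hedge the import
    from importlib.resources.readers import MultiplexedPath
except ImportError:
    from importlib.readers import MultiplexedPath

# hypothetical installed namespace package, used only for illustration
root = resources.files("my_namespace_pkg")
if isinstance(root, MultiplexedPath):
    # no single directory to point at; mirror the commit and take the first
    # underlying concrete path (note that _paths is a private attribute)
    root = root._paths[0]
print(root)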

mlir_utils/dialects/ext/func.py

Lines changed: 102 additions & 83 deletions
@@ -18,95 +18,114 @@
 )


-def func_base(
-    FuncOp,
-    ReturnOp,
-    CallOp,
-    sym_visibility=None,
-    arg_attrs=None,
-    res_attrs=None,
-    loc=None,
-    ip=None,
-):
-    ip = ip or InsertionPoint.current
-
-    # if this is set to true then wrapper below won't emit a call op
-    # it is set below by a def emit fn that is attached to the body_builder
-    # wrapper; thus you can call wrapped_fn.emit() (i.e., without an operands)
-    # and the func will be emitted.
-    _emit = False
-
-    def builder_wrapper(body_builder):
-        @wraps(body_builder)
-        def wrapper(*call_args):
-            # TODO(max): implement constexpr ie enable passing constants that skip being
-            # part of the signature
-            sig = inspect.signature(body_builder)
-            implicit_return = sig.return_annotation is inspect._empty
-            input_types = [p.annotation for p in sig.parameters.values()]
-            if not (
-                len(input_types) == len(sig.parameters)
-                and all(isinstance(t, Type) for t in input_types)
-            ):
-                input_types = [a.type for a in call_args]
-            function_type = TypeAttr.get(
-                FunctionType.get(
-                    inputs=input_types,
-                    results=[] if implicit_return else sig.return_annotation,
-                )
+class FuncOpMeta(type):
+    def __call__(cls, *args, **kwargs):
+        cls_obj = cls.__new__(cls)
+        if len(args) == 1 and len(kwargs) == 0 and inspect.isfunction(args[0]):
+            return cls.__init__(cls_obj, args[0])
+        else:
+
+            def init_wrapper(f):
+                cls.__init__(cls_obj, f, *args, **kwargs)
+                return cls_obj
+
+            return lambda f: init_wrapper(f)
+
+
+class FuncBase(metaclass=FuncOpMeta):
+    def __init__(
+        self,
+        body_builder,
+        func_op_ctor,
+        return_op_ctor,
+        call_op_ctor,
+        sym_visibility=None,
+        arg_attrs=None,
+        res_attrs=None,
+        loc=None,
+        ip=None,
+    ):
+        assert inspect.isfunction(body_builder), body_builder
+        assert inspect.isclass(func_op_ctor), func_op_ctor
+        assert inspect.isclass(return_op_ctor), return_op_ctor
+        assert inspect.isclass(call_op_ctor), call_op_ctor
+
+        self.body_builder = body_builder
+        self.func_name = self.body_builder.__name__
+
+        self.func_op_ctor = func_op_ctor
+        self.return_op_ctor = return_op_ctor
+        self.call_op_ctor = call_op_ctor
+        self.sym_visibility = (
+            StringAttr.get(str(sym_visibility)) if sym_visibility is not None else None
+        )
+        self.arg_attrs = arg_attrs
+        self.res_attrs = res_attrs
+        self.loc = loc
+        self.ip = ip or InsertionPoint.current
+        self.emitted = False
+
+    def __str__(self):
+        return str(f"{self.__class__} {self.__dict__}")
+
+    def body_builder_wrapper(self, *call_args):
+        sig = inspect.signature(self.body_builder)
+        implicit_return = sig.return_annotation is inspect._empty
+        input_types = [p.annotation for p in sig.parameters.values()]
+        if not (
+            len(input_types) == len(sig.parameters)
+            and all(isinstance(t, Type) for t in input_types)
+        ):
+            input_types = [a.type for a in call_args]
+        function_type = TypeAttr.get(
+            FunctionType.get(
+                inputs=input_types,
+                results=[] if implicit_return else sig.return_annotation,
             )
-            # FuncOp is extended but we do really want the base
-            func_name = body_builder.__name__
-            func_op = FuncOp(
-                func_name,
-                function_type,
-                sym_visibility=StringAttr.get(str(sym_visibility))
-                if sym_visibility is not None
-                else None,
-                arg_attrs=arg_attrs,
-                res_attrs=res_attrs,
-                loc=loc,
-                ip=ip,
+        )
+        func_op = self.func_op_ctor(
+            self.func_name,
+            function_type,
+            sym_visibility=self.sym_visibility,
+            arg_attrs=self.arg_attrs,
+            res_attrs=self.res_attrs,
+            loc=self.loc,
+            ip=self.ip,
+        )
+        func_op.regions[0].blocks.append(*input_types)
+        with InsertionPoint(func_op.regions[0].blocks[0]):
+            results = get_result_or_results(
+                self.body_builder(*func_op.regions[0].blocks[0].arguments)
             )
-            func_op.regions[0].blocks.append(*input_types)
-            with InsertionPoint(func_op.regions[0].blocks[0]):
-                results = get_result_or_results(
-                    body_builder(*func_op.regions[0].blocks[0].arguments)
-                )
-                if results is not None:
-                    if isinstance(results, (tuple, list)):
-                        results = list(results)
-                    else:
-                        results = [results]
+            if results is not None:
+                if isinstance(results, (tuple, list)):
+                    results = list(results)
                 else:
-                    results = []
-                ReturnOp(results)
-            # Recompute the function type.
-            return_types = [v.type for v in results]
-            function_type = FunctionType.get(inputs=input_types, results=return_types)
-            func_op.attributes["function_type"] = TypeAttr.get(function_type)
-
-            if _emit:
-                return maybe_cast(get_result_or_results(func_op))
+                    results = [results]
             else:
-                call_op = CallOp(
-                    [r.type for r in results],
-                    FlatSymbolRefAttr.get(func_name),
-                    call_args,
-                )
-                return maybe_cast(get_result_or_results(call_op))
+                results = []
+            self.return_op_ctor(results)

-        def emit():
-            nonlocal _emit
-            _emit = True
-            wrapper()
+        return results, input_types, func_op

-        wrapper.emit = emit
-        return wrapper
+    def emit(self):
+        self.results, input_types, func_op = self.body_builder_wrapper()
+        return_types = [v.type for v in self.results]
+        function_type = FunctionType.get(inputs=input_types, results=return_types)
+        func_op.attributes["function_type"] = TypeAttr.get(function_type)
+        self.emitted = True
+        # this is the func op itself (funcs never have a resulting ssa value)
+        return maybe_cast(get_result_or_results(func_op))

-    return builder_wrapper
+    def __call__(self, *call_args):
+        if not self.emitted:
+            self.emit()
+        call_op = CallOp(
+            [r.type for r in self.results],
+            FlatSymbolRefAttr.get(self.func_name),
+            call_args,
+        )
+        return maybe_cast(get_result_or_results(call_op))


-func = make_maybe_no_args_decorator(
-    partial(func_base, FuncOp=FuncOp.__base__, ReturnOp=ReturnOp, CallOp=CallOp)
-)
+func = FuncBase(FuncOp.__base__, ReturnOp, CallOp.__base__)
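
Note on the new decorator machinery: `FuncOpMeta.__call__` is the usual "decorator with optional arguments" trick. It is what lets `@func` be applied either bare or with configuration arguments, and it is why the module-level `func = FuncBase(FuncOp.__base__, ReturnOp, CallOp.__base__)` evaluates to a decorator rather than an instance: three class arguments and no function means the metaclass takes the argument-capturing branch and returns the `init_wrapper` closure. A stripped-down, MLIR-free sketch of the same pattern, using only the standard library; the class and function names below are illustrative and not from the repo:

import inspect


class DecoratorMeta(type):
    # Illustrative stand-in for FuncOpMeta: calling the class with a single
    # bare function constructs the instance directly; calling it with
    # anything else returns a decorator that captures those arguments first.
    def __call__(cls, *args, **kwargs):
        obj = cls.__new__(cls)
        if len(args) == 1 and not kwargs and inspect.isfunction(args[0]):
            cls.__init__(obj, args[0])
            return obj

        def wrap(f):
            cls.__init__(obj, f, *args, **kwargs)
            return obj

        return wrap


class Deferred(metaclass=DecoratorMeta):
    # Hypothetical example class standing in for FuncBase.
    def __init__(self, body, tag="default"):
        self.body, self.tag = body, tag

    def __call__(self, *args):
        # instance calls go through this method, not the metaclass
        return self.tag, self.body(*args)


@Deferred                # bare use: Deferred(add) -> Deferred instance
def add(x, y):
    return x + y


@Deferred("scaled")      # parameterized use: Deferred("scaled")(mul) -> instance
def mul(x, y):
    return x * y


print(add(2, 3))   # ('default', 5)
print(mul(2, 3))   # ('scaled', 6)

The sketch returns the instance directly in the bare-function branch; the committed `FuncOpMeta` instead returns the result of `cls.__init__` there and relies on the argument-capturing branch for the `func = FuncBase(...)` entry point shown above.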

mlir_utils/dialects/ext/tensor.py

Lines changed: 1 addition & 2 deletions
@@ -2,7 +2,7 @@
 from typing import Union, Tuple, Sequence

 import numpy as np
-from mlir.dialects.tensor import EmptyOp
+from mlir.dialects.tensor import EmptyOp, GenerateOp
 from mlir.ir import Type, Value, RankedTensorType, DenseElementsAttr, ShapedType

 from mlir_utils.dialects.ext.arith import ArithValue
@@ -62,7 +62,6 @@ def empty(
         shape: Union[list[Union[int, Value]], tuple[Union[int, Value], ...]],
         el_type: Type,
     ) -> "Tensor":
-
        return cls(EmptyOp(shape, el_type).result)


tests/test_func.py

Lines changed: 61 additions & 2 deletions
@@ -2,7 +2,6 @@
 from textwrap import dedent

 import pytest
-
 from mlir_utils.dialects.ext.arith import constant
 from mlir_utils.dialects.ext.func import func

@@ -20,7 +19,7 @@ def demo_fun1():
         return one

     assert hasattr(demo_fun1, "emit")
-    assert inspect.isfunction(demo_fun1.emit)
+    assert inspect.ismethod(demo_fun1.emit)
     demo_fun1.emit()
     correct = dedent(
         """\
@@ -33,3 +32,63 @@ def demo_fun1():
     """
     )
     filecheck(correct, ctx.module)
+
+
+def test_func_base_meta(ctx: MLIRContext):
+    print()
+
+    @func
+    def foo1():
+        one = constant(1)
+        return one
+
+    # print("wrapped foo", foo1)
+    foo1.emit()
+    correct = dedent(
+        """\
+    module {
+      func.func @foo1() -> i64 {
+        %c1_i64 = arith.constant 1 : i64
+        return %c1_i64 : i64
+      }
+    }
+    """
+    )
+    filecheck(correct, ctx.module)
+
+    foo1()
+    correct = dedent(
+        """\
+    module {
+      func.func @foo1() -> i64 {
+        %c1_i64 = arith.constant 1 : i64
+        return %c1_i64 : i64
+      }
+      %0 = func.call @foo1() : () -> i64
+    }
+    """
+    )
+    filecheck(correct, ctx.module)
+
+
+def test_func_base_meta2(ctx: MLIRContext):
+    print()
+
+    @func
+    def foo1():
+        one = constant(1)
+        return one
+
+    foo1()
+    correct = dedent(
+        """\
+    module {
+      func.func @foo1() -> i64 {
+        %c1_i64 = arith.constant 1 : i64
+        return %c1_i64 : i64
+      }
+      %0 = func.call @foo1() : () -> i64
+    }
+    """
+    )
+    filecheck(correct, ctx.module)

tests/test_regions.py

Lines changed: 1 addition & 1 deletion
@@ -7,7 +7,7 @@

 from mlir_utils.dialects.ext.arith import constant
 from mlir_utils.dialects.ext.func import func
-from mlir_utils.dialects.ext.tensor import Tensor, S, rank
+from mlir_utils.dialects.ext.tensor import S, rank

 # noinspection PyUnresolvedReferences
 from mlir_utils.testing import mlir_ctx as ctx, filecheck, MLIRContext

tests/test_value_caster.py

Lines changed: 2 additions & 1 deletion
@@ -1,5 +1,4 @@
 import pytest
-from mlir.ir import OpResult

 from mlir_utils.dialects.ext.tensor import S, empty
 from mlir_utils.dialects.ext.arith import constant
@@ -9,6 +8,8 @@
 from mlir_utils.testing import mlir_ctx as ctx, filecheck, MLIRContext
 from mlir_utils.types import f64_t, RankedTensorType

+from mlir.ir import OpResult
+
 # needed since the fix isn't defined here nor conftest.py
 pytest.mark.usefixtures("ctx")
