|
18 | 18 | )
|
19 | 19 |
|
20 | 20 |
|
21 |
| -def func_base( |
22 |
| - FuncOp, |
23 |
| - ReturnOp, |
24 |
| - CallOp, |
25 |
| - sym_visibility=None, |
26 |
| - arg_attrs=None, |
27 |
| - res_attrs=None, |
28 |
| - loc=None, |
29 |
| - ip=None, |
30 |
| -): |
31 |
| - ip = ip or InsertionPoint.current |
32 |
| - |
33 |
| - # if this is set to true then wrapper below won't emit a call op |
34 |
| - # it is set below by a def emit fn that is attached to the body_builder |
35 |
| - # wrapper; thus you can call wrapped_fn.emit() (i.e., without an operands) |
36 |
| - # and the func will be emitted. |
37 |
| - _emit = False |
38 |
| - |
39 |
| - def builder_wrapper(body_builder): |
40 |
| - @wraps(body_builder) |
41 |
| - def wrapper(*call_args): |
42 |
| - # TODO(max): implement constexpr ie enable passing constants that skip being |
43 |
| - # part of the signature |
44 |
| - sig = inspect.signature(body_builder) |
45 |
| - implicit_return = sig.return_annotation is inspect._empty |
46 |
| - input_types = [p.annotation for p in sig.parameters.values()] |
47 |
| - if not ( |
48 |
| - len(input_types) == len(sig.parameters) |
49 |
| - and all(isinstance(t, Type) for t in input_types) |
50 |
| - ): |
51 |
| - input_types = [a.type for a in call_args] |
52 |
| - function_type = TypeAttr.get( |
53 |
| - FunctionType.get( |
54 |
| - inputs=input_types, |
55 |
| - results=[] if implicit_return else sig.return_annotation, |
56 |
| - ) |
| 21 | +class FuncOpMeta(type): |
| 22 | + def __call__(cls, *args, **kwargs): |
| 23 | + cls_obj = cls.__new__(cls) |
| 24 | + if len(args) == 1 and len(kwargs) == 0 and inspect.isfunction(args[0]): |
| 25 | + return cls.__init__(cls_obj, args[0]) |
| 26 | + else: |
| 27 | + |
| 28 | + def init_wrapper(f): |
| 29 | + cls.__init__(cls_obj, f, *args, **kwargs) |
| 30 | + return cls_obj |
| 31 | + |
| 32 | + return lambda f: init_wrapper(f) |
| 33 | + |
| 34 | + |
| 35 | +class FuncBase(metaclass=FuncOpMeta): |
| 36 | + def __init__( |
| 37 | + self, |
| 38 | + body_builder, |
| 39 | + func_op_ctor, |
| 40 | + return_op_ctor, |
| 41 | + call_op_ctor, |
| 42 | + sym_visibility=None, |
| 43 | + arg_attrs=None, |
| 44 | + res_attrs=None, |
| 45 | + loc=None, |
| 46 | + ip=None, |
| 47 | + ): |
| 48 | + assert inspect.isfunction(body_builder), body_builder |
| 49 | + assert inspect.isclass(func_op_ctor), func_op_ctor |
| 50 | + assert inspect.isclass(return_op_ctor), return_op_ctor |
| 51 | + assert inspect.isclass(call_op_ctor), call_op_ctor |
| 52 | + |
| 53 | + self.body_builder = body_builder |
| 54 | + self.func_name = self.body_builder.__name__ |
| 55 | + |
| 56 | + self.func_op_ctor = func_op_ctor |
| 57 | + self.return_op_ctor = return_op_ctor |
| 58 | + self.call_op_ctor = call_op_ctor |
| 59 | + self.sym_visibility = ( |
| 60 | + StringAttr.get(str(sym_visibility)) if sym_visibility is not None else None |
| 61 | + ) |
| 62 | + self.arg_attrs = arg_attrs |
| 63 | + self.res_attrs = res_attrs |
| 64 | + self.loc = loc |
| 65 | + self.ip = ip or InsertionPoint.current |
| 66 | + self.emitted = False |
| 67 | + |
| 68 | + def __str__(self): |
| 69 | + return str(f"{self.__class__} {self.__dict__}") |
| 70 | + |
| 71 | + def body_builder_wrapper(self, *call_args): |
| 72 | + sig = inspect.signature(self.body_builder) |
| 73 | + implicit_return = sig.return_annotation is inspect._empty |
| 74 | + input_types = [p.annotation for p in sig.parameters.values()] |
| 75 | + if not ( |
| 76 | + len(input_types) == len(sig.parameters) |
| 77 | + and all(isinstance(t, Type) for t in input_types) |
| 78 | + ): |
| 79 | + input_types = [a.type for a in call_args] |
| 80 | + function_type = TypeAttr.get( |
| 81 | + FunctionType.get( |
| 82 | + inputs=input_types, |
| 83 | + results=[] if implicit_return else sig.return_annotation, |
57 | 84 | )
|
58 |
| - # FuncOp is extended but we do really want the base |
59 |
| - func_name = body_builder.__name__ |
60 |
| - func_op = FuncOp( |
61 |
| - func_name, |
62 |
| - function_type, |
63 |
| - sym_visibility=StringAttr.get(str(sym_visibility)) |
64 |
| - if sym_visibility is not None |
65 |
| - else None, |
66 |
| - arg_attrs=arg_attrs, |
67 |
| - res_attrs=res_attrs, |
68 |
| - loc=loc, |
69 |
| - ip=ip, |
| 85 | + ) |
| 86 | + func_op = self.func_op_ctor( |
| 87 | + self.func_name, |
| 88 | + function_type, |
| 89 | + sym_visibility=self.sym_visibility, |
| 90 | + arg_attrs=self.arg_attrs, |
| 91 | + res_attrs=self.res_attrs, |
| 92 | + loc=self.loc, |
| 93 | + ip=self.ip, |
| 94 | + ) |
| 95 | + func_op.regions[0].blocks.append(*input_types) |
| 96 | + with InsertionPoint(func_op.regions[0].blocks[0]): |
| 97 | + results = get_result_or_results( |
| 98 | + self.body_builder(*func_op.regions[0].blocks[0].arguments) |
70 | 99 | )
|
71 |
| - func_op.regions[0].blocks.append(*input_types) |
72 |
| - with InsertionPoint(func_op.regions[0].blocks[0]): |
73 |
| - results = get_result_or_results( |
74 |
| - body_builder(*func_op.regions[0].blocks[0].arguments) |
75 |
| - ) |
76 |
| - if results is not None: |
77 |
| - if isinstance(results, (tuple, list)): |
78 |
| - results = list(results) |
79 |
| - else: |
80 |
| - results = [results] |
| 100 | + if results is not None: |
| 101 | + if isinstance(results, (tuple, list)): |
| 102 | + results = list(results) |
81 | 103 | else:
|
82 |
| - results = [] |
83 |
| - ReturnOp(results) |
84 |
| - # Recompute the function type. |
85 |
| - return_types = [v.type for v in results] |
86 |
| - function_type = FunctionType.get(inputs=input_types, results=return_types) |
87 |
| - func_op.attributes["function_type"] = TypeAttr.get(function_type) |
88 |
| - |
89 |
| - if _emit: |
90 |
| - return maybe_cast(get_result_or_results(func_op)) |
| 104 | + results = [results] |
91 | 105 | else:
|
92 |
| - call_op = CallOp( |
93 |
| - [r.type for r in results], |
94 |
| - FlatSymbolRefAttr.get(func_name), |
95 |
| - call_args, |
96 |
| - ) |
97 |
| - return maybe_cast(get_result_or_results(call_op)) |
| 106 | + results = [] |
| 107 | + self.return_op_ctor(results) |
98 | 108 |
|
99 |
| - def emit(): |
100 |
| - nonlocal _emit |
101 |
| - _emit = True |
102 |
| - wrapper() |
| 109 | + return results, input_types, func_op |
103 | 110 |
|
104 |
| - wrapper.emit = emit |
105 |
| - return wrapper |
| 111 | + def emit(self): |
| 112 | + self.results, input_types, func_op = self.body_builder_wrapper() |
| 113 | + return_types = [v.type for v in self.results] |
| 114 | + function_type = FunctionType.get(inputs=input_types, results=return_types) |
| 115 | + func_op.attributes["function_type"] = TypeAttr.get(function_type) |
| 116 | + self.emitted = True |
| 117 | + # this is the func op itself (funcs never have a resulting ssa value) |
| 118 | + return maybe_cast(get_result_or_results(func_op)) |
106 | 119 |
|
107 |
| - return builder_wrapper |
| 120 | + def __call__(self, *call_args): |
| 121 | + if not self.emitted: |
| 122 | + self.emit() |
| 123 | + call_op = CallOp( |
| 124 | + [r.type for r in self.results], |
| 125 | + FlatSymbolRefAttr.get(self.func_name), |
| 126 | + call_args, |
| 127 | + ) |
| 128 | + return maybe_cast(get_result_or_results(call_op)) |
108 | 129 |
|
109 | 130 |
|
110 |
| -func = make_maybe_no_args_decorator( |
111 |
| - partial(func_base, FuncOp=FuncOp.__base__, ReturnOp=ReturnOp, CallOp=CallOp) |
112 |
| -) |
| 131 | +func = FuncBase(FuncOp.__base__, ReturnOp, CallOp.__base__) |
0 commit comments