Skip to content

Commit 1a8a998

Browse files
committed
refactor coerce and constant
1 parent 7a38d5d commit 1a8a998

File tree

12 files changed

+835
-701
lines changed

12 files changed

+835
-701
lines changed

mlir_utils/dialects/ext/arith.py

Lines changed: 89 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
import numpy as np
88
from mlir.dialects import arith as arith_dialect
9+
from mlir.dialects import complex as complex_dialect
910
from mlir.dialects._arith_ops_ext import _is_integer_like_type
1011
from mlir.dialects._ods_common import get_op_result_or_value
1112
from mlir.dialects.linalg.opdsl.lang.emitter import (
@@ -19,6 +20,7 @@
1920
Context,
2021
DenseElementsAttr,
2122
IndexType,
23+
InsertionPoint,
2224
IntegerAttr,
2325
IntegerType,
2426
Location,
@@ -28,9 +30,20 @@
2830
Type,
2931
Value,
3032
register_attribute_builder,
33+
ComplexType,
34+
BF16Type,
35+
F16Type,
36+
F32Type,
37+
F64Type,
38+
FloatAttr,
3139
)
3240

33-
from mlir_utils.util import get_result_or_results, maybe_cast, get_user_code_loc
41+
from mlir_utils.util import (
42+
get_result_or_results,
43+
maybe_cast,
44+
get_user_code_loc,
45+
register_value_caster,
46+
)
3447

3548
try:
3649
from mlir_utils.dialects.arith import *
@@ -46,7 +59,8 @@ def constant(
4659
index: Optional[bool] = None,
4760
*,
4861
loc: Location = None,
49-
) -> arith_dialect.ConstantOp:
62+
ip: InsertionPoint = None,
63+
) -> Value:
5064
"""Instantiate arith.constant with value `value`.
5165
5266
Args:
@@ -67,21 +81,62 @@ def constant(
6781
type = IndexType.get()
6882
if type is None:
6983
type = infer_mlir_type(value)
70-
elif RankedTensorType.isinstance(type) and isinstance(value, (int, float, bool)):
84+
85+
assert type is not None
86+
87+
if _is_complex_type(type):
88+
value = complex(value)
89+
return maybe_cast(
90+
get_result_or_results(
91+
complex_dialect.ConstantOp(
92+
type,
93+
list(
94+
map(
95+
lambda x: FloatAttr.get(type.element_type, x),
96+
[value.real, value.imag],
97+
)
98+
),
99+
loc=loc,
100+
ip=ip,
101+
)
102+
)
103+
)
104+
105+
if _is_floating_point_type(type) and not isinstance(value, np.ndarray):
106+
value = float(value)
107+
108+
if RankedTensorType.isinstance(type) and isinstance(value, (int, float, bool)):
71109
ranked_tensor_type = RankedTensorType(type)
72-
value = np.ones(
110+
value = np.full(
73111
ranked_tensor_type.shape,
112+
value,
74113
dtype=mlir_type_to_np_dtype(ranked_tensor_type.element_type),
75114
)
76-
assert type is not None
77115

78116
if isinstance(value, np.ndarray):
79117
value = DenseElementsAttr.get(
80118
value,
81119
type=type,
82120
)
121+
83122
return maybe_cast(
84-
get_result_or_results(arith_dialect.ConstantOp(type, value, loc=loc))
123+
get_result_or_results(arith_dialect.ConstantOp(type, value, loc=loc, ip=ip))
124+
)
125+
126+
127+
def index_cast(
128+
value: Value,
129+
*,
130+
to: Type = None,
131+
loc: Location = None,
132+
ip: InsertionPoint = None,
133+
) -> Value:
134+
if loc is None:
135+
loc = get_user_code_loc()
136+
if to is None:
137+
to = IndexType.get()
138+
return maybe_cast(
139+
get_result_or_results(arith_dialect.IndexCastOp(to, value, loc=loc, ip=ip))
85140
)
86141

87142

@@ -231,6 +286,7 @@ def _binary_op(
231286
rhs: "ArithValue",
232287
op: str,
233288
predicate: str = None,
289+
signedness: str = None,
234290
*,
235291
loc: Location = None,
236292
) -> "ArithValue":
@@ -247,12 +303,15 @@ def _binary_op(
247303
"""
248304
if loc is None:
249305
loc = get_user_code_loc()
250-
if not isinstance(rhs, lhs.__class__):
306+
if (
307+
isinstance(rhs, Value)
308+
and lhs.type != rhs.type
309+
or isinstance(rhs, (float, int, bool, np.ndarray))
310+
):
251311
lhs, rhs = lhs.coerce(rhs)
252-
if lhs.type != rhs.type:
253-
raise ValueError(f"{lhs=} {rhs=} must have the same type.")
312+
assert lhs.type == rhs.type, f"{lhs=} {rhs=} must have the same type."
254313

255-
assert op in {"add", "sub", "mul", "cmp", "truediv", "floordiv", "mod"}
314+
assert op in {"add", "and", "or", "sub", "mul", "cmp", "truediv", "floordiv", "mod"}
256315

257316
if op == "cmp":
258317
assert predicate is not None
@@ -301,15 +360,20 @@ def _binary_op(
301360
elif _is_integer_like_type(lhs.dtype):
302361
# eq, ne signs don't matter
303362
if predicate not in {"eq", "ne"}:
304-
if lhs.dtype.is_signed:
305-
predicate = "s" + predicate
363+
if signedness is not None:
364+
predicate = signedness + predicate
306365
else:
307-
predicate = "u" + predicate
366+
if lhs.dtype.is_signed:
367+
predicate = "s" + predicate
368+
else:
369+
predicate = "u" + predicate
308370
return lhs.__class__(op(predicate, lhs, rhs, loc=loc), dtype=lhs.dtype)
309371
else:
310372
return lhs.__class__(op(lhs, rhs, loc=loc), dtype=lhs.dtype)
311373

312374

375+
# TODO(max): these could be generic in the dtype
376+
# TODO(max): hit .verify() before constructing (maybe)
313377
class ArithValue(Value, metaclass=ArithValueMeta):
314378
"""Class for functionality shared by Value subclasses that support
315379
arithmetic operations.
@@ -363,6 +427,9 @@ def __repr__(self):
363427
__rsub__ = partialmethod(_binary_op, op="sub")
364428
__rmul__ = partialmethod(_binary_op, op="mul")
365429

430+
__and__ = partialmethod(_binary_op, op="and")
431+
__or__ = partialmethod(_binary_op, op="or")
432+
366433
def __eq__(self, other):
367434
if not isinstance(other, self.__class__):
368435
try:
@@ -435,6 +502,14 @@ def __float__(self):
435502
def coerce(self, other) -> tuple["Scalar", "Scalar"]:
436503
if isinstance(other, (int, float, bool)):
437504
other = Scalar(other, dtype=self.dtype)
505+
elif isinstance(other, Scalar) and _is_index_type(self.type):
506+
other = index_cast(other)
507+
elif isinstance(other, Scalar) and _is_index_type(other.type):
508+
other = index_cast(other, to=self.type)
438509
else:
439-
raise ValueError(f"can't coerce {other=} to Scalar")
510+
raise ValueError(f"can't coerce {other=} to {self=}")
440511
return self, other
512+
513+
514+
for t in [BF16Type, F16Type, F32Type, F64Type, IndexType, IntegerType, ComplexType]:
515+
register_value_caster(t.static_typeid)(Scalar)

mlir_utils/dialects/ext/scf.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from typing import Optional, Sequence
66

77
from bytecode import ConcreteBytecode, ConcreteInstr
8+
from mlir.dialects.linalg.opdsl.lang.emitter import _is_index_type
89
from mlir.dialects.scf import IfOp, ForOp
910
from mlir.ir import InsertionPoint, Value, OpResultList, OpResult
1011

@@ -16,7 +17,7 @@
1617
OpCode,
1718
)
1819
from mlir_utils.ast.util import ast_call, set_lineno
19-
from mlir_utils.dialects.ext.arith import constant
20+
from mlir_utils.dialects.ext.arith import constant, index_cast
2021
from mlir_utils.dialects.scf import yield_ as yield__
2122
from mlir_utils.util import (
2223
region_op,
@@ -43,15 +44,17 @@ def _for(
4344
if stop is None:
4445
stop = start
4546
start = 0
46-
if isinstance(start, int):
47-
start = constant(start, index=True)
48-
if isinstance(stop, int):
49-
stop = constant(stop, index=True)
50-
if isinstance(step, int):
51-
step = constant(step, index=True)
47+
params = [start, stop, step]
48+
for i, p in enumerate(params):
49+
if isinstance(p, int):
50+
p = constant(p, index=True)
51+
if not _is_index_type(p.type):
52+
p = index_cast(p)
53+
params[i] = p
54+
5255
if loc is None:
5356
loc = get_user_code_loc()
54-
return ForOp(start, stop, step, iter_args, loc=loc, ip=ip)
57+
return ForOp(*params, iter_args, loc=loc, ip=ip)
5558

5659

5760
for_ = region_op(_for, terminator=yield__)

mlir_utils/dialects/ext/tensor.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ def __getitem__(self, idx: tuple) -> "Tensor":
140140
if isinstance(idx, tuple) and all(i == slice(None) for i in idx):
141141
return self
142142
if idx is None:
143-
return _expand_dims(self, (0,))
143+
return expand_dims(self, (0,))
144144

145145
idx = list((idx,) if isinstance(idx, int) else idx)
146146
for i, d in enumerate(idx):
@@ -198,8 +198,10 @@ def coerce(self, other) -> tuple["Tensor", "Tensor"]:
198198
if isinstance(other, (int, float)):
199199
other = Tensor(np.full(self.shape, other), dtype=self.dtype)
200200
return self, other
201-
elif _is_scalar(other):
202-
other = tensor.splat(self.type, other)
201+
elif isinstance(other, Scalar):
202+
other = tensor.splat(
203+
RankedTensorType.get(self.shape, other.dtype), other
204+
)
203205
return self, other
204206

205207
raise ValueError(f"can't coerce unknown {other=}")
@@ -256,7 +258,7 @@ def static_strides(self):
256258
return tuple(strides)
257259

258260

259-
def _expand_dims(inp, newaxis_dims) -> Tensor:
261+
def expand_dims(inp, newaxis_dims) -> Tensor:
260262
"""Expand the shape of a tensor.
261263
262264
Insert a new axis that will appear at the `axis` position in the expanded
@@ -514,7 +516,7 @@ def _extract_slice(
514516
raise ValueError(f"non-constant indices not supported {indexer}")
515517

516518
# This adds newaxis/None dimensions.
517-
return _expand_dims(out, indexer.newaxis_dims)
519+
return expand_dims(out, indexer.newaxis_dims)
518520

519521

520522
def _insert_slice(
@@ -523,7 +525,7 @@ def _insert_slice(
523525
idx,
524526
):
525527
if isinstance(source, Scalar):
526-
source = _expand_dims(source, (0,))
528+
source = expand_dims(source, (0,))
527529

528530
indexer = _indices_to_indexer(idx, dest.shape)
529531

mlir_utils/types.py

Lines changed: 44 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,19 +5,24 @@
55
import numpy as np
66
from mlir.ir import (
77
Attribute,
8+
BF16Type,
9+
ComplexType,
810
F16Type,
911
F32Type,
1012
F64Type,
13+
Float8E5M2Type,
14+
Float8E4M3FNType,
15+
Float8E4M3B11FNUZType,
1116
IndexType,
1217
IntegerType,
1318
MemRefType,
19+
NoneType,
20+
OpaqueType,
1421
RankedTensorType,
1522
Type,
1623
UnrankedMemRefType,
1724
UnrankedTensorType,
1825
VectorType,
19-
BF16Type,
20-
OpaqueType,
2126
)
2227

2328
_index_t = lambda: IndexType.get()
@@ -43,6 +48,16 @@
4348
_f64_t = lambda: F64Type.get()
4449
_bf16_t = lambda: BF16Type.get()
4550

51+
_f8e5m2_t = lambda: Float8E5M2Type.get()
52+
_f8e4m3_t = lambda: Float8E4M3FNType.get()
53+
_f8e4m3b11fnuz_t = lambda: Float8E4M3B11FNUZType.get()
54+
55+
_cmp16_t = lambda: ComplexType.get(_f16_t())
56+
_cmp32_t = lambda: ComplexType.get(_f32_t())
57+
_cmp64_t = lambda: ComplexType.get(_f64_t())
58+
59+
_none_t = lambda: NoneType.get()
60+
4661
opaque_t = lambda dialect_namespace, buffer: OpaqueType.get(dialect_namespace, buffer)
4762

4863

@@ -53,26 +68,29 @@ def _placeholder_opaque_t():
5368
_name_to_type = {
5469
"index_t": _index_t,
5570
"bool_t": _bool_t,
56-
5771
"i8_t": _i8_t,
5872
"i16_t": _i16_t,
5973
"i32_t": _i32_t,
6074
"i64_t": _i64_t,
61-
6275
"si8_t": _si8_t,
6376
"si16_t": _si16_t,
6477
"si32_t": _si32_t,
6578
"si64_t": _si64_t,
66-
6779
"ui8_t": _ui8_t,
6880
"ui16_t": _ui16_t,
6981
"ui32_t": _ui32_t,
7082
"ui64_t": _ui64_t,
71-
7283
"f16_t": _f16_t,
7384
"f32_t": _f32_t,
7485
"f64_t": _f64_t,
7586
"bf16_t": _bf16_t,
87+
"f8e5m2_t": _f8e5m2_t,
88+
"f8e4m3_t": _f8e4m3_t,
89+
"f8e4m3b11fnuz_t": _f8e4m3b11fnuz_t,
90+
"cmp16_t": _cmp16_t,
91+
"cmp32_t": _cmp32_t,
92+
"cmp64_t": _cmp64_t,
93+
"none_t": _none_t,
7694
}
7795

7896

@@ -115,7 +133,7 @@ def mlir_type_to_np_dtype(mlir_type):
115133

116134
def infer_mlir_type(
117135
py_val: Union[int, float, bool, np.ndarray]
118-
) -> Union[IntegerType, F64Type, RankedTensorType]:
136+
) -> Union[IntegerType, F32Type, F64Type, RankedTensorType]:
119137
"""Infer MLIR type (`ir.Type`) from supported python values.
120138
121139
Note ints and floats are mapped to 32-bit types when the value fits the 32-bit range, and to 64-bit types otherwise.
@@ -129,9 +147,26 @@ def infer_mlir_type(
129147
if isinstance(py_val, bool):
130148
return _bool_t()
131149
elif isinstance(py_val, int):
132-
return _i64_t()
150+
if -(2 ** 31) <= py_val < 2 ** 31:
151+
return _i32_t()
152+
elif 2 ** 31 <= py_val < 2 ** 32:
153+
return _ui32_t()
154+
elif -(2 ** 63) <= py_val < 2 ** 63:
155+
return _i64_t()
156+
elif 2 ** 63 <= py_val < 2 ** 64:
157+
return _ui64_t()
158+
else:
159+
raise RuntimeError(f"Nonrepresentable integer {py_val}.")
133160
elif isinstance(py_val, float):
134-
return _f64_t()
161+
if (
162+
abs(py_val) == float("inf")
163+
or abs(py_val) == 0.0
164+
or py_val != py_val # NaN
165+
or np.finfo(np.float32).min <= abs(py_val) <= np.finfo(np.float32).max
166+
):
167+
return _f32_t()
168+
else:
169+
return _f64_t()
135170
elif isinstance(py_val, np.ndarray):
136171
dtype = np_dtype_to_mlir_type(py_val.dtype.type)
137172
return RankedTensorType.get(py_val.shape, dtype)

0 commit comments

Comments
 (0)