
Commit 074afc8

make value_caster extensible

1 parent 22e50f9 commit 074afc8

6 files changed: +98 -33 lines changed

README.md

Lines changed: 12 additions & 2 deletions
@@ -26,9 +26,11 @@ This package is meant to work in concert with the upstream bindings.
 Practically speaking that means you need to have *some* package installed that includes mlir python bindings.
 In addition, you have to do one of two things to **configure this package** (after installing it):

-1. `$ configure-mlir-python-utils -y <MLIR_PYTHON_PACKAGE_PREFIX>`, where `MLIR_PYTHON_PACKAGE_PREFIX` is (as it says) the
+1. `$ configure-mlir-python-utils -y <MLIR_PYTHON_PACKAGE_PREFIX>`, where `MLIR_PYTHON_PACKAGE_PREFIX` is (as it says)
+   the
    package prefix for your chosen upstream bindings. So for example, for `torch-mlir`, you would
-   execute `configure-mlir-python-utils torch_mlir`, since `torch-mlir`'s bindings are the root of the `torch-mlir` python
+   execute `configure-mlir-python-utils torch_mlir`, since `torch-mlir`'s bindings are the root of the `torch-mlir`
+   python
    package. **When in doubt about this prefix**, it is everything up until `ir` (e.g., as
    in `from torch_mlir import ir`).
 2. `$ export MLIR_PYTHON_PACKAGE_PREFIX=<MLIR_PYTHON_PACKAGE_PREFIX>`, i.e., you can set this string as an environment

@@ -49,4 +51,12 @@ pip install setuptools -U
 pip install -e .[torch-mlir-test] \
   -f https://github.com/makslevental/mlir-wheels/releases/expanded_assets/latest \
   -f https://llvm.github.io/torch-mlir/package-index/
+```
+
+There's an annoying bug where, if you try to register to a different set of host bindings, it won't work the first time (e.g.,
+going from `torch-mlir` to `mlir`).
+The workaround is to delete the prefix token before configuring, like so:
+
+```shell
+rm /home/mlevental/dev_projects/mlir_utils/mlir_utils/_configuration/__MLIR_PYTHON_PACKAGE_PREFIX__ && configure-mlir-python-utils mlir
 ```

mlir_utils/_configuration/module_alias_map.py

Lines changed: 8 additions & 0 deletions
@@ -87,3 +87,11 @@ def find_spec(
             )
         else:
             return None
+
+
+def maybe_remove_alias_module_loader():
+    for i in range(len(sys.meta_path)):
+        finder = sys.meta_path[i]
+        if isinstance(finder, AliasedModuleFinder):
+            del sys.meta_path[i]
+            return
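
The new helper drops the first `AliasedModuleFinder` from `sys.meta_path`, which is what makes re-configuring against a different set of host bindings possible (see the README workaround above). A minimal sketch of how it would be used; the import path and the surrounding setup are assumptions, not part of this diff:

```python
# Sketch only: assumes AliasedModuleFinder and maybe_remove_alias_module_loader
# are both importable from this module.
import sys

from mlir_utils._configuration.module_alias_map import (
    AliasedModuleFinder,
    maybe_remove_alias_module_loader,
)

# After configuration, an AliasedModuleFinder on sys.meta_path redirects
# imports of the alias package to the configured host bindings.
print(any(isinstance(f, AliasedModuleFinder) for f in sys.meta_path))

# Drop the stale finder before configuring against different bindings,
# so the next configuration can install a fresh alias mapping.
maybe_remove_alias_module_loader()
print(any(isinstance(f, AliasedModuleFinder) for f in sys.meta_path))
```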

mlir_utils/dialects/ext/arith.py

Lines changed: 1 addition & 1 deletion
@@ -155,7 +155,7 @@ def __call__(cls, *args, **kwargs):


 class ArithValue(Value, metaclass=ArithValueMeta):
-    """Mixin class for functionality shared by Value subclasses that support
+    """Class for functionality shared by Value subclasses that support
     arithmetic operations.

     Note, since we bind the ArithValueMeta here, it is here that the __new__ and

mlir_utils/dialects/ext/tensor.py

Lines changed: 3 additions & 25 deletions
@@ -6,6 +6,7 @@
 from mlir.ir import Type, Value, RankedTensorType, DenseElementsAttr, ShapedType

 from mlir_utils.dialects.ext.arith import ArithValue
+from mlir_utils.dialects.util import register_value_caster

 try:
     from mlir_utils.dialects.tensor import *

@@ -64,28 +65,5 @@ def empty(

         return cls(EmptyOp(shape, el_type).result)

-    def __class_getitem__(
-        cls, dim_sizes_dtype: Tuple[Union[list[int], tuple[int, ...]], Type]
-    ) -> Type:
-        """A convenience method for creating RankedTensorType.
-
-        Args:
-            dim_sizes_dtype: A tuple of both the shape of the type and the dtype.
-
-        Returns:
-            An instance of RankedTensorType.
-        """
-        if len(dim_sizes_dtype) != 2:
-            raise ValueError(
-                f"Wrong type of argument to {cls.__name__}: {dim_sizes_dtype=}"
-            )
-        dim_sizes, dtype = dim_sizes_dtype
-        if not isinstance(dtype, Type):
-            raise ValueError(f"{dtype=} is not {Type=}")
-        static_sizes = []
-        for s in dim_sizes:
-            if isinstance(s, int):
-                static_sizes.append(s)
-            else:
-                static_sizes.append(ShapedType.get_dynamic_size())
-        return RankedTensorType.get(static_sizes, dtype)
+
+register_value_caster(RankedTensorType.static_typeid, Tensor)
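
This registration is the extension hook the commit title refers to: instead of `maybe_cast` hard-coding `Tensor`, any package can register a caster for its own type. A hedged sketch of what a downstream registration could look like; `MemRef` is an illustrative name and not part of this commit:

```python
# Hypothetical downstream use of the new hook, mirroring the Tensor
# registration above; MemRef is illustrative and not defined in this commit.
from mlir.ir import MemRefType

from mlir_utils.dialects.ext.arith import ArithValue
from mlir_utils.dialects.util import register_value_caster


class MemRef(ArithValue):
    """Wrapper that maybe_cast would return for values with a memref type."""


# maybe_cast will now try MemRef(val) for any result whose typeid matches.
register_value_caster(MemRefType.static_typeid, MemRef)
```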

mlir_utils/dialects/util.py

Lines changed: 39 additions & 5 deletions
@@ -1,9 +1,11 @@
 import ctypes
-from functools import wraps
 import inspect
+from collections import defaultdict
+from functools import wraps
+from typing import Callable

 from mlir.dialects._ods_common import get_op_result_or_value, get_op_results_or_values
-from mlir.ir import InsertionPoint, Value, Type
+from mlir.ir import InsertionPoint, Value, Type, TypeID


 def get_result_or_results(op):

@@ -31,20 +33,52 @@ def maybe_no_args(*args, **kwargs):
     return maybe_no_args


+__VALUE_CASTERS: defaultdict[
+    TypeID, list[Callable[[Value], Value | None]]
+] = defaultdict(list)
+
+
+def register_value_caster(
+    typeid: TypeID, caster: Callable[[Value], Value], priority: int = None
+):
+    if not isinstance(typeid, TypeID):
+        raise ValueError(f"{typeid=} is not a TypeID")
+    if priority is None:
+        __VALUE_CASTERS[typeid].append(caster)
+    else:
+        __VALUE_CASTERS[typeid].insert(priority, caster)
+
+
+def has_value_caster(typeid: TypeID):
+    if not isinstance(typeid, TypeID):
+        raise ValueError(f"{typeid=} is not a TypeID")
+    if not typeid in __VALUE_CASTERS:
+        return False
+    return True
+
+
+def get_value_caster(typeid: TypeID):
+    if not has_value_caster(typeid):
+        raise ValueError(f"no registered caster for {typeid=}")
+    return __VALUE_CASTERS[typeid]
+
+
 def maybe_cast(val: Value):
     """Maybe cast an ir.Value to one of Tensor, Scalar.

     Args:
         val: The ir.Value to maybe cast.
     """
-    from mlir_utils.dialects.ext.tensor import Tensor
     from mlir_utils.dialects.ext.arith import Scalar

     if not isinstance(val, Value):
         return val

-    if Tensor.isinstance(val):
-        return Tensor(val)
+    if has_value_caster(val.type.typeid):
+        for caster in get_value_caster(val.type.typeid):
+            if casted := caster(val):
+                return casted
+        raise ValueError(f"no successful casts for {val=}")
     if Scalar.isinstance(val):
         return Scalar(val)
     return val
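
Taken together: casters are stored per `TypeID`, appended in registration order (or inserted at an explicit `priority` index), and `maybe_cast` returns the first truthy result, raising if every registered caster comes back falsy. A minimal sketch of driving the registry under those assumptions; `trace_caster` is illustrative and not part of this commit:

```python
# Sketch of driving the registry added above; trace_caster is illustrative.
from mlir.ir import RankedTensorType, Value

from mlir_utils.dialects.util import register_value_caster


def trace_caster(val: Value):
    # Returning a falsy value makes maybe_cast fall through to the next
    # registered caster; if every caster returns falsy, maybe_cast raises.
    print("maybe_cast saw", val)
    return None


# priority=0 inserts at the front of the list for this typeid, so this runs
# before the Tensor caster that tensor.py registers on import.
register_value_caster(RankedTensorType.static_typeid, trace_caster, priority=0)
```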

tests/test_value_caster.py

Lines changed: 35 additions & 0 deletions
@@ -0,0 +1,35 @@
+import pytest
+from mlir.ir import OpResult
+
+from mlir_utils.dialects.ext.tensor import S, empty
+from mlir_utils.dialects.ext.arith import constant
+from mlir_utils.dialects.util import register_value_caster
+
+# noinspection PyUnresolvedReferences
+from mlir_utils.testing import mlir_ctx as ctx, filecheck, MLIRContext
+from mlir_utils.types import f64_t, RankedTensorType
+
+# needed since the fix isn't defined here nor conftest.py
+pytest.mark.usefixtures("ctx")
+
+
+def test_caster_registration(ctx: MLIRContext):
+    sizes = S, 3, S
+    ten = empty(sizes, f64_t)
+    assert repr(ten) == "Tensor(%0, tensor<?x3x?xf64>)"
+
+    def dummy_caster(val):
+        print(val)
+        return val
+
+    register_value_caster(RankedTensorType.static_typeid, dummy_caster)
+    ten = empty(sizes, f64_t)
+    assert repr(ten) == "Tensor(%1, tensor<?x3x?xf64>)"
+
+    register_value_caster(RankedTensorType.static_typeid, dummy_caster, 0)
+    ten = empty(sizes, f64_t)
+    assert repr(ten) != "Tensor(%1, tensor<?x3x?xf64>)"
+    assert isinstance(ten, OpResult)
+
+    one = constant(1)
+    assert repr(one) == "Scalar(%3, i64)"
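
Why the last assertions hold: per the dispatch loop in `maybe_cast`, the first caster to return a truthy value wins. The appended `dummy_caster` runs after the `Tensor` caster registered by `tensor.py`, so the second `empty` still reprs as a `Tensor`; once `dummy_caster` is re-registered at priority 0 it runs first and hands back the raw value, so the result stays an uncast `OpResult`. The `Scalar` path is unchanged by this commit.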
