Skip to content

Commit 3dee9c5

Browse files
committed
bump version
1 parent 175450f commit 3dee9c5

File tree

10 files changed

+78
-44
lines changed

10 files changed

+78
-44
lines changed

.github/workflows/test_pypi.yml

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,20 +12,26 @@ jobs:
1212
runs-on: ubuntu-20.04
1313
steps:
1414
- uses: actions/checkout@v3
15+
- uses: actions/setup-python@v4
16+
with:
17+
python-version: '3.11'
1518

1619
- name: Build wheels
1720
run: |
1821
pip wheel -w wheelhouse .
1922
2023
- uses: actions/upload-artifact@v3
2124
with:
22-
path: ./wheelhouse/*.whl
25+
path: ./wheelhouse/mlir_python_utils*.whl
2326

2427
build_sdist:
2528
name: Build source distribution
2629
runs-on: ubuntu-latest
2730
steps:
2831
- uses: actions/checkout@v3
32+
- uses: actions/setup-python@v4
33+
with:
34+
python-version: '3.11'
2935

3036
- name: Build sdist
3137
run: pipx run build --sdist

.github/workflows/wheels.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ jobs:
2929
- name: Upload wheels
3030
uses: actions/upload-artifact@v3
3131
with:
32-
path: wheelhouse/*.whl
32+
path: wheelhouse/mlir_python_utils*.whl
3333
name: build_artifact
3434

3535
upload_wheels:

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,8 @@ or for maximum convenience
1515

1616
```shell
1717
$ pip install mlir-python-utils[mlir] \
18+
-i https://test.pypi.org/simple \
1819
-f https://github.com/makslevental/mlir-wheels/releases/expanded_assets/latest
19-
-f https://github.com/makslevental/mlir-python-utils/releases/expanded_assets/latest
2020
$ configure-mlir-python-utils mlir
2121
```
2222

examples/throwaway.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from mlir_utils.dialects import gpu
1919
from mlir_utils.dialects.ext import func
2020
from mlir_utils.dialects.ext.arith import constant
21-
from mlir_utils.types import f64, index
21+
from mlir_utils.types import f64_t, index_t
2222

2323
generate_all_upstream_trampolines()
2424
# from mlir.dialects.scf import WhileOp
@@ -51,11 +51,11 @@
5151
#
5252
with mlir_mod_ctx() as ctx:
5353

54-
one = constant(1, index)
55-
two = constant(2, index)
54+
one = constant(1, index_t)
55+
two = constant(2, index_t)
5656

5757
@generate(
58-
Tensor[(S, 3, S), f64], dynamic_extents=[one, two], block_args=[index] * 3
58+
Tensor[(S, 3, S), f64_t], dynamic_extents=[one, two], block_args=[index_t] * 3
5959
)
6060
def demo_fun1(i, j, k):
6161
one = constant(1.0)

mlir_utils/_configuration/module_alias_map.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,9 @@ def module_repr(self, module: ModuleType) -> str:
5656

5757
class AliasedModuleFinder(MetaPathFinder):
5858
def __init__(self, alias_map: Mapping[str, str]):
59+
for k, v in dict(alias_map).items():
60+
if k == v:
61+
alias_map.pop(k)
5962
self.alias_map = alias_map
6063

6164
def find_spec(

mlir_utils/dialects/util.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
import ctypes
22
from functools import wraps
3+
import inspect
34

45
from mlir.dialects._ods_common import get_op_result_or_value, get_op_results_or_values
5-
from mlir.ir import InsertionPoint, Value
6+
from mlir.ir import InsertionPoint, Value, Type
67

78

89
def get_result_or_results(op):
@@ -53,12 +54,21 @@ def maybe_cast(val: Value):
5354
def region_op(op_constructor):
5455
# the decorator itself
5556
def op_decorator(*args, **kwargs):
56-
block_arg_types = kwargs.pop("block_args", [])
5757
op = op_constructor(*args, **kwargs)
5858

5959
def builder_wrapper(body_builder):
6060
# add a block with block args having types ...
61-
op.regions[0].blocks.append(*[t for t in block_arg_types])
61+
sig = inspect.signature(body_builder)
62+
types = [p.annotation for p in sig.parameters.values()]
63+
if not (
64+
len(types) == len(sig.parameters)
65+
and all(isinstance(t, Type) for t in types)
66+
):
67+
raise ValueError(
68+
f"for {body_builder=} either missing a type annotation or type annotation isn't a mlir type: {sig}"
69+
)
70+
71+
op.regions[0].blocks.append(*types)
6272
with InsertionPoint(op.regions[0].blocks[0]):
6373
body_builder(
6474
*[maybe_cast(a) for a in op.regions[0].blocks[0].arguments]

mlir_utils/types.py

Lines changed: 36 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -8,30 +8,31 @@
88
IndexType,
99
F16Type,
1010
F32Type,
11+
Type,
1112
)
1213

13-
index = IndexType.get()
14-
bool_ = IntegerType.get_signless(1)
15-
i8 = IntegerType.get_signless(8)
16-
i16 = IntegerType.get_signless(16)
17-
i32 = IntegerType.get_signless(32)
18-
i64 = IntegerType.get_signless(64)
19-
f16 = F16Type.get()
20-
f32 = F32Type.get()
21-
f64 = F64Type.get()
14+
index_t = IndexType.get()
15+
bool_t = IntegerType.get_signless(1)
16+
i8_t = IntegerType.get_signless(8)
17+
i16_t = IntegerType.get_signless(16)
18+
i32_t = IntegerType.get_signless(32)
19+
i64_t = IntegerType.get_signless(64)
20+
f16_t = F16Type.get()
21+
f32_t = F32Type.get()
22+
f64_t = F64Type.get()
2223

2324
NP_DTYPE_TO_MLIR_TYPE = lambda: {
24-
np.int8: i8,
25-
np.int16: i16,
26-
np.int32: i32,
27-
np.int64: i64,
25+
np.int8: i8_t,
26+
np.int16: i16_t,
27+
np.int32: i32_t,
28+
np.int64: i64_t,
2829
# this is techincally wrong i guess but numpy by default casts python scalars to this
2930
# so to support passing lists of ints we map this to index type
30-
np.longlong: index,
31-
np.uintp: index,
32-
np.float16: f16,
33-
np.float32: f32,
34-
np.float64: f64,
31+
np.longlong: index_t,
32+
np.uintp: index_t,
33+
np.float16: f16_t,
34+
np.float32: f32_t,
35+
np.float64: f64_t,
3536
}
3637

3738
MLIR_TYPE_TO_NP_DTYPE = lambda: {v: k for k, v in NP_DTYPE_TO_MLIR_TYPE().items()}
@@ -51,15 +52,29 @@ def infer_mlir_type(
5152
MLIR type corresponding to py_val.
5253
"""
5354
if isinstance(py_val, bool):
54-
return bool_
55+
return bool_t
5556
elif isinstance(py_val, int):
56-
return i64
57+
return i64_t
5758
elif isinstance(py_val, float):
58-
return f64
59+
return f64_t
5960
elif isinstance(py_val, np.ndarray):
6061
dtype = NP_DTYPE_TO_MLIR_TYPE()[py_val.dtype.type]
6162
return RankedTensorType.get(py_val.shape, dtype)
6263
else:
6364
raise NotImplementedError(
6465
f"Unsupported Python value {py_val=} with type {type(py_val)}"
6566
)
67+
68+
69+
def tensor_t(*args, element_type: Type = None):
70+
if (element_type is None and not isinstance(args[-1], Type)) or (
71+
isinstance(args[-1], Type) and element_type is not None
72+
):
73+
raise ValueError(
74+
f"either element_type must be provided explicitly XOR last arg to tensor type constructor must be the element type"
75+
)
76+
if element_type is not None:
77+
type = element_type
78+
else:
79+
type = args[-1]
80+
return RankedTensorType.get(args[:-1], type)

pyproject.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
[project]
22
name = "mlir-python-utils"
3-
version = "0.0.1"
3+
version = "0.0.2"
4+
description = "The missing pieces (as far as boilerplate reduction goes) of the upstream MLIR python bindings."
45
requires-python = ">=3.11"
56
license = { file = "LICENSE" }
7+
readme = "README.md"
68
dependencies = [
79
"numpy",
810
"black",

tests/test_operator_overloading.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,24 +7,24 @@
77

88
# noinspection PyUnresolvedReferences
99
from mlir_utils.testing import mlir_ctx as ctx, filecheck, MLIRContext
10-
from mlir_utils.types import f64, index
10+
from mlir_utils.types import f64_t, index_t
1111

1212
# needed since the fix isn't defined here nor conftest.py
1313
pytest.mark.usefixtures("ctx")
1414

1515

1616
def test_tensor_arithmetic(ctx: MLIRContext):
1717
print()
18-
one = constant(1, index)
18+
one = constant(1, index_t)
1919
assert isinstance(one, Scalar)
20-
two = constant(2, index)
20+
two = constant(2, index_t)
2121
assert isinstance(two, Scalar)
2222
three = one + two
2323
assert isinstance(three, Scalar)
2424

25-
ten1 = empty((10, 10, 10), f64)
25+
ten1 = empty((10, 10, 10), f64_t)
2626
assert isinstance(ten1, Tensor)
27-
ten2 = empty((10, 10, 10), f64)
27+
ten2 = empty((10, 10, 10), f64_t)
2828
assert isinstance(ten2, Tensor)
2929
ten3 = ten1 + ten2
3030
assert isinstance(ten3, Tensor)

tests/test_regions.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
# noinspection PyUnresolvedReferences
1313
from mlir_utils.testing import mlir_ctx as ctx, filecheck, MLIRContext
14-
from mlir_utils.types import f64, index
14+
from mlir_utils.types import f64_t, index_t, tensor_t
1515

1616
# needed since the fix isn't defined here nor conftest.py
1717
pytest.mark.usefixtures("ctx")
@@ -93,13 +93,11 @@ def demo_fun1():
9393

9494

9595
def test_block_args(ctx: MLIRContext):
96-
one = constant(1, index)
97-
two = constant(2, index)
96+
one = constant(1, index_t)
97+
two = constant(2, index_t)
9898

99-
@generate(
100-
Tensor[(S, 3, S), f64], dynamic_extents=[one, two], block_args=[index] * 3
101-
)
102-
def demo_fun1(i, j, k):
99+
@generate(tensor_t(S, 3, S, f64_t), dynamic_extents=[one, two])
100+
def demo_fun1(i: index_t, j: index_t, k: index_t):
103101
one = constant(1.0)
104102
tensor_yield(one)
105103

0 commit comments

Comments
 (0)