Skip to content

Commit 90c0b70

Browse files
committed
[python] fix enum ambiguity
1 parent 9d55e86 commit 90c0b70

26 files changed

+513
-87
lines changed

mlir/cmake/modules/AddMLIRPython.cmake

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -318,7 +318,7 @@ function(declare_mlir_dialect_python_bindings)
318318
set(LLVM_TARGET_DEFINITIONS ${td_file})
319319
endif()
320320
set(enum_filename "${relative_td_directory}/_${ARG_DIALECT_NAME}_enum_gen.py")
321-
mlir_tablegen(${enum_filename} -gen-python-enum-bindings)
321+
mlir_tablegen(${enum_filename} -gen-python-enum-bindings -bind-dialect=${ARG_DIALECT_NAME})
322322
list(APPEND _sources ${enum_filename})
323323
endif()
324324

@@ -390,7 +390,7 @@ function(declare_mlir_dialect_extension_python_bindings)
390390
set(LLVM_TARGET_DEFINITIONS ${td_file})
391391
endif()
392392
set(enum_filename "${relative_td_directory}/_${ARG_EXTENSION_NAME}_enum_gen.py")
393-
mlir_tablegen(${enum_filename} -gen-python-enum-bindings)
393+
mlir_tablegen(${enum_filename} -gen-python-enum-bindings -bind-dialect=${ARG_DIALECT_NAME})
394394
list(APPEND _sources ${enum_filename})
395395
endif()
396396

mlir/python/CMakeLists.txt

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,8 +63,7 @@ declare_mlir_dialect_python_bindings(
6363
TD_FILE dialects/AffineOps.td
6464
SOURCES
6565
dialects/affine.py
66-
DIALECT_NAME affine
67-
GEN_ENUM_BINDINGS)
66+
DIALECT_NAME affine)
6867

6968
declare_mlir_dialect_python_bindings(
7069
ADD_TO_PARENT MLIRPythonSources.Dialects

mlir/python/mlir/dialects/_ods_common.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,7 @@ def get_op_result_or_op_results(
143143
else op
144144
)
145145

146+
146147
ResultValueTypeTuple = _cext.ir.Operation, _cext.ir.OpView, _cext.ir.Value
147148
ResultValueT = _Union[ResultValueTypeTuple]
148149
VariadicResultValueT = _Union[ResultValueT, _Sequence[ResultValueT]]

mlir/python/mlir/dialects/amdgpu.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,5 +2,21 @@
22
# See https://llvm.org/LICENSE.txt for license information.
33
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
44

5+
from ..ir import IntegerAttr, IntegerType, register_attribute_builder
56
from ._amdgpu_ops_gen import *
67
from ._amdgpu_enum_gen import *
8+
9+
10+
@register_attribute_builder("builtin.AMDGPU_DPPPerm")
11+
def _amdgpu_dppperm(x, context):
12+
return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
13+
14+
15+
@register_attribute_builder("builtin.AMDGPU_MFMAPermB")
16+
def _amdgpu_mfmapermb(x, context):
17+
return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
18+
19+
20+
@register_attribute_builder("builtin.AMDGPU_SchedBarrierOpOpt")
21+
def _amdgpu_schedbarrieropopt(x, context):
22+
return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))

mlir/python/mlir/dialects/arith.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,3 +108,38 @@ def constant(
108108
result: Type, value: Union[int, float, Attribute, _array], *, loc=None, ip=None
109109
) -> Value:
110110
return _get_op_result_or_op_results(ConstantOp(result, value, loc=loc, ip=ip))
111+
112+
113+
@register_attribute_builder("builtin.Arith_CmpFPredicateAttr")
114+
def _arith_cmpfpredicateattr(x, context):
115+
return IntegerAttr.get(IntegerType.get_signless(64, context=context), int(x))
116+
117+
118+
@register_attribute_builder("builtin.Arith_CmpIPredicateAttr")
119+
def _arith_cmpipredicateattr(x, context):
120+
return IntegerAttr.get(IntegerType.get_signless(64, context=context), int(x))
121+
122+
123+
@register_attribute_builder("builtin.Arith_DenormalMode")
124+
def _arith_denormalmode(x, context):
125+
return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
126+
127+
128+
@register_attribute_builder("builtin.Arith_IntegerOverflowFlags")
129+
def _arith_integeroverflowflags(x, context):
130+
return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
131+
132+
133+
@register_attribute_builder("builtin.Arith_RoundingModeAttr")
134+
def _arith_roundingmodeattr(x, context):
135+
return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
136+
137+
138+
@register_attribute_builder("builtin.AtomicRMWKindAttr")
139+
def _atomicrmwkindattr(x, context):
140+
return IntegerAttr.get(IntegerType.get_signless(64, context=context), int(x))
141+
142+
143+
@register_attribute_builder("builtin.FastMathFlags")
144+
def _fastmathflags(x, context):
145+
return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))

mlir/python/mlir/dialects/bufferization.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,5 +2,11 @@
22
# See https://llvm.org/LICENSE.txt for license information.
33
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
44

5+
from ..ir import IntegerAttr, IntegerType, register_attribute_builder
56
from ._bufferization_ops_gen import *
67
from ._bufferization_enum_gen import *
8+
9+
10+
@register_attribute_builder("builtin.LayoutMapOption")
11+
def _layoutmapoption(x, context):
12+
return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))

mlir/python/mlir/dialects/gpu/__init__.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,62 @@
22
# See https://llvm.org/LICENSE.txt for license information.
33
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
44

5+
from ...ir import IntegerAttr, IntegerType, register_attribute_builder
56
from .._gpu_ops_gen import *
67
from .._gpu_enum_gen import *
78
from ..._mlir_libs._mlirDialectsGPU import *
9+
10+
11+
@register_attribute_builder("builtin.GPU_AddressSpaceEnum")
12+
def _gpu_addressspaceenum(x, context):
13+
return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
14+
15+
16+
@register_attribute_builder("builtin.GPU_AllReduceOperation")
17+
def _gpu_allreduceoperation(x, context):
18+
return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
19+
20+
21+
@register_attribute_builder("builtin.GPU_CompilationTargetEnum")
22+
def _gpu_compilationtargetenum(x, context):
23+
return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
24+
25+
26+
@register_attribute_builder("builtin.GPU_Dimension")
27+
def _gpu_dimension(x, context):
28+
return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
29+
30+
31+
@register_attribute_builder("builtin.GPU_Prune2To4SpMatFlag")
32+
def _gpu_prune2to4spmatflag(x, context):
33+
return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
34+
35+
36+
@register_attribute_builder("builtin.GPU_ShuffleMode")
37+
def _gpu_shufflemode(x, context):
38+
return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
39+
40+
41+
@register_attribute_builder("builtin.GPU_SpGEMMWorkEstimationOrComputeKind")
42+
def _gpu_spgemmworkestimationorcomputekind(x, context):
43+
return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
44+
45+
46+
@register_attribute_builder("builtin.GPU_TransposeMode")
47+
def _gpu_transposemode(x, context):
48+
return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
49+
50+
51+
@register_attribute_builder("builtin.MMAElementWise")
52+
def _mmaelementwise(x, context):
53+
return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
54+
55+
56+
@register_attribute_builder("builtin.MappingIdEnum")
57+
def _mappingidenum(x, context):
58+
return IntegerAttr.get(IntegerType.get_signless(64, context=context), int(x))
59+
60+
61+
@register_attribute_builder("builtin.ProcessorEnum")
62+
def _processorenum(x, context):
63+
return IntegerAttr.get(IntegerType.get_signless(64, context=context), int(x))

mlir/python/mlir/dialects/index.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,5 +2,11 @@
22
# See https://llvm.org/LICENSE.txt for license information.
33
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
44

5+
from ..ir import IntegerAttr, IntegerType, register_attribute_builder
56
from ._index_ops_gen import *
67
from ._index_enum_gen import *
8+
9+
10+
@register_attribute_builder("builtin.IndexCmpPredicate")
11+
def _indexcmppredicate(x, context):
12+
return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))

mlir/python/mlir/dialects/linalg/__init__.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,3 +102,28 @@ def broadcast(
102102
)
103103
fill_builtin_region(op.operation)
104104
return op
105+
106+
107+
@register_attribute_builder("builtin.BinaryFn")
108+
def _binaryfn(x, context):
109+
return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
110+
111+
112+
@register_attribute_builder("builtin.IteratorType")
113+
def _iteratortype(x, context):
114+
return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
115+
116+
117+
@register_attribute_builder("builtin.TernaryFn")
118+
def _ternaryfn(x, context):
119+
return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
120+
121+
122+
@register_attribute_builder("builtin.TypeFn")
123+
def _typefn(x, context):
124+
return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
125+
126+
127+
@register_attribute_builder("builtin.UnaryFn")
128+
def _unaryfn(x, context):
129+
return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))

mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -888,6 +888,7 @@ def conv_2d_nchw_fchw_q(
888888
- TypeFn.cast_signed(U, IZp)
889889
) * (TypeFn.cast_signed(U, K[D.f, D.c, D.kh, D.kw]) - TypeFn.cast_signed(U, KZp))
890890

891+
891892
@linalg_structured_op
892893
def conv_2d_nchw_fchw(
893894
I=TensorDef(T1, S.N, S.C, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW),

0 commit comments

Comments
 (0)