Skip to content

Commit 774818c

Browse files
dominikgreweftynse
authored andcommitted
Expose MlirOperationClone in Python bindings.
Expose MlirOperationClone in Python bindings. Reviewed By: ftynse Differential Revision: https://reviews.llvm.org/D122526
1 parent 58d0da8 commit 774818c

File tree

3 files changed

+49
-12
lines changed

3 files changed

+49
-12
lines changed

mlir/lib/Bindings/Python/IRCore.cpp

Lines changed: 26 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1075,6 +1075,21 @@ py::object PyOperation::createFromCapsule(py::object capsule) {
10751075
.releaseObject();
10761076
}
10771077

1078+
static void maybeInsertOperation(PyOperationRef &op,
1079+
const py::object &maybeIp) {
1080+
// InsertPoint active?
1081+
if (!maybeIp.is(py::cast(false))) {
1082+
PyInsertionPoint *ip;
1083+
if (maybeIp.is_none()) {
1084+
ip = PyThreadContextEntry::getDefaultInsertionPoint();
1085+
} else {
1086+
ip = py::cast<PyInsertionPoint *>(maybeIp);
1087+
}
1088+
if (ip)
1089+
ip->insert(*op.get());
1090+
}
1091+
}
1092+
10781093
py::object PyOperation::create(
10791094
const std::string &name, llvm::Optional<std::vector<PyType *>> results,
10801095
llvm::Optional<std::vector<PyValue *>> operands,
@@ -1192,22 +1207,20 @@ py::object PyOperation::create(
11921207
MlirOperation operation = mlirOperationCreate(&state);
11931208
PyOperationRef created =
11941209
PyOperation::createDetached(location->getContext(), operation);
1195-
1196-
// InsertPoint active?
1197-
if (!maybeIp.is(py::cast(false))) {
1198-
PyInsertionPoint *ip;
1199-
if (maybeIp.is_none()) {
1200-
ip = PyThreadContextEntry::getDefaultInsertionPoint();
1201-
} else {
1202-
ip = py::cast<PyInsertionPoint *>(maybeIp);
1203-
}
1204-
if (ip)
1205-
ip->insert(*created.get());
1206-
}
1210+
maybeInsertOperation(created, maybeIp);
12071211

12081212
return created->createOpView();
12091213
}
12101214

1215+
py::object PyOperation::clone(const py::object &maybeIp) {
1216+
MlirOperation clonedOperation = mlirOperationClone(operation);
1217+
PyOperationRef cloned =
1218+
PyOperation::createDetached(getContext(), clonedOperation);
1219+
maybeInsertOperation(cloned, maybeIp);
1220+
1221+
return cloned->createOpView();
1222+
}
1223+
12111224
py::object PyOperation::createOpView() {
12121225
checkValid();
12131226
MlirIdentifier ident = mlirOperationGetName(get());
@@ -2616,6 +2629,7 @@ void mlir::python::populateIRCore(py::module &m) {
26162629
return py::none();
26172630
})
26182631
.def("erase", &PyOperation::erase)
2632+
.def("clone", &PyOperation::clone, py::arg("ip") = py::none())
26192633
.def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
26202634
&PyOperation::getCapsule)
26212635
.def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyOperation::createFromCapsule)

mlir/lib/Bindings/Python/IRModule.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -575,6 +575,9 @@ class PyOperation : public PyOperationBase, public BaseContextObject {
575575
/// parent context's live operations map, and sets the valid bit false.
576576
void erase();
577577

578+
/// Clones this operation.
579+
pybind11::object clone(const pybind11::object &ip);
580+
578581
private:
579582
PyOperation(PyMlirContextRef contextRef, MlirOperation operation);
580583
static PyOperationRef createInstance(PyMlirContextRef contextRef,

mlir/test/python/ir/operation.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -767,6 +767,26 @@ def testOperationErase():
767767
Operation.create("custom.op2")
768768

769769

770+
# CHECK-LABEL: TEST: testOperationClone
771+
@run
772+
def testOperationClone():
773+
ctx = Context()
774+
ctx.allow_unregistered_dialects = True
775+
with Location.unknown(ctx):
776+
m = Module.create()
777+
with InsertionPoint(m.body):
778+
op = Operation.create("custom.op1")
779+
780+
# CHECK: "custom.op1"
781+
print(m)
782+
783+
clone = op.operation.clone()
784+
op.operation.erase()
785+
786+
# CHECK: "custom.op1"
787+
print(m)
788+
789+
770790
# CHECK-LABEL: TEST: testOperationLoc
771791
@run
772792
def testOperationLoc():

0 commit comments

Comments
 (0)