Skip to content

Commit 24685aa

Browse files
committed
[mlir][python] allow for detaching operations from a block
Provide support for removing an operation from the block that contains it and moving it back to detached state. This allows for the operation to be moved to a different block, a common IR manipulation for, e.g., module merging. Also fix a potential one-past-end iterator dereference in Operation::moveAfter discovered in the process. Reviewed By: mehdi_amini Differential Revision: https://reviews.llvm.org/D112700
1 parent 81e8c39 commit 24685aa

File tree

6 files changed

+161
-4
lines changed

6 files changed

+161
-4
lines changed

mlir/include/mlir-c/IR.h

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -346,6 +346,10 @@ MLIR_CAPI_EXPORTED MlirOperation mlirOperationClone(MlirOperation op);
346346
/// Takes an operation owned by the caller and destroys it.
347347
MLIR_CAPI_EXPORTED void mlirOperationDestroy(MlirOperation op);
348348

349+
/// Removes the given operation from its parent block. The operation is not
350+
/// destroyed. The ownership of the operation is transferred to the caller.
351+
MLIR_CAPI_EXPORTED void mlirOperationRemoveFromParent(MlirOperation op);
352+
349353
/// Checks whether the underlying operation is null.
350354
static inline bool mlirOperationIsNull(MlirOperation op) { return !op.ptr; }
351355

@@ -455,6 +459,19 @@ MLIR_CAPI_EXPORTED void mlirOperationDump(MlirOperation op);
455459
/// Verify the operation and return true if it passes, false if it fails.
456460
MLIR_CAPI_EXPORTED bool mlirOperationVerify(MlirOperation op);
457461

462+
/// Moves the given operation immediately after the other operation in its
463+
/// parent block. The given operation may be owned by the caller or by its
464+
/// current block. The other operation must belong to a block. In any case, the
465+
/// ownership is transferred to the block of the other operation.
466+
MLIR_CAPI_EXPORTED void mlirOperationMoveAfter(MlirOperation op,
467+
MlirOperation other);
468+
469+
/// Moves the given operation immediately before the other operation in its
470+
/// parent block. The given operation may be owner by the caller or by its
471+
/// current block. The other operation must belong to a block. In any case, the
472+
/// ownership is transferred to the block of the other operation.
473+
MLIR_CAPI_EXPORTED void mlirOperationMoveBefore(MlirOperation op,
474+
MlirOperation other);
458475
//===----------------------------------------------------------------------===//
459476
// Region API.
460477
//===----------------------------------------------------------------------===//

mlir/lib/Bindings/Python/IRCore.cpp

Lines changed: 51 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -875,6 +875,24 @@ py::object PyOperationBase::getAsm(bool binary,
875875
return fileObject.attr("getvalue")();
876876
}
877877

878+
void PyOperationBase::moveAfter(PyOperationBase &other) {
879+
PyOperation &operation = getOperation();
880+
PyOperation &otherOp = other.getOperation();
881+
operation.checkValid();
882+
otherOp.checkValid();
883+
mlirOperationMoveAfter(operation, otherOp);
884+
operation.parentKeepAlive = otherOp.parentKeepAlive;
885+
}
886+
887+
void PyOperationBase::moveBefore(PyOperationBase &other) {
888+
PyOperation &operation = getOperation();
889+
PyOperation &otherOp = other.getOperation();
890+
operation.checkValid();
891+
otherOp.checkValid();
892+
mlirOperationMoveBefore(operation, otherOp);
893+
operation.parentKeepAlive = otherOp.parentKeepAlive;
894+
}
895+
878896
llvm::Optional<PyOperationRef> PyOperation::getParentOperation() {
879897
checkValid();
880898
if (!isAttached())
@@ -2185,7 +2203,25 @@ void mlir::python::populateIRCore(py::module &m) {
21852203
return mlirOperationVerify(self.getOperation());
21862204
},
21872205
"Verify the operation and return true if it passes, false if it "
2188-
"fails.");
2206+
"fails.")
2207+
.def("move_after", &PyOperationBase::moveAfter, py::arg("other"),
2208+
"Puts self immediately after the other operation in its parent "
2209+
"block.")
2210+
.def("move_before", &PyOperationBase::moveBefore, py::arg("other"),
2211+
"Puts self immediately before the other operation in its parent "
2212+
"block.")
2213+
.def(
2214+
"detach_from_parent",
2215+
[](PyOperationBase &self) {
2216+
PyOperation &operation = self.getOperation();
2217+
operation.checkValid();
2218+
if (!operation.isAttached())
2219+
throw py::value_error("Detached operation has no parent.");
2220+
2221+
operation.detachFromParent();
2222+
return operation.createOpView();
2223+
},
2224+
"Detaches the operation from its parent block.");
21892225

21902226
py::class_<PyOperation, PyOperationBase>(m, "Operation", py::module_local())
21912227
.def_static("create", &PyOperation::create, py::arg("name"),
@@ -2380,7 +2416,20 @@ void mlir::python::populateIRCore(py::module &m) {
23802416
printAccum.getUserData());
23812417
return printAccum.join();
23822418
},
2383-
"Returns the assembly form of the block.");
2419+
"Returns the assembly form of the block.")
2420+
.def(
2421+
"append",
2422+
[](PyBlock &self, PyOperationBase &operation) {
2423+
if (operation.getOperation().isAttached())
2424+
operation.getOperation().detachFromParent();
2425+
2426+
MlirOperation mlirOperation = operation.getOperation().get();
2427+
mlirBlockAppendOwnedOperation(self.get(), mlirOperation);
2428+
operation.getOperation().setAttached(
2429+
self.getParentOperation().getObject());
2430+
},
2431+
"Appends an operation to this block. If the operation is currently "
2432+
"in another block, it will be moved.");
23842433

23852434
//----------------------------------------------------------------------------
23862435
// Mapping of PyInsertionPoint.

mlir/lib/Bindings/Python/IRModule.h

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -399,6 +399,10 @@ class PyOperationBase {
399399
bool enableDebugInfo, bool prettyDebugInfo,
400400
bool printGenericOpForm, bool useLocalScope);
401401

402+
/// Moves the operation before or after the other operation.
403+
void moveAfter(PyOperationBase &other);
404+
void moveBefore(PyOperationBase &other);
405+
402406
/// Each must provide access to the raw Operation.
403407
virtual PyOperation &getOperation() = 0;
404408
};
@@ -428,6 +432,14 @@ class PyOperation : public PyOperationBase, public BaseContextObject {
428432
createDetached(PyMlirContextRef contextRef, MlirOperation operation,
429433
pybind11::object parentKeepAlive = pybind11::object());
430434

435+
/// Detaches the operation from its parent block and updates its state
436+
/// accordingly.
437+
void detachFromParent() {
438+
mlirOperationRemoveFromParent(getOperation());
439+
setDetached();
440+
parentKeepAlive = pybind11::object();
441+
}
442+
431443
/// Gets the backing operation.
432444
operator MlirOperation() const { return get(); }
433445
MlirOperation get() const {
@@ -441,10 +453,14 @@ class PyOperation : public PyOperationBase, public BaseContextObject {
441453
}
442454

443455
bool isAttached() { return attached; }
444-
void setAttached() {
456+
void setAttached(pybind11::object parent = pybind11::object()) {
445457
assert(!attached && "operation already attached");
446458
attached = true;
447459
}
460+
void setDetached() {
461+
assert(attached && "operation already detached");
462+
attached = false;
463+
}
448464
void checkValid() const;
449465

450466
/// Gets the owning block or raises an exception if the operation has no
@@ -495,6 +511,8 @@ class PyOperation : public PyOperationBase, public BaseContextObject {
495511
pybind11::object parentKeepAlive;
496512
bool attached = true;
497513
bool valid = true;
514+
515+
friend class PyOperationBase;
498516
};
499517

500518
/// A PyOpView is equivalent to the C++ "Op" wrappers: these are the basis for

mlir/lib/CAPI/IR/IR.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -338,6 +338,8 @@ MlirOperation mlirOperationClone(MlirOperation op) {
338338

339339
void mlirOperationDestroy(MlirOperation op) { unwrap(op)->erase(); }
340340

341+
void mlirOperationRemoveFromParent(MlirOperation op) { unwrap(op)->remove(); }
342+
341343
bool mlirOperationEqual(MlirOperation op, MlirOperation other) {
342344
return unwrap(op) == unwrap(other);
343345
}
@@ -451,6 +453,14 @@ bool mlirOperationVerify(MlirOperation op) {
451453
return succeeded(verify(unwrap(op)));
452454
}
453455

456+
void mlirOperationMoveAfter(MlirOperation op, MlirOperation other) {
457+
return unwrap(op)->moveAfter(unwrap(other));
458+
}
459+
460+
void mlirOperationMoveBefore(MlirOperation op, MlirOperation other) {
461+
return unwrap(op)->moveBefore(unwrap(other));
462+
}
463+
454464
//===----------------------------------------------------------------------===//
455465
// Region API.
456466
//===----------------------------------------------------------------------===//

mlir/lib/IR/Operation.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -505,7 +505,7 @@ void Operation::moveAfter(Operation *existingOp) {
505505
void Operation::moveAfter(Block *block,
506506
llvm::iplist<Operation>::iterator iterator) {
507507
assert(iterator != block->end() && "cannot move after end of block");
508-
moveBefore(&*std::next(iterator));
508+
moveBefore(block, std::next(iterator));
509509
}
510510

511511
/// This drops all operand uses from this operation, which is an essential

mlir/test/python/ir/operation.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -740,3 +740,66 @@ def testOperationLoc():
740740
op = Operation.create("custom.op", loc=loc)
741741
assert op.location == loc
742742
assert op.operation.location == loc
743+
744+
# CHECK-LABEL: TEST: testModuleMerge
745+
@run
746+
def testModuleMerge():
747+
with Context():
748+
m1 = Module.parse("func private @foo()")
749+
m2 = Module.parse("""
750+
func private @bar()
751+
func private @qux()
752+
""")
753+
foo = m1.body.operations[0]
754+
bar = m2.body.operations[0]
755+
qux = m2.body.operations[1]
756+
bar.move_before(foo)
757+
qux.move_after(foo)
758+
759+
# CHECK: module
760+
# CHECK: func private @bar
761+
# CHECK: func private @foo
762+
# CHECK: func private @qux
763+
print(m1)
764+
765+
# CHECK: module {
766+
# CHECK-NEXT: }
767+
print(m2)
768+
769+
770+
# CHECK-LABEL: TEST: testAppendMoveFromAnotherBlock
771+
@run
772+
def testAppendMoveFromAnotherBlock():
773+
with Context():
774+
m1 = Module.parse("func private @foo()")
775+
m2 = Module.parse("func private @bar()")
776+
func = m1.body.operations[0]
777+
m2.body.append(func)
778+
779+
# CHECK: module
780+
# CHECK: func private @bar
781+
# CHECK: func private @foo
782+
783+
print(m2)
784+
# CHECK: module {
785+
# CHECK-NEXT: }
786+
print(m1)
787+
788+
789+
# CHECK-LABEL: TEST: testDetachFromParent
790+
@run
791+
def testDetachFromParent():
792+
with Context():
793+
m1 = Module.parse("func private @foo()")
794+
func = m1.body.operations[0].detach_from_parent()
795+
796+
try:
797+
func.detach_from_parent()
798+
except ValueError as e:
799+
if "has no parent" not in str(e):
800+
raise
801+
else:
802+
assert False, "expected ValueError when detaching a detached operation"
803+
804+
print(m1)
805+
# CHECK-NOT: func private @foo

0 commit comments

Comments
 (0)