Skip to content

Commit ed499dd

Browse files
committed
[MLIR] Fix operation clone
Operation clone is currently faulty. Suppose you have a block like as follows: ``` (%x0 : i32) { %x1 = f(%x0) return %x1 } ``` The test case we have is that we want to "unroll" this, in which we want to change this to compute `f(f(x0))` instead of just `f(x0)`. We do so by making a copy of the body at the end of the block and set the uses of the argument in the copy operations with the value returned from the original block. This is implemented as follows: 1) map to the block arguments to the returned value (`map[x0] = x1`). 2) clone the body Now for this small example, this works as intended and we get the following. ``` (%x0 : i32) { %x1 = f(%x0) %x2 = f(%x1) return %x2 } ``` This is because the current logic to clone `x1 = f(x0)` first looks up the arguments in the map (which finds `x0` maps to `x1` from the initialization), and then sets the map of the result to the cloned result (`map[x1] = x2`). However, this fails if `x0` is not an argument to the op, but instead used inside the region, like below. ``` (%x0 : i32) { %x1 = f() { yield %x0 } return %x1 } ``` This is because cloning an op currently first looks up the args (none), sets the map of the result (`map[%x1] = %x2`), and then clones the regions. This results in the following, which is clearly illegal: ``` (%x0 : i32) { %x1 = f() { yield %x0 } %x2 = f() { yield %x2 } return %x2 } ``` Diving deeper, this is partially due to the ordering (how this PR fixes it), as well as how region cloning works. Namely it will first clone with the mapping, and then it will remap all operands. Since the ordering above now has a map of `x0 -> x1` and `x1 -> x2`, we end up with the incorrect behavior here. Reviewed By: ftynse Differential Revision: https://reviews.llvm.org/D122531
1 parent ca2be81 commit ed499dd

File tree

6 files changed

+101
-5
lines changed

6 files changed

+101
-5
lines changed

mlir/include/mlir/IR/Operation.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,10 @@ class alignas(8) Operation final
8585
/// original one, but they will be left empty.
8686
/// Operands are remapped using `mapper` (if present), and `mapper` is updated
8787
/// to contain the results.
88-
Operation *cloneWithoutRegions(BlockAndValueMapping &mapper);
88+
/// The `mapResults` argument specifies whether the results of the operation
89+
/// should also be mapped.
90+
Operation *cloneWithoutRegions(BlockAndValueMapping &mapper,
91+
bool mapResults = true);
8992

9093
/// Create a partial copy of this operation without traversing into attached
9194
/// regions. The new operation will have the same number of regions as the

mlir/lib/IR/Operation.cpp

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -526,7 +526,8 @@ InFlightDiagnostic Operation::emitOpError(const Twine &message) {
526526
/// Create a deep copy of this operation but keep the operation regions empty.
527527
/// Operands are remapped using `mapper` (if present), and `mapper` is updated
528528
/// to contain the results.
529-
Operation *Operation::cloneWithoutRegions(BlockAndValueMapping &mapper) {
529+
Operation *Operation::cloneWithoutRegions(BlockAndValueMapping &mapper,
530+
bool mapResults) {
530531
SmallVector<Value, 8> operands;
531532
SmallVector<Block *, 2> successors;
532533

@@ -545,8 +546,10 @@ Operation *Operation::cloneWithoutRegions(BlockAndValueMapping &mapper) {
545546
successors, getNumRegions());
546547

547548
// Remember the mapping of any results.
548-
for (unsigned i = 0, e = getNumResults(); i != e; ++i)
549-
mapper.map(getResult(i), newOp->getResult(i));
549+
if (mapResults) {
550+
for (unsigned i = 0, e = getNumResults(); i != e; ++i)
551+
mapper.map(getResult(i), newOp->getResult(i));
552+
}
550553

551554
return newOp;
552555
}
@@ -562,12 +565,15 @@ Operation *Operation::cloneWithoutRegions() {
562565
/// sub-operations to the corresponding operation that is copied, and adds
563566
/// those mappings to the map.
564567
Operation *Operation::clone(BlockAndValueMapping &mapper) {
565-
auto *newOp = cloneWithoutRegions(mapper);
568+
auto *newOp = cloneWithoutRegions(mapper, /*mapResults=*/false);
566569

567570
// Clone the regions.
568571
for (unsigned i = 0; i != numRegions; ++i)
569572
getRegion(i).cloneInto(&newOp->getRegion(i), mapper);
570573

574+
for (unsigned i = 0, e = getNumResults(); i != e; ++i)
575+
mapper.map(getResult(i), newOp->getResult(i));
576+
571577
return newOp;
572578
}
573579

mlir/test/IR/test-clone.mlir

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline="func.func(test-clone)" -split-input-file
2+
3+
module {
4+
func @fixpoint(%arg1 : i32) -> i32 {
5+
%r = "test.use"(%arg1) ({
6+
"test.yield"(%arg1) : (i32) -> ()
7+
}) : (i32) -> i32
8+
return %r : i32
9+
}
10+
}
11+
12+
// CHECK: func @fixpoint(%[[arg0:.+]]: i32) -> i32 {
13+
// CHECK-NEXT: %[[i0:.+]] = "test.use"(%[[arg0]]) ({
14+
// CHECK-NEXT: "test.yield"(%arg0) : (i32) -> ()
15+
// CHECK-NEXT: }) : (i32) -> i32
16+
// CHECK-NEXT: %[[i1:.+]] = "test.use"(%[[i0]]) ({
17+
// CHECK-NEXT: "test.yield"(%[[i0]]) : (i32) -> ()
18+
// CHECK-NEXT: }) : (i32) -> i32
19+
// CHECK-NEXT: return %[[i1]] : i32
20+
// CHECK-NEXT: }

mlir/test/lib/IR/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# Exclude tests from libMLIR.so
22
add_mlir_library(MLIRTestIR
33
TestBuiltinAttributeInterfaces.cpp
4+
TestClone.cpp
45
TestDiagnostics.cpp
56
TestDominance.cpp
67
TestFunc.cpp

mlir/test/lib/IR/TestClone.cpp

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
//===- TestSymbolUses.cpp - Pass to test symbol uselists ------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "TestDialect.h"
10+
#include "mlir/IR/BuiltinOps.h"
11+
#include "mlir/Pass/Pass.h"
12+
13+
using namespace mlir;
14+
15+
namespace {
16+
17+
/// This is a test pass which clones the body of a function. Specifically
18+
/// this pass replaces f(x) to instead return f(f(x)) in which the cloned body
19+
/// takes the result of the first operation return as an input.
20+
struct ClonePass
21+
: public PassWrapper<ClonePass, InterfacePass<FunctionOpInterface>> {
22+
StringRef getArgument() const final { return "test-clone"; }
23+
StringRef getDescription() const final { return "Test clone of op"; }
24+
void runOnOperation() override {
25+
FunctionOpInterface op = getOperation();
26+
27+
// Limit testing to ops with only one region.
28+
if (op->getNumRegions() != 1)
29+
return;
30+
31+
Region &region = op->getRegion(0);
32+
if (!region.hasOneBlock())
33+
return;
34+
35+
Block &regionEntry = region.front();
36+
auto terminator = regionEntry.getTerminator();
37+
38+
// Only handle functions whose returns match the inputs.
39+
if (terminator->getNumOperands() != regionEntry.getNumArguments())
40+
return;
41+
42+
BlockAndValueMapping map;
43+
for (auto tup :
44+
llvm::zip(terminator->getOperands(), regionEntry.getArguments())) {
45+
if (std::get<0>(tup).getType() != std::get<1>(tup).getType())
46+
return;
47+
map.map(std::get<1>(tup), std::get<0>(tup));
48+
}
49+
50+
OpBuilder B(op->getContext());
51+
B.setInsertionPointToEnd(&regionEntry);
52+
SmallVector<Operation *> toClone;
53+
for (Operation &inst : regionEntry)
54+
toClone.push_back(&inst);
55+
for (Operation *inst : toClone)
56+
B.clone(*inst, map);
57+
terminator->erase();
58+
}
59+
};
60+
} // namespace
61+
62+
namespace mlir {
63+
void registerCloneTestPasses() { PassRegistration<ClonePass>(); }
64+
} // namespace mlir

mlir/tools/mlir-opt/mlir-opt.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ using namespace mlir;
3030
// Defined in the test directory, no public header.
3131
namespace mlir {
3232
void registerConvertToTargetEnvPass();
33+
void registerCloneTestPasses();
3334
void registerPassManagerTestPass();
3435
void registerPrintSpirvAvailabilityPass();
3536
void registerShapeFunctionTestPasses();
@@ -119,6 +120,7 @@ void registerTestTransformDialectExtension(DialectRegistry &);
119120

120121
#ifdef MLIR_INCLUDE_TESTS
121122
void registerTestPasses() {
123+
registerCloneTestPasses();
122124
registerConvertToTargetEnvPass();
123125
registerPassManagerTestPass();
124126
registerPrintSpirvAvailabilityPass();

0 commit comments

Comments
 (0)