diff --git a/clang/lib/CIR/Dialect/Transforms/FlattenCFG.cpp b/clang/lib/CIR/Dialect/Transforms/FlattenCFG.cpp index 26e5c0572f12e..434d53e582d59 100644 --- a/clang/lib/CIR/Dialect/Transforms/FlattenCFG.cpp +++ b/clang/lib/CIR/Dialect/Transforms/FlattenCFG.cpp @@ -16,9 +16,11 @@ #include "mlir/IR/Block.h" #include "mlir/IR/Builders.h" #include "mlir/IR/PatternMatch.h" +#include "mlir/IR/ValueRange.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "clang/AST/DeclBase.h" #include "clang/CIR/Dialect/IR/CIRDialect.h" #include "clang/CIR/Dialect/Passes.h" #include "clang/CIR/MissingFeatures.h" @@ -492,14 +494,6 @@ class CIRTernaryOpFlattening : public mlir::OpRewritePattern { Block *condBlock = rewriter.getInsertionBlock(); Block::iterator opPosition = rewriter.getInsertionPoint(); Block *remainingOpsBlock = rewriter.splitBlock(condBlock, opPosition); - llvm::SmallVector locs; - // Ternary result is optional, make sure to populate the location only - // when relevant. - if (op->getResultTypes().size()) - locs.push_back(loc); - Block *continueBlock = - rewriter.createBlock(remainingOpsBlock, op->getResultTypes(), locs); - rewriter.create(loc, remainingOpsBlock); Region &trueRegion = op.getTrueRegion(); Block *trueBlock = &trueRegion.front(); @@ -508,24 +502,29 @@ class CIRTernaryOpFlattening : public mlir::OpRewritePattern { auto trueYieldOp = dyn_cast(trueTerminator); rewriter.replaceOpWithNewOp(trueYieldOp, trueYieldOp.getArgs(), - continueBlock); - rewriter.inlineRegionBefore(trueRegion, continueBlock); + remainingOpsBlock); + rewriter.inlineRegionBefore(trueRegion, remainingOpsBlock); - Block *falseBlock = continueBlock; Region &falseRegion = op.getFalseRegion(); + Block *falseBlock = &falseRegion.front(); - falseBlock = &falseRegion.front(); mlir::Operation *falseTerminator = falseRegion.back().getTerminator(); rewriter.setInsertionPointToEnd(&falseRegion.back()); auto falseYieldOp = dyn_cast(falseTerminator); rewriter.replaceOpWithNewOp(falseYieldOp, falseYieldOp.getArgs(), - continueBlock); - rewriter.inlineRegionBefore(falseRegion, continueBlock); + remainingOpsBlock); + rewriter.inlineRegionBefore(falseRegion, remainingOpsBlock); rewriter.setInsertionPointToEnd(condBlock); rewriter.create(loc, op.getCond(), trueBlock, falseBlock); - rewriter.replaceOp(op, continueBlock->getArguments()); + if (ValueTypeRange rt = op.getResultTypes(); rt.size()) { + iterator_range args = remainingOpsBlock->addArguments(rt, op.getLoc()); + SmallVector values; + llvm::copy(args, std::back_inserter(values)); + rewriter.replaceOpUsesWithinBlock(op, values, remainingOpsBlock); + } + rewriter.eraseOp(op); // Ok, we're done! return mlir::success(); diff --git a/clang/test/CIR/Lowering/ternary.cir b/clang/test/CIR/Lowering/ternary.cir index 247c6ae3a1e17..c2226cd92ece7 100644 --- a/clang/test/CIR/Lowering/ternary.cir +++ b/clang/test/CIR/Lowering/ternary.cir @@ -25,6 +25,4 @@ module { // LLVM: br label %[[M]] // LLVM: [[M]]: // LLVM: [[R:%[[:alnum:]]+]] = phi i32 [ 1, %[[B2]] ], [ 0, %[[B1]] ] -// LLVM: br label %[[B3:[[:alnum:]]+]] -// LLVM: [[B3]]: // LLVM: ret i32 [[R]] diff --git a/clang/test/CIR/Transforms/ternary.cir b/clang/test/CIR/Transforms/ternary.cir index 67ef7f95a6b52..0c22268495697 100644 --- a/clang/test/CIR/Transforms/ternary.cir +++ b/clang/test/CIR/Transforms/ternary.cir @@ -37,8 +37,6 @@ module { // CHECK: %6 = cir.const #cir.int<5> : !s32i // CHECK: cir.br ^bb3(%6 : !s32i) // CHECK: ^bb3(%7: !s32i): // 2 preds: ^bb1, ^bb2 -// CHECK: cir.br ^bb4 -// CHECK: ^bb4: // pred: ^bb3 // CHECK: cir.store %7, %1 : !s32i, !cir.ptr // CHECK: %8 = cir.load %1 : !cir.ptr, !s32i // CHECK: cir.return %8 : !s32i @@ -60,8 +58,6 @@ module { // CHECK: ^bb2: // pred: ^bb0 // CHECK: cir.br ^bb3 // CHECK: ^bb3: // 2 preds: ^bb1, ^bb2 -// CHECK: cir.br ^bb4 -// CHECK: ^bb4: // pred: ^bb3 // CHECK: cir.return // CHECK: }