From a83380ab62bb6f78e5baf32ac3db3e97d6c97b69 Mon Sep 17 00:00:00 2001 From: Krzysztof Drewniak Date: Tue, 10 Dec 2024 21:17:13 +0000 Subject: [PATCH] [mlir][IntRangeAnalysis] Handle unstructured loop arguments correctly The integer range analysis currently has a bug where, because of how it interacts with dead code analysis, it will sometimes declare code dead that isn't dead, becaues it hasn't seen the edge that loops an incremented value back to itself yet. This commit fixes the issue by overriding the join method on lattice values in order to detect these back-edges on non-entry blocks and then snapping the passed-around value to its maximum possible range, just like we do for loop-varying values in region control flow. Fixes #119045 --- .../Analysis/DataFlow/IntegerRangeAnalysis.h | 10 ++++ .../DataFlow/IntegerRangeAnalysis.cpp | 46 ++++++++++++++ mlir/test/Dialect/Arith/int-range-opts.mlir | 60 +++++++++++++++++++ 3 files changed, 116 insertions(+) diff --git a/mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h index f99eae379596b..464a47355b420 100644 --- a/mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h +++ b/mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h @@ -31,6 +31,16 @@ class IntegerValueRangeLattice : public Lattice { public: using Lattice::Lattice; + /// Override the join logic so that arguments to non-entry blocks + /// whose arguments come from later in the program get set to + /// a maximal value so that we don't prematurely declare code to be + /// deade. + ChangeResult join(const AbstractSparseLattice &rhs) override; + + ChangeResult join(const IntegerValueRange &range) { + return Lattice::join(range); + } + /// If the range can be narrowed to an integer constant, update the constant /// value of the SSA value. void onUpdate(DataFlowSolver *solver) const override; diff --git a/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp b/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp index a97e43708d9a3..a45fcee345e91 100644 --- a/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp +++ b/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp @@ -37,6 +37,50 @@ using namespace mlir; using namespace mlir::dataflow; +/// Return true if `block` is a non-entry block with a predecessor that's +/// defined after the block. This allows us to detect loop-varying values +/// in unstructured control flow. +static bool isLoopLikeBlock(Block *block) { + if (!block || block->isEntryBlock()) + return false; + Region *parent = block->getParent(); + if (!parent) + return false; + + SmallPtrSet preds; + for (Block *pred : block->getPredecessors()) + preds.insert(pred); + if (preds.size() <= 1) + return false; + + for (Block ®ionBlock : parent->getBlocks()) { + if (®ionBlock == block) + break; + preds.erase(®ionBlock); + } + + // The block loops back on itself or has an edge from further in the program. + return !preds.empty(); +} + +ChangeResult IntegerValueRangeLattice::join(const AbstractSparseLattice &rhs) { + Value lhsAnchor = getAnchor(); + Block *lhsBlock = lhsAnchor.getParentBlock(); + unsigned width = ConstantIntRanges::getStorageBitwidth(lhsAnchor.getType()); + /// Special-case: we're in unstructured control flow and one of the + /// predecessors of this block argument is defined in a block that comes after + /// the argument. So we conservatively conclude that the value could be + /// anything. + if (width > 0 && isa(lhsAnchor) && isLoopLikeBlock(lhsBlock)) { + LLVM_DEBUG(llvm::dbgs() << "Found loop-varying block argument " << lhsAnchor + << " from " << rhs.getAnchor() << "\n"); + LLVM_DEBUG(llvm::dbgs() << "Inferring maximum range\n"); + IntegerValueRange maxRange = IntegerValueRange::getMaxRange(lhsAnchor); + return join(maxRange); + } + return Lattice::join(rhs); +} + void IntegerValueRangeLattice::onUpdate(DataFlowSolver *solver) const { Lattice::onUpdate(solver); @@ -206,6 +250,8 @@ void IntegerRangeAnalysis::visitNonControlFlowArguments( if (max.sge(min)) { IntegerValueRangeLattice *ivEntry = getLatticeElement(*iv); auto ivRange = ConstantIntRanges::fromSigned(min, max); + LLVM_DEBUG(llvm::dbgs() + << "Inferred loop bound range: " << ivRange << "\n"); propagateIfChanged(ivEntry, ivEntry->join(IntegerValueRange{ivRange})); } return; diff --git a/mlir/test/Dialect/Arith/int-range-opts.mlir b/mlir/test/Dialect/Arith/int-range-opts.mlir index ea5969a100258..e312cf175f5b5 100644 --- a/mlir/test/Dialect/Arith/int-range-opts.mlir +++ b/mlir/test/Dialect/Arith/int-range-opts.mlir @@ -132,3 +132,63 @@ func.func @wraps() -> i8 { %mod = arith.remsi %val, %c64 : i8 return %mod : i8 } + +// ----- + +// Note: I wish I had a simpler example than this, but getting rid of a +// bunch of the arithmetic made the issue go away. +// CHECK-LABEL: @blocks_prematurely_declared_dead_bug +// CHECK-NOT: arith.constant true +func.func @blocks_prematurely_declared_dead_bug(%mem: memref) { + %cst = arith.constant dense : vector<1xi1> + %c1 = arith.constant 1 : index + %cst_0 = arith.constant dense<0.000000e+00> : vector<1xf16> + %cst_1 = arith.constant 0.000000e+00 : f16 + %c16 = arith.constant 16 : index + %c0 = arith.constant 0 : index + %c64 = arith.constant 64 : index + %thread_id_x = gpu.thread_id x upper_bound 64 + %6 = test.with_bounds { smin = 16 : index, smax = 112 : index, umin = 16 : index, umax = 112 : index } : index + %8 = arith.divui %6, %c16 : index + %9 = arith.muli %8, %c16 : index + cf.br ^bb1(%c0 : index) +^bb1(%12: index): // 2 preds: ^bb0, ^bb7 + %13 = arith.cmpi slt, %12, %9 : index + cf.cond_br %13, ^bb2, ^bb8 +^bb2: // pred: ^bb1 + %14 = arith.subi %9, %12 : index + %15 = arith.minsi %14, %c64 : index + %16 = arith.subi %15, %thread_id_x : index + %17 = vector.constant_mask [1] : vector<1xi1> + %18 = arith.cmpi sgt, %16, %c0 : index + %19 = arith.select %18, %17, %cst : vector<1xi1> + %20 = vector.extract %19[0] : i1 from vector<1xi1> + %21 = vector.insert %20, %cst [0] : i1 into vector<1xi1> + %22 = arith.addi %12, %thread_id_x : index + cf.br ^bb3(%c0, %cst_0 : index, vector<1xf16>) +^bb3(%23: index, %24: vector<1xf16>): // 2 preds: ^bb2, ^bb6 + %25 = arith.cmpi slt, %23, %c1 : index + cf.cond_br %25, ^bb4, ^bb7 +^bb4: // pred: ^bb3 + %26 = vector.extractelement %21[%23 : index] : vector<1xi1> + cf.cond_br %26, ^bb5, ^bb6(%24 : vector<1xf16>) +^bb5: // pred: ^bb4 + %27 = arith.addi %22, %23 : index + %28 = memref.load %mem[%27] : memref + %29 = vector.insertelement %28, %24[%23 : index] : vector<1xf16> + cf.br ^bb6(%29 : vector<1xf16>) +^bb6(%30: vector<1xf16>): // 2 preds: ^bb4, ^bb5 + %31 = arith.addi %23, %c1 : index + cf.br ^bb3(%31, %30 : index, vector<1xf16>) +^bb7: // pred: ^bb3 + %37 = arith.addi %12, %c64 : index + cf.br ^bb1(%37 : index) +^bb8: // pred: ^bb1 + %70 = arith.cmpi eq, %thread_id_x, %c0 : index + cf.cond_br %70, ^bb9, ^bb10 +^bb9: // pred: ^bb8 + memref.store %cst_1, %mem[%c0] : memref + cf.br ^bb10 +^bb10: // 2 preds: ^bb8, ^bb9 + return +}