diff --git a/mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h
index f99eae379596b..464a47355b420 100644
--- a/mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h
+++ b/mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h
@@ -31,6 +31,16 @@ class IntegerValueRangeLattice : public Lattice<IntegerValueRange> {
 public:
   using Lattice::Lattice;
 
+  /// Override the join logic so that arguments of non-entry blocks whose
+  /// incoming values are defined later in the program get set to the
+  /// maximal range, so that we don't prematurely declare code to be
+  /// dead.
+  ChangeResult join(const AbstractSparseLattice &rhs) override;
+
+  ChangeResult join(const IntegerValueRange &range) {
+    return Lattice::join(range);
+  }
+
   /// If the range can be narrowed to an integer constant, update the constant
   /// value of the SSA value.
   void onUpdate(DataFlowSolver *solver) const override;
diff --git a/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp b/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp
index a97e43708d9a3..a45fcee345e91 100644
--- a/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp
@@ -37,6 +37,50 @@
 using namespace mlir;
 using namespace mlir::dataflow;
 
+/// Return true if `block` is a non-entry block with multiple predecessors, at
+/// least one of which doesn't appear before it in its region. This allows us
+/// to detect loop-varying values in unstructured control flow.
+static bool isLoopLikeBlock(Block *block) {
+  if (!block || block->isEntryBlock())
+    return false;
+  Region *parent = block->getParent();
+  if (!parent)
+    return false;
+
+  SmallPtrSet<Block *, 4> preds;
+  for (Block *pred : block->getPredecessors())
+    preds.insert(pred);
+  if (preds.size() <= 1)
+    return false;
+
+  for (Block &regionBlock : parent->getBlocks()) {
+    if (&regionBlock == block)
+      break;
+    preds.erase(&regionBlock);
+  }
+
+  // The block loops back on itself or has an edge from further in the program.
+  return !preds.empty();
+}
+
+ChangeResult IntegerValueRangeLattice::join(const AbstractSparseLattice &rhs) {
+  Value lhsAnchor = getAnchor();
+  Block *lhsBlock = lhsAnchor.getParentBlock();
+  unsigned width = ConstantIntRanges::getStorageBitwidth(lhsAnchor.getType());
+  // Special case: we're in unstructured control flow and one of the
+  // predecessors of this block argument is defined in a block that comes after
+  // the argument. So we conservatively conclude that the value could be
+  // anything.
+  if (width > 0 && isa<BlockArgument>(lhsAnchor) && isLoopLikeBlock(lhsBlock)) {
+    LLVM_DEBUG(llvm::dbgs() << "Found loop-varying block argument " << lhsAnchor
+                            << " from " << rhs.getAnchor() << "\n");
+    LLVM_DEBUG(llvm::dbgs() << "Inferring maximum range\n");
+    IntegerValueRange maxRange = IntegerValueRange::getMaxRange(lhsAnchor);
+    return join(maxRange);
+  }
+  return Lattice::join(rhs);
+}
+
 void IntegerValueRangeLattice::onUpdate(DataFlowSolver *solver) const {
   Lattice::onUpdate(solver);
 
@@ -206,6 +250,8 @@ void IntegerRangeAnalysis::visitNonControlFlowArguments(
     if (max.sge(min)) {
       IntegerValueRangeLattice *ivEntry = getLatticeElement(*iv);
       auto ivRange = ConstantIntRanges::fromSigned(min, max);
+      LLVM_DEBUG(llvm::dbgs()
+                 << "Inferred loop bound range: " << ivRange << "\n");
       propagateIfChanged(ivEntry, ivEntry->join(IntegerValueRange{ivRange}));
     }
     return;
diff --git a/mlir/test/Dialect/Arith/int-range-opts.mlir b/mlir/test/Dialect/Arith/int-range-opts.mlir
index ea5969a100258..e312cf175f5b5 100644
--- a/mlir/test/Dialect/Arith/int-range-opts.mlir
+++ b/mlir/test/Dialect/Arith/int-range-opts.mlir
@@ -132,3 +132,63 @@ func.func @wraps() -> i8 {
   %mod = arith.remsi %val, %c64 : i8
   return %mod : i8
 }
+
+// -----
+
+// Note: I wish I had a simpler example than this, but getting rid of a
+// bunch of the arithmetic made the issue go away.
+// CHECK-LABEL: @blocks_prematurely_declared_dead_bug
+// CHECK-NOT: arith.constant true
+func.func @blocks_prematurely_declared_dead_bug(%mem: memref<?xf16>) {
+  %cst = arith.constant dense<false> : vector<1xi1>
+  %c1 = arith.constant 1 : index
+  %cst_0 = arith.constant dense<0.000000e+00> : vector<1xf16>
+  %cst_1 = arith.constant 0.000000e+00 : f16
+  %c16 = arith.constant 16 : index
+  %c0 = arith.constant 0 : index
+  %c64 = arith.constant 64 : index
+  %thread_id_x = gpu.thread_id x upper_bound 64
+  %6 = test.with_bounds { smin = 16 : index, smax = 112 : index, umin = 16 : index, umax = 112 : index } : index
+  %8 = arith.divui %6, %c16 : index
+  %9 = arith.muli %8, %c16 : index
+  cf.br ^bb1(%c0 : index)
+^bb1(%12: index):  // 2 preds: ^bb0, ^bb7
+  %13 = arith.cmpi slt, %12, %9 : index
+  cf.cond_br %13, ^bb2, ^bb8
+^bb2:  // pred: ^bb1
+  %14 = arith.subi %9, %12 : index
+  %15 = arith.minsi %14, %c64 : index
+  %16 = arith.subi %15, %thread_id_x : index
+  %17 = vector.constant_mask [1] : vector<1xi1>
+  %18 = arith.cmpi sgt, %16, %c0 : index
+  %19 = arith.select %18, %17, %cst : vector<1xi1>
+  %20 = vector.extract %19[0] : i1 from vector<1xi1>
+  %21 = vector.insert %20, %cst [0] : i1 into vector<1xi1>
+  %22 = arith.addi %12, %thread_id_x : index
+  cf.br ^bb3(%c0, %cst_0 : index, vector<1xf16>)
+^bb3(%23: index, %24: vector<1xf16>):  // 2 preds: ^bb2, ^bb6
+  %25 = arith.cmpi slt, %23, %c1 : index
+  cf.cond_br %25, ^bb4, ^bb7
+^bb4:  // pred: ^bb3
+  %26 = vector.extractelement %21[%23 : index] : vector<1xi1>
+  cf.cond_br %26, ^bb5, ^bb6(%24 : vector<1xf16>)
+^bb5:  // pred: ^bb4
+  %27 = arith.addi %22, %23 : index
+  %28 = memref.load %mem[%27] : memref<?xf16>
+  %29 = vector.insertelement %28, %24[%23 : index] : vector<1xf16>
+  cf.br ^bb6(%29 : vector<1xf16>)
+^bb6(%30: vector<1xf16>):  // 2 preds: ^bb4, ^bb5
+  %31 = arith.addi %23, %c1 : index
+  cf.br ^bb3(%31, %30 : index, vector<1xf16>)
+^bb7:  // pred: ^bb3
+  %37 = arith.addi %12, %c64 : index
+  cf.br ^bb1(%37 : index)
+^bb8:  // pred: ^bb1
+  %70 = arith.cmpi eq, %thread_id_x, %c0 : index
+  cf.cond_br %70, ^bb9, ^bb10
+^bb9:  // pred: ^bb8
+  memref.store %cst_1, %mem[%c0] : memref<?xf16>
+  cf.br ^bb10
+^bb10:  // 2 preds: ^bb8, ^bb9
+  return
+}