diff --git a/mlir/include/mlir/Dialect/Affine/Analysis/AffineAnalysis.h b/mlir/include/mlir/Dialect/Affine/Analysis/AffineAnalysis.h
index 4134aef8174bc..3e4b8648061ff 100644
--- a/mlir/include/mlir/Dialect/Affine/Analysis/AffineAnalysis.h
+++ b/mlir/include/mlir/Dialect/Affine/Analysis/AffineAnalysis.h
@@ -81,13 +81,13 @@ LogicalResult getIndexSet(MutableArrayRef<Operation *> ops,
 /// Encapsulates a memref load or store access information.
 struct MemRefAccess {
   Value memref;
-  Operation *opInst;
+  Operation *opInst = nullptr;
   SmallVector<Value, 4> indices;
 
-  /// Constructs a MemRefAccess from a load or store operation.
-  // TODO: add accessors to standard op's load, store, DMA op's to return
-  // MemRefAccess, i.e., loadOp->getAccess(), dmaOp->getRead/WriteAccess.
-  explicit MemRefAccess(Operation *opInst);
+  /// Constructs a MemRefAccess from an affine read/write operation.
+  explicit MemRefAccess(Operation *memOp);
+
+  MemRefAccess() = default;
 
   // Returns the rank of the memref associated with this access.
   unsigned getRank() const;
@@ -126,10 +126,12 @@ struct MemRefAccess {
   /// time (considering the memrefs, their respective affine access maps and
   /// operands). The equality of access functions + operands is checked by
   /// subtracting fully composed value maps, and then simplifying the difference
-  /// using the expression flattener.
-  /// TODO: this does not account for aliasing of memrefs.
+  /// using the expression flattener. This does not account for aliasing of
+  /// memrefs.
   bool operator==(const MemRefAccess &rhs) const;
   bool operator!=(const MemRefAccess &rhs) const { return !(*this == rhs); }
+
+  explicit operator bool() const { return !!memref; }
 };
 
 // DependenceComponent contains state about the direction of a dependence as an
diff --git a/mlir/lib/Dialect/Affine/Analysis/Utils.cpp b/mlir/lib/Dialect/Affine/Analysis/Utils.cpp
index 8bdb4c3593335..4739290bf6e4b 100644
--- a/mlir/lib/Dialect/Affine/Analysis/Utils.cpp
+++ b/mlir/lib/Dialect/Affine/Analysis/Utils.cpp
@@ -1550,15 +1550,17 @@ mlir::affine::computeSliceUnion(ArrayRef<Operation *> opsA,
   FlatAffineValueConstraints sliceUnionCst;
   assert(sliceUnionCst.getNumDimAndSymbolVars() == 0);
   std::vector<std::pair<Operation *, Operation *>> dependentOpPairs;
-  for (Operation *i : opsA) {
-    MemRefAccess srcAccess(i);
-    for (Operation *j : opsB) {
-      MemRefAccess dstAccess(j);
+  MemRefAccess srcAccess;
+  MemRefAccess dstAccess;
+  for (Operation *a : opsA) {
+    srcAccess = MemRefAccess(a);
+    for (Operation *b : opsB) {
+      dstAccess = MemRefAccess(b);
       if (srcAccess.memref != dstAccess.memref)
         continue;
       // Check if 'loopDepth' exceeds nesting depth of src/dst ops.
-      if ((!isBackwardSlice && loopDepth > getNestingDepth(i)) ||
-          (isBackwardSlice && loopDepth > getNestingDepth(j))) {
+      if ((!isBackwardSlice && loopDepth > getNestingDepth(a)) ||
+          (isBackwardSlice && loopDepth > getNestingDepth(b))) {
         LLVM_DEBUG(llvm::dbgs() << "Invalid loop depth\n");
         return SliceComputationResult::GenericFailure;
       }
@@ -1577,13 +1579,12 @@ mlir::affine::computeSliceUnion(ArrayRef<Operation *> opsA,
       }
       if (result.value == DependenceResult::NoDependence)
         continue;
-      dependentOpPairs.emplace_back(i, j);
+      dependentOpPairs.emplace_back(a, b);
 
       // Compute slice bounds for 'srcAccess' and 'dstAccess'.
       ComputationSliceState tmpSliceState;
-      mlir::affine::getComputationSliceState(i, j, dependenceConstraints,
-                                             loopDepth, isBackwardSlice,
-                                             &tmpSliceState);
+      getComputationSliceState(a, b, dependenceConstraints, loopDepth,
+                               isBackwardSlice, &tmpSliceState);
 
       if (sliceUnionCst.getNumDimAndSymbolVars() == 0) {
         // Initialize 'sliceUnionCst' with the bounds computed in previous step.
@@ -1948,16 +1949,16 @@ AffineForOp mlir::affine::insertBackwardComputationSlice(
 // Constructs MemRefAccess populating it with the memref, its indices and
 // opinst from 'loadOrStoreOpInst'.
-MemRefAccess::MemRefAccess(Operation *loadOrStoreOpInst) {
-  if (auto loadOp = dyn_cast<AffineReadOpInterface>(loadOrStoreOpInst)) {
+MemRefAccess::MemRefAccess(Operation *memOp) {
+  if (auto loadOp = dyn_cast<AffineReadOpInterface>(memOp)) {
     memref = loadOp.getMemRef();
-    opInst = loadOrStoreOpInst;
+    opInst = memOp;
     llvm::append_range(indices, loadOp.getMapOperands());
   } else {
-    assert(isa<AffineWriteOpInterface>(loadOrStoreOpInst) &&
+    assert(isa<AffineWriteOpInterface>(memOp) &&
           "Affine read/write op expected");
-    auto storeOp = cast<AffineWriteOpInterface>(loadOrStoreOpInst);
-    opInst = loadOrStoreOpInst;
+    auto storeOp = cast<AffineWriteOpInterface>(memOp);
+    opInst = memOp;
     memref = storeOp.getMemRef();
     llvm::append_range(indices, storeOp.getMapOperands());
   }