Skip to content

Commit fa88c18

Browse files
authored
[MLIR][Affine] Add default null init for mlir::affine::MemRefAccess (#147922)
Add default null init for `mlir::affine::MemRefAccess`. This is consistent with various other MLIR structures and had been missing for `mlir::affine::MemRefAccess`.
1 parent 968d38d commit fa88c18

File tree

2 files changed

+26
-23
lines changed

2 files changed

+26
-23
lines changed

mlir/include/mlir/Dialect/Affine/Analysis/AffineAnalysis.h

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -81,13 +81,13 @@ LogicalResult getIndexSet(MutableArrayRef<Operation *> ops,
8181
/// Encapsulates a memref load or store access information.
8282
struct MemRefAccess {
8383
Value memref;
84-
Operation *opInst;
84+
Operation *opInst = nullptr;
8585
SmallVector<Value, 4> indices;
8686

87-
/// Constructs a MemRefAccess from a load or store operation.
88-
// TODO: add accessors to standard op's load, store, DMA op's to return
89-
// MemRefAccess, i.e., loadOp->getAccess(), dmaOp->getRead/WriteAccess.
90-
explicit MemRefAccess(Operation *opInst);
87+
/// Constructs a MemRefAccess from an affine read/write operation.
88+
explicit MemRefAccess(Operation *memOp);
89+
90+
MemRefAccess() = default;
9191

9292
// Returns the rank of the memref associated with this access.
9393
unsigned getRank() const;
@@ -126,10 +126,12 @@ struct MemRefAccess {
126126
/// time (considering the memrefs, their respective affine access maps and
127127
/// operands). The equality of access functions + operands is checked by
128128
/// subtracting fully composed value maps, and then simplifying the difference
129-
/// using the expression flattener.
130-
/// TODO: this does not account for aliasing of memrefs.
129+
/// using the expression flattener. This does not account for aliasing of
130+
/// memrefs.
131131
bool operator==(const MemRefAccess &rhs) const;
132132
bool operator!=(const MemRefAccess &rhs) const { return !(*this == rhs); }
133+
134+
explicit operator bool() const { return !!memref; }
133135
};
134136

135137
// DependenceComponent contains state about the direction of a dependence as an

mlir/lib/Dialect/Affine/Analysis/Utils.cpp

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1550,15 +1550,17 @@ mlir::affine::computeSliceUnion(ArrayRef<Operation *> opsA,
15501550
FlatAffineValueConstraints sliceUnionCst;
15511551
assert(sliceUnionCst.getNumDimAndSymbolVars() == 0);
15521552
std::vector<std::pair<Operation *, Operation *>> dependentOpPairs;
1553-
for (Operation *i : opsA) {
1554-
MemRefAccess srcAccess(i);
1555-
for (Operation *j : opsB) {
1556-
MemRefAccess dstAccess(j);
1553+
MemRefAccess srcAccess;
1554+
MemRefAccess dstAccess;
1555+
for (Operation *a : opsA) {
1556+
srcAccess = MemRefAccess(a);
1557+
for (Operation *b : opsB) {
1558+
dstAccess = MemRefAccess(b);
15571559
if (srcAccess.memref != dstAccess.memref)
15581560
continue;
15591561
// Check if 'loopDepth' exceeds nesting depth of src/dst ops.
1560-
if ((!isBackwardSlice && loopDepth > getNestingDepth(i)) ||
1561-
(isBackwardSlice && loopDepth > getNestingDepth(j))) {
1562+
if ((!isBackwardSlice && loopDepth > getNestingDepth(a)) ||
1563+
(isBackwardSlice && loopDepth > getNestingDepth(b))) {
15621564
LLVM_DEBUG(llvm::dbgs() << "Invalid loop depth\n");
15631565
return SliceComputationResult::GenericFailure;
15641566
}
@@ -1577,13 +1579,12 @@ mlir::affine::computeSliceUnion(ArrayRef<Operation *> opsA,
15771579
}
15781580
if (result.value == DependenceResult::NoDependence)
15791581
continue;
1580-
dependentOpPairs.emplace_back(i, j);
1582+
dependentOpPairs.emplace_back(a, b);
15811583

15821584
// Compute slice bounds for 'srcAccess' and 'dstAccess'.
15831585
ComputationSliceState tmpSliceState;
1584-
mlir::affine::getComputationSliceState(i, j, dependenceConstraints,
1585-
loopDepth, isBackwardSlice,
1586-
&tmpSliceState);
1586+
getComputationSliceState(a, b, dependenceConstraints, loopDepth,
1587+
isBackwardSlice, &tmpSliceState);
15871588

15881589
if (sliceUnionCst.getNumDimAndSymbolVars() == 0) {
15891590
// Initialize 'sliceUnionCst' with the bounds computed in previous step.
@@ -1948,16 +1949,16 @@ AffineForOp mlir::affine::insertBackwardComputationSlice(
19481949

19491950
// Constructs MemRefAccess populating it with the memref, its indices and
19501951
// opinst from 'loadOrStoreOpInst'.
1951-
MemRefAccess::MemRefAccess(Operation *loadOrStoreOpInst) {
1952-
if (auto loadOp = dyn_cast<AffineReadOpInterface>(loadOrStoreOpInst)) {
1952+
MemRefAccess::MemRefAccess(Operation *memOp) {
1953+
if (auto loadOp = dyn_cast<AffineReadOpInterface>(memOp)) {
19531954
memref = loadOp.getMemRef();
1954-
opInst = loadOrStoreOpInst;
1955+
opInst = memOp;
19551956
llvm::append_range(indices, loadOp.getMapOperands());
19561957
} else {
1957-
assert(isa<AffineWriteOpInterface>(loadOrStoreOpInst) &&
1958+
assert(isa<AffineWriteOpInterface>(memOp) &&
19581959
"Affine read/write op expected");
1959-
auto storeOp = cast<AffineWriteOpInterface>(loadOrStoreOpInst);
1960-
opInst = loadOrStoreOpInst;
1960+
auto storeOp = cast<AffineWriteOpInterface>(memOp);
1961+
opInst = memOp;
19611962
memref = storeOp.getMemRef();
19621963
llvm::append_range(indices, storeOp.getMapOperands());
19631964
}

0 commit comments

Comments
 (0)