Skip to content

[mlir] Add requiresReplacedValues and visitReplacedValues to PromotableOpInterface #86792

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Apr 4, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -562,8 +562,10 @@ class LLVM_DbgIntrOp<string name, string argName, list<Trait> traits = []>
}];
}

def LLVM_DbgDeclareOp : LLVM_DbgIntrOp<"dbg.declare", "addr",
[DeclareOpInterfaceMethods<PromotableOpInterface>]> {
def LLVM_DbgDeclareOp : LLVM_DbgIntrOp<"dbg.declare", "addr", [
DeclareOpInterfaceMethods<PromotableOpInterface, [
"requiresVisitingMutatedDefs", "visitMutatedDefs"
]>]> {
let summary = "Describes how the address relates to a source language variable.";
let arguments = (ins
LLVM_AnyPointer:$addr,
Expand Down
21 changes: 21 additions & 0 deletions mlir/include/mlir/Interfaces/MemorySlotInterfaces.td
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,27 @@ def PromotableOpInterface : OpInterface<"PromotableOpInterface"> {
(ins "const ::llvm::SmallPtrSetImpl<mlir::OpOperand *> &":$blockingUses,
"::mlir::RewriterBase &":$rewriter)
>,
InterfaceMethod<[{
Checks whether the operation requires visiting the mutated
definitions by a store operation.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you explain what this means more in the documentation?

}], "bool", "requiresVisitingMutatedDefs", (ins), [{}],
[{ return false; }]
>,
InterfaceMethod<[{
Visits all the mutated definitions by a store operation.

This method will only be called after all blocking uses have been
scheduled for removal and if `requiresVisitingMutatedDefs` returned
true.

The rewriter is located after the promotable operation on call. All IR
mutations must happen through the rewriter. During the transformation,
*no operation should be deleted*.
}],
"void", "visitMutatedDefs",
(ins "::llvm::ArrayRef<std::pair<::mlir::Operation*, ::mlir::Value>>":$mutatedDefs,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Name suggestion: replacingValues

"::mlir::RewriterBase &":$rewriter), [{}], [{ return; }]
Comment on lines +259 to +260
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I feel like to be absolutely complete this method should also provide the memory slot, as one may want to know which memory slot specifically those definitions correspond to. But at the same time I am not really sure that this would be useful in practice, so change it only if you feel like it.

>,
];
}

Expand Down
19 changes: 12 additions & 7 deletions mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -166,13 +166,6 @@ bool LLVM::StoreOp::canUsesBeRemoved(
DeletionKind LLVM::StoreOp::removeBlockingUses(
const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
RewriterBase &rewriter, Value reachingDefinition) {
// `canUsesBeRemoved` checked this blocking use must be the stored slot
// pointer.
for (Operation *user : slot.ptr.getUsers())
if (auto declareOp = dyn_cast<LLVM::DbgDeclareOp>(user))
rewriter.create<LLVM::DbgValueOp>(declareOp->getLoc(), getValue(),
declareOp.getVarInfo(),
declareOp.getLocationExpr());
return DeletionKind::Delete;
}

Expand Down Expand Up @@ -405,6 +398,18 @@ DeletionKind LLVM::DbgValueOp::removeBlockingUses(
return DeletionKind::Keep;
}

bool LLVM::DbgDeclareOp::requiresVisitingMutatedDefs() { return true; }

void LLVM::DbgDeclareOp::visitMutatedDefs(
ArrayRef<std::pair<Operation *, Value>> definitions,
RewriterBase &rewriter) {
for (auto [op, value] : definitions) {
rewriter.setInsertionPointAfter(op);
rewriter.create<LLVM::DbgValueOp>(getLoc(), value, getVarInfo(),
getLocationExpr());
}
}

//===----------------------------------------------------------------------===//
// Interfaces for GEPOp
//===----------------------------------------------------------------------===//
Expand Down
14 changes: 13 additions & 1 deletion mlir/lib/Transforms/Mem2Reg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,7 @@ class MemorySlotPromoter {
/// Contains the reaching definition at this operation. Reaching definitions
/// are only computed for promotable memory operations with blocking uses.
DenseMap<PromotableMemOpInterface, Value> reachingDefs;
DenseMap<PromotableMemOpInterface, Value> mutatedValues;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I really dislike the "mutatedValues" name because it is unclear to me in what way those values are mutated. This name suggests you put something in those values when what happens conceptually is the opposite.

DominanceInfo &dominance;
MemorySlotPromotionInfo info;
const Mem2RegStatistics &statistics;
Expand Down Expand Up @@ -438,6 +439,7 @@ Value MemorySlotPromoter::computeReachingDefInBlock(Block *block,
assert(stored && "a memory operation storing to a slot must provide a "
"new definition of the slot");
reachingDef = stored;
mutatedValues[memOp] = stored;
}
}
}
Expand Down Expand Up @@ -552,6 +554,8 @@ void MemorySlotPromoter::removeBlockingUses() {
dominanceSort(usersToRemoveUses, *slot.ptr.getParentBlock()->getParent());

llvm::SmallVector<Operation *> toErase;
llvm::SmallVector<std::pair<Operation *, Value>> mutatedDefinitions;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same as above for this name.

llvm::SmallVector<PromotableOpInterface> visitMutatedDefinitions;
for (Operation *toPromote : llvm::reverse(usersToRemoveUses)) {
if (auto toPromoteMemOp = dyn_cast<PromotableMemOpInterface>(toPromote)) {
Value reachingDef = reachingDefs.lookup(toPromoteMemOp);
Expand All @@ -565,7 +569,9 @@ void MemorySlotPromoter::removeBlockingUses() {
slot, info.userToBlockingUses[toPromote], rewriter,
reachingDef) == DeletionKind::Delete)
toErase.push_back(toPromote);

if (toPromoteMemOp.storesTo(slot))
if (Value mutatedValue = mutatedValues[toPromoteMemOp])
mutatedDefinitions.push_back({toPromoteMemOp, mutatedValue});
continue;
}

Expand All @@ -574,6 +580,12 @@ void MemorySlotPromoter::removeBlockingUses() {
if (toPromoteBasic.removeBlockingUses(info.userToBlockingUses[toPromote],
rewriter) == DeletionKind::Delete)
toErase.push_back(toPromote);
if (toPromoteBasic.requiresVisitingMutatedDefs())
visitMutatedDefinitions.push_back(toPromoteBasic);
}
for (PromotableOpInterface op : visitMutatedDefinitions) {
rewriter.setInsertionPointAfter(op);
op.visitMutatedDefs(mutatedDefinitions, rewriter);
}

for (Operation *toEraseOp : toErase)
Expand Down