Skip to content

Commit 4cca22f

Browse files
[mlir][memref] Do not access erased op in memref.global lowering (#148355)
Do not access the erased `memref.global` operation in the lowering pattern. That won't work anymore in a One-Shot Dialect Conversion and triggers a use-after-free sanitizer error. After the One-Shot Dialect Conversion refactoring, a `ConversionPatternRewriter` will behave more like a normal `PatternRewriter`.
1 parent c4cc357 commit 4cca22f

File tree

1 file changed

+14
-18
lines changed

1 file changed

+14
-18
lines changed

mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp

Lines changed: 14 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -754,9 +754,11 @@ class GlobalMemrefOpLowering : public ConvertOpToLLVMPattern<memref::GlobalOp> {
754754

755755
LLVM::Linkage linkage =
756756
global.isPublic() ? LLVM::Linkage::External : LLVM::Linkage::Private;
757+
bool isExternal = global.isExternal();
758+
bool isUninitialized = global.isUninitialized();
757759

758760
Attribute initialValue = nullptr;
759-
if (!global.isExternal() && !global.isUninitialized()) {
761+
if (!isExternal && !isUninitialized) {
760762
auto elementsAttr = llvm::cast<ElementsAttr>(*global.getInitialValue());
761763
initialValue = elementsAttr;
762764

@@ -773,35 +775,29 @@ class GlobalMemrefOpLowering : public ConvertOpToLLVMPattern<memref::GlobalOp> {
773775
return global.emitOpError(
774776
"memory space cannot be converted to an integer address space");
775777

778+
// Remove old operation from symbol table.
779+
SymbolTable *symbolTable = nullptr;
776780
if (symbolTables) {
777781
Operation *symbolTableOp =
778782
global->getParentWithTrait<OpTrait::SymbolTable>();
779-
780-
if (symbolTableOp) {
781-
SymbolTable &symbolTable = symbolTables->getSymbolTable(symbolTableOp);
782-
symbolTable.remove(global);
783-
}
783+
symbolTable = &symbolTables->getSymbolTable(symbolTableOp);
784+
symbolTable->remove(global);
784785
}
785786

787+
// Create new operation.
786788
auto newGlobal = rewriter.replaceOpWithNewOp<LLVM::GlobalOp>(
787789
global, arrayTy, global.getConstant(), linkage, global.getSymName(),
788790
initialValue, alignment, *addressSpace);
789791

790-
if (symbolTables) {
791-
Operation *symbolTableOp =
792-
global->getParentWithTrait<OpTrait::SymbolTable>();
793-
794-
if (symbolTableOp) {
795-
SymbolTable &symbolTable = symbolTables->getSymbolTable(symbolTableOp);
796-
symbolTable.insert(newGlobal, rewriter.getInsertionPoint());
797-
}
798-
}
792+
// Insert new operation into symbol table.
793+
if (symbolTable)
794+
symbolTable->insert(newGlobal, rewriter.getInsertionPoint());
799795

800-
if (!global.isExternal() && global.isUninitialized()) {
796+
if (!isExternal && isUninitialized) {
801797
rewriter.createBlock(&newGlobal.getInitializerRegion());
802798
Value undef[] = {
803-
rewriter.create<LLVM::UndefOp>(global.getLoc(), arrayTy)};
804-
rewriter.create<LLVM::ReturnOp>(global.getLoc(), undef);
799+
rewriter.create<LLVM::UndefOp>(newGlobal.getLoc(), arrayTy)};
800+
rewriter.create<LLVM::ReturnOp>(newGlobal.getLoc(), undef);
805801
}
806802
return success();
807803
}

0 commit comments

Comments
 (0)