-
Notifications
You must be signed in to change notification settings - Fork 14.4k
[flang] Introduce omp.target_allocmem and omp.target_freemem omp dialect ops. #145464
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
4915fb8
e41a2c7
a864094
2f5b289
1e7a216
3879eb7
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -125,10 +125,177 @@ struct PrivateClauseOpConversion | |||||
return mlir::success(); | ||||||
} | ||||||
}; | ||||||
|
||||||
/// Return the `omp_target_alloc` LLVM function declaration of the module
/// enclosing \p op, creating it when the module does not declare it yet.
///
/// The declaration built here has the type `ptr (i64, i32)`.
static mlir::LLVM::LLVMFuncOp getOmpTargetAlloc(mlir::Operation *op) {
  auto module = op->getParentOfType<mlir::ModuleOp>();

  // Reuse an existing declaration so the symbol is only defined once.
  auto existing =
      module.lookupSymbol<mlir::LLVM::LLVMFuncOp>("omp_target_alloc");
  if (existing)
    return existing;

  auto *context = module->getContext();
  auto sizeTy = mlir::IntegerType::get(context, 64);
  auto deviceTy = mlir::IntegerType::get(context, 32);
  auto resultTy = mlir::LLVM::LLVMPointerType::get(context);
  auto funcTy = mlir::LLVM::LLVMFunctionType::get(resultTy, {sizeTy, deviceTy},
                                                  /*isVarArg=*/false);

  // The builder is positioned on the module body region.
  mlir::OpBuilder moduleBuilder(module.getBodyRegion());
  return moduleBuilder.create<mlir::LLVM::LLVMFuncOp>(
      moduleBuilder.getUnknownLoc(), "omp_target_alloc", funcTy);
}
|
||||||
/// Convert the FIR type \p firType of an object to its LLVM type.
///
/// Box types are converted with `convertBoxTypeAsStruct`; every other type
/// goes through the regular `convertType` hook of \p converter.
static mlir::Type convertObjectType(const fir::LLVMTypeConverter &converter,
                                    mlir::Type firType) {
  auto boxTy = mlir::dyn_cast<fir::BaseBoxType>(firType);
  if (!boxTy)
    return converter.convertType(firType);
  return converter.convertBoxTypeAsStruct(boxTy);
}
|
||||||
/// Build the attribute list for a call created from an op carrying \p attrs.
///
/// Despite its name this helper does three things:
///  - copies every attribute of \p attrs except `operandSegmentSizes`;
///  - appends a fresh `operandSegmentSizes` of `{numCallOperands, 0}`;
///  - appends an empty `op_bundle_sizes`.
///
/// NOTE(review): a reviewer asked whether these attribute changes should be
/// done at the call site instead, since the helper is not obviously reusable.
static llvm::SmallVector<mlir::NamedAttribute>
addLLVMOpBundleAttrs(mlir::ConversionPatternRewriter &rewriter,
                     llvm::ArrayRef<mlir::NamedAttribute> attrs,
                     int32_t numCallOperands) {
  llvm::SmallVector<mlir::NamedAttribute> result;
  // Room for the copied attributes plus the two appended below.
  result.reserve(attrs.size() + 2);

  for (mlir::NamedAttribute candidate : attrs) {
    // The old segment sizes describe the source op, not the call.
    if (candidate.getName() == "operandSegmentSizes")
      continue;
    result.push_back(candidate);
  }

  auto segmentSizes = rewriter.getDenseI32ArrayAttr({numCallOperands, 0});
  result.push_back(rewriter.getNamedAttr("operandSegmentSizes", segmentSizes));

  auto bundleSizes = rewriter.getDenseI32ArrayAttr({});
  result.push_back(rewriter.getNamedAttr("op_bundle_sizes", bundleSizes));
  return result;
}
|
||||||
/// Create an `llvm.mlir.constant` of type \p ity at \p loc whose value
/// attribute is the 64-bit integer attribute holding \p offset.
static mlir::LLVM::ConstantOp
genConstantIndex(mlir::Location loc, mlir::Type ity,
                 mlir::ConversionPatternRewriter &rewriter,
                 std::int64_t offset) {
  return rewriter.create<mlir::LLVM::ConstantOp>(
      loc, ity, rewriter.getI64IntegerAttr(offset));
}
|
||||||
/// Materialize, as a constant of type \p idxTy, the distance in bytes between
/// two consecutive objects of type \p llvmObjectType: the type size reported
/// by \p dataLayout rounded up to the type's ABI alignment.
static mlir::Value
computeElementDistance(mlir::Location loc, mlir::Type llvmObjectType,
                       mlir::Type idxTy,
                       mlir::ConversionPatternRewriter &rewriter,
                       const mlir::DataLayout &dataLayout) {
  const llvm::TypeSize size = dataLayout.getTypeSize(llvmObjectType);
  // Keep the alignment at full width; storing it in an `unsigned short`
  // would silently truncate large alignments.
  const std::uint64_t alignment =
      dataLayout.getTypeABIAlignment(llvmObjectType);
  // Read the size through the explicit fixed-size accessor instead of the
  // discouraged implicit TypeSize -> integer conversion.
  const std::int64_t distance =
      llvm::alignTo(size.getFixedValue(), alignment);
  return genConstantIndex(loc, idxTy, rewriter, distance);
}
|
||||||
/// Generate the size in bytes of the LLVM type \p llTy as a value of type
/// \p idxTy. This simply forwards to computeElementDistance with the
/// arguments reordered.
static mlir::Value genTypeSizeInBytes(mlir::Location loc, mlir::Type idxTy,
                                      mlir::ConversionPatternRewriter &rewriter,
                                      mlir::Type llTy,
                                      const mlir::DataLayout &dataLayout) {
  mlir::Value sizeInBytes =
      computeElementDistance(loc, llTy, idxTy, rewriter, dataLayout);
  return sizeInBytes;
}
|
||||||
template <typename OP> | ||||||
static mlir::Value | ||||||
genAllocationScaleSize(OP op, mlir::Type ity, | ||||||
mlir::ConversionPatternRewriter &rewriter) { | ||||||
mlir::Location loc = op.getLoc(); | ||||||
mlir::Type dataTy = op.getInType(); | ||||||
auto seqTy = mlir::dyn_cast<fir::SequenceType>(dataTy); | ||||||
fir::SequenceType::Extent constSize = 1; | ||||||
if (seqTy) { | ||||||
int constRows = seqTy.getConstantRows(); | ||||||
const fir::SequenceType::ShapeRef &shape = seqTy.getShape(); | ||||||
if (constRows != static_cast<int>(shape.size())) { | ||||||
for (auto extent : shape) { | ||||||
if (constRows-- > 0) | ||||||
continue; | ||||||
if (extent != fir::SequenceType::getUnknownExtent()) | ||||||
constSize *= extent; | ||||||
} | ||||||
} | ||||||
} | ||||||
|
||||||
if (constSize != 1) { | ||||||
mlir::Value constVal{ | ||||||
genConstantIndex(loc, ity, rewriter, constSize).getResult()}; | ||||||
return constVal; | ||||||
} | ||||||
return nullptr; | ||||||
} | ||||||
|
||||||
/// Cast the integer value \p val to the integer type \p ty.
///
/// A narrower target produces an `llvm.trunc`, a wider target an
/// `llvm.sext`, and equal widths return \p val unchanged. When \p fold is
/// set, the cast is built with `createOrFold` instead of `create`.
static mlir::Value integerCast(const fir::LLVMTypeConverter &converter,
                               mlir::Location loc,
                               mlir::ConversionPatternRewriter &rewriter,
                               mlir::Type ty, mlir::Value val,
                               bool fold = false) {
  // If the value was not yet lowered, lower its type so that it can
  // be used in getPrimitiveTypeSizeInBits.
  auto sourceTy = val.getType();
  if (!mlir::isa<mlir::IntegerType>(sourceTy))
    sourceTy = converter.convertType(sourceTy);

  auto targetBits = mlir::LLVM::getPrimitiveTypeSizeInBits(ty);
  auto sourceBits = mlir::LLVM::getPrimitiveTypeSizeInBits(sourceTy);
  const bool narrows = targetBits < sourceBits;
  const bool widens = targetBits > sourceBits;

  // Same width: nothing to do.
  if (!narrows && !widens)
    return val;

  if (narrows) {
    if (fold)
      return rewriter.createOrFold<mlir::LLVM::TruncOp>(loc, ty, val);
    return rewriter.create<mlir::LLVM::TruncOp>(loc, ty, val);
  }

  if (fold)
    return rewriter.createOrFold<mlir::LLVM::SExtOp>(loc, ty, val);
  return rewriter.create<mlir::LLVM::SExtOp>(loc, ty, val);
}
|
||||||
// FIR Op specific conversion for TargetAllocMemOp | ||||||
struct TargetAllocMemOpConversion | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think this conversion pattern should not replace |
||||||
: public OpenMPFIROpConversion<mlir::omp::TargetAllocMemOp> { | ||||||
using OpenMPFIROpConversion::OpenMPFIROpConversion; | ||||||
|
||||||
llvm::LogicalResult | ||||||
matchAndRewrite(mlir::omp::TargetAllocMemOp allocmemOp, OpAdaptor adaptor, | ||||||
mlir::ConversionPatternRewriter &rewriter) const override { | ||||||
mlir::Type heapTy = allocmemOp.getAllocatedType(); | ||||||
mlir::LLVM::LLVMFuncOp mallocFunc = getOmpTargetAlloc(allocmemOp); | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: move before use location below. |
||||||
mlir::Location loc = allocmemOp.getLoc(); | ||||||
auto ity = lowerTy().indexType(); | ||||||
mlir::Type dataTy = fir::unwrapRefType(heapTy); | ||||||
mlir::Type llvmObjectTy = convertObjectType(lowerTy(), dataTy); | ||||||
mlir::Type llvmPtrTy = | ||||||
mlir::LLVM::LLVMPointerType::get(allocmemOp.getContext(), 0); | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
if (fir::isRecordWithTypeParameters(fir::unwrapSequenceType(dataTy))) | ||||||
TODO(loc, "omp.target_allocmem codegen of derived type with length " | ||||||
"parameters"); | ||||||
mlir::Value size = genTypeSizeInBytes(loc, ity, rewriter, llvmObjectTy, | ||||||
lowerTy().getDataLayout()); | ||||||
if (auto scaleSize = genAllocationScaleSize(allocmemOp, ity, rewriter)) | ||||||
size = rewriter.create<mlir::LLVM::MulOp>(loc, ity, size, scaleSize); | ||||||
for (mlir::Value opnd : adaptor.getOperands().drop_front()) | ||||||
size = rewriter.create<mlir::LLVM::MulOp>( | ||||||
loc, ity, size, integerCast(lowerTy(), loc, rewriter, ity, opnd)); | ||||||
auto mallocTyWidth = lowerTy().getIndexTypeBitwidth(); | ||||||
auto mallocTy = | ||||||
mlir::IntegerType::get(rewriter.getContext(), mallocTyWidth); | ||||||
if (mallocTyWidth != ity.getIntOrFloatBitWidth()) | ||||||
size = integerCast(lowerTy(), loc, rewriter, mallocTy, size); | ||||||
Comment on lines
+282
to
+283
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think this |
||||||
allocmemOp->setAttr("callee", mlir::SymbolRefAttr::get(mallocFunc)); | ||||||
auto callOp = rewriter.create<mlir::LLVM::CallOp>( | ||||||
loc, llvmPtrTy, | ||||||
mlir::SmallVector<mlir::Value, 2>({size, allocmemOp.getDevice()}), | ||||||
addLLVMOpBundleAttrs(rewriter, allocmemOp->getAttrs(), 2)); | ||||||
rewriter.replaceOpWithNewOp<mlir::LLVM::PtrToIntOp>( | ||||||
allocmemOp, rewriter.getIntegerType(64), callOp.getResult()); | ||||||
return mlir::success(); | ||||||
} | ||||||
}; | ||||||
} // namespace | ||||||
|
||||||
void fir::populateOpenMPFIRToLLVMConversionPatterns( | ||||||
const LLVMTypeConverter &converter, mlir::RewritePatternSet &patterns) { | ||||||
patterns.add<MapInfoOpConversion>(converter); | ||||||
patterns.add<PrivateClauseOpConversion>(converter); | ||||||
patterns.add<TargetAllocMemOpConversion>(converter); | ||||||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
// RUN: %flang_fc1 -emit-llvm %s -o - | FileCheck %s | ||
|
||
// Statically shaped 3x3 array of i32 allocated on device 0; the expected
// allocation size below is 36 bytes.
// CHECK-LABEL: define void @omp_target_allocmem_array_of_nonchar( | ||
// CHECK: call ptr @omp_target_alloc(i64 36, i32 0) | ||
// CHECK: call void @omp_target_free(ptr {{.*}}, i32 0) | ||
func.func @omp_target_allocmem_array_of_nonchar() -> () { | ||
%device = arith.constant 0 : i32 | ||
%1 = omp.target_allocmem %device : i32, !fir.array<3x3xi32> | ||
omp.target_freemem %device, %1 : i32, i64 | ||
return | ||
} | ||
|
||
// Statically shaped 3x3 array of constant-length (10) characters of kind 1;
// the expected allocation size below is 90 bytes.
// CHECK-LABEL: define void @omp_target_allocmem_array_of_char( | ||
// CHECK: call ptr @omp_target_alloc(i64 90, i32 0) | ||
// CHECK: call void @omp_target_free(ptr {{.*}}, i32 0) | ||
func.func @omp_target_allocmem_array_of_char() -> () { | ||
%device = arith.constant 0 : i32 | ||
%1 = omp.target_allocmem %device : i32, !fir.array<3x3x!fir.char<1,10>> | ||
omp.target_freemem %device, %1 : i32, i64 | ||
return | ||
} | ||
|
||
// 3x3 array of characters whose length is the runtime argument %l; the
// expected size is the sign-extended length multiplied by 9.
// Unlike the two tests above, no omp_target_free call is verified here.
// CHECK-LABEL: define void @omp_target_allocmem_array_of_dynchar( | ||
// CHECK-SAME: i32 %[[len:.*]]) | ||
// CHECK: %[[mul1:.*]] = sext i32 %[[len]] to i64 | ||
// CHECK: %[[mul2:.*]] = mul i64 9, %[[mul1]] | ||
// CHECK: call ptr @omp_target_alloc(i64 %[[mul2]], i32 0) | ||
func.func @omp_target_allocmem_array_of_dynchar(%l: i32) -> () { | ||
%device = arith.constant 0 : i32 | ||
%1 = omp.target_allocmem %device : i32, !fir.array<3x3x!fir.char<1,?>>(%l : i32) | ||
omp.target_freemem %device, %1 : i32, i64 | ||
return | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It seems this function and some of the ones below are shared between `CodeGen.cpp` and `CodeGenOpenMP.cpp`. Can we move them to a shared location, e.g. `flang/Optimizer/Support/Utils.h/.cpp`?