From 4915fb8f1d27847dc3d36899e233d5ac988f96c5 Mon Sep 17 00:00:00 2001 From: skc7 Date: Mon, 23 Jun 2025 16:39:55 +0530 Subject: [PATCH 1/6] [flang] Introduce omp_target_allocmem and omp_target_freemem fir ops. --- .../include/flang/Optimizer/Dialect/FIROps.td | 58 ++++++++++ flang/lib/Optimizer/CodeGen/CodeGen.cpp | 102 +++++++++++++++++- 2 files changed, 159 insertions(+), 1 deletion(-) diff --git a/flang/include/flang/Optimizer/Dialect/FIROps.td b/flang/include/flang/Optimizer/Dialect/FIROps.td index 8ac847dd7dd0a..2dff0f05fade7 100644 --- a/flang/include/flang/Optimizer/Dialect/FIROps.td +++ b/flang/include/flang/Optimizer/Dialect/FIROps.td @@ -517,6 +517,64 @@ def fir_ZeroOp : fir_OneResultOp<"zero_bits", [NoMemoryEffect]> { let assemblyFormat = "type($intype) attr-dict"; } +def fir_OmpTargetAllocMemOp : fir_Op<"omp_target_allocmem", + [MemoryEffects<[MemAlloc]>, AttrSizedOperandSegments]> { + let summary = "allocate storage on an openmp device for an object of a given type"; + + let description = [{ + Creates a heap memory reference suitable for storing a value of the + given type, T. The heap refernce returned has type `!fir.heap`. + The memory object is in an undefined state. `omp_target_allocmem` operations must + be paired with `omp_target_freemem` operations to avoid memory leaks. 
+ + ``` + %0 = "fir.omp_target_allocmem"(%device, %type) : (i32, index) -> !fir.heap> + ``` + }]; + + let arguments = (ins + Arg:$device, + TypeAttr:$in_type, + OptionalAttr:$uniq_name, + OptionalAttr:$bindc_name, + Variadic:$typeparams, + Variadic:$shape + ); + let results = (outs fir_HeapType); + + let extraClassDeclaration = [{ + mlir::Type getAllocatedType(); + bool hasLenParams() { return !getTypeparams().empty(); } + bool hasShapeOperands() { return !getShape().empty(); } + unsigned numLenParams() { return getTypeparams().size(); } + operand_range getLenParams() { return getTypeparams(); } + unsigned numShapeOperands() { return getShape().size(); } + operand_range getShapeOperands() { return getShape(); } + static mlir::Type getRefTy(mlir::Type ty); + }]; +} + +def fir_OmpTargetFreeMemOp : fir_Op<"omp_target_freemem", + [MemoryEffects<[MemFree]>]> { + let summary = "free a heap object"; + + let description = [{ + Deallocates a heap memory reference that was allocated by an `omp_target_allocmem`. + The memory object that is deallocated is placed in an undefined state + after `fir.omp_target_freemem`. + ``` + %0 = "fir.omp_target_allocmem"(%device, %type) : (i32, index) -> !fir.heap> + ... 
+ "fir.omp_target_freemem"(%device, %0) : (i32, !fir.heap>) -> () + ``` + }]; + + let arguments = (ins + Arg:$device, + Arg:$heapref + ); +} + //===----------------------------------------------------------------------===// // Terminator operations //===----------------------------------------------------------------------===// diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp index a3de3ae9d116a..042ade6b1e0a1 100644 --- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp +++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp @@ -1168,6 +1168,105 @@ struct FreeMemOpConversion : public fir::FIROpConversion { }; } // namespace +static mlir::LLVM::LLVMFuncOp getOmpTargetAlloc(mlir::Operation *op) { + auto module = op->getParentOfType(); + if (mlir::LLVM::LLVMFuncOp mallocFunc = + module.lookupSymbol("omp_target_alloc")) + return mallocFunc; + mlir::OpBuilder moduleBuilder(module.getBodyRegion()); + auto i64Ty = mlir::IntegerType::get(module->getContext(), 64); + auto i32Ty = mlir::IntegerType::get(module->getContext(), 32); + return moduleBuilder.create( + moduleBuilder.getUnknownLoc(), "omp_target_alloc", + mlir::LLVM::LLVMFunctionType::get( + mlir::LLVM::LLVMPointerType::get(module->getContext()), + {i64Ty, i32Ty}, + /*isVarArg=*/false)); +} + +namespace { +struct OmpTargetAllocMemOpConversion + : public fir::FIROpConversion { + using FIROpConversion::FIROpConversion; + + mlir::LogicalResult + matchAndRewrite(fir::OmpTargetAllocMemOp heap, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const override { + mlir::Type heapTy = heap.getType(); + mlir::LLVM::LLVMFuncOp mallocFunc = getOmpTargetAlloc(heap); + mlir::Location loc = heap.getLoc(); + auto ity = lowerTy().indexType(); + mlir::Type dataTy = fir::unwrapRefType(heapTy); + mlir::Type llvmObjectTy = convertObjectType(dataTy); + if (fir::isRecordWithTypeParameters(fir::unwrapSequenceType(dataTy))) + TODO(loc, "fir.omp_target_allocmem codegen of derived type with length " + 
"parameters"); + mlir::Value size = genTypeSizeInBytes(loc, ity, rewriter, llvmObjectTy); + if (auto scaleSize = genAllocationScaleSize(heap, ity, rewriter)) + size = rewriter.create(loc, ity, size, scaleSize); + for (mlir::Value opnd : adaptor.getOperands().drop_front()) + size = rewriter.create( + loc, ity, size, integerCast(loc, rewriter, ity, opnd)); + auto mallocTyWidth = lowerTy().getIndexTypeBitwidth(); + auto mallocTy = + mlir::IntegerType::get(rewriter.getContext(), mallocTyWidth); + if (mallocTyWidth != ity.getIntOrFloatBitWidth()) + size = integerCast(loc, rewriter, mallocTy, size); + heap->setAttr("callee", mlir::SymbolRefAttr::get(mallocFunc)); + rewriter.replaceOpWithNewOp( + heap, ::getLlvmPtrType(heap.getContext()), + mlir::SmallVector({size, heap.getDevice()}), + addLLVMOpBundleAttrs(rewriter, heap->getAttrs(), 2)); + return mlir::success(); + } + + /// Compute the allocation size in bytes of the element type of + /// \p llTy pointer type. The result is returned as a value of \p idxTy + /// integer type. 
+ mlir::Value genTypeSizeInBytes(mlir::Location loc, mlir::Type idxTy, + mlir::ConversionPatternRewriter &rewriter, + mlir::Type llTy) const { + return computeElementDistance(loc, llTy, idxTy, rewriter, getDataLayout()); + } +}; +} // namespace + +static mlir::LLVM::LLVMFuncOp getOmpTargetFree(mlir::Operation *op) { + auto module = op->getParentOfType(); + if (mlir::LLVM::LLVMFuncOp freeFunc = + module.lookupSymbol("omp_target_free")) + return freeFunc; + mlir::OpBuilder moduleBuilder(module.getBodyRegion()); + auto i32Ty = mlir::IntegerType::get(module->getContext(), 32); + return moduleBuilder.create( + moduleBuilder.getUnknownLoc(), "omp_target_free", + mlir::LLVM::LLVMFunctionType::get( + mlir::LLVM::LLVMVoidType::get(module->getContext()), + {getLlvmPtrType(module->getContext()), i32Ty}, + /*isVarArg=*/false)); +} + +namespace { +struct OmpTargetFreeMemOpConversion + : public fir::FIROpConversion { + using FIROpConversion::FIROpConversion; + + mlir::LogicalResult + matchAndRewrite(fir::OmpTargetFreeMemOp freemem, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const override { + mlir::LLVM::LLVMFuncOp freeFunc = getOmpTargetFree(freemem); + mlir::Location loc = freemem.getLoc(); + freemem->setAttr("callee", mlir::SymbolRefAttr::get(freeFunc)); + rewriter.create( + loc, mlir::TypeRange{}, + mlir::ValueRange{adaptor.getHeapref(), freemem.getDevice()}, + addLLVMOpBundleAttrs(rewriter, freemem->getAttrs(), 2)); + rewriter.eraseOp(freemem); + return mlir::success(); + } +}; +} // namespace + // Convert subcomponent array indices from column-major to row-major ordering. 
static llvm::SmallVector convertSubcomponentIndices(mlir::Location loc, mlir::Type eleTy, @@ -4274,7 +4373,8 @@ void fir::populateFIRToLLVMConversionPatterns( GlobalLenOpConversion, GlobalOpConversion, InsertOnRangeOpConversion, IsPresentOpConversion, LenParamIndexOpConversion, LoadOpConversion, LocalitySpecifierOpConversion, MulcOpConversion, NegcOpConversion, - NoReassocOpConversion, SelectCaseOpConversion, SelectOpConversion, + NoReassocOpConversion, OmpTargetAllocMemOpConversion, + OmpTargetFreeMemOpConversion, SelectCaseOpConversion, SelectOpConversion, SelectRankOpConversion, SelectTypeOpConversion, ShapeOpConversion, ShapeShiftOpConversion, ShiftOpConversion, SliceOpConversion, StoreOpConversion, StringLitOpConversion, SubcOpConversion, From e41a2c76786538fc411e104542a0282cfebef4f7 Mon Sep 17 00:00:00 2001 From: skc7 Date: Thu, 26 Jun 2025 10:20:36 +0530 Subject: [PATCH 2/6] [flang] Fix parsing and printing. --- .../include/flang/Optimizer/Dialect/FIROps.td | 13 ++- flang/lib/Optimizer/Dialect/FIROps.cpp | 90 ++++++++++++++++--- flang/test/Fir/omp_target_allocmem.fir | 28 ++++++ flang/test/Fir/omp_target_freemem.fir | 28 ++++++ 4 files changed, 145 insertions(+), 14 deletions(-) create mode 100644 flang/test/Fir/omp_target_allocmem.fir create mode 100644 flang/test/Fir/omp_target_freemem.fir diff --git a/flang/include/flang/Optimizer/Dialect/FIROps.td b/flang/include/flang/Optimizer/Dialect/FIROps.td index 2dff0f05fade7..666b66a8670d6 100644 --- a/flang/include/flang/Optimizer/Dialect/FIROps.td +++ b/flang/include/flang/Optimizer/Dialect/FIROps.td @@ -528,7 +528,8 @@ def fir_OmpTargetAllocMemOp : fir_Op<"omp_target_allocmem", be paired with `omp_target_freemem` operations to avoid memory leaks. 
``` - %0 = "fir.omp_target_allocmem"(%device, %type) : (i32, index) -> !fir.heap> + %device = arith.constant 0 : i32 + %1 = fir.omp_target_allocmem %device : i32, !fir.array<3x3xi32> ``` }]; @@ -542,6 +543,9 @@ def fir_OmpTargetAllocMemOp : fir_Op<"omp_target_allocmem", ); let results = (outs fir_HeapType); + let hasCustomAssemblyFormat = 1; + let hasVerifier = 1; + let extraClassDeclaration = [{ mlir::Type getAllocatedType(); bool hasLenParams() { return !getTypeparams().empty(); } @@ -563,9 +567,9 @@ def fir_OmpTargetFreeMemOp : fir_Op<"omp_target_freemem", The memory object that is deallocated is placed in an undefined state after `fir.omp_target_freemem`. ``` - %0 = "fir.omp_target_allocmem"(%device, %type) : (i32, index) -> !fir.heap> - ... - "fir.omp_target_freemem"(%device, %0) : (i32, !fir.heap>) -> () + %device = arith.constant 0 : i32 + %1 = fir.omp_target_allocmem %device : i32, !fir.array<3x3xi32> + fir.omp_target_freemem %device, %1 : i32, !fir.heap> ``` }]; @@ -573,6 +577,7 @@ def fir_OmpTargetFreeMemOp : fir_Op<"omp_target_freemem", Arg:$device, Arg:$heapref ); + let assemblyFormat = "$device `,` $heapref attr-dict `:` type($device) `,` qualified(type($heapref))"; } //===----------------------------------------------------------------------===// diff --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp index ecfa2939e96a6..9335a4b041ac8 100644 --- a/flang/lib/Optimizer/Dialect/FIROps.cpp +++ b/flang/lib/Optimizer/Dialect/FIROps.cpp @@ -106,24 +106,38 @@ static bool verifyTypeParamCount(mlir::Type inType, unsigned numParams) { return false; } -/// Parser shared by Alloca and Allocmem -/// +/// Parser shared by Alloca, Allocmem and OmpTargetAllocmem +/// boolean flag isTargetOp is used to identify omp_target_allocmem /// operation ::= %res = (`fir.alloca` | `fir.allocmem`) $in_type /// ( `(` $typeparams `)` )? ( `,` $shape )? 
/// attr-dict-without-keyword +/// operation ::= %res = (`fir.omp_target_alloca`) $device : devicetype, +/// $in_type ( `(` $typeparams `)` )? ( `,` $shape )? +/// attr-dict-without-keyword template -static mlir::ParseResult parseAllocatableOp(FN wrapResultType, - mlir::OpAsmParser &parser, - mlir::OperationState &result) { +static mlir::ParseResult +parseAllocatableOp(FN wrapResultType, mlir::OpAsmParser &parser, + mlir::OperationState &result, bool isTargetOp = false) { + auto &builder = parser.getBuilder(); + bool hasOperands = false; + std::int32_t typeparamsSize = 0; + // Parse device number as a new operand + if (isTargetOp) { + mlir::OpAsmParser::UnresolvedOperand deviceOperand; + mlir::Type deviceType; + if (parser.parseOperand(deviceOperand) || parser.parseColonType(deviceType)) + return mlir::failure(); + if (parser.resolveOperand(deviceOperand, deviceType, result.operands)) + return mlir::failure(); + if (parser.parseComma()) + return mlir::failure(); + } mlir::Type intype; if (parser.parseType(intype)) return mlir::failure(); - auto &builder = parser.getBuilder(); result.addAttribute("in_type", mlir::TypeAttr::get(intype)); llvm::SmallVector operands; llvm::SmallVector typeVec; - bool hasOperands = false; - std::int32_t typeparamsSize = 0; if (!parser.parseOptionalLParen()) { // parse the LEN params of the derived type. 
( : ) if (parser.parseOperandList(operands, mlir::OpAsmParser::Delimiter::None) || @@ -147,13 +161,19 @@ static mlir::ParseResult parseAllocatableOp(FN wrapResultType, parser.resolveOperands(operands, typeVec, parser.getNameLoc(), result.operands)) return mlir::failure(); + mlir::Type restype = wrapResultType(intype); if (!restype) { parser.emitError(parser.getNameLoc(), "invalid allocate type: ") << intype; return mlir::failure(); } - result.addAttribute("operandSegmentSizes", builder.getDenseI32ArrayAttr( - {typeparamsSize, shapeSize})); + llvm::SmallVector segmentSizes; + if (isTargetOp) + segmentSizes.push_back(1); + segmentSizes.push_back(typeparamsSize); + segmentSizes.push_back(shapeSize); + result.addAttribute("operandSegmentSizes", + builder.getDenseI32ArrayAttr(segmentSizes)); if (parser.parseOptionalAttrDict(result.attributes) || parser.addTypeToList(restype, result.types)) return mlir::failure(); @@ -385,6 +405,56 @@ llvm::LogicalResult fir::AllocMemOp::verify() { return mlir::success(); } +//===----------------------------------------------------------------------===// +// OmpTargetAllocMemOp +//===----------------------------------------------------------------------===// + +mlir::Type fir::OmpTargetAllocMemOp::getAllocatedType() { + return mlir::cast(getType()).getEleTy(); +} + +mlir::Type fir::OmpTargetAllocMemOp::getRefTy(mlir::Type ty) { + return fir::HeapType::get(ty); +} + +mlir::ParseResult +fir::OmpTargetAllocMemOp::parse(mlir::OpAsmParser &parser, + mlir::OperationState &result) { + return parseAllocatableOp(wrapAllocMemResultType, parser, result, true); +} + +void fir::OmpTargetAllocMemOp::print(mlir::OpAsmPrinter &p) { + p << " "; + p.printOperand(getDevice()); + p << " : "; + p << getDevice().getType(); + p << ", "; + p << getInType(); + if (!getTypeparams().empty()) { + p << '(' << getTypeparams() << " : " << getTypeparams().getTypes() << ')'; + } + for (auto sh : getShape()) { + p << ", "; + p.printOperand(sh); + } + 
p.printOptionalAttrDict((*this)->getAttrs(), + {"in_type", "operandSegmentSizes"}); +} + +llvm::LogicalResult fir::OmpTargetAllocMemOp::verify() { + llvm::SmallVector visited; + if (verifyInType(getInType(), visited, numShapeOperands())) + return emitOpError("invalid type for allocation"); + if (verifyTypeParamCount(getInType(), numLenParams())) + return emitOpError("LEN params do not correspond to type"); + mlir::Type outType = getType(); + if (!mlir::dyn_cast(outType)) + return emitOpError("must be a !fir.heap type"); + if (fir::isa_unknown_size_box(fir::dyn_cast_ptrEleTy(outType))) + return emitOpError("cannot allocate !fir.box of unknown rank or type"); + return mlir::success(); +} + //===----------------------------------------------------------------------===// // ArrayCoorOp //===----------------------------------------------------------------------===// diff --git a/flang/test/Fir/omp_target_allocmem.fir b/flang/test/Fir/omp_target_allocmem.fir new file mode 100644 index 0000000000000..5140c91c9510c --- /dev/null +++ b/flang/test/Fir/omp_target_allocmem.fir @@ -0,0 +1,28 @@ +// RUN: %flang_fc1 -emit-llvm %s -o - | FileCheck %s + +// CHECK-LABEL: define ptr @omp_target_allocmem_array_of_nonchar( +// CHECK: call ptr @omp_target_alloc(i64 36, i32 0) +func.func @omp_target_allocmem_array_of_nonchar() -> !fir.heap> { + %device = arith.constant 0 : i32 + %1 = fir.omp_target_allocmem %device : i32, !fir.array<3x3xi32> + return %1 : !fir.heap> +} + +// CHECK-LABEL: define ptr @omp_target_allocmem_array_of_char( +// CHECK: call ptr @omp_target_alloc(i64 90, i32 0) +func.func @omp_target_allocmem_array_of_char() -> !fir.heap>> { + %device = arith.constant 0 : i32 + %1 = fir.omp_target_allocmem %device : i32, !fir.array<3x3x!fir.char<1,10>> + return %1 : !fir.heap>> +} + +// CHECK-LABEL: define ptr @omp_target_allocmem_array_of_dynchar( +// CHECK-SAME: i32 %[[len:.*]]) +// CHECK: %[[mul1:.*]] = sext i32 %[[len]] to i64 +// CHECK: %[[mul2:.*]] = mul i64 9, %[[mul1]] 
+// CHECK: call ptr @omp_target_alloc(i64 %[[mul2]], i32 0) +func.func @omp_target_allocmem_array_of_dynchar(%l: i32) -> !fir.heap>> { + %device = arith.constant 0 : i32 + %1 = fir.omp_target_allocmem %device : i32, !fir.array<3x3x!fir.char<1,?>>(%l : i32) + return %1 : !fir.heap>> +} diff --git a/flang/test/Fir/omp_target_freemem.fir b/flang/test/Fir/omp_target_freemem.fir new file mode 100644 index 0000000000000..02e136076a9cf --- /dev/null +++ b/flang/test/Fir/omp_target_freemem.fir @@ -0,0 +1,28 @@ +// RUN: %flang_fc1 -emit-llvm %s -o - | FileCheck %s + +// CHECK-LABEL: define void @omp_target_allocmem_array_of_nonchar( +// CHECK: call void @omp_target_free(ptr {{.*}}, i32 0) +func.func @omp_target_allocmem_array_of_nonchar() -> () { + %device = arith.constant 0 : i32 + %1 = fir.omp_target_allocmem %device : i32, !fir.array<3x3xi32> + fir.omp_target_freemem %device, %1 : i32, !fir.heap> + return +} + +// CHECK-LABEL: define void @omp_target_allocmem_array_of_char( +// CHECK: call void @omp_target_free(ptr {{.*}}, i32 0) +func.func @omp_target_allocmem_array_of_char() -> () { + %device = arith.constant 0 : i32 + %1 = fir.omp_target_allocmem %device : i32, !fir.array<3x3x!fir.char<1,10>> + fir.omp_target_freemem %device, %1 : i32, !fir.heap>> + return +} + +// CHECK-LABEL: define void @omp_target_allocmem_array_of_dynchar( +// CHECK: call void @omp_target_free(ptr {{.*}}, i32 0) +func.func @omp_target_allocmem_array_of_dynchar(%l: i32) -> () { + %device = arith.constant 0 : i32 + %1 = fir.omp_target_allocmem %device : i32, !fir.array<3x3x!fir.char<1,?>>(%l : i32) + fir.omp_target_freemem %device, %1 : i32, !fir.heap>> + return +} From a864094e5bd482499f82ef913c92c1f5b6bd070f Mon Sep 17 00:00:00 2001 From: skc7 Date: Thu, 26 Jun 2025 10:32:21 +0530 Subject: [PATCH 3/6] [flang] Fix doc in td --- flang/include/flang/Optimizer/Dialect/FIROps.td | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/flang/include/flang/Optimizer/Dialect/FIROps.td 
b/flang/include/flang/Optimizer/Dialect/FIROps.td index 666b66a8670d6..93d617027e30b 100644 --- a/flang/include/flang/Optimizer/Dialect/FIROps.td +++ b/flang/include/flang/Optimizer/Dialect/FIROps.td @@ -560,7 +560,7 @@ def fir_OmpTargetAllocMemOp : fir_Op<"omp_target_allocmem", def fir_OmpTargetFreeMemOp : fir_Op<"omp_target_freemem", [MemoryEffects<[MemFree]>]> { - let summary = "free a heap object"; + let summary = "free a heap object on an openmp device"; let description = [{ Deallocates a heap memory reference that was allocated by an `omp_target_allocmem`. @@ -569,7 +569,7 @@ def fir_OmpTargetFreeMemOp : fir_Op<"omp_target_freemem", ``` %device = arith.constant 0 : i32 %1 = fir.omp_target_allocmem %device : i32, !fir.array<3x3xi32> - fir.omp_target_freemem %device, %1 : i32, !fir.heap> + fir.omp_target_freemem %device, %1 : i32, !fir.heap> ``` }]; From 2f5b289464787d89ea00ab295ccb2ebee6b57329 Mon Sep 17 00:00:00 2001 From: skc7 Date: Thu, 3 Jul 2025 12:22:55 +0530 Subject: [PATCH 4/6] [omp][mlir] Introduce TargetAllocMem and TargetFreeMem ops in openMP mlir dialect --- .../include/flang/Optimizer/Dialect/FIROps.td | 63 ------- flang/lib/Optimizer/CodeGen/CodeGen.cpp | 102 +---------- flang/lib/Optimizer/CodeGen/CodeGenOpenMP.cpp | 161 ++++++++++++++++++ flang/lib/Optimizer/Dialect/FIROps.cpp | 88 +--------- flang/test/Fir/omp_target_allocmem.fir | 28 --- ...em.fir => omp_target_allocmem_freemem.fir} | 19 ++- mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td | 64 +++++++ mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp | 102 +++++++++++ .../OpenMP/OpenMPToLLVMIRTranslation.cpp | 82 +++++++++ .../ompenmp-target-allocmem-freemem.mlir | 42 +++++ 10 files changed, 473 insertions(+), 278 deletions(-) delete mode 100644 flang/test/Fir/omp_target_allocmem.fir rename flang/test/Fir/{omp_target_freemem.fir => omp_target_allocmem_freemem.fir} (51%) create mode 100644 mlir/test/Target/LLVMIR/ompenmp-target-allocmem-freemem.mlir diff --git 
a/flang/include/flang/Optimizer/Dialect/FIROps.td b/flang/include/flang/Optimizer/Dialect/FIROps.td index 93d617027e30b..8ac847dd7dd0a 100644 --- a/flang/include/flang/Optimizer/Dialect/FIROps.td +++ b/flang/include/flang/Optimizer/Dialect/FIROps.td @@ -517,69 +517,6 @@ def fir_ZeroOp : fir_OneResultOp<"zero_bits", [NoMemoryEffect]> { let assemblyFormat = "type($intype) attr-dict"; } -def fir_OmpTargetAllocMemOp : fir_Op<"omp_target_allocmem", - [MemoryEffects<[MemAlloc]>, AttrSizedOperandSegments]> { - let summary = "allocate storage on an openmp device for an object of a given type"; - - let description = [{ - Creates a heap memory reference suitable for storing a value of the - given type, T. The heap refernce returned has type `!fir.heap`. - The memory object is in an undefined state. `omp_target_allocmem` operations must - be paired with `omp_target_freemem` operations to avoid memory leaks. - - ``` - %device = arith.constant 0 : i32 - %1 = fir.omp_target_allocmem %device : i32, !fir.array<3x3xi32> - ``` - }]; - - let arguments = (ins - Arg:$device, - TypeAttr:$in_type, - OptionalAttr:$uniq_name, - OptionalAttr:$bindc_name, - Variadic:$typeparams, - Variadic:$shape - ); - let results = (outs fir_HeapType); - - let hasCustomAssemblyFormat = 1; - let hasVerifier = 1; - - let extraClassDeclaration = [{ - mlir::Type getAllocatedType(); - bool hasLenParams() { return !getTypeparams().empty(); } - bool hasShapeOperands() { return !getShape().empty(); } - unsigned numLenParams() { return getTypeparams().size(); } - operand_range getLenParams() { return getTypeparams(); } - unsigned numShapeOperands() { return getShape().size(); } - operand_range getShapeOperands() { return getShape(); } - static mlir::Type getRefTy(mlir::Type ty); - }]; -} - -def fir_OmpTargetFreeMemOp : fir_Op<"omp_target_freemem", - [MemoryEffects<[MemFree]>]> { - let summary = "free a heap object on an openmp device"; - - let description = [{ - Deallocates a heap memory reference that was 
allocated by an `omp_target_allocmem`. - The memory object that is deallocated is placed in an undefined state - after `fir.omp_target_freemem`. - ``` - %device = arith.constant 0 : i32 - %1 = fir.omp_target_allocmem %device : i32, !fir.array<3x3xi32> - fir.omp_target_freemem %device, %1 : i32, !fir.heap> - ``` - }]; - - let arguments = (ins - Arg:$device, - Arg:$heapref - ); - let assemblyFormat = "$device `,` $heapref attr-dict `:` type($device) `,` qualified(type($heapref))"; -} - //===----------------------------------------------------------------------===// // Terminator operations //===----------------------------------------------------------------------===// diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp index 042ade6b1e0a1..a3de3ae9d116a 100644 --- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp +++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp @@ -1168,105 +1168,6 @@ struct FreeMemOpConversion : public fir::FIROpConversion { }; } // namespace -static mlir::LLVM::LLVMFuncOp getOmpTargetAlloc(mlir::Operation *op) { - auto module = op->getParentOfType(); - if (mlir::LLVM::LLVMFuncOp mallocFunc = - module.lookupSymbol("omp_target_alloc")) - return mallocFunc; - mlir::OpBuilder moduleBuilder(module.getBodyRegion()); - auto i64Ty = mlir::IntegerType::get(module->getContext(), 64); - auto i32Ty = mlir::IntegerType::get(module->getContext(), 32); - return moduleBuilder.create( - moduleBuilder.getUnknownLoc(), "omp_target_alloc", - mlir::LLVM::LLVMFunctionType::get( - mlir::LLVM::LLVMPointerType::get(module->getContext()), - {i64Ty, i32Ty}, - /*isVarArg=*/false)); -} - -namespace { -struct OmpTargetAllocMemOpConversion - : public fir::FIROpConversion { - using FIROpConversion::FIROpConversion; - - mlir::LogicalResult - matchAndRewrite(fir::OmpTargetAllocMemOp heap, OpAdaptor adaptor, - mlir::ConversionPatternRewriter &rewriter) const override { - mlir::Type heapTy = heap.getType(); - mlir::LLVM::LLVMFuncOp mallocFunc = 
getOmpTargetAlloc(heap); - mlir::Location loc = heap.getLoc(); - auto ity = lowerTy().indexType(); - mlir::Type dataTy = fir::unwrapRefType(heapTy); - mlir::Type llvmObjectTy = convertObjectType(dataTy); - if (fir::isRecordWithTypeParameters(fir::unwrapSequenceType(dataTy))) - TODO(loc, "fir.omp_target_allocmem codegen of derived type with length " - "parameters"); - mlir::Value size = genTypeSizeInBytes(loc, ity, rewriter, llvmObjectTy); - if (auto scaleSize = genAllocationScaleSize(heap, ity, rewriter)) - size = rewriter.create(loc, ity, size, scaleSize); - for (mlir::Value opnd : adaptor.getOperands().drop_front()) - size = rewriter.create( - loc, ity, size, integerCast(loc, rewriter, ity, opnd)); - auto mallocTyWidth = lowerTy().getIndexTypeBitwidth(); - auto mallocTy = - mlir::IntegerType::get(rewriter.getContext(), mallocTyWidth); - if (mallocTyWidth != ity.getIntOrFloatBitWidth()) - size = integerCast(loc, rewriter, mallocTy, size); - heap->setAttr("callee", mlir::SymbolRefAttr::get(mallocFunc)); - rewriter.replaceOpWithNewOp( - heap, ::getLlvmPtrType(heap.getContext()), - mlir::SmallVector({size, heap.getDevice()}), - addLLVMOpBundleAttrs(rewriter, heap->getAttrs(), 2)); - return mlir::success(); - } - - /// Compute the allocation size in bytes of the element type of - /// \p llTy pointer type. The result is returned as a value of \p idxTy - /// integer type. 
- mlir::Value genTypeSizeInBytes(mlir::Location loc, mlir::Type idxTy, - mlir::ConversionPatternRewriter &rewriter, - mlir::Type llTy) const { - return computeElementDistance(loc, llTy, idxTy, rewriter, getDataLayout()); - } -}; -} // namespace - -static mlir::LLVM::LLVMFuncOp getOmpTargetFree(mlir::Operation *op) { - auto module = op->getParentOfType(); - if (mlir::LLVM::LLVMFuncOp freeFunc = - module.lookupSymbol("omp_target_free")) - return freeFunc; - mlir::OpBuilder moduleBuilder(module.getBodyRegion()); - auto i32Ty = mlir::IntegerType::get(module->getContext(), 32); - return moduleBuilder.create( - moduleBuilder.getUnknownLoc(), "omp_target_free", - mlir::LLVM::LLVMFunctionType::get( - mlir::LLVM::LLVMVoidType::get(module->getContext()), - {getLlvmPtrType(module->getContext()), i32Ty}, - /*isVarArg=*/false)); -} - -namespace { -struct OmpTargetFreeMemOpConversion - : public fir::FIROpConversion { - using FIROpConversion::FIROpConversion; - - mlir::LogicalResult - matchAndRewrite(fir::OmpTargetFreeMemOp freemem, OpAdaptor adaptor, - mlir::ConversionPatternRewriter &rewriter) const override { - mlir::LLVM::LLVMFuncOp freeFunc = getOmpTargetFree(freemem); - mlir::Location loc = freemem.getLoc(); - freemem->setAttr("callee", mlir::SymbolRefAttr::get(freeFunc)); - rewriter.create( - loc, mlir::TypeRange{}, - mlir::ValueRange{adaptor.getHeapref(), freemem.getDevice()}, - addLLVMOpBundleAttrs(rewriter, freemem->getAttrs(), 2)); - rewriter.eraseOp(freemem); - return mlir::success(); - } -}; -} // namespace - // Convert subcomponent array indices from column-major to row-major ordering. 
static llvm::SmallVector convertSubcomponentIndices(mlir::Location loc, mlir::Type eleTy, @@ -4373,8 +4274,7 @@ void fir::populateFIRToLLVMConversionPatterns( GlobalLenOpConversion, GlobalOpConversion, InsertOnRangeOpConversion, IsPresentOpConversion, LenParamIndexOpConversion, LoadOpConversion, LocalitySpecifierOpConversion, MulcOpConversion, NegcOpConversion, - NoReassocOpConversion, OmpTargetAllocMemOpConversion, - OmpTargetFreeMemOpConversion, SelectCaseOpConversion, SelectOpConversion, + NoReassocOpConversion, SelectCaseOpConversion, SelectOpConversion, SelectRankOpConversion, SelectTypeOpConversion, ShapeOpConversion, ShapeShiftOpConversion, ShiftOpConversion, SliceOpConversion, StoreOpConversion, StringLitOpConversion, SubcOpConversion, diff --git a/flang/lib/Optimizer/CodeGen/CodeGenOpenMP.cpp b/flang/lib/Optimizer/CodeGen/CodeGenOpenMP.cpp index 37f1c9f97e1ce..a04c5d7eb7ee7 100644 --- a/flang/lib/Optimizer/CodeGen/CodeGenOpenMP.cpp +++ b/flang/lib/Optimizer/CodeGen/CodeGenOpenMP.cpp @@ -125,10 +125,171 @@ struct PrivateClauseOpConversion return mlir::success(); } }; + +static mlir::LLVM::LLVMFuncOp getOmpTargetAlloc(mlir::Operation *op) { + auto module = op->getParentOfType(); + if (mlir::LLVM::LLVMFuncOp mallocFunc = + module.lookupSymbol("omp_target_alloc")) + return mallocFunc; + mlir::OpBuilder moduleBuilder(module.getBodyRegion()); + auto i64Ty = mlir::IntegerType::get(module->getContext(), 64); + auto i32Ty = mlir::IntegerType::get(module->getContext(), 32); + return moduleBuilder.create( + moduleBuilder.getUnknownLoc(), "omp_target_alloc", + mlir::LLVM::LLVMFunctionType::get( + mlir::LLVM::LLVMPointerType::get(module->getContext()), + {i64Ty, i32Ty}, + /*isVarArg=*/false)); +} + +static mlir::Type +convertObjectType(const fir::LLVMTypeConverter &converter, mlir::Type firType) { + if (auto boxTy = mlir::dyn_cast(firType)) + return converter.convertBoxTypeAsStruct(boxTy); + return converter.convertType(firType); +} + +static llvm::SmallVector 
+addLLVMOpBundleAttrs(mlir::ConversionPatternRewriter &rewriter, + llvm::ArrayRef attrs, + int32_t numCallOperands) { + llvm::SmallVector newAttrs; + newAttrs.reserve(attrs.size() + 2); + + for (mlir::NamedAttribute attr : attrs) { + if (attr.getName() != "operandSegmentSizes") + newAttrs.push_back(attr); + } + + newAttrs.push_back(rewriter.getNamedAttr( + "operandSegmentSizes", + rewriter.getDenseI32ArrayAttr({numCallOperands, 0}))); + newAttrs.push_back(rewriter.getNamedAttr("op_bundle_sizes", + rewriter.getDenseI32ArrayAttr({}))); + return newAttrs; +} + +static mlir::LLVM::ConstantOp +genConstantIndex(mlir::Location loc, mlir::Type ity, + mlir::ConversionPatternRewriter &rewriter, + std::int64_t offset) { + auto cattr = rewriter.getI64IntegerAttr(offset); + return rewriter.create(loc, ity, cattr); +} + +static mlir::Value +computeElementDistance(mlir::Location loc, mlir::Type llvmObjectType, + mlir::Type idxTy, + mlir::ConversionPatternRewriter &rewriter, + const mlir::DataLayout &dataLayout) { + llvm::TypeSize size = dataLayout.getTypeSize(llvmObjectType); + unsigned short alignment = dataLayout.getTypeABIAlignment(llvmObjectType); + std::int64_t distance = llvm::alignTo(size, alignment); + return genConstantIndex(loc, idxTy, rewriter, distance); +} + +static mlir::Value genTypeSizeInBytes(mlir::Location loc, mlir::Type idxTy, + mlir::ConversionPatternRewriter &rewriter, + mlir::Type llTy, const mlir::DataLayout &dataLayout) { + return computeElementDistance(loc, llTy, idxTy, rewriter, dataLayout); +} + +template +static mlir::Value +genAllocationScaleSize(OP op, mlir::Type ity, + mlir::ConversionPatternRewriter &rewriter) { + mlir::Location loc = op.getLoc(); + mlir::Type dataTy = op.getInType(); + auto seqTy = mlir::dyn_cast(dataTy); + fir::SequenceType::Extent constSize = 1; + if (seqTy) { + int constRows = seqTy.getConstantRows(); + const fir::SequenceType::ShapeRef &shape = seqTy.getShape(); + if (constRows != static_cast(shape.size())) { + for (auto 
extent : shape) { + if (constRows-- > 0) + continue; + if (extent != fir::SequenceType::getUnknownExtent()) + constSize *= extent; + } + } + } + + if (constSize != 1) { + mlir::Value constVal{ + genConstantIndex(loc, ity, rewriter, constSize).getResult()}; + return constVal; + } + return nullptr; +} + +static mlir::Value integerCast(const fir::LLVMTypeConverter &converter, + mlir::Location loc, mlir::ConversionPatternRewriter &rewriter, + mlir::Type ty, mlir::Value val, bool fold = false) { + auto valTy = val.getType(); + // If the value was not yet lowered, lower its type so that it can + // be used in getPrimitiveTypeSizeInBits. + if (!mlir::isa(valTy)) + valTy = converter.convertType(valTy); + auto toSize = mlir::LLVM::getPrimitiveTypeSizeInBits(ty); + auto fromSize = mlir::LLVM::getPrimitiveTypeSizeInBits(valTy); + if (fold) { + if (toSize < fromSize) + return rewriter.createOrFold(loc, ty, val); + if (toSize > fromSize) + return rewriter.createOrFold(loc, ty, val); + } else { + if (toSize < fromSize) + return rewriter.create(loc, ty, val); + if (toSize > fromSize) + return rewriter.create(loc, ty, val); + } + return val; +} + +// FIR Op specific conversion for TargetAllocMemOp +struct TargetAllocMemOpConversion + : public OpenMPFIROpConversion { + using OpenMPFIROpConversion::OpenMPFIROpConversion; + + llvm::LogicalResult + matchAndRewrite(mlir::omp::TargetAllocMemOp allocmemOp, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const override { + mlir::Type heapTy = allocmemOp.getAllocatedType(); + mlir::LLVM::LLVMFuncOp mallocFunc = getOmpTargetAlloc(allocmemOp); + mlir::Location loc = allocmemOp.getLoc(); + auto ity = lowerTy().indexType(); + mlir::Type dataTy = fir::unwrapRefType(heapTy); + mlir::Type llvmObjectTy = convertObjectType(lowerTy(), dataTy); + mlir::Type llvmPtrTy = mlir::LLVM::LLVMPointerType::get(allocmemOp.getContext(), 0); + if (fir::isRecordWithTypeParameters(fir::unwrapSequenceType(dataTy))) + TODO(loc, "omp.target_allocmem 
codegen of derived type with length " + "parameters"); + mlir::Value size = genTypeSizeInBytes(loc, ity, rewriter, llvmObjectTy, lowerTy().getDataLayout()); + if (auto scaleSize = genAllocationScaleSize(allocmemOp, ity, rewriter)) + size = rewriter.create(loc, ity, size, scaleSize); + for (mlir::Value opnd : adaptor.getOperands().drop_front()) + size = rewriter.create( + loc, ity, size, integerCast(lowerTy(), loc, rewriter, ity, opnd)); + auto mallocTyWidth = lowerTy().getIndexTypeBitwidth(); + auto mallocTy = + mlir::IntegerType::get(rewriter.getContext(), mallocTyWidth); + if (mallocTyWidth != ity.getIntOrFloatBitWidth()) + size = integerCast(lowerTy(), loc, rewriter, mallocTy, size); + allocmemOp->setAttr("callee", mlir::SymbolRefAttr::get(mallocFunc)); + auto callOp = rewriter.create( + loc, llvmPtrTy, + mlir::SmallVector({size, allocmemOp.getDevice()}), + addLLVMOpBundleAttrs(rewriter, allocmemOp->getAttrs(), 2)); + rewriter.replaceOpWithNewOp(allocmemOp, rewriter.getIntegerType(64), callOp.getResult()); + return mlir::success(); + } +}; } // namespace void fir::populateOpenMPFIRToLLVMConversionPatterns( const LLVMTypeConverter &converter, mlir::RewritePatternSet &patterns) { patterns.add(converter); patterns.add(converter); + patterns.add(converter); } diff --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp index 9335a4b041ac8..f6c794b8fe9ae 100644 --- a/flang/lib/Optimizer/Dialect/FIROps.cpp +++ b/flang/lib/Optimizer/Dialect/FIROps.cpp @@ -106,38 +106,24 @@ static bool verifyTypeParamCount(mlir::Type inType, unsigned numParams) { return false; } -/// Parser shared by Alloca, Allocmem and OmpTargetAllocmem +/// Parser shared by Alloca and Allocmem /// boolean flag isTargetOp is used to identify omp_target_allocmem /// operation ::= %res = (`fir.alloca` | `fir.allocmem`) $in_type /// ( `(` $typeparams `)` )? ( `,` $shape )? 
/// attr-dict-without-keyword -/// operation ::= %res = (`fir.omp_target_alloca`) $device : devicetype, -/// $in_type ( `(` $typeparams `)` )? ( `,` $shape )? -/// attr-dict-without-keyword template -static mlir::ParseResult -parseAllocatableOp(FN wrapResultType, mlir::OpAsmParser &parser, - mlir::OperationState &result, bool isTargetOp = false) { - auto &builder = parser.getBuilder(); - bool hasOperands = false; - std::int32_t typeparamsSize = 0; - // Parse device number as a new operand - if (isTargetOp) { - mlir::OpAsmParser::UnresolvedOperand deviceOperand; - mlir::Type deviceType; - if (parser.parseOperand(deviceOperand) || parser.parseColonType(deviceType)) - return mlir::failure(); - if (parser.resolveOperand(deviceOperand, deviceType, result.operands)) - return mlir::failure(); - if (parser.parseComma()) - return mlir::failure(); - } +static mlir::ParseResult parseAllocatableOp(FN wrapResultType, + mlir::OpAsmParser &parser, + mlir::OperationState &result) { mlir::Type intype; if (parser.parseType(intype)) return mlir::failure(); + auto &builder = parser.getBuilder(); result.addAttribute("in_type", mlir::TypeAttr::get(intype)); llvm::SmallVector operands; llvm::SmallVector typeVec; + bool hasOperands = false; + std::int32_t typeparamsSize = 0; if (!parser.parseOptionalLParen()) { // parse the LEN params of the derived type. 
( : ) if (parser.parseOperandList(operands, mlir::OpAsmParser::Delimiter::None) || @@ -161,19 +147,13 @@ parseAllocatableOp(FN wrapResultType, mlir::OpAsmParser &parser, parser.resolveOperands(operands, typeVec, parser.getNameLoc(), result.operands)) return mlir::failure(); - mlir::Type restype = wrapResultType(intype); if (!restype) { parser.emitError(parser.getNameLoc(), "invalid allocate type: ") << intype; return mlir::failure(); } - llvm::SmallVector segmentSizes; - if (isTargetOp) - segmentSizes.push_back(1); - segmentSizes.push_back(typeparamsSize); - segmentSizes.push_back(shapeSize); - result.addAttribute("operandSegmentSizes", - builder.getDenseI32ArrayAttr(segmentSizes)); + result.addAttribute("operandSegmentSizes", builder.getDenseI32ArrayAttr( + {typeparamsSize, shapeSize})); if (parser.parseOptionalAttrDict(result.attributes) || parser.addTypeToList(restype, result.types)) return mlir::failure(); @@ -405,56 +385,6 @@ llvm::LogicalResult fir::AllocMemOp::verify() { return mlir::success(); } -//===----------------------------------------------------------------------===// -// OmpTargetAllocMemOp -//===----------------------------------------------------------------------===// - -mlir::Type fir::OmpTargetAllocMemOp::getAllocatedType() { - return mlir::cast(getType()).getEleTy(); -} - -mlir::Type fir::OmpTargetAllocMemOp::getRefTy(mlir::Type ty) { - return fir::HeapType::get(ty); -} - -mlir::ParseResult -fir::OmpTargetAllocMemOp::parse(mlir::OpAsmParser &parser, - mlir::OperationState &result) { - return parseAllocatableOp(wrapAllocMemResultType, parser, result, true); -} - -void fir::OmpTargetAllocMemOp::print(mlir::OpAsmPrinter &p) { - p << " "; - p.printOperand(getDevice()); - p << " : "; - p << getDevice().getType(); - p << ", "; - p << getInType(); - if (!getTypeparams().empty()) { - p << '(' << getTypeparams() << " : " << getTypeparams().getTypes() << ')'; - } - for (auto sh : getShape()) { - p << ", "; - p.printOperand(sh); - } - 
p.printOptionalAttrDict((*this)->getAttrs(), - {"in_type", "operandSegmentSizes"}); -} - -llvm::LogicalResult fir::OmpTargetAllocMemOp::verify() { - llvm::SmallVector visited; - if (verifyInType(getInType(), visited, numShapeOperands())) - return emitOpError("invalid type for allocation"); - if (verifyTypeParamCount(getInType(), numLenParams())) - return emitOpError("LEN params do not correspond to type"); - mlir::Type outType = getType(); - if (!mlir::dyn_cast(outType)) - return emitOpError("must be a !fir.heap type"); - if (fir::isa_unknown_size_box(fir::dyn_cast_ptrEleTy(outType))) - return emitOpError("cannot allocate !fir.box of unknown rank or type"); - return mlir::success(); -} - //===----------------------------------------------------------------------===// // ArrayCoorOp //===----------------------------------------------------------------------===// diff --git a/flang/test/Fir/omp_target_allocmem.fir b/flang/test/Fir/omp_target_allocmem.fir deleted file mode 100644 index 5140c91c9510c..0000000000000 --- a/flang/test/Fir/omp_target_allocmem.fir +++ /dev/null @@ -1,28 +0,0 @@ -// RUN: %flang_fc1 -emit-llvm %s -o - | FileCheck %s - -// CHECK-LABEL: define ptr @omp_target_allocmem_array_of_nonchar( -// CHECK: call ptr @omp_target_alloc(i64 36, i32 0) -func.func @omp_target_allocmem_array_of_nonchar() -> !fir.heap> { - %device = arith.constant 0 : i32 - %1 = fir.omp_target_allocmem %device : i32, !fir.array<3x3xi32> - return %1 : !fir.heap> -} - -// CHECK-LABEL: define ptr @omp_target_allocmem_array_of_char( -// CHECK: call ptr @omp_target_alloc(i64 90, i32 0) -func.func @omp_target_allocmem_array_of_char() -> !fir.heap>> { - %device = arith.constant 0 : i32 - %1 = fir.omp_target_allocmem %device : i32, !fir.array<3x3x!fir.char<1,10>> - return %1 : !fir.heap>> -} - -// CHECK-LABEL: define ptr @omp_target_allocmem_array_of_dynchar( -// CHECK-SAME: i32 %[[len:.*]]) -// CHECK: %[[mul1:.*]] = sext i32 %[[len]] to i64 -// CHECK: %[[mul2:.*]] = mul i64 9, 
%[[mul1]] -// CHECK: call ptr @omp_target_alloc(i64 %[[mul2]], i32 0) -func.func @omp_target_allocmem_array_of_dynchar(%l: i32) -> !fir.heap>> { - %device = arith.constant 0 : i32 - %1 = fir.omp_target_allocmem %device : i32, !fir.array<3x3x!fir.char<1,?>>(%l : i32) - return %1 : !fir.heap>> -} diff --git a/flang/test/Fir/omp_target_freemem.fir b/flang/test/Fir/omp_target_allocmem_freemem.fir similarity index 51% rename from flang/test/Fir/omp_target_freemem.fir rename to flang/test/Fir/omp_target_allocmem_freemem.fir index 02e136076a9cf..9202202728454 100644 --- a/flang/test/Fir/omp_target_freemem.fir +++ b/flang/test/Fir/omp_target_allocmem_freemem.fir @@ -1,28 +1,33 @@ // RUN: %flang_fc1 -emit-llvm %s -o - | FileCheck %s // CHECK-LABEL: define void @omp_target_allocmem_array_of_nonchar( +// CHECK: call ptr @omp_target_alloc(i64 36, i32 0) // CHECK: call void @omp_target_free(ptr {{.*}}, i32 0) func.func @omp_target_allocmem_array_of_nonchar() -> () { %device = arith.constant 0 : i32 - %1 = fir.omp_target_allocmem %device : i32, !fir.array<3x3xi32> - fir.omp_target_freemem %device, %1 : i32, !fir.heap> + %1 = omp.target_allocmem %device : i32, !fir.array<3x3xi32> + omp.target_freemem %device, %1 : i32, i64 return } // CHECK-LABEL: define void @omp_target_allocmem_array_of_char( +// CHECK: call ptr @omp_target_alloc(i64 90, i32 0) // CHECK: call void @omp_target_free(ptr {{.*}}, i32 0) func.func @omp_target_allocmem_array_of_char() -> () { %device = arith.constant 0 : i32 - %1 = fir.omp_target_allocmem %device : i32, !fir.array<3x3x!fir.char<1,10>> - fir.omp_target_freemem %device, %1 : i32, !fir.heap>> + %1 = omp.target_allocmem %device : i32, !fir.array<3x3x!fir.char<1,10>> + omp.target_freemem %device, %1 : i32, i64 return } // CHECK-LABEL: define void @omp_target_allocmem_array_of_dynchar( -// CHECK: call void @omp_target_free(ptr {{.*}}, i32 0) +// CHECK-SAME: i32 %[[len:.*]]) +// CHECK: %[[mul1:.*]] = sext i32 %[[len]] to i64 +// CHECK: %[[mul2:.*]] = mul 
i64 9, %[[mul1]] +// CHECK: call ptr @omp_target_alloc(i64 %[[mul2]], i32 0) func.func @omp_target_allocmem_array_of_dynchar(%l: i32) -> () { %device = arith.constant 0 : i32 - %1 = fir.omp_target_allocmem %device : i32, !fir.array<3x3x!fir.char<1,?>>(%l : i32) - fir.omp_target_freemem %device, %1 : i32, !fir.heap>> + %1 = omp.target_allocmem %device : i32, !fir.array<3x3x!fir.char<1,?>>(%l : i32) + omp.target_freemem %device, %1 : i32, i64 return } diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td index ac80926053a2d..bcc7c906041ac 100644 --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td @@ -1887,4 +1887,68 @@ def MaskedOp : OpenMP_Op<"masked", clauses = [ ]; } +//===----------------------------------------------------------------------===// +// TargetAllocMemOp +//===----------------------------------------------------------------------===// + +def TargetAllocMemOp : OpenMP_Op<"target_allocmem", + [MemoryEffects<[MemAlloc]>, AttrSizedOperandSegments]> { + let summary = "allocate storage on an openmp device for an object of a given type"; + + let description = [{ + Allocates memory on the specified OpenMP device for an object of the given type. + Returns an integer value representing the device pointer to the allocated memory. + The memory is uninitialized after allocation. Operations must be paired with + `omp.target_freemem` to avoid memory leaks. 
+ + ```mlir + %device = arith.constant 0 : i32 + %ptr = omp.target_allocmem %device : i32, vector<3x3xi32> + ``` + }]; + + let arguments = (ins + Arg:$device, + TypeAttr:$in_type, + OptionalAttr:$uniq_name, + OptionalAttr:$bindc_name, + Variadic:$typeparams, + Variadic:$shape + ); + let results = (outs I64); + + let hasCustomAssemblyFormat = 1; + let hasVerifier = 1; + + let extraClassDeclaration = [{ + mlir::Type getAllocatedType(); + }]; +} + +//===----------------------------------------------------------------------===// +// TargetFreeMemOp +//===----------------------------------------------------------------------===// + +def TargetFreeMemOp : OpenMP_Op<"target_freemem", + [MemoryEffects<[MemFree]>]> { + let summary = "free memory on an openmp device"; + + let description = [{ + Deallocates memory on the specified OpenMP device that was previously + allocated by an `omp.target_allocmem` operation. The memory is placed + in an undefined state after deallocation. + ``` + %device = arith.constant 0 : i32 + %ptr = omp.target_allocmem %device : i32, vector<3x3xi32> + omp.target_freemem %device, %ptr : i32, i64 + ``` + }]; + + let arguments = (ins + Arg:$device, + Arg:$heapref + ); + let assemblyFormat = "$device `,` $heapref attr-dict `:` type($device) `,` qualified(type($heapref))"; +} + #endif // OPENMP_OPS diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp index e94d570b57122..5e3ad6ccf3ffe 100644 --- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp +++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp @@ -3493,6 +3493,108 @@ LogicalResult ScanOp::verify() { "reduction modifier"); } +//===----------------------------------------------------------------------===// +// TargetAllocMemOp +//===----------------------------------------------------------------------===// + +mlir::Type omp::TargetAllocMemOp::getAllocatedType() { + return getInTypeAttr().getValue(); +} + +/// operation ::= %res = (`omp.target_allocmem`) 
$device : devicetype, +/// $in_type ( `(` $typeparams `)` )? ( `,` $shape )? +/// attr-dict-without-keyword +static mlir::ParseResult parseTargetAllocMemOp(mlir::OpAsmParser &parser, + mlir::OperationState &result) { + auto &builder = parser.getBuilder(); + bool hasOperands = false; + std::int32_t typeparamsSize = 0; + + // Parse device number as a new operand + mlir::OpAsmParser::UnresolvedOperand deviceOperand; + mlir::Type deviceType; + if (parser.parseOperand(deviceOperand) || parser.parseColonType(deviceType)) + return mlir::failure(); + if (parser.resolveOperand(deviceOperand, deviceType, result.operands)) + return mlir::failure(); + if (parser.parseComma()) + return mlir::failure(); + + mlir::Type intype; + if (parser.parseType(intype)) + return mlir::failure(); + result.addAttribute("in_type", mlir::TypeAttr::get(intype)); + llvm::SmallVector operands; + llvm::SmallVector typeVec; + if (!parser.parseOptionalLParen()) { + // parse the LEN params of the derived type. (<params> : <types>) + if (parser.parseOperandList(operands, mlir::OpAsmParser::Delimiter::None) || + parser.parseColonTypeList(typeVec) || parser.parseRParen()) + return mlir::failure(); + typeparamsSize = operands.size(); + hasOperands = true; + } + std::int32_t shapeSize = 0; + if (!parser.parseOptionalComma()) { + // parse size to scale by, vector of n dimensions of type index + if (parser.parseOperandList(operands, mlir::OpAsmParser::Delimiter::None)) + return mlir::failure(); + shapeSize = operands.size() - typeparamsSize; + auto idxTy = builder.getIndexType(); + for (std::int32_t i = typeparamsSize, end = operands.size(); i != end; ++i) + typeVec.push_back(idxTy); + hasOperands = true; + } + if (hasOperands && + parser.resolveOperands(operands, typeVec, parser.getNameLoc(), + result.operands)) + return mlir::failure(); + + mlir::Type restype = builder.getIntegerType(64); + ; + if (!restype) { + parser.emitError(parser.getNameLoc(), "invalid allocate type: ") << intype; + return mlir::failure(); + } + 
llvm::SmallVector segmentSizes{1, typeparamsSize, shapeSize}; + result.addAttribute("operandSegmentSizes", + builder.getDenseI32ArrayAttr(segmentSizes)); + if (parser.parseOptionalAttrDict(result.attributes) || + parser.addTypeToList(restype, result.types)) + return mlir::failure(); + return mlir::success(); +} + +mlir::ParseResult omp::TargetAllocMemOp::parse(mlir::OpAsmParser &parser, + mlir::OperationState &result) { + return parseTargetAllocMemOp(parser, result); +} + +void omp::TargetAllocMemOp::print(mlir::OpAsmPrinter &p) { + p << " "; + p.printOperand(getDevice()); + p << " : "; + p << getDevice().getType(); + p << ", "; + p << getInType(); + if (!getTypeparams().empty()) { + p << '(' << getTypeparams() << " : " << getTypeparams().getTypes() << ')'; + } + for (auto sh : getShape()) { + p << ", "; + p.printOperand(sh); + } + p.printOptionalAttrDict((*this)->getAttrs(), + {"in_type", "operandSegmentSizes"}); +} + +llvm::LogicalResult omp::TargetAllocMemOp::verify() { + mlir::Type outType = getType(); + if (!mlir::dyn_cast(outType)) + return emitOpError("must be a integer type"); + return mlir::success(); +} + #define GET_ATTRDEF_CLASSES #include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc" diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp index eece8573f00ec..63b41f3b13363 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp @@ -5714,6 +5714,82 @@ static bool isTargetDeviceOp(Operation *op) { return false; } +static llvm::Function *getOmpTargetAlloc(llvm::IRBuilderBase &builder, + llvm::Module *llvmModule) { + llvm::Type *i64Ty = builder.getInt64Ty(); + llvm::Type *i32Ty = builder.getInt32Ty(); + llvm::Type *returnType = builder.getPtrTy(0); + llvm::FunctionType *fnType = + llvm::FunctionType::get(returnType, {i64Ty, i32Ty}, false); + llvm::Function *func 
= cast( + llvmModule->getOrInsertFunction("omp_target_alloc", fnType).getCallee()); + return func; +} + +static LogicalResult +convertTargetAllocMemOp(Operation &opInst, llvm::IRBuilderBase &builder, + LLVM::ModuleTranslation &moduleTranslation) { + auto allocMemOp = cast(opInst); + if (!allocMemOp) + return failure(); + + // Get "omp_target_alloc" function + llvm::Module *llvmModule = moduleTranslation.getLLVMModule(); + llvm::Function *ompTargetAllocFunc = getOmpTargetAlloc(builder, llvmModule); + // Get the corresponding device value in llvm + mlir::Value deviceNum = allocMemOp.getDevice(); + llvm::Value *llvmDeviceNum = moduleTranslation.lookupValue(deviceNum); + // Get the allocation size. + llvm::DataLayout dataLayout = llvmModule->getDataLayout(); + mlir::Type heapTy = allocMemOp.getAllocatedType(); + llvm::Type *llvmHeapTy = moduleTranslation.convertType(heapTy); + llvm::TypeSize typeSize = dataLayout.getTypeStoreSize(llvmHeapTy); + llvm::ConstantInt *allocSize = builder.getInt64(typeSize.getFixedValue()); + // Create call to "omp_target_alloc" with the args as translated llvm values. 
+ llvm::CallInst *call = + builder.CreateCall(ompTargetAllocFunc, {allocSize, llvmDeviceNum}); + llvm::Value *resultI64 = builder.CreatePtrToInt(call, builder.getInt64Ty()); + + // Map the result + moduleTranslation.mapValue(allocMemOp.getResult(), resultI64); + return success(); +} + +static llvm::Function *getOmpTargetFree(llvm::IRBuilderBase &builder, + llvm::Module *llvmModule) { + llvm::Type *ptrTy = builder.getPtrTy(0); + llvm::Type *i32Ty = builder.getInt32Ty(); + llvm::Type *voidTy = builder.getVoidTy(); + llvm::FunctionType *fnType = + llvm::FunctionType::get(voidTy, {ptrTy, i32Ty}, false); + llvm::Function *func = dyn_cast( + llvmModule->getOrInsertFunction("omp_target_free", fnType).getCallee()); + return func; +} + +static LogicalResult +convertTargetFreeMemOp(Operation &opInst, llvm::IRBuilderBase &builder, + LLVM::ModuleTranslation &moduleTranslation) { + auto freeMemOp = cast(opInst); + if (!freeMemOp) + return failure(); + + // Get "omp_target_free" function + llvm::Module *llvmModule = moduleTranslation.getLLVMModule(); + llvm::Function *ompTragetFreeFunc = getOmpTargetFree(builder, llvmModule); + // Get the corresponding device value in llvm + mlir::Value deviceNum = freeMemOp.getDevice(); + llvm::Value *llvmDeviceNum = moduleTranslation.lookupValue(deviceNum); + // Get the corresponding heapref value in llvm + mlir::Value heapref = freeMemOp.getHeapref(); + llvm::Value *llvmHeapref = moduleTranslation.lookupValue(heapref); + // Convert heapref int to ptr and call "omp_target_free" + llvm::Value *intToPtr = + builder.CreateIntToPtr(llvmHeapref, builder.getPtrTy(0)); + builder.CreateCall(ompTragetFreeFunc, {intToPtr, llvmDeviceNum}); + return success(); +} + /// Given an OpenMP MLIR operation, create the corresponding LLVM IR (including /// OpenMP runtime calls). static LogicalResult @@ -5871,6 +5947,12 @@ convertHostOrTargetOperation(Operation *op, llvm::IRBuilderBase &builder, // etc. 
and then discarded return success(); }) + .Case([&](omp::TargetAllocMemOp) { + return convertTargetAllocMemOp(*op, builder, moduleTranslation); + }) + .Case([&](omp::TargetFreeMemOp) { + return convertTargetFreeMemOp(*op, builder, moduleTranslation); + }) .Default([&](Operation *inst) { return inst->emitError() << "not yet implemented: " << inst->getName(); diff --git a/mlir/test/Target/LLVMIR/ompenmp-target-allocmem-freemem.mlir b/mlir/test/Target/LLVMIR/ompenmp-target-allocmem-freemem.mlir new file mode 100644 index 0000000000000..1bc97609ccff4 --- /dev/null +++ b/mlir/test/Target/LLVMIR/ompenmp-target-allocmem-freemem.mlir @@ -0,0 +1,42 @@ +// RUN: mlir-opt %s -convert-openmp-to-llvm | mlir-translate -mlir-to-llvmir | FileCheck %s + +// This file contains MLIR test cases for omp.target_allocmem and omp.target_freemem + +// CHECK-LABEL: test_alloc_free_i64 +// CHECK: %[[ALLOC:.*]] = call ptr @omp_target_alloc(i64 8, i32 0) +// CHECK: %[[PTRTOINT:.*]] = ptrtoint ptr %[[ALLOC]] to i64 +// CHECK: %[[INTTOPTR:.*]] = inttoptr i64 %[[PTRTOINT]] to ptr +// CHECK: call void @omp_target_free(ptr %[[INTTOPTR]], i32 0) +// CHECK: ret void +llvm.func @test_alloc_free_i64() -> () { + %device = llvm.mlir.constant(0 : i32) : i32 + %1 = omp.target_allocmem %device : i32, i64 + omp.target_freemem %device, %1 : i32, i64 + llvm.return +} + +// CHECK-LABEL: test_alloc_free_vector_1d_f32 +// CHECK: %[[ALLOC:.*]] = call ptr @omp_target_alloc(i64 64, i32 0) +// CHECK: %[[PTRTOINT:.*]] = ptrtoint ptr %[[ALLOC]] to i64 +// CHECK: %[[INTTOPTR:.*]] = inttoptr i64 %[[PTRTOINT]] to ptr +// CHECK: call void @omp_target_free(ptr %[[INTTOPTR]], i32 0) +// CHECK: ret void +llvm.func @test_alloc_free_vector_1d_f32() -> () { + %device = llvm.mlir.constant(0 : i32) : i32 + %1 = omp.target_allocmem %device : i32, vector<16xf32> + omp.target_freemem %device, %1 : i32, i64 + llvm.return +} + +// CHECK-LABEL: test_alloc_free_vector_2d_f32 +// CHECK: %[[ALLOC:.*]] = call ptr @omp_target_alloc(i64 1024, 
i32 0) +// CHECK: %[[PTRTOINT:.*]] = ptrtoint ptr %[[ALLOC]] to i64 +// CHECK: %[[INTTOPTR:.*]] = inttoptr i64 %[[PTRTOINT]] to ptr +// CHECK: call void @omp_target_free(ptr %[[INTTOPTR]], i32 0) +// CHECK: ret void +llvm.func @test_alloc_free_vector_2d_f32() -> () { + %device = llvm.mlir.constant(0 : i32) : i32 + %1 = omp.target_allocmem %device : i32, vector<16x16xf32> + omp.target_freemem %device, %1 : i32, i64 + llvm.return +} From 1e7a216d3d9e07146ebb364025983e0f5000eefa Mon Sep 17 00:00:00 2001 From: skc7 Date: Thu, 3 Jul 2025 21:05:35 +0530 Subject: [PATCH 5/6] Fix comments --- flang/lib/Optimizer/Dialect/FIROps.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp index f6c794b8fe9ae..ba5c81d826e5e 100644 --- a/flang/lib/Optimizer/Dialect/FIROps.cpp +++ b/flang/lib/Optimizer/Dialect/FIROps.cpp @@ -107,7 +107,6 @@ static bool verifyTypeParamCount(mlir::Type inType, unsigned numParams) { } /// Parser shared by Alloca and Allocmem -/// boolean flag isTargetOp is used to identify omp_target_allocmem /// operation ::= %res = (`fir.alloca` | `fir.allocmem`) $in_type /// ( `(` $typeparams `)` )? ( `,` $shape )? 
/// attr-dict-without-keyword From 3879eb7271098d52281aca44e77fc82af9987f02 Mon Sep 17 00:00:00 2001 From: skc7 Date: Wed, 9 Jul 2025 15:49:06 +0530 Subject: [PATCH 6/6] clang format --- flang/lib/Optimizer/CodeGen/CodeGenOpenMP.cpp | 24 ++++++++++++------- 1 file changed, 15 insertions(+), 9 deletions(-) diff --git a/flang/lib/Optimizer/CodeGen/CodeGenOpenMP.cpp b/flang/lib/Optimizer/CodeGen/CodeGenOpenMP.cpp index a04c5d7eb7ee7..14cc7bb511f0f 100644 --- a/flang/lib/Optimizer/CodeGen/CodeGenOpenMP.cpp +++ b/flang/lib/Optimizer/CodeGen/CodeGenOpenMP.cpp @@ -142,8 +142,8 @@ static mlir::LLVM::LLVMFuncOp getOmpTargetAlloc(mlir::Operation *op) { /*isVarArg=*/false)); } -static mlir::Type -convertObjectType(const fir::LLVMTypeConverter &converter, mlir::Type firType) { +static mlir::Type convertObjectType(const fir::LLVMTypeConverter &converter, + mlir::Type firType) { if (auto boxTy = mlir::dyn_cast(firType)) return converter.convertBoxTypeAsStruct(boxTy); return converter.convertType(firType); @@ -189,8 +189,9 @@ computeElementDistance(mlir::Location loc, mlir::Type llvmObjectType, } static mlir::Value genTypeSizeInBytes(mlir::Location loc, mlir::Type idxTy, - mlir::ConversionPatternRewriter &rewriter, - mlir::Type llTy, const mlir::DataLayout &dataLayout) { + mlir::ConversionPatternRewriter &rewriter, + mlir::Type llTy, + const mlir::DataLayout &dataLayout) { return computeElementDistance(loc, llTy, idxTy, rewriter, dataLayout); } @@ -224,8 +225,10 @@ genAllocationScaleSize(OP op, mlir::Type ity, } static mlir::Value integerCast(const fir::LLVMTypeConverter &converter, - mlir::Location loc, mlir::ConversionPatternRewriter &rewriter, - mlir::Type ty, mlir::Value val, bool fold = false) { + mlir::Location loc, + mlir::ConversionPatternRewriter &rewriter, + mlir::Type ty, mlir::Value val, + bool fold = false) { auto valTy = val.getType(); // If the value was not yet lowered, lower its type so that it can // be used in getPrimitiveTypeSizeInBits. 
@@ -261,11 +264,13 @@ struct TargetAllocMemOpConversion auto ity = lowerTy().indexType(); mlir::Type dataTy = fir::unwrapRefType(heapTy); mlir::Type llvmObjectTy = convertObjectType(lowerTy(), dataTy); - mlir::Type llvmPtrTy = mlir::LLVM::LLVMPointerType::get(allocmemOp.getContext(), 0); + mlir::Type llvmPtrTy = + mlir::LLVM::LLVMPointerType::get(allocmemOp.getContext(), 0); if (fir::isRecordWithTypeParameters(fir::unwrapSequenceType(dataTy))) TODO(loc, "omp.target_allocmem codegen of derived type with length " "parameters"); - mlir::Value size = genTypeSizeInBytes(loc, ity, rewriter, llvmObjectTy, lowerTy().getDataLayout()); + mlir::Value size = genTypeSizeInBytes(loc, ity, rewriter, llvmObjectTy, + lowerTy().getDataLayout()); if (auto scaleSize = genAllocationScaleSize(allocmemOp, ity, rewriter)) size = rewriter.create(loc, ity, size, scaleSize); for (mlir::Value opnd : adaptor.getOperands().drop_front()) @@ -281,7 +286,8 @@ struct TargetAllocMemOpConversion loc, llvmPtrTy, mlir::SmallVector({size, allocmemOp.getDevice()}), addLLVMOpBundleAttrs(rewriter, allocmemOp->getAttrs(), 2)); - rewriter.replaceOpWithNewOp(allocmemOp, rewriter.getIntegerType(64), callOp.getResult()); + rewriter.replaceOpWithNewOp( + allocmemOp, rewriter.getIntegerType(64), callOp.getResult()); return mlir::success(); } };