diff --git a/flang/lib/Optimizer/CodeGen/CodeGenOpenMP.cpp b/flang/lib/Optimizer/CodeGen/CodeGenOpenMP.cpp index 37f1c9f97e1ce..14cc7bb511f0f 100644 --- a/flang/lib/Optimizer/CodeGen/CodeGenOpenMP.cpp +++ b/flang/lib/Optimizer/CodeGen/CodeGenOpenMP.cpp @@ -125,10 +125,177 @@ struct PrivateClauseOpConversion return mlir::success(); } }; + +static mlir::LLVM::LLVMFuncOp getOmpTargetAlloc(mlir::Operation *op) { + auto module = op->getParentOfType(); + if (mlir::LLVM::LLVMFuncOp mallocFunc = + module.lookupSymbol("omp_target_alloc")) + return mallocFunc; + mlir::OpBuilder moduleBuilder(module.getBodyRegion()); + auto i64Ty = mlir::IntegerType::get(module->getContext(), 64); + auto i32Ty = mlir::IntegerType::get(module->getContext(), 32); + return moduleBuilder.create( + moduleBuilder.getUnknownLoc(), "omp_target_alloc", + mlir::LLVM::LLVMFunctionType::get( + mlir::LLVM::LLVMPointerType::get(module->getContext()), + {i64Ty, i32Ty}, + /*isVarArg=*/false)); +} + +static mlir::Type convertObjectType(const fir::LLVMTypeConverter &converter, + mlir::Type firType) { + if (auto boxTy = mlir::dyn_cast(firType)) + return converter.convertBoxTypeAsStruct(boxTy); + return converter.convertType(firType); +} + +static llvm::SmallVector +addLLVMOpBundleAttrs(mlir::ConversionPatternRewriter &rewriter, + llvm::ArrayRef attrs, + int32_t numCallOperands) { + llvm::SmallVector newAttrs; + newAttrs.reserve(attrs.size() + 2); + + for (mlir::NamedAttribute attr : attrs) { + if (attr.getName() != "operandSegmentSizes") + newAttrs.push_back(attr); + } + + newAttrs.push_back(rewriter.getNamedAttr( + "operandSegmentSizes", + rewriter.getDenseI32ArrayAttr({numCallOperands, 0}))); + newAttrs.push_back(rewriter.getNamedAttr("op_bundle_sizes", + rewriter.getDenseI32ArrayAttr({}))); + return newAttrs; +} + +static mlir::LLVM::ConstantOp +genConstantIndex(mlir::Location loc, mlir::Type ity, + mlir::ConversionPatternRewriter &rewriter, + std::int64_t offset) { + auto cattr = rewriter.getI64IntegerAttr(offset); + return rewriter.create(loc, ity, cattr); +} + +static mlir::Value +computeElementDistance(mlir::Location loc, mlir::Type llvmObjectType, + mlir::Type idxTy, + mlir::ConversionPatternRewriter &rewriter, + const mlir::DataLayout &dataLayout) { + llvm::TypeSize size = dataLayout.getTypeSize(llvmObjectType); + unsigned short alignment = dataLayout.getTypeABIAlignment(llvmObjectType); + std::int64_t distance = llvm::alignTo(size, alignment); + return genConstantIndex(loc, idxTy, rewriter, distance); +} + +static mlir::Value genTypeSizeInBytes(mlir::Location loc, mlir::Type idxTy, + mlir::ConversionPatternRewriter &rewriter, + mlir::Type llTy, + const mlir::DataLayout &dataLayout) { + return computeElementDistance(loc, llTy, idxTy, rewriter, dataLayout); +} + +template +static mlir::Value +genAllocationScaleSize(OP op, mlir::Type ity, + mlir::ConversionPatternRewriter &rewriter) { + mlir::Location loc = op.getLoc(); + mlir::Type dataTy = op.getInType(); + auto seqTy = mlir::dyn_cast(dataTy); + fir::SequenceType::Extent constSize = 1; + if (seqTy) { + int constRows = seqTy.getConstantRows(); + const fir::SequenceType::ShapeRef &shape = seqTy.getShape(); + if (constRows != static_cast(shape.size())) { + for (auto extent : shape) { + if (constRows-- > 0) + continue; + if (extent != fir::SequenceType::getUnknownExtent()) + constSize *= extent; + } + } + } + + if (constSize != 1) { + mlir::Value constVal{ + genConstantIndex(loc, ity, rewriter, constSize).getResult()}; + return constVal; + } + return nullptr; +} + +static mlir::Value integerCast(const fir::LLVMTypeConverter &converter, + mlir::Location loc, + mlir::ConversionPatternRewriter &rewriter, + mlir::Type ty, mlir::Value val, + bool fold = false) { + auto valTy = val.getType(); + // If the value was not yet lowered, lower its type so that it can + // be used in getPrimitiveTypeSizeInBits. + if (!mlir::isa(valTy)) + valTy = converter.convertType(valTy); + auto toSize = mlir::LLVM::getPrimitiveTypeSizeInBits(ty); + auto fromSize = mlir::LLVM::getPrimitiveTypeSizeInBits(valTy); + if (fold) { + if (toSize < fromSize) + return rewriter.createOrFold(loc, ty, val); + if (toSize > fromSize) + return rewriter.createOrFold(loc, ty, val); + } else { + if (toSize < fromSize) + return rewriter.create(loc, ty, val); + if (toSize > fromSize) + return rewriter.create(loc, ty, val); + } + return val; +} + +// FIR Op specific conversion for TargetAllocMemOp +struct TargetAllocMemOpConversion + : public OpenMPFIROpConversion { + using OpenMPFIROpConversion::OpenMPFIROpConversion; + + llvm::LogicalResult + matchAndRewrite(mlir::omp::TargetAllocMemOp allocmemOp, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const override { + mlir::Type heapTy = allocmemOp.getAllocatedType(); + mlir::LLVM::LLVMFuncOp mallocFunc = getOmpTargetAlloc(allocmemOp); + mlir::Location loc = allocmemOp.getLoc(); + auto ity = lowerTy().indexType(); + mlir::Type dataTy = fir::unwrapRefType(heapTy); + mlir::Type llvmObjectTy = convertObjectType(lowerTy(), dataTy); + mlir::Type llvmPtrTy = + mlir::LLVM::LLVMPointerType::get(allocmemOp.getContext(), 0); + if (fir::isRecordWithTypeParameters(fir::unwrapSequenceType(dataTy))) + TODO(loc, "omp.target_allocmem codegen of derived type with length " + "parameters"); + mlir::Value size = genTypeSizeInBytes(loc, ity, rewriter, llvmObjectTy, + lowerTy().getDataLayout()); + if (auto scaleSize = genAllocationScaleSize(allocmemOp, ity, rewriter)) + size = rewriter.create(loc, ity, size, scaleSize); + for (mlir::Value opnd : adaptor.getOperands().drop_front()) + size = rewriter.create( + loc, ity, size, integerCast(lowerTy(), loc, rewriter, ity, opnd)); + auto mallocTyWidth = lowerTy().getIndexTypeBitwidth(); + auto mallocTy = + mlir::IntegerType::get(rewriter.getContext(), mallocTyWidth); + if (mallocTyWidth != ity.getIntOrFloatBitWidth()) + size = integerCast(lowerTy(), loc, rewriter, mallocTy, size); + allocmemOp->setAttr("callee", mlir::SymbolRefAttr::get(mallocFunc)); + auto callOp = rewriter.create( + loc, llvmPtrTy, + mlir::SmallVector({size, allocmemOp.getDevice()}), + addLLVMOpBundleAttrs(rewriter, allocmemOp->getAttrs(), 2)); + rewriter.replaceOpWithNewOp( + allocmemOp, rewriter.getIntegerType(64), callOp.getResult()); + return mlir::success(); + } +}; } // namespace void fir::populateOpenMPFIRToLLVMConversionPatterns( const LLVMTypeConverter &converter, mlir::RewritePatternSet &patterns) { patterns.add(converter); patterns.add(converter); + patterns.add(converter); } diff --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp index ecfa2939e96a6..ba5c81d826e5e 100644 --- a/flang/lib/Optimizer/Dialect/FIROps.cpp +++ b/flang/lib/Optimizer/Dialect/FIROps.cpp @@ -107,7 +107,6 @@ static bool verifyTypeParamCount(mlir::Type inType, unsigned numParams) { } /// Parser shared by Alloca and Allocmem -/// /// operation ::= %res = (`fir.alloca` | `fir.allocmem`) $in_type /// ( `(` $typeparams `)` )? ( `,` $shape )? /// attr-dict-without-keyword diff --git a/flang/test/Fir/omp_target_allocmem_freemem.fir b/flang/test/Fir/omp_target_allocmem_freemem.fir new file mode 100644 index 0000000000000..9202202728454 --- /dev/null +++ b/flang/test/Fir/omp_target_allocmem_freemem.fir @@ -0,0 +1,33 @@ +// RUN: %flang_fc1 -emit-llvm %s -o - | FileCheck %s + +// CHECK-LABEL: define void @omp_target_allocmem_array_of_nonchar( +// CHECK: call ptr @omp_target_alloc(i64 36, i32 0) +// CHECK: call void @omp_target_free(ptr {{.*}}, i32 0) +func.func @omp_target_allocmem_array_of_nonchar() -> () { + %device = arith.constant 0 : i32 + %1 = omp.target_allocmem %device : i32, !fir.array<3x3xi32> + omp.target_freemem %device, %1 : i32, i64 + return +} + +// CHECK-LABEL: define void @omp_target_allocmem_array_of_char( +// CHECK: call ptr @omp_target_alloc(i64 90, i32 0) +// CHECK: call void @omp_target_free(ptr {{.*}}, i32 0) +func.func @omp_target_allocmem_array_of_char() -> () { + %device = arith.constant 0 : i32 + %1 = omp.target_allocmem %device : i32, !fir.array<3x3x!fir.char<1,10>> + omp.target_freemem %device, %1 : i32, i64 + return +} + +// CHECK-LABEL: define void @omp_target_allocmem_array_of_dynchar( +// CHECK-SAME: i32 %[[len:.*]]) +// CHECK: %[[mul1:.*]] = sext i32 %[[len]] to i64 +// CHECK: %[[mul2:.*]] = mul i64 9, %[[mul1]] +// CHECK: call ptr @omp_target_alloc(i64 %[[mul2]], i32 0) +func.func @omp_target_allocmem_array_of_dynchar(%l: i32) -> () { + %device = arith.constant 0 : i32 + %1 = omp.target_allocmem %device : i32, !fir.array<3x3x!fir.char<1,?>>(%l : i32) + omp.target_freemem %device, %1 : i32, i64 + return +} diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td index ac80926053a2d..bcc7c906041ac 100644 --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td @@ -1887,4 +1887,68 @@ def MaskedOp : OpenMP_Op<"masked", clauses = [ ]; } +//===----------------------------------------------------------------------===// +// TargetAllocMemOp +//===----------------------------------------------------------------------===// + +def TargetAllocMemOp : OpenMP_Op<"target_allocmem", + [MemoryEffects<[MemAlloc]>, AttrSizedOperandSegments]> { + let summary = "allocate storage on an openmp device for an object of a given type"; + + let description = [{ + Allocates memory on the specified OpenMP device for an object of the given type. + Returns an integer value representing the device pointer to the allocated memory. + The memory is uninitialized after allocation. Operations must be paired with + `omp.target_freemem` to avoid memory leaks. + + ```mlir + %device = arith.constant 0 : i32 + %ptr = omp.target_allocmem %device : i32, vector<3x3xi32> + ``` + }]; + + let arguments = (ins + Arg:$device, + TypeAttr:$in_type, + OptionalAttr:$uniq_name, + OptionalAttr:$bindc_name, + Variadic:$typeparams, + Variadic:$shape + ); + let results = (outs I64); + + let hasCustomAssemblyFormat = 1; + let hasVerifier = 1; + + let extraClassDeclaration = [{ + mlir::Type getAllocatedType(); + }]; +} + +//===----------------------------------------------------------------------===// +// TargetFreeMemOp +//===----------------------------------------------------------------------===// + +def TargetFreeMemOp : OpenMP_Op<"target_freemem", + [MemoryEffects<[MemFree]>]> { + let summary = "free memory on an openmp device"; + + let description = [{ + Deallocates memory on the specified OpenMP device that was previously + allocated by an `omp.target_allocmem` operation. The memory is placed + in an undefined state after deallocation. + ``` + %device = arith.constant 0 : i32 + %ptr = omp.target_allocmem %device : i32, vector<3x3xi32> + omp.target_freemem %device, %ptr : i32, i64 + ``` + }]; + + let arguments = (ins + Arg:$device, + Arg:$heapref + ); + let assemblyFormat = "$device `,` $heapref attr-dict `:` type($device) `,` qualified(type($heapref))"; +} + #endif // OPENMP_OPS diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp index e94d570b57122..5e3ad6ccf3ffe 100644 --- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp +++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp @@ -3493,6 +3493,108 @@ LogicalResult ScanOp::verify() { "reduction modifier"); } +//===----------------------------------------------------------------------===// +// TargetAllocMemOp +//===----------------------------------------------------------------------===// + +mlir::Type omp::TargetAllocMemOp::getAllocatedType() { + return getInTypeAttr().getValue(); +} + +/// operation ::= %res = (`omp.target_alloc_mem`) $device : devicetype, +/// $in_type ( `(` $typeparams `)` )? ( `,` $shape )? +/// attr-dict-without-keyword +static mlir::ParseResult parseTargetAllocMemOp(mlir::OpAsmParser &parser, + mlir::OperationState &result) { + auto &builder = parser.getBuilder(); + bool hasOperands = false; + std::int32_t typeparamsSize = 0; + + // Parse device number as a new operand + mlir::OpAsmParser::UnresolvedOperand deviceOperand; + mlir::Type deviceType; + if (parser.parseOperand(deviceOperand) || parser.parseColonType(deviceType)) + return mlir::failure(); + if (parser.resolveOperand(deviceOperand, deviceType, result.operands)) + return mlir::failure(); + if (parser.parseComma()) + return mlir::failure(); + + mlir::Type intype; + if (parser.parseType(intype)) + return mlir::failure(); + result.addAttribute("in_type", mlir::TypeAttr::get(intype)); + llvm::SmallVector operands; + llvm::SmallVector typeVec; + if (!parser.parseOptionalLParen()) { + // parse the LEN params of the derived type. ( : ) + if (parser.parseOperandList(operands, mlir::OpAsmParser::Delimiter::None) || + parser.parseColonTypeList(typeVec) || parser.parseRParen()) + return mlir::failure(); + typeparamsSize = operands.size(); + hasOperands = true; + } + std::int32_t shapeSize = 0; + if (!parser.parseOptionalComma()) { + // parse size to scale by, vector of n dimensions of type index + if (parser.parseOperandList(operands, mlir::OpAsmParser::Delimiter::None)) + return mlir::failure(); + shapeSize = operands.size() - typeparamsSize; + auto idxTy = builder.getIndexType(); + for (std::int32_t i = typeparamsSize, end = operands.size(); i != end; ++i) + typeVec.push_back(idxTy); + hasOperands = true; + } + if (hasOperands && + parser.resolveOperands(operands, typeVec, parser.getNameLoc(), + result.operands)) + return mlir::failure(); + + mlir::Type restype = builder.getIntegerType(64); + ; + if (!restype) { + parser.emitError(parser.getNameLoc(), "invalid allocate type: ") << intype; + return mlir::failure(); + } + llvm::SmallVector segmentSizes{1, typeparamsSize, shapeSize}; + result.addAttribute("operandSegmentSizes", + builder.getDenseI32ArrayAttr(segmentSizes)); + if (parser.parseOptionalAttrDict(result.attributes) || + parser.addTypeToList(restype, result.types)) + return mlir::failure(); + return mlir::success(); +} + +mlir::ParseResult omp::TargetAllocMemOp::parse(mlir::OpAsmParser &parser, + mlir::OperationState &result) { + return parseTargetAllocMemOp(parser, result); +} + +void omp::TargetAllocMemOp::print(mlir::OpAsmPrinter &p) { + p << " "; + p.printOperand(getDevice()); + p << " : "; + p << getDevice().getType(); + p << ", "; + p << getInType(); + if (!getTypeparams().empty()) { + p << '(' << getTypeparams() << " : " << getTypeparams().getTypes() << ')'; + } + for (auto sh : getShape()) { + p << ", "; + p.printOperand(sh); + } + p.printOptionalAttrDict((*this)->getAttrs(), + {"in_type", "operandSegmentSizes"}); +} + +llvm::LogicalResult omp::TargetAllocMemOp::verify() { + mlir::Type outType = getType(); + if (!mlir::dyn_cast(outType)) + return emitOpError("must be a integer type"); + return mlir::success(); +} + #define GET_ATTRDEF_CLASSES #include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc" diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp index eece8573f00ec..63b41f3b13363 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp @@ -5714,6 +5714,82 @@ static bool isTargetDeviceOp(Operation *op) { return false; } +static llvm::Function *getOmpTargetAlloc(llvm::IRBuilderBase &builder, + llvm::Module *llvmModule) { + llvm::Type *i64Ty = builder.getInt64Ty(); + llvm::Type *i32Ty = builder.getInt32Ty(); + llvm::Type *returnType = builder.getPtrTy(0); + llvm::FunctionType *fnType = + llvm::FunctionType::get(returnType, {i64Ty, i32Ty}, false); + llvm::Function *func = cast( + llvmModule->getOrInsertFunction("omp_target_alloc", fnType).getCallee()); + return func; +} + +static LogicalResult +convertTargetAllocMemOp(Operation &opInst, llvm::IRBuilderBase &builder, + LLVM::ModuleTranslation &moduleTranslation) { + auto allocMemOp = cast(opInst); + if (!allocMemOp) + return failure(); + + // Get "omp_target_alloc" function + llvm::Module *llvmModule = moduleTranslation.getLLVMModule(); + llvm::Function *ompTargetAllocFunc = getOmpTargetAlloc(builder, llvmModule); + // Get the corresponding device value in llvm + mlir::Value deviceNum = allocMemOp.getDevice(); + llvm::Value *llvmDeviceNum = moduleTranslation.lookupValue(deviceNum); + // Get the allocation size. + llvm::DataLayout dataLayout = llvmModule->getDataLayout(); + mlir::Type heapTy = allocMemOp.getAllocatedType(); + llvm::Type *llvmHeapTy = moduleTranslation.convertType(heapTy); + llvm::TypeSize typeSize = dataLayout.getTypeStoreSize(llvmHeapTy); + llvm::ConstantInt *allocSize = builder.getInt64(typeSize.getFixedValue()); + // Create call to "omp_target_alloc" with the args as translated llvm values. + llvm::CallInst *call = + builder.CreateCall(ompTargetAllocFunc, {allocSize, llvmDeviceNum}); + llvm::Value *resultI64 = builder.CreatePtrToInt(call, builder.getInt64Ty()); + + // Map the result + moduleTranslation.mapValue(allocMemOp.getResult(), resultI64); + return success(); +} + +static llvm::Function *getOmpTargetFree(llvm::IRBuilderBase &builder, + llvm::Module *llvmModule) { + llvm::Type *ptrTy = builder.getPtrTy(0); + llvm::Type *i32Ty = builder.getInt32Ty(); + llvm::Type *voidTy = builder.getVoidTy(); + llvm::FunctionType *fnType = + llvm::FunctionType::get(voidTy, {ptrTy, i32Ty}, false); + llvm::Function *func = dyn_cast( + llvmModule->getOrInsertFunction("omp_target_free", fnType).getCallee()); + return func; +} + +static LogicalResult +convertTargetFreeMemOp(Operation &opInst, llvm::IRBuilderBase &builder, + LLVM::ModuleTranslation &moduleTranslation) { + auto freeMemOp = cast(opInst); + if (!freeMemOp) + return failure(); + + // Get "omp_target_free" function + llvm::Module *llvmModule = moduleTranslation.getLLVMModule(); + llvm::Function *ompTragetFreeFunc = getOmpTargetFree(builder, llvmModule); + // Get the corresponding device value in llvm + mlir::Value deviceNum = freeMemOp.getDevice(); + llvm::Value *llvmDeviceNum = moduleTranslation.lookupValue(deviceNum); + // Get the corresponding heapref value in llvm + mlir::Value heapref = freeMemOp.getHeapref(); + llvm::Value *llvmHeapref = moduleTranslation.lookupValue(heapref); + // Convert heapref int to ptr and call "omp_target_free" + llvm::Value *intToPtr = + builder.CreateIntToPtr(llvmHeapref, builder.getPtrTy(0)); + builder.CreateCall(ompTragetFreeFunc, {intToPtr, llvmDeviceNum}); + return success(); +} + /// Given an OpenMP MLIR operation, create the corresponding LLVM IR (including /// OpenMP runtime calls). static LogicalResult @@ -5871,6 +5947,12 @@ convertHostOrTargetOperation(Operation *op, llvm::IRBuilderBase &builder, // etc. and then discarded return success(); }) + .Case([&](omp::TargetAllocMemOp) { + return convertTargetAllocMemOp(*op, builder, moduleTranslation); + }) + .Case([&](omp::TargetFreeMemOp) { + return convertTargetFreeMemOp(*op, builder, moduleTranslation); + }) .Default([&](Operation *inst) { return inst->emitError() << "not yet implemented: " << inst->getName(); diff --git a/mlir/test/Target/LLVMIR/ompenmp-target-allocmem-freemem.mlir b/mlir/test/Target/LLVMIR/ompenmp-target-allocmem-freemem.mlir new file mode 100644 index 0000000000000..1bc97609ccff4 --- /dev/null +++ b/mlir/test/Target/LLVMIR/ompenmp-target-allocmem-freemem.mlir @@ -0,0 +1,42 @@ +// RUN: mlir-opt %s -convert-openmp-to-llvm | mlir-translate -mlir-to-llvmir | FileCheck %s + +// This file contains MLIR test cases for omp.target_allocmem and omp.target_freemem + +// CHECK-LABEL: test_alloc_free_i64 +// CHECK: %[[ALLOC:.*]] = call ptr @omp_target_alloc(i64 8, i32 0) +// CHECK: %[[PTRTOINT:.*]] = ptrtoint ptr %[[ALLOC]] to i64 +// CHECK: %[[INTTOPTR:.*]] = inttoptr i64 %[[PTRTOINT]] to ptr +// CHECK: call void @omp_target_free(ptr %[[INTTOPTR]], i32 0) +// CHECK: ret void +llvm.func @test_alloc_free_i64() -> () { + %device = llvm.mlir.constant(0 : i32) : i32 + %1 = omp.target_allocmem %device : i32, i64 + omp.target_freemem %device, %1 : i32, i64 + llvm.return +} + +// CHECK-LABEL: test_alloc_free_vector_1d_f32 +// CHECK: %[[ALLOC:.*]] = call ptr @omp_target_alloc(i64 64, i32 0) +// CHECK: %[[PTRTOINT:.*]] = ptrtoint ptr %[[ALLOC]] to i64 +// CHECK: %[[INTTOPTR:.*]] = inttoptr i64 %[[PTRTOINT]] to ptr +// CHECK: call void @omp_target_free(ptr %[[INTTOPTR]], i32 0) +// CHECK: ret void +llvm.func @test_alloc_free_vector_1d_f32() -> () { + %device = llvm.mlir.constant(0 : i32) : i32 + %1 = omp.target_allocmem %device : i32, vector<16xf32> + omp.target_freemem %device, %1 : i32, i64 + llvm.return +} + +// CHECK-LABEL: test_alloc_free_vector_2d_f32 +// CHECK: %[[ALLOC:.*]] = call ptr @omp_target_alloc(i64 1024, i32 0) +// CHECK: %[[PTRTOINT:.*]] = ptrtoint ptr %[[ALLOC]] to i64 +// CHECK: %[[INTTOPTR:.*]] = inttoptr i64 %[[PTRTOINT]] to ptr +// CHECK: call void @omp_target_free(ptr %[[INTTOPTR]], i32 0) +// CHECK: ret void +llvm.func @test_alloc_free_vector_2d_f32() -> () { + %device = llvm.mlir.constant(0 : i32) : i32 + %1 = omp.target_allocmem %device : i32, vector<16x16xf32> + omp.target_freemem %device, %1 : i32, i64 + llvm.return +}