@@ -1167,6 +1167,105 @@ struct FreeMemOpConversion : public fir::FIROpConversion<fir::FreeMemOp> {
1167
1167
};
1168
1168
} // namespace
1169
1169
1170
+ static mlir::LLVM::LLVMFuncOp getOmpTargetAlloc (mlir::Operation *op) {
1171
+ auto module = op->getParentOfType <mlir::ModuleOp>();
1172
+ if (mlir::LLVM::LLVMFuncOp mallocFunc =
1173
+ module .lookupSymbol <mlir::LLVM::LLVMFuncOp>(" omp_target_alloc" ))
1174
+ return mallocFunc;
1175
+ mlir::OpBuilder moduleBuilder (module .getBodyRegion ());
1176
+ auto i64Ty = mlir::IntegerType::get (module ->getContext (), 64 );
1177
+ auto i32Ty = mlir::IntegerType::get (module ->getContext (), 32 );
1178
+ return moduleBuilder.create <mlir::LLVM::LLVMFuncOp>(
1179
+ moduleBuilder.getUnknownLoc (), " omp_target_alloc" ,
1180
+ mlir::LLVM::LLVMFunctionType::get (
1181
+ mlir::LLVM::LLVMPointerType::get (module ->getContext ()),
1182
+ {i64Ty, i32Ty},
1183
+ /* isVarArg=*/ false ));
1184
+ }
1185
+
1186
+ namespace {
1187
+ struct OmpTargetAllocMemOpConversion
1188
+ : public fir::FIROpConversion<fir::OmpTargetAllocMemOp> {
1189
+ using FIROpConversion::FIROpConversion;
1190
+
1191
+ mlir::LogicalResult
1192
+ matchAndRewrite (fir::OmpTargetAllocMemOp heap, OpAdaptor adaptor,
1193
+ mlir::ConversionPatternRewriter &rewriter) const override {
1194
+ mlir::Type heapTy = heap.getType ();
1195
+ mlir::LLVM::LLVMFuncOp mallocFunc = getOmpTargetAlloc (heap);
1196
+ mlir::Location loc = heap.getLoc ();
1197
+ auto ity = lowerTy ().indexType ();
1198
+ mlir::Type dataTy = fir::unwrapRefType (heapTy);
1199
+ mlir::Type llvmObjectTy = convertObjectType (dataTy);
1200
+ if (fir::isRecordWithTypeParameters (fir::unwrapSequenceType (dataTy)))
1201
+ TODO (loc, " fir.omp_target_allocmem codegen of derived type with length "
1202
+ " parameters" );
1203
+ mlir::Value size = genTypeSizeInBytes (loc, ity, rewriter, llvmObjectTy);
1204
+ if (auto scaleSize = genAllocationScaleSize (heap, ity, rewriter))
1205
+ size = rewriter.create <mlir::LLVM::MulOp>(loc, ity, size, scaleSize);
1206
+ for (mlir::Value opnd : adaptor.getOperands ())
1207
+ size = rewriter.create <mlir::LLVM::MulOp>(
1208
+ loc, ity, size, integerCast (loc, rewriter, ity, opnd));
1209
+ auto mallocTyWidth = lowerTy ().getIndexTypeBitwidth ();
1210
+ auto mallocTy =
1211
+ mlir::IntegerType::get (rewriter.getContext (), mallocTyWidth);
1212
+ if (mallocTyWidth != ity.getIntOrFloatBitWidth ())
1213
+ size = integerCast (loc, rewriter, mallocTy, size);
1214
+ heap->setAttr (" callee" , mlir::SymbolRefAttr::get (mallocFunc));
1215
+ rewriter.replaceOpWithNewOp <mlir::LLVM::CallOp>(
1216
+ heap, ::getLlvmPtrType (heap.getContext ()),
1217
+ mlir::SmallVector<mlir::Value, 2 >({size, heap.getDevice ()}),
1218
+ addLLVMOpBundleAttrs (rewriter, heap->getAttrs (), 2 ));
1219
+ return mlir::success ();
1220
+ }
1221
+
1222
+ // / Compute the allocation size in bytes of the element type of
1223
+ // / \p llTy pointer type. The result is returned as a value of \p idxTy
1224
+ // / integer type.
1225
+ mlir::Value genTypeSizeInBytes (mlir::Location loc, mlir::Type idxTy,
1226
+ mlir::ConversionPatternRewriter &rewriter,
1227
+ mlir::Type llTy) const {
1228
+ return computeElementDistance (loc, llTy, idxTy, rewriter, getDataLayout ());
1229
+ }
1230
+ };
1231
+ } // namespace
1232
+
1233
+ static mlir::LLVM::LLVMFuncOp getOmpTargetFree (mlir::Operation *op) {
1234
+ auto module = op->getParentOfType <mlir::ModuleOp>();
1235
+ if (mlir::LLVM::LLVMFuncOp freeFunc =
1236
+ module .lookupSymbol <mlir::LLVM::LLVMFuncOp>(" omp_target_free" ))
1237
+ return freeFunc;
1238
+ mlir::OpBuilder moduleBuilder (module .getBodyRegion ());
1239
+ auto i32Ty = mlir::IntegerType::get (module ->getContext (), 32 );
1240
+ return moduleBuilder.create <mlir::LLVM::LLVMFuncOp>(
1241
+ moduleBuilder.getUnknownLoc (), " omp_target_free" ,
1242
+ mlir::LLVM::LLVMFunctionType::get (
1243
+ mlir::LLVM::LLVMVoidType::get (module ->getContext ()),
1244
+ {getLlvmPtrType (module ->getContext ()), i32Ty},
1245
+ /* isVarArg=*/ false ));
1246
+ }
1247
+
1248
+ namespace {
1249
+ struct OmpTargetFreeMemOpConversion
1250
+ : public fir::FIROpConversion<fir::OmpTargetFreeMemOp> {
1251
+ using FIROpConversion::FIROpConversion;
1252
+
1253
+ mlir::LogicalResult
1254
+ matchAndRewrite (fir::OmpTargetFreeMemOp freemem, OpAdaptor adaptor,
1255
+ mlir::ConversionPatternRewriter &rewriter) const override {
1256
+ mlir::LLVM::LLVMFuncOp freeFunc = getOmpTargetFree (freemem);
1257
+ mlir::Location loc = freemem.getLoc ();
1258
+ freemem->setAttr (" callee" , mlir::SymbolRefAttr::get (freeFunc));
1259
+ rewriter.create <mlir::LLVM::CallOp>(
1260
+ loc, mlir::TypeRange{},
1261
+ mlir::ValueRange{adaptor.getHeapref (), freemem.getDevice ()},
1262
+ addLLVMOpBundleAttrs (rewriter, freemem->getAttrs (), 2 ));
1263
+ rewriter.eraseOp (freemem);
1264
+ return mlir::success ();
1265
+ }
1266
+ };
1267
+ } // namespace
1268
+
1170
1269
// Convert subcomponent array indices from column-major to row-major ordering.
1171
1270
static llvm::SmallVector<mlir::Value>
1172
1271
convertSubcomponentIndices (mlir::Location loc, mlir::Type eleTy,
@@ -4224,6 +4323,7 @@ void fir::populateFIRToLLVMConversionPatterns(
4224
4323
GlobalLenOpConversion, GlobalOpConversion, InsertOnRangeOpConversion,
4225
4324
IsPresentOpConversion, LenParamIndexOpConversion, LoadOpConversion,
4226
4325
MulcOpConversion, NegcOpConversion, NoReassocOpConversion,
4326
+ OmpTargetAllocMemOpConversion, OmpTargetFreeMemOpConversion,
4227
4327
SelectCaseOpConversion, SelectOpConversion, SelectRankOpConversion,
4228
4328
SelectTypeOpConversion, ShapeOpConversion, ShapeShiftOpConversion,
4229
4329
ShiftOpConversion, SliceOpConversion, StoreOpConversion,
0 commit comments