Skip to content

Commit cd84582

Browse files
committed
[Flang] Add fir omp target alloc and free ops
This commit is a cherry-pick of ivaradanov's commit be860ac
1 parent 79dd250 commit cd84582

File tree

2 files changed

+171
-10
lines changed

2 files changed

+171
-10
lines changed

flang/lib/Optimizer/CodeGen/CodeGen.cpp

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1167,6 +1167,105 @@ struct FreeMemOpConversion : public fir::FIROpConversion<fir::FreeMemOp> {
11671167
};
11681168
} // namespace
11691169

1170+
/// Return the `omp_target_alloc` LLVM function declaration from the module
/// that encloses \p op. If the module has no such symbol yet, a declaration
/// with the signature `ptr (i64, i32)` is created through a builder
/// constructed on the module body region and returned.
static mlir::LLVM::LLVMFuncOp getOmpTargetAlloc(mlir::Operation *op) {
  constexpr const char *allocFnName = "omp_target_alloc";
  auto parentModule = op->getParentOfType<mlir::ModuleOp>();

  // Reuse an existing declaration so the symbol is only defined once.
  auto existingDecl =
      parentModule.lookupSymbol<mlir::LLVM::LLVMFuncOp>(allocFnName);
  if (existingDecl)
    return existingDecl;

  mlir::OpBuilder declBuilder(parentModule.getBodyRegion());
  mlir::MLIRContext *ctx = parentModule->getContext();
  auto sizeTy = mlir::IntegerType::get(ctx, 64);
  auto deviceTy = mlir::IntegerType::get(ctx, 32);
  auto allocFnTy = mlir::LLVM::LLVMFunctionType::get(
      mlir::LLVM::LLVMPointerType::get(ctx), {sizeTy, deviceTy},
      /*isVarArg=*/false);
  return declBuilder.create<mlir::LLVM::LLVMFuncOp>(
      declBuilder.getUnknownLoc(), allocFnName, allocFnTy);
}
1185+
1186+
namespace {
1187+
struct OmpTargetAllocMemOpConversion
1188+
: public fir::FIROpConversion<fir::OmpTargetAllocMemOp> {
1189+
using FIROpConversion::FIROpConversion;
1190+
1191+
mlir::LogicalResult
1192+
matchAndRewrite(fir::OmpTargetAllocMemOp heap, OpAdaptor adaptor,
1193+
mlir::ConversionPatternRewriter &rewriter) const override {
1194+
mlir::Type heapTy = heap.getType();
1195+
mlir::LLVM::LLVMFuncOp mallocFunc = getOmpTargetAlloc(heap);
1196+
mlir::Location loc = heap.getLoc();
1197+
auto ity = lowerTy().indexType();
1198+
mlir::Type dataTy = fir::unwrapRefType(heapTy);
1199+
mlir::Type llvmObjectTy = convertObjectType(dataTy);
1200+
if (fir::isRecordWithTypeParameters(fir::unwrapSequenceType(dataTy)))
1201+
TODO(loc, "fir.omp_target_allocmem codegen of derived type with length "
1202+
"parameters");
1203+
mlir::Value size = genTypeSizeInBytes(loc, ity, rewriter, llvmObjectTy);
1204+
if (auto scaleSize = genAllocationScaleSize(heap, ity, rewriter))
1205+
size = rewriter.create<mlir::LLVM::MulOp>(loc, ity, size, scaleSize);
1206+
for (mlir::Value opnd : adaptor.getOperands())
1207+
size = rewriter.create<mlir::LLVM::MulOp>(
1208+
loc, ity, size, integerCast(loc, rewriter, ity, opnd));
1209+
auto mallocTyWidth = lowerTy().getIndexTypeBitwidth();
1210+
auto mallocTy =
1211+
mlir::IntegerType::get(rewriter.getContext(), mallocTyWidth);
1212+
if (mallocTyWidth != ity.getIntOrFloatBitWidth())
1213+
size = integerCast(loc, rewriter, mallocTy, size);
1214+
heap->setAttr("callee", mlir::SymbolRefAttr::get(mallocFunc));
1215+
rewriter.replaceOpWithNewOp<mlir::LLVM::CallOp>(
1216+
heap, ::getLlvmPtrType(heap.getContext()),
1217+
mlir::SmallVector<mlir::Value, 2>({size, heap.getDevice()}),
1218+
addLLVMOpBundleAttrs(rewriter, heap->getAttrs(), 2));
1219+
return mlir::success();
1220+
}
1221+
1222+
/// Compute the allocation size in bytes of the element type of
1223+
/// \p llTy pointer type. The result is returned as a value of \p idxTy
1224+
/// integer type.
1225+
mlir::Value genTypeSizeInBytes(mlir::Location loc, mlir::Type idxTy,
1226+
mlir::ConversionPatternRewriter &rewriter,
1227+
mlir::Type llTy) const {
1228+
return computeElementDistance(loc, llTy, idxTy, rewriter, getDataLayout());
1229+
}
1230+
};
1231+
} // namespace
1232+
1233+
/// Return the `omp_target_free` LLVM function declaration from the module
/// that encloses \p op. If the module has no such symbol yet, a declaration
/// with the signature `void (ptr, i32)` is created through a builder
/// constructed on the module body region and returned.
static mlir::LLVM::LLVMFuncOp getOmpTargetFree(mlir::Operation *op) {
  constexpr const char *freeFnName = "omp_target_free";
  auto parentModule = op->getParentOfType<mlir::ModuleOp>();

  // Reuse an existing declaration so the symbol is only defined once.
  auto existingDecl =
      parentModule.lookupSymbol<mlir::LLVM::LLVMFuncOp>(freeFnName);
  if (existingDecl)
    return existingDecl;

  mlir::OpBuilder declBuilder(parentModule.getBodyRegion());
  mlir::MLIRContext *ctx = parentModule->getContext();
  auto deviceTy = mlir::IntegerType::get(ctx, 32);
  auto freeFnTy = mlir::LLVM::LLVMFunctionType::get(
      mlir::LLVM::LLVMVoidType::get(ctx), {getLlvmPtrType(ctx), deviceTy},
      /*isVarArg=*/false);
  return declBuilder.create<mlir::LLVM::LLVMFuncOp>(
      declBuilder.getUnknownLoc(), freeFnName, freeFnTy);
}
1247+
1248+
namespace {
1249+
struct OmpTargetFreeMemOpConversion
1250+
: public fir::FIROpConversion<fir::OmpTargetFreeMemOp> {
1251+
using FIROpConversion::FIROpConversion;
1252+
1253+
mlir::LogicalResult
1254+
matchAndRewrite(fir::OmpTargetFreeMemOp freemem, OpAdaptor adaptor,
1255+
mlir::ConversionPatternRewriter &rewriter) const override {
1256+
mlir::LLVM::LLVMFuncOp freeFunc = getOmpTargetFree(freemem);
1257+
mlir::Location loc = freemem.getLoc();
1258+
freemem->setAttr("callee", mlir::SymbolRefAttr::get(freeFunc));
1259+
rewriter.create<mlir::LLVM::CallOp>(
1260+
loc, mlir::TypeRange{},
1261+
mlir::ValueRange{adaptor.getHeapref(), freemem.getDevice()},
1262+
addLLVMOpBundleAttrs(rewriter, freemem->getAttrs(), 2));
1263+
rewriter.eraseOp(freemem);
1264+
return mlir::success();
1265+
}
1266+
};
1267+
} // namespace
1268+
11701269
// Convert subcomponent array indices from column-major to row-major ordering.
11711270
static llvm::SmallVector<mlir::Value>
11721271
convertSubcomponentIndices(mlir::Location loc, mlir::Type eleTy,
@@ -4224,6 +4323,7 @@ void fir::populateFIRToLLVMConversionPatterns(
42244323
GlobalLenOpConversion, GlobalOpConversion, InsertOnRangeOpConversion,
42254324
IsPresentOpConversion, LenParamIndexOpConversion, LoadOpConversion,
42264325
MulcOpConversion, NegcOpConversion, NoReassocOpConversion,
4326+
OmpTargetAllocMemOpConversion, OmpTargetFreeMemOpConversion,
42274327
SelectCaseOpConversion, SelectOpConversion, SelectRankOpConversion,
42284328
SelectTypeOpConversion, ShapeOpConversion, ShapeShiftOpConversion,
42294329
ShiftOpConversion, SliceOpConversion, StoreOpConversion,

flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp

Lines changed: 71 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -751,6 +751,15 @@ static SplitResult isolateOp(Operation *splitBeforeOp, bool splitAfter,
751751
return SplitResult{preTargetOp, isolatedTargetOp, postTargetOp};
752752
}
753753

754+
/// Create an `mlir::LLVM::ConstantOp` of 32-bit integer type holding
/// \p value, using \p rewriter and tagging the op with \p loc.
static mlir::LLVM::ConstantOp
genI32Constant(mlir::Location loc, mlir::RewriterBase &rewriter, int value) {
  mlir::Type resultTy = rewriter.getI32Type();
  return rewriter.create<mlir::LLVM::ConstantOp>(
      loc, resultTy, rewriter.getI32IntegerAttr(value));
}
760+
761+
/// Return the 32-bit integer type created in context \p c; it is used as the
/// type of the OpenMP device operand.
static Type getOmpDeviceType(MLIRContext *c) {
  constexpr unsigned deviceBitWidth = 32;
  return IntegerType::get(c, deviceBitWidth);
}
762+
754763
static void moveToHost(omp::TargetOp targetOp, RewriterBase &rewriter) {
755764
OpBuilder::InsertionGuard guard(rewriter);
756765
Block *targetBlock = &targetOp.getRegion().front();
@@ -776,22 +785,74 @@ static void moveToHost(omp::TargetOp targetOp, RewriterBase &rewriter) {
776785
if (isRuntimeCall(op))
777786
runtimeCall = cast<fir::CallOp>(op);
778787

779-
if (allocOp || freeOp || runtimeCall)
780-
continue;
781-
opsToMove.push_back(op);
788+
if (allocOp || freeOp || runtimeCall) {
789+
Value device = targetOp.getDevice();
790+
if (!device) {
791+
device = genI32Constant(it->getLoc(), rewriter, 0);
792+
}
793+
if (allocOp) {
794+
auto tmpAllocOp = rewriter.create<fir::OmpTargetAllocMemOp>(
795+
allocOp.getLoc(), allocOp.getType(), device,
796+
allocOp.getInTypeAttr(), allocOp.getUniqNameAttr(),
797+
allocOp.getBindcNameAttr(), allocOp.getTypeparams(),
798+
allocOp.getShape());
799+
auto newAllocOp = cast<fir::OmpTargetAllocMemOp>(
800+
rewriter.clone(*tmpAllocOp.getOperation(), mapping));
801+
mapping.map(allocOp.getResult(), newAllocOp.getResult());
802+
rewriter.eraseOp(tmpAllocOp);
803+
} else if (freeOp) {
804+
auto tmpFreeOp = rewriter.create<fir::OmpTargetFreeMemOp>(
805+
freeOp.getLoc(), device, freeOp.getHeapref());
806+
rewriter.clone(*tmpFreeOp.getOperation(), mapping);
807+
rewriter.eraseOp(tmpFreeOp);
808+
} else if (runtimeCall) {
809+
auto module = runtimeCall->getParentOfType<ModuleOp>();
810+
auto callee = cast<func::FuncOp>(
811+
module.lookupSymbol(runtimeCall.getCalleeAttr()));
812+
std::string newCalleeName = (callee.getName()).str();
813+
mlir::OpBuilder moduleBuilder(module.getBodyRegion());
814+
func::FuncOp newCallee =
815+
cast_or_null<func::FuncOp>(module.lookupSymbol(newCalleeName));
816+
if (!newCallee) {
817+
SmallVector<Type> argTypes(callee.getFunctionType().getInputs());
818+
argTypes.push_back(getOmpDeviceType(rewriter.getContext()));
819+
newCallee = moduleBuilder.create<func::FuncOp>(
820+
callee->getLoc(), newCalleeName,
821+
FunctionType::get(rewriter.getContext(), argTypes,
822+
callee.getFunctionType().getResults()));
823+
if (callee.getArgAttrs())
824+
newCallee.setArgAttrsAttr(*callee.getArgAttrs());
825+
if (callee.getResAttrs())
826+
newCallee.setResAttrsAttr(*callee.getResAttrs());
827+
newCallee.setSymVisibility(callee.getSymVisibility());
828+
newCallee->setDiscardableAttrs(
829+
callee->getDiscardableAttrDictionary());
830+
}
831+
SmallVector<Value> operands = runtimeCall.getOperands();
832+
operands.push_back(device);
833+
auto tmpCall = rewriter.create<fir::CallOp>(
834+
runtimeCall.getLoc(), runtimeCall.getResultTypes(),
835+
SymbolRefAttr::get(newCallee), operands, nullptr, nullptr, nullptr,
836+
runtimeCall.getFastmathAttr());
837+
Operation *newCall = rewriter.clone(*tmpCall, mapping);
838+
mapping.map(&*it, newCall);
839+
rewriter.eraseOp(tmpCall);
840+
}
841+
} else {
842+
Operation *clonedOp = rewriter.clone(*op, mapping);
843+
for (unsigned i = 0; i < op->getNumResults(); ++i) {
844+
mapping.map(op->getResult(i), clonedOp->getResult(i));
845+
}
846+
}
782847
}
783-
// Move ops before targetOp and erase from region
784-
for (Operation *op : opsToMove)
785-
rewriter.clone(*op, mapping);
786-
787848
rewriter.eraseOp(targetOp);
788849
}
789850

790851
void fissionTarget(omp::TargetOp targetOp, RewriterBase &rewriter) {
791852
auto tuple = getNestedOpToIsolate(targetOp);
792853
if (!tuple) {
793854
LLVM_DEBUG(llvm::dbgs() << " No op to isolate\n");
794-
//moveToHost(targetOp, rewriter);
855+
moveToHost(targetOp, rewriter);
795856
return;
796857
}
797858

@@ -801,13 +862,13 @@ void fissionTarget(omp::TargetOp targetOp, RewriterBase &rewriter) {
801862

802863
if (splitBefore && splitAfter) {
803864
auto res = isolateOp(toIsolate, splitAfter, rewriter);
804-
//moveToHost(res.preTargetOp, rewriter);
865+
moveToHost(res.preTargetOp, rewriter);
805866
fissionTarget(res.postTargetOp, rewriter);
806867
return;
807868
}
808869
if (splitBefore) {
809870
auto res = isolateOp(toIsolate, splitAfter, rewriter);
810-
//moveToHost(res.preTargetOp, rewriter);
871+
moveToHost(res.preTargetOp, rewriter);
811872
return;
812873
}
813874
if (splitAfter) {

0 commit comments

Comments
 (0)