diff --git a/flang/include/flang/Optimizer/OpenMP/Passes.td b/flang/include/flang/Optimizer/OpenMP/Passes.td
index 704faf0ccd856..743b6d381ed42 100644
--- a/flang/include/flang/Optimizer/OpenMP/Passes.td
+++ b/flang/include/flang/Optimizer/OpenMP/Passes.td
@@ -93,6 +93,10 @@ def LowerWorkshare : Pass<"lower-workshare", "::mlir::ModuleOp"> {
   let summary = "Lower workshare construct";
 }
 
+def LowerWorkdistribute : Pass<"lower-workdistribute", "::mlir::ModuleOp"> {
+  let summary = "Lower workdistribute construct";
+}
+
 def GenericLoopConversionPass
     : Pass<"omp-generic-loop-conversion", "mlir::func::FuncOp"> {
   let summary = "Converts OpenMP generic `omp.loop` to semantically "
diff --git a/flang/lib/Optimizer/OpenMP/CMakeLists.txt b/flang/lib/Optimizer/OpenMP/CMakeLists.txt
index e31543328a9f9..cd746834741f9 100644
--- a/flang/lib/Optimizer/OpenMP/CMakeLists.txt
+++ b/flang/lib/Optimizer/OpenMP/CMakeLists.txt
@@ -7,6 +7,7 @@ add_flang_library(FlangOpenMPTransforms
   MapsForPrivatizedSymbols.cpp
   MapInfoFinalization.cpp
   MarkDeclareTarget.cpp
+  LowerWorkdistribute.cpp
   LowerWorkshare.cpp
   LowerNontemporal.cpp
 
diff --git a/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp b/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp
new file mode 100644
index 0000000000000..3f4116d524452
--- /dev/null
+++ b/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp
@@ -0,0 +1,898 @@
+//===- LowerWorkdistribute.cpp - lower workdistribute construct ----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the lowering and optimisations of omp.workdistribute.
+//
+//===----------------------------------------------------------------------===//
+
+// Note: the angle-bracketed include targets below were lost in extraction and
+// have been reconstructed from the symbols this file uses.
+#include "flang/Optimizer/Builder/FIRBuilder.h"
+#include "flang/Optimizer/Dialect/FIRDialect.h"
+#include "flang/Optimizer/Dialect/FIROps.h"
+#include "flang/Optimizer/Dialect/FIRType.h"
+#include "flang/Optimizer/HLFIR/Passes.h"
+#include "flang/Optimizer/OpenMP/Utils.h"
+#include "flang/Optimizer/Transforms/Passes.h"
+#include "mlir/Analysis/SliceAnalysis.h"
+#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Value.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "mlir/Transforms/RegionUtils.h"
+#include <mlir/Dialect/Arith/IR/Arith.h>
+#include <mlir/Dialect/Func/IR/FuncOps.h>
+#include <mlir/Dialect/LLVMIR/LLVMDialect.h>
+#include <mlir/IR/BlockSupport.h>
+#include <mlir/IR/BuiltinOps.h>
+#include <mlir/IR/Diagnostics.h>
+#include <mlir/IR/IRMapping.h>
+#include <mlir/IR/PatternMatch.h>
+#include <mlir/Interfaces/SideEffectInterfaces.h>
+#include <mlir/Support/LLVM.h>
+#include "llvm/Frontend/OpenMP/OMPConstants.h"
+#include <optional>
+#include <variant>
+
+namespace flangomp {
+#define GEN_PASS_DEF_LOWERWORKDISTRIBUTE
+#include "flang/Optimizer/OpenMP/Passes.h.inc"
+} // namespace flangomp
+
+#define DEBUG_TYPE "lower-workdistribute"
+
+using namespace mlir;
+
+namespace {
+
+static bool isRuntimeCall(Operation *op) {
+  if (auto callOp = dyn_cast<fir::CallOp>(op)) {
+    auto callee = callOp.getCallee();
+    if (!callee)
+      return false;
+    auto *func = op->getParentOfType<ModuleOp>().lookupSymbol(*callee);
+    if (func->getAttr(fir::FIROpsDialect::getFirRuntimeAttrName()))
+      return true;
+  }
+  return false;
+}
+
+/// This is the single source of truth about whether we should parallelize an
+/// operation nested in an omp.workdistribute region.
+static bool shouldParallelize(Operation *op) {
+  if (llvm::any_of(op->getResults(),
+                   [](OpResult v) -> bool { return !v.use_empty(); }))
+    return false;
+  // We will parallelize unordered loops - these come from array syntax
+  if (auto loop = dyn_cast<fir::DoLoopOp>(op)) {
+    auto unordered = loop.getUnordered();
+    if (!unordered)
+      return false;
+    return *unordered;
+  }
+  if (isRuntimeCall(op)) {
+    return true;
+  }
+  // We cannot parallelize anything else
+  return false;
+}
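+
+// Illustrative only: under the heuristic above, an unordered loop produced by
+// Fortran array syntax, e.g.
+//
+//   fir.do_loop %i = %lb to %ub step %c1 unordered {
+//     ...
+//   }
+//
+// or a call to a fir.runtime-attributed function is parallelized, while an
+// ordered fir.do_loop or any op whose results have uses runs sequentially.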
+
+template <typename T>
+static T getPerfectlyNested(Operation *op) {
+  if (op->getNumRegions() != 1)
+    return nullptr;
+  auto &region = op->getRegion(0);
+  if (region.getBlocks().size() != 1)
+    return nullptr;
+  auto *block = &region.front();
+  auto *firstOp = &block->front();
+  if (auto nested = dyn_cast<T>(firstOp))
+    if (firstOp->getNextNode() == block->getTerminator())
+      return nested;
+  return nullptr;
+}
+
+/// If B() and D() are parallelizable,
+///
+/// omp.teams {
+///   omp.workdistribute {
+///     A()
+///     B()
+///     C()
+///     D()
+///     E()
+///   }
+/// }
+///
+/// becomes
+///
+/// A()
+/// omp.teams {
+///   omp.workdistribute {
+///     B()
+///   }
+/// }
+/// C()
+/// omp.teams {
+///   omp.workdistribute {
+///     D()
+///   }
+/// }
+/// E()
+
+static bool FissionWorkdistribute(omp::WorkdistributeOp workdistribute) {
+  OpBuilder rewriter(workdistribute);
+  auto loc = workdistribute->getLoc();
+  auto teams = dyn_cast<omp::TeamsOp>(workdistribute->getParentOp());
+  if (!teams) {
+    emitError(loc, "workdistribute not nested in teams\n");
+    return false;
+  }
+  if (workdistribute.getRegion().getBlocks().size() != 1) {
+    emitError(loc, "workdistribute with multiple blocks\n");
+    return false;
+  }
+  if (teams.getRegion().getBlocks().size() != 1) {
+    emitError(loc, "teams with multiple blocks\n");
+    return false;
+  }
+
+  auto *teamsBlock = &teams.getRegion().front();
+  bool changed = false;
+  // Move the ops inside teams and before workdistribute outside.
+  IRMapping irMapping;
+  llvm::SmallVector<Operation *> teamsHoisted;
+  for (auto &op : teams.getOps()) {
+    if (&op == workdistribute) {
+      break;
+    }
+    if (shouldParallelize(&op)) {
+      emitError(loc,
+                "teams has parallelizable ops before first workdistribute\n");
+      return false;
+    } else {
+      rewriter.setInsertionPoint(teams);
+      rewriter.clone(op, irMapping);
+      teamsHoisted.push_back(&op);
+      changed = true;
+    }
+  }
+  for (auto *op : llvm::reverse(teamsHoisted)) {
+    op->replaceAllUsesWith(irMapping.lookup(op));
+    op->erase();
+  }
+
+  // While we have unhandled operations in the original workdistribute
+  auto *workdistributeBlock = &workdistribute.getRegion().front();
+  auto *terminator = workdistributeBlock->getTerminator();
+  while (&workdistributeBlock->front() != terminator) {
+    rewriter.setInsertionPoint(teams);
+    IRMapping mapping;
+    llvm::SmallVector<Operation *> hoisted;
+    Operation *parallelize = nullptr;
+    for (auto &op : workdistribute.getOps()) {
+      if (&op == terminator) {
+        break;
+      }
+      if (shouldParallelize(&op)) {
+        parallelize = &op;
+        break;
+      } else {
+        rewriter.clone(op, mapping);
+        hoisted.push_back(&op);
+        changed = true;
+      }
+    }
+
+    for (auto *op : llvm::reverse(hoisted)) {
+      op->replaceAllUsesWith(mapping.lookup(op));
+      op->erase();
+    }
+
+    if (parallelize && hoisted.empty() &&
+        parallelize->getNextNode() == terminator)
+      break;
+    if (parallelize) {
+      auto newTeams = rewriter.cloneWithoutRegions(teams);
+      auto *newTeamsBlock = rewriter.createBlock(
+          &newTeams.getRegion(), newTeams.getRegion().begin(), {}, {});
+      for (auto arg : teamsBlock->getArguments())
+        newTeamsBlock->addArgument(arg.getType(), arg.getLoc());
+      auto newWorkdistribute = rewriter.create<omp::WorkdistributeOp>(loc);
+      rewriter.create<omp::TerminatorOp>(loc);
+      rewriter.createBlock(&newWorkdistribute.getRegion(),
+                           newWorkdistribute.getRegion().begin(), {}, {});
+      auto *cloned = rewriter.clone(*parallelize);
+      parallelize->replaceAllUsesWith(cloned);
+      parallelize->erase();
+      rewriter.create<omp::TerminatorOp>(loc);
+      changed = true;
+    }
+  }
+  return changed;
+}
+
+/// If fir.do_loop is present inside teams workdistribute
+///
+/// omp.teams {
+///   omp.workdistribute {
+///     fir.do_loop unordered {
+///       ...
+///     }
+///   }
+/// }
+///
+/// Then, it is lowered to
+///
+/// omp.teams {
+///   omp.parallel {
+///     omp.distribute {
+///       omp.wsloop {
+///         omp.loop_nest
+///           ...
+///       }
+///     }
+///   }
+/// }
+
+static void genParallelOp(Location loc, OpBuilder &rewriter, bool composite) {
+  auto parallelOp = rewriter.create<mlir::omp::ParallelOp>(loc);
+  parallelOp.setComposite(composite);
+  rewriter.createBlock(&parallelOp.getRegion());
+  rewriter.setInsertionPoint(rewriter.create<mlir::omp::TerminatorOp>(loc));
+  return;
+}
+
+static void genDistributeOp(Location loc, OpBuilder &rewriter, bool composite) {
+  mlir::omp::DistributeOperands distributeClauseOps;
+  auto distributeOp =
+      rewriter.create<mlir::omp::DistributeOp>(loc, distributeClauseOps);
+  distributeOp.setComposite(composite);
+  auto distributeBlock = rewriter.createBlock(&distributeOp.getRegion());
+  rewriter.setInsertionPointToStart(distributeBlock);
+  return;
+}
+
+static void
+genLoopNestClauseOps(OpBuilder &rewriter, fir::DoLoopOp loop,
+                     mlir::omp::LoopNestOperands &loopNestClauseOps) {
+  assert(loopNestClauseOps.loopLowerBounds.empty() &&
+         "Loop nest bounds were already emitted!");
+  loopNestClauseOps.loopLowerBounds.push_back(loop.getLowerBound());
+  loopNestClauseOps.loopUpperBounds.push_back(loop.getUpperBound());
+  loopNestClauseOps.loopSteps.push_back(loop.getStep());
+  loopNestClauseOps.loopInclusive = rewriter.getUnitAttr();
+}
+
+static void genWsLoopOp(mlir::OpBuilder &rewriter, fir::DoLoopOp doLoop,
+                        const mlir::omp::LoopNestOperands &clauseOps,
+                        bool composite) {
+
+  auto wsloopOp = rewriter.create<mlir::omp::WsloopOp>(doLoop.getLoc());
+  wsloopOp.setComposite(composite);
+  rewriter.createBlock(&wsloopOp.getRegion());
+
+  auto loopNestOp =
+      rewriter.create<mlir::omp::LoopNestOp>(doLoop.getLoc(), clauseOps);
+
+  // Clone the loop's body into the loop nest construct.
+  rewriter.cloneRegionBefore(doLoop.getRegion(), loopNestOp.getRegion(),
+                             loopNestOp.getRegion().begin());
+  Block *clonedBlock = &loopNestOp.getRegion().back();
+  mlir::Operation *terminatorOp = clonedBlock->getTerminator();
+
+  // Erase fir.result op of do loop and create yield op.
+  if (auto resultOp = dyn_cast<fir::ResultOp>(terminatorOp)) {
+    rewriter.setInsertionPoint(terminatorOp);
+    rewriter.create<mlir::omp::YieldOp>(doLoop->getLoc());
+    terminatorOp->erase();
+  }
+  return;
+}
+
+static bool WorkdistributeDoLower(omp::WorkdistributeOp workdistribute) {
+  OpBuilder rewriter(workdistribute);
+  auto doLoop = getPerfectlyNested<fir::DoLoopOp>(workdistribute);
+  auto wdLoc = workdistribute->getLoc();
+  if (doLoop && shouldParallelize(doLoop)) {
+    assert(doLoop.getReduceOperands().empty());
+    genParallelOp(wdLoc, rewriter, true);
+    genDistributeOp(wdLoc, rewriter, true);
+    mlir::omp::LoopNestOperands loopNestClauseOps;
+    genLoopNestClauseOps(rewriter, doLoop, loopNestClauseOps);
+    genWsLoopOp(rewriter, doLoop, loopNestClauseOps, true);
+    workdistribute.erase();
+    return true;
+  }
+  return false;
+}
+
+/// If A() and B() are present inside teams workdistribute
+///
+/// omp.teams {
+///   omp.workdistribute {
+///     A()
+///     B()
+///   }
+/// }
+///
+/// Then, it is lowered to
+///
+/// A()
+/// B()
+
+static bool TeamsWorkdistributeToSingleOp(omp::TeamsOp teamsOp) {
+  auto workdistributeOp = getPerfectlyNested<omp::WorkdistributeOp>(teamsOp);
+  if (!workdistributeOp)
+    return false;
+  // Get the block containing teamsOp (the parent block).
+  Block *parentBlock = teamsOp->getBlock();
+  Block &workdistributeBlock = *workdistributeOp.getRegion().begin();
+  auto insertPoint = Block::iterator(teamsOp);
+  // Get the range of operations to move (excluding the terminator).
+  auto workdistributeBegin = workdistributeBlock.begin();
+  auto workdistributeEnd = workdistributeBlock.getTerminator()->getIterator();
+  // Move the operations from the workdistribute block to before teamsOp.
+  parentBlock->getOperations().splice(insertPoint,
+                                      workdistributeBlock.getOperations(),
+                                      workdistributeBegin, workdistributeEnd);
+  // Erase the now-empty workdistributeOp.
+  workdistributeOp.erase();
+  Block &teamsBlock = *teamsOp.getRegion().begin();
+  // Check if only the terminator remains and erase the teams op.
+  if (teamsBlock.getOperations().size() == 1 &&
+      teamsBlock.getTerminator() != nullptr) {
+    teamsOp.erase();
+  }
+  return true;
+}
+
+struct SplitTargetResult {
+  omp::TargetOp targetOp;
+  omp::TargetDataOp dataOp;
+};
+
+/// If multiple workdistribute are nested in a target region, we will need to
+/// split the target region, but we want to preserve the data semantics of the
+/// original data region and avoid unnecessary data movement at each of the
+/// subkernels - we split the target region into a target_data{target}
+/// nest where only the outer one moves the data
+std::optional<SplitTargetResult> splitTargetData(omp::TargetOp targetOp,
+                                                 RewriterBase &rewriter) {
+  auto loc = targetOp->getLoc();
+  if (targetOp.getMapVars().empty()) {
+    LLVM_DEBUG(llvm::dbgs()
+               << DEBUG_TYPE << " target region has no data maps\n");
+    return std::nullopt;
+  }
+
+  SmallVector<omp::MapInfoOp> mapInfos;
+  for (auto opr : targetOp.getMapVars()) {
+    auto mapInfo = cast<omp::MapInfoOp>(opr.getDefiningOp());
+    mapInfos.push_back(mapInfo);
+  }
+
+  rewriter.setInsertionPoint(targetOp);
+  SmallVector<Value> innerMapInfos;
+  SmallVector<Value> outerMapInfos;
+
+  for (auto mapInfo : mapInfos) {
+    auto originalMapType =
+        (llvm::omp::OpenMPOffloadMappingFlags)(mapInfo.getMapType());
+    auto originalCaptureType = mapInfo.getMapCaptureType();
+    llvm::omp::OpenMPOffloadMappingFlags newMapType;
+    mlir::omp::VariableCaptureKind newCaptureType;
+
+    if (originalCaptureType == mlir::omp::VariableCaptureKind::ByCopy) {
+      newMapType = originalMapType;
+      newCaptureType = originalCaptureType;
+    } else if (originalCaptureType == mlir::omp::VariableCaptureKind::ByRef) {
+      newMapType = llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE;
+      newCaptureType = originalCaptureType;
+      outerMapInfos.push_back(mapInfo);
+    } else {
+      llvm_unreachable("Unhandled case");
+    }
+    auto innerMapInfo = cast<omp::MapInfoOp>(rewriter.clone(*mapInfo));
+    innerMapInfo.setMapTypeAttr(rewriter.getIntegerAttr(
+        rewriter.getIntegerType(64, false),
+        static_cast<
+            std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
+            newMapType)));
+    innerMapInfo.setMapCaptureType(newCaptureType);
+    innerMapInfos.push_back(innerMapInfo.getResult());
+  }
+
+  rewriter.setInsertionPoint(targetOp);
+  auto device = targetOp.getDevice();
+  auto ifExpr = targetOp.getIfExpr();
+  auto deviceAddrVars = targetOp.getHasDeviceAddrVars();
+  auto devicePtrVars = targetOp.getIsDevicePtrVars();
+  auto targetDataOp = rewriter.create<omp::TargetDataOp>(
+      loc, device, ifExpr, outerMapInfos, deviceAddrVars, devicePtrVars);
+  auto targetDataBlock = rewriter.createBlock(&targetDataOp.getRegion());
+  rewriter.create<mlir::omp::TerminatorOp>(loc);
+  rewriter.setInsertionPointToStart(targetDataBlock);
+
+  auto newTargetOp = rewriter.create<omp::TargetOp>(
+      targetOp.getLoc(), targetOp.getAllocateVars(),
+      targetOp.getAllocatorVars(), targetOp.getBareAttr(),
+      targetOp.getDependKindsAttr(), targetOp.getDependVars(),
+      targetOp.getDevice(), targetOp.getHasDeviceAddrVars(),
+      targetOp.getHostEvalVars(), targetOp.getIfExpr(),
+      targetOp.getInReductionVars(), targetOp.getInReductionByrefAttr(),
+      targetOp.getInReductionSymsAttr(), targetOp.getIsDevicePtrVars(),
+      innerMapInfos, targetOp.getNowaitAttr(), targetOp.getPrivateVars(),
+      targetOp.getPrivateSymsAttr(), targetOp.getPrivateNeedsBarrierAttr(),
+      targetOp.getThreadLimit(), targetOp.getPrivateMapsAttr());
+  rewriter.inlineRegionBefore(targetOp.getRegion(), newTargetOp.getRegion(),
+                              newTargetOp.getRegion().begin());
+
+  rewriter.replaceOp(targetOp, newTargetOp);
+  return SplitTargetResult{cast<omp::TargetOp>(newTargetOp), targetDataOp};
+}
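+
+// For illustration only (names invented), a mapped target region such as
+//
+//   omp.target map_entries(%m -> %arg : !fir.ref<index>) { ... }
+//
+// becomes roughly
+//
+//   omp.target_data map_entries(%m : !fir.ref<index>) {
+//     omp.target map_entries(%m_none -> %arg : !fir.ref<index>) { ... }
+//   }
+//
+// where %m_none is a clone of %m with OMP_MAP_NONE flags, so each inner
+// kernel reuses the mapping established once by the enclosing target_data.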
+
+static std::optional<std::tuple<Operation *, bool, bool>>
+getNestedOpToIsolate(omp::TargetOp targetOp) {
+  if (targetOp.getRegion().empty())
+    return std::nullopt;
+  auto *targetBlock = &targetOp.getRegion().front();
+  for (auto &op : *targetBlock) {
+    bool first = &op == &*targetBlock->begin();
+    bool last = op.getNextNode() == targetBlock->getTerminator();
+    if (first && last)
+      return std::nullopt;
+
+    if (isa<omp::TeamsOp, omp::ParallelOp>(&op))
+      return {{&op, first, last}};
+  }
+  return std::nullopt;
+}
+
+struct TempOmpVar {
+  omp::MapInfoOp from, to;
+};
+
+static bool isPtr(Type ty) {
+  return isa<fir::ReferenceType>(ty) || isa<LLVM::LLVMPointerType>(ty);
+}
+
+static Type getPtrTypeForOmp(Type ty) {
+  if (isPtr(ty))
+    return LLVM::LLVMPointerType::get(ty.getContext());
+  else
+    return fir::LLVMPointerType::get(ty);
+}
+
+static TempOmpVar allocateTempOmpVar(Location loc, Type ty,
+                                     RewriterBase &rewriter) {
+  MLIRContext &ctx = *ty.getContext();
+  Value alloc;
+  Type allocType;
+  auto llvmPtrTy = LLVM::LLVMPointerType::get(&ctx);
+  if (isPtr(ty)) {
+    Type intTy = rewriter.getI32Type();
+    auto one = rewriter.create<LLVM::ConstantOp>(loc, intTy, 1);
+    allocType = llvmPtrTy;
+    alloc = rewriter.create<LLVM::AllocaOp>(loc, llvmPtrTy, allocType, one);
+    allocType = intTy;
+  } else {
+    allocType = ty;
+    alloc = rewriter.create<fir::AllocaOp>(loc, allocType);
+  }
+  auto getMapInfo = [&](uint64_t mappingFlags, const char *name) {
+    return rewriter.create<omp::MapInfoOp>(
+        loc, alloc.getType(), alloc, TypeAttr::get(allocType),
+        rewriter.getIntegerAttr(rewriter.getIntegerType(64,
+                                                        /*isSigned=*/false),
+                                mappingFlags),
+        rewriter.getAttr<omp::VariableCaptureKindAttr>(
+            omp::VariableCaptureKind::ByRef),
+        /*varPtrPtr=*/Value{},
+        /*members=*/SmallVector<Value>{},
+        /*member_index=*/mlir::ArrayAttr{},
+        /*bounds=*/ValueRange(),
+        /*mapperId=*/mlir::FlatSymbolRefAttr(),
+        /*name=*/rewriter.getStringAttr(name), rewriter.getBoolAttr(false));
+  };
+  uint64_t mapFrom =
+      static_cast<std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
+          llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM);
+  uint64_t mapTo =
+      static_cast<std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
+          llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO);
+  auto mapInfoFrom = getMapInfo(mapFrom, "__flang_workdistribute_from");
+  auto mapInfoTo = getMapInfo(mapTo, "__flang_workdistribute_to");
+  return TempOmpVar{mapInfoFrom, mapInfoTo};
+};
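+
+// A rough sketch of the IR this produces for a cached `index` value (SSA
+// names abbreviated):
+//
+//   %tmp = fir.alloca index
+//   %from = omp.map.info var_ptr(%tmp : !fir.ref<index>, index)
+//       map_clauses(from) capture(ByRef) -> !fir.ref<index>
+//       {name = "__flang_workdistribute_from"}
+//   %to = omp.map.info var_ptr(%tmp : !fir.ref<index>, index)
+//       map_clauses(to) capture(ByRef) -> !fir.ref<index>
+//       {name = "__flang_workdistribute_to"}
+//
+// The `from` copy is attached to the kernel that produces the value and the
+// `to` copy to the kernels that consume it.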
+
+static bool usedOutsideSplit(Value v, Operation *split) {
+  if (!split)
+    return false;
+  auto targetOp = cast<omp::TargetOp>(split->getParentOp());
+  auto *targetBlock = &targetOp.getRegion().front();
+  for (auto *user : v.getUsers()) {
+    while (user->getBlock() != targetBlock) {
+      user = user->getParentOp();
+    }
+    if (!user->isBeforeInBlock(split))
+      return true;
+  }
+  return false;
+};
+
+static bool isRecomputableAfterFission(Operation *op, Operation *splitBefore) {
+  if (isa<fir::DeclareOp>(op))
+    return true;
+
+  llvm::SmallVector<MemoryEffects::EffectInstance> effects;
+  MemoryEffectOpInterface interface = dyn_cast<MemoryEffectOpInterface>(op);
+  if (!interface) {
+    return false;
+  }
+  interface.getEffects(effects);
+  if (effects.empty())
+    return true;
+  return false;
+}
+
+struct SplitResult {
+  omp::TargetOp preTargetOp;
+  omp::TargetOp isolatedTargetOp;
+  omp::TargetOp postTargetOp;
+};
+
+static void collectNonRecomputableDeps(Value &v, omp::TargetOp targetOp,
+                                       SetVector<Operation *> &nonRecomputable,
+                                       SetVector<Operation *> &toCache,
+                                       SetVector<Operation *> &toRecompute) {
+  Operation *op = v.getDefiningOp();
+  if (!op) {
+    assert(cast<BlockArgument>(v).getOwner()->getParentOp() == targetOp);
+    return;
+  }
+  if (nonRecomputable.contains(op)) {
+    toCache.insert(op);
+    return;
+  }
+  toRecompute.insert(op);
+  for (auto opr : op->getOperands())
+    collectNonRecomputableDeps(opr, targetOp, nonRecomputable, toCache,
+                               toRecompute);
+}
+
+static void reloadCacheAndRecompute(Location loc, RewriterBase &rewriter,
+                                    MLIRContext &ctx, IRMapping &mapping,
+                                    Operation *splitBefore, Block *targetBlock,
+                                    Block *newTargetBlock,
+                                    SmallVector<Value> &allocs,
+                                    SetVector<Operation *> &toRecompute) {
+  for (unsigned i = 0; i < targetBlock->getNumArguments(); i++) {
+    auto originalArg = targetBlock->getArgument(i);
+    auto newArg = newTargetBlock->addArgument(originalArg.getType(),
+                                              originalArg.getLoc());
+    mapping.map(originalArg, newArg);
+  }
+  auto llvmPtrTy = LLVM::LLVMPointerType::get(&ctx);
+  for (auto original : allocs) {
+    Value newArg = newTargetBlock->addArgument(
+        getPtrTypeForOmp(original.getType()), original.getLoc());
+    Value restored;
+    if (isPtr(original.getType())) {
+      restored = rewriter.create<LLVM::LoadOp>(loc, llvmPtrTy, newArg);
+      if (!isa<LLVM::LLVMPointerType>(original.getType()))
+        restored =
+            rewriter.create<fir::ConvertOp>(loc, original.getType(), restored);
+    } else {
+      restored = rewriter.create<fir::LoadOp>(loc, newArg);
+    }
+    mapping.map(original, restored);
+  }
+  for (auto it = targetBlock->begin(); it != splitBefore->getIterator();
+       it++) {
+    if (toRecompute.contains(&*it))
+      rewriter.clone(*it, mapping);
+  }
+}
+
+static SplitResult isolateOp(Operation *splitBeforeOp, bool splitAfter,
+                             RewriterBase &rewriter) {
+  auto targetOp = cast<omp::TargetOp>(splitBeforeOp->getParentOp());
+  MLIRContext &ctx = *targetOp.getContext();
+  assert(targetOp);
+  auto loc = targetOp.getLoc();
+  auto *targetBlock = &targetOp.getRegion().front();
+  rewriter.setInsertionPoint(targetOp);
+
+  auto preMapOperands = llvm::to_vector(targetOp.getMapVars());
+  auto postMapOperands = llvm::to_vector(targetOp.getMapVars());
+
+  SmallVector<Value> requiredVals;
+  SetVector<Operation *> toCache;
+  SetVector<Operation *> toRecompute;
+  SetVector<Operation *> nonRecomputable;
+  SmallVector<Value> allocs;
+
+  for (auto it = targetBlock->begin(); it != splitBeforeOp->getIterator();
+       it++) {
+    for (auto res : it->getResults()) {
+      if (usedOutsideSplit(res, splitBeforeOp))
+        requiredVals.push_back(res);
+    }
+    if (!isRecomputableAfterFission(&*it, splitBeforeOp))
+      nonRecomputable.insert(&*it);
+  }
+
+  for (auto requiredVal : requiredVals)
+    collectNonRecomputableDeps(requiredVal, targetOp, nonRecomputable, toCache,
+                               toRecompute);
+
+  for (Operation *op : toCache) {
+    for (auto res : op->getResults()) {
+      auto alloc =
+          allocateTempOmpVar(targetOp.getLoc(), res.getType(), rewriter);
+      allocs.push_back(res);
+      preMapOperands.push_back(alloc.from);
+      postMapOperands.push_back(alloc.to);
+    }
+  }
+
+  rewriter.setInsertionPoint(targetOp);
+
+  auto preTargetOp = rewriter.create<omp::TargetOp>(
+      targetOp.getLoc(), targetOp.getAllocateVars(),
+      targetOp.getAllocatorVars(), targetOp.getBareAttr(),
+      targetOp.getDependKindsAttr(), targetOp.getDependVars(),
+      targetOp.getDevice(), targetOp.getHasDeviceAddrVars(),
+      targetOp.getHostEvalVars(), targetOp.getIfExpr(),
+      targetOp.getInReductionVars(), targetOp.getInReductionByrefAttr(),
+      targetOp.getInReductionSymsAttr(), targetOp.getIsDevicePtrVars(),
+      preMapOperands, targetOp.getNowaitAttr(), targetOp.getPrivateVars(),
+      targetOp.getPrivateSymsAttr(), targetOp.getPrivateNeedsBarrierAttr(),
+      targetOp.getThreadLimit(), targetOp.getPrivateMapsAttr());
+  auto *preTargetBlock = rewriter.createBlock(
+      &preTargetOp.getRegion(), preTargetOp.getRegion().begin(), {}, {});
+  IRMapping preMapping;
+  for (unsigned i = 0; i < targetBlock->getNumArguments(); i++) {
+    auto originalArg = targetBlock->getArgument(i);
+    auto newArg = preTargetBlock->addArgument(originalArg.getType(),
+                                              originalArg.getLoc());
+    preMapping.map(originalArg, newArg);
+  }
+  for (auto it = targetBlock->begin(); it != splitBeforeOp->getIterator();
+       it++)
+    rewriter.clone(*it, preMapping);
+
+  auto llvmPtrTy = LLVM::LLVMPointerType::get(targetOp.getContext());
+
+  for (auto original : allocs) {
+    Value toStore = preMapping.lookup(original);
+    auto newArg = preTargetBlock->addArgument(
+        getPtrTypeForOmp(original.getType()), original.getLoc());
+    if (isPtr(original.getType())) {
+      if (!isa<LLVM::LLVMPointerType>(toStore.getType()))
+        toStore = rewriter.create<fir::ConvertOp>(loc, llvmPtrTy, toStore);
+      rewriter.create<LLVM::StoreOp>(loc, toStore, newArg);
+    } else {
+      rewriter.create<fir::StoreOp>(loc, toStore, newArg);
+    }
+  }
+  rewriter.create<omp::TerminatorOp>(loc);
+
+  rewriter.setInsertionPoint(targetOp);
+
+  auto isolatedTargetOp = rewriter.create<omp::TargetOp>(
+      targetOp.getLoc(), targetOp.getAllocateVars(),
+      targetOp.getAllocatorVars(), targetOp.getBareAttr(),
+      targetOp.getDependKindsAttr(), targetOp.getDependVars(),
+      targetOp.getDevice(), targetOp.getHasDeviceAddrVars(),
+      targetOp.getHostEvalVars(), targetOp.getIfExpr(),
+      targetOp.getInReductionVars(), targetOp.getInReductionByrefAttr(),
+      targetOp.getInReductionSymsAttr(), targetOp.getIsDevicePtrVars(),
+      postMapOperands, targetOp.getNowaitAttr(), targetOp.getPrivateVars(),
+      targetOp.getPrivateSymsAttr(), targetOp.getPrivateNeedsBarrierAttr(),
+      targetOp.getThreadLimit(), targetOp.getPrivateMapsAttr());
+
+  auto *isolatedTargetBlock =
+      rewriter.createBlock(&isolatedTargetOp.getRegion(),
+                           isolatedTargetOp.getRegion().begin(), {}, {});
+
+  IRMapping isolatedMapping;
+  reloadCacheAndRecompute(loc, rewriter, ctx, isolatedMapping, splitBeforeOp,
+                          targetBlock, isolatedTargetBlock, allocs,
+                          toRecompute);
+  rewriter.clone(*splitBeforeOp, isolatedMapping);
+  rewriter.create<omp::TerminatorOp>(loc);
+
+  omp::TargetOp postTargetOp = nullptr;
+
+  if (splitAfter) {
+    rewriter.setInsertionPoint(targetOp);
+    postTargetOp = rewriter.create<omp::TargetOp>(
+        targetOp.getLoc(), targetOp.getAllocateVars(),
+        targetOp.getAllocatorVars(), targetOp.getBareAttr(),
+        targetOp.getDependKindsAttr(), targetOp.getDependVars(),
+        targetOp.getDevice(), targetOp.getHasDeviceAddrVars(),
+        targetOp.getHostEvalVars(), targetOp.getIfExpr(),
+        targetOp.getInReductionVars(), targetOp.getInReductionByrefAttr(),
+        targetOp.getInReductionSymsAttr(), targetOp.getIsDevicePtrVars(),
+        postMapOperands, targetOp.getNowaitAttr(), targetOp.getPrivateVars(),
+        targetOp.getPrivateSymsAttr(), targetOp.getPrivateNeedsBarrierAttr(),
+        targetOp.getThreadLimit(), targetOp.getPrivateMapsAttr());
+    auto *postTargetBlock = rewriter.createBlock(
+        &postTargetOp.getRegion(), postTargetOp.getRegion().begin(), {}, {});
+    IRMapping postMapping;
+    reloadCacheAndRecompute(loc, rewriter, ctx, postMapping, splitBeforeOp,
+                            targetBlock, postTargetBlock, allocs, toRecompute);
+
+    assert(splitBeforeOp->getNumResults() == 0 ||
+           llvm::all_of(splitBeforeOp->getResults(),
+                        [](Value result) { return result.use_empty(); }));
+
+    for (auto it = std::next(splitBeforeOp->getIterator());
+         it != targetBlock->end(); it++)
+      rewriter.clone(*it, postMapping);
+  }
+
+  rewriter.eraseOp(targetOp);
+  return SplitResult{preTargetOp, isolatedTargetOp, postTargetOp};
+}
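+
+// As a sketch, isolating a parallelizable op T with one op A before it and
+// one op B after it yields three target regions (map details omitted):
+//
+//   omp.target {...}  // pre: runs A, stores cached results to temporaries
+//   omp.target {...}  // isolated: reloads/recomputes inputs, runs T
+//   omp.target {...}  // post: reloads/recomputes inputs, runs B
+//
+// The pre and post parts are later pulled back to the host by moveToHost.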
+
+static mlir::LLVM::ConstantOp
+genI32Constant(mlir::Location loc, mlir::RewriterBase &rewriter, int value) {
+  mlir::Type i32Ty = rewriter.getI32Type();
+  mlir::IntegerAttr attr = rewriter.getI32IntegerAttr(value);
+  return rewriter.create<mlir::LLVM::ConstantOp>(loc, i32Ty, attr);
+}
+
+static Type getOmpDeviceType(MLIRContext *c) { return IntegerType::get(c, 32); }
+
+static void moveToHost(omp::TargetOp targetOp, RewriterBase &rewriter) {
+  OpBuilder::InsertionGuard guard(rewriter);
+  Block *targetBlock = &targetOp.getRegion().front();
+  assert(targetBlock == &targetOp.getRegion().back());
+  IRMapping mapping;
+  for (auto map :
+       zip_equal(targetOp.getMapVars(), targetBlock->getArguments())) {
+    Value mapInfo = std::get<0>(map);
+    BlockArgument arg = std::get<1>(map);
+    Operation *op = mapInfo.getDefiningOp();
+    assert(op);
+    auto mapInfoOp = cast<omp::MapInfoOp>(op);
+    mapping.map(arg, mapInfoOp.getVarPtr());
+  }
+  rewriter.setInsertionPoint(targetOp);
+  for (auto it = targetBlock->begin(), end = std::prev(targetBlock->end());
+       it != end; ++it) {
+    auto *op = &*it;
+    auto allocOp = dyn_cast<fir::AllocMemOp>(op);
+    auto freeOp = dyn_cast<fir::FreeMemOp>(op);
+    fir::CallOp runtimeCall = nullptr;
+    if (isRuntimeCall(op))
+      runtimeCall = cast<fir::CallOp>(op);
+
+    if (allocOp || freeOp || runtimeCall) {
+      Value device = targetOp.getDevice();
+      if (!device) {
+        device = genI32Constant(it->getLoc(), rewriter, 0);
+      }
+      if (allocOp) {
+        auto tmpAllocOp = rewriter.create<fir::OmpTargetAllocMemOp>(
+            allocOp.getLoc(), allocOp.getType(), device,
+            allocOp.getInTypeAttr(), allocOp.getUniqNameAttr(),
+            allocOp.getBindcNameAttr(), allocOp.getTypeparams(),
+            allocOp.getShape());
+        auto newAllocOp = cast<fir::OmpTargetAllocMemOp>(
+            rewriter.clone(*tmpAllocOp.getOperation(), mapping));
+        mapping.map(allocOp.getResult(), newAllocOp.getResult());
+        rewriter.eraseOp(tmpAllocOp);
+      } else if (freeOp) {
+        auto tmpFreeOp = rewriter.create<fir::OmpTargetFreeMemOp>(
+            freeOp.getLoc(), device, freeOp.getHeapref());
+        rewriter.clone(*tmpFreeOp.getOperation(), mapping);
+        rewriter.eraseOp(tmpFreeOp);
+      } else if (runtimeCall) {
+        auto module = runtimeCall->getParentOfType<ModuleOp>();
+        auto callee = cast<func::FuncOp>(
+            module.lookupSymbol(runtimeCall.getCalleeAttr()));
+        std::string newCalleeName = (callee.getName() + "_omp").str();
+        mlir::OpBuilder moduleBuilder(module.getBodyRegion());
+        func::FuncOp newCallee =
+            cast_or_null<func::FuncOp>(module.lookupSymbol(newCalleeName));
+        if (!newCallee) {
+          SmallVector<Type> argTypes(callee.getFunctionType().getInputs());
+          argTypes.push_back(getOmpDeviceType(rewriter.getContext()));
+          newCallee = moduleBuilder.create<func::FuncOp>(
+              callee->getLoc(), newCalleeName,
+              FunctionType::get(rewriter.getContext(), argTypes,
+                                callee.getFunctionType().getResults()));
+          if (callee.getArgAttrs())
+            newCallee.setArgAttrsAttr(*callee.getArgAttrs());
+          if (callee.getResAttrs())
+            newCallee.setResAttrsAttr(*callee.getResAttrs());
+          newCallee.setSymVisibility(callee.getSymVisibility());
+          newCallee->setDiscardableAttrs(
+              callee->getDiscardableAttrDictionary());
+        }
+        auto operands = llvm::to_vector(runtimeCall.getOperands());
+        operands.push_back(device);
+        auto tmpCall = rewriter.create<fir::CallOp>(
+            runtimeCall.getLoc(), runtimeCall.getResultTypes(),
+            SymbolRefAttr::get(newCallee), operands, nullptr, nullptr, nullptr,
+            runtimeCall.getFastmathAttr());
+        Operation *newCall = rewriter.clone(*tmpCall, mapping);
+        mapping.map(&*it, newCall);
+        rewriter.eraseOp(tmpCall);
+      }
+    } else {
+      Operation *clonedOp = rewriter.clone(*op, mapping);
+      for (unsigned i = 0; i < op->getNumResults(); ++i) {
+        mapping.map(op->getResult(i), clonedOp->getResult(i));
+      }
+    }
+  }
+  rewriter.eraseOp(targetOp);
+}
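+
+// For example (hypothetical runtime function; real names come from the
+// Fortran runtime), a call inside the moved region such as
+//
+//   fir.call @_FortranAAssign(%a, %b) : (...) -> ()
+//
+// is rewritten to target a device-aware "_omp" variant that takes the device
+// id as a trailing argument:
+//
+//   %dev = llvm.mlir.constant(0 : i32) : i32
+//   fir.call @_FortranAAssign_omp(%a, %b, %dev) : (..., i32) -> ()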
+
+void fissionTarget(omp::TargetOp targetOp, RewriterBase &rewriter) {
+  auto tuple = getNestedOpToIsolate(targetOp);
+  if (!tuple) {
+    LLVM_DEBUG(llvm::dbgs() << " No op to isolate\n");
+    moveToHost(targetOp, rewriter);
+    return;
+  }
+
+  Operation *toIsolate = std::get<0>(*tuple);
+  bool splitBefore = !std::get<1>(*tuple);
+  bool splitAfter = !std::get<2>(*tuple);
+
+  if (splitBefore && splitAfter) {
+    auto res = isolateOp(toIsolate, splitAfter, rewriter);
+    moveToHost(res.preTargetOp, rewriter);
+    fissionTarget(res.postTargetOp, rewriter);
+    return;
+  }
+  if (splitBefore) {
+    auto res = isolateOp(toIsolate, splitAfter, rewriter);
+    moveToHost(res.preTargetOp, rewriter);
+    return;
+  }
+  if (splitAfter) {
+    assert(false && "TODO");
+    auto res = isolateOp(toIsolate->getNextNode(), splitAfter, rewriter);
+    fissionTarget(res.postTargetOp, rewriter);
+    return;
+  }
+}
+
+class LowerWorkdistributePass
+    : public flangomp::impl::LowerWorkdistributeBase<LowerWorkdistributePass> {
+public:
+  void runOnOperation() override {
+    MLIRContext &context = getContext();
+    auto moduleOp = getOperation();
+    bool changed = false;
+    moduleOp->walk([&](mlir::omp::WorkdistributeOp workdistribute) {
+      changed |= FissionWorkdistribute(workdistribute);
+    });
+    moduleOp->walk([&](mlir::omp::WorkdistributeOp workdistribute) {
+      changed |= WorkdistributeDoLower(workdistribute);
+    });
+    moduleOp->walk([&](mlir::omp::TeamsOp teams) {
+      changed |= TeamsWorkdistributeToSingleOp(teams);
+    });
+
+    if (changed) {
+      SmallVector<omp::TargetOp> targetOps;
+      moduleOp->walk(
+          [&](omp::TargetOp targetOp) { targetOps.push_back(targetOp); });
+      IRRewriter rewriter(&context);
+      for (auto targetOp : targetOps) {
+        auto res = splitTargetData(targetOp, rewriter);
+        if (res)
+          fissionTarget(res->targetOp, rewriter);
+      }
+    }
+  }
+};
+} // namespace
diff --git a/flang/lib/Optimizer/Passes/Pipelines.cpp b/flang/lib/Optimizer/Passes/Pipelines.cpp
index 70f57bdeddd3f..c63e3799be650 100644
--- a/flang/lib/Optimizer/Passes/Pipelines.cpp
+++ b/flang/lib/Optimizer/Passes/Pipelines.cpp
@@ -288,8 +288,10 @@ void createHLFIRToFIRPassPipeline(mlir::PassManager &pm, bool enableOpenMP,
   addNestedPassToAllTopLevelOperations<PassConstructor>(
       pm, hlfir::createInlineHLFIRAssign);
   pm.addPass(hlfir::createConvertHLFIRtoFIR());
-  if (enableOpenMP)
+  if (enableOpenMP) {
     pm.addPass(flangomp::createLowerWorkshare());
+    pm.addPass(flangomp::createLowerWorkdistribute());
+  }
 }
 
 /// Create a pass pipeline for handling certain OpenMP transformations needed
diff --git a/flang/test/Fir/basic-program.fir b/flang/test/Fir/basic-program.fir
index 7ac8b92f48953..a611629eeb280 100644
--- a/flang/test/Fir/basic-program.fir
+++ b/flang/test/Fir/basic-program.fir
@@ -69,6 +69,7 @@ func.func @_QQmain() {
 // PASSES-NEXT: InlineHLFIRAssign
 // PASSES-NEXT: ConvertHLFIRtoFIR
 // PASSES-NEXT: LowerWorkshare
+// PASSES-NEXT: LowerWorkdistribute
 // PASSES-NEXT: CSE
 // PASSES-NEXT:   (S) 0 num-cse'd - Number of operations CSE'd
 // PASSES-NEXT:   (S) 0 num-dce'd - Number of operations DCE'd
diff --git a/flang/test/Transforms/OpenMP/lower-workdistribute-doloop.mlir b/flang/test/Transforms/OpenMP/lower-workdistribute-doloop.mlir
new file mode 100644
index 0000000000000..00d10d6264ec9
--- /dev/null
+++ b/flang/test/Transforms/OpenMP/lower-workdistribute-doloop.mlir
@@ -0,0 +1,33 @@
+// RUN: fir-opt --lower-workdistribute %s | FileCheck %s
+
+// CHECK-LABEL: func.func @x({{.*}})
+// CHECK: omp.teams {
+// CHECK: omp.parallel {
+// CHECK: omp.distribute {
+// CHECK: omp.wsloop {
+// CHECK: omp.loop_nest (%[[VAL_1:.*]]) : index = (%[[ARG0:.*]]) to (%[[ARG1:.*]]) inclusive step (%[[ARG2:.*]]) {
+// CHECK: %[[VAL_0:.*]] = arith.constant 0 : index
+// CHECK: fir.store %[[VAL_0]] to %[[ARG4:.*]] : !fir.ref<index>
+// CHECK: omp.yield
+// CHECK: }
+// CHECK: } {omp.composite}
+// CHECK: } {omp.composite}
+// CHECK: omp.terminator
+// CHECK: } {omp.composite}
+// CHECK: omp.terminator
+// CHECK: }
+// CHECK: return
+// CHECK: }
+func.func @x(%lb : index, %ub : index, %step : index, %b : i1, %addr : !fir.ref<index>) {
+  omp.teams {
+    omp.workdistribute {
+      fir.do_loop %iv = %lb to %ub step %step unordered {
+        %zero = arith.constant 0 : index
+        fir.store %zero to %addr : !fir.ref<index>
+      }
+      omp.terminator
+    }
+    omp.terminator
+  }
+  return
+}
diff --git a/flang/test/Transforms/OpenMP/lower-workdistribute-fission-target.mlir b/flang/test/Transforms/OpenMP/lower-workdistribute-fission-target.mlir
new file mode 100644
index 0000000000000..19bdb9ce10fbd
--- /dev/null
+++ b/flang/test/Transforms/OpenMP/lower-workdistribute-fission-target.mlir
@@ -0,0 +1,112 @@
+// RUN: fir-opt --lower-workdistribute %s | FileCheck %s
+
+// CHECK-LABEL: func.func @x(
+// CHECK: %[[VAL_0:.*]] = fir.alloca index {bindc_name = "lb"}
+// CHECK: fir.store %[[ARG0:.*]] to %[[VAL_0]] : !fir.ref<index>
+// CHECK: %[[VAL_1:.*]] = fir.alloca index {bindc_name = "ub"}
+// CHECK: fir.store %[[ARG1:.*]] to %[[VAL_1]] : !fir.ref<index>
+// CHECK: %[[VAL_2:.*]] = fir.alloca index {bindc_name = "step"}
+// CHECK: fir.store %[[ARG2:.*]] to %[[VAL_2]] : !fir.ref<index>
+// CHECK: %[[VAL_3:.*]] = omp.map.info var_ptr(%[[VAL_0]] : !fir.ref<index>, index) map_clauses(to) capture(ByRef) -> !fir.ref<index> {name = "lb"}
+// CHECK: %[[VAL_4:.*]] = omp.map.info var_ptr(%[[VAL_1]] : !fir.ref<index>, index) map_clauses(to) capture(ByRef) -> !fir.ref<index> {name = "ub"}
+// CHECK: %[[VAL_5:.*]] = omp.map.info var_ptr(%[[VAL_2]] : !fir.ref<index>, index) map_clauses(to) capture(ByRef) -> !fir.ref<index> {name = "step"}
+// CHECK: %[[VAL_6:.*]] = omp.map.info var_ptr(%[[ARG3:.*]] : !fir.ref<index>, index) map_clauses(tofrom) capture(ByRef) -> !fir.ref<index> {name = "addr"}
+// CHECK: %[[VAL_7:.*]] = omp.map.info var_ptr(%[[VAL_0]] : !fir.ref<index>, index) map_clauses(exit_release_or_enter_alloc) capture(ByRef) -> !fir.ref<index> {name = "lb"}
+// CHECK: %[[VAL_8:.*]] = omp.map.info var_ptr(%[[VAL_1]] : !fir.ref<index>, index) map_clauses(exit_release_or_enter_alloc) capture(ByRef) -> !fir.ref<index> {name = "ub"}
+// CHECK: %[[VAL_9:.*]] = omp.map.info var_ptr(%[[VAL_2]] : !fir.ref<index>, index) map_clauses(exit_release_or_enter_alloc) capture(ByRef) -> !fir.ref<index> {name = "step"}
+// CHECK: %[[VAL_10:.*]] = omp.map.info var_ptr(%[[ARG3:.*]] : !fir.ref<index>, index) map_clauses(exit_release_or_enter_alloc) capture(ByRef) -> !fir.ref<index> {name = "addr"}
+// CHECK: omp.target_data map_entries(%[[VAL_3]], %[[VAL_4]], %[[VAL_5]], %[[VAL_6]] : !fir.ref<index>, !fir.ref<index>, !fir.ref<index>, !fir.ref<index>) {
+// CHECK: %[[VAL_11:.*]] = fir.alloca index
+// CHECK: %[[VAL_12:.*]] = omp.map.info var_ptr(%[[VAL_11]] : !fir.ref<index>, index) map_clauses(from) capture(ByRef) -> !fir.ref<index> {name = "__flang_workdistribute_from"}
+// CHECK: %[[VAL_13:.*]] = omp.map.info var_ptr(%[[VAL_11]] : !fir.ref<index>, index) map_clauses(to) capture(ByRef) -> !fir.ref<index> {name = "__flang_workdistribute_to"}
+// CHECK: %[[VAL_14:.*]] = fir.alloca index
+// CHECK: %[[VAL_15:.*]] = omp.map.info var_ptr(%[[VAL_14]] : !fir.ref<index>, index) map_clauses(from) capture(ByRef) -> !fir.ref<index> {name = "__flang_workdistribute_from"}
+// CHECK: %[[VAL_16:.*]] = omp.map.info var_ptr(%[[VAL_14]] : !fir.ref<index>, index) map_clauses(to) capture(ByRef) -> !fir.ref<index> {name = "__flang_workdistribute_to"}
+// CHECK: %[[VAL_17:.*]] = fir.alloca index
+// CHECK: %[[VAL_18:.*]] = omp.map.info var_ptr(%[[VAL_17]] : !fir.ref<index>, index) map_clauses(from) capture(ByRef) -> !fir.ref<index> {name = "__flang_workdistribute_from"}
+// CHECK: %[[VAL_19:.*]] = omp.map.info var_ptr(%[[VAL_17]] : !fir.ref<index>, index) map_clauses(to) capture(ByRef) -> !fir.ref<index> {name = "__flang_workdistribute_to"}
+// CHECK: %[[VAL_20:.*]] = fir.alloca !fir.heap<index>
+// CHECK: %[[VAL_21:.*]] = omp.map.info var_ptr(%[[VAL_20]] : !fir.ref<!fir.heap<index>>, !fir.heap<index>) map_clauses(from) capture(ByRef) -> !fir.ref<!fir.heap<index>> {name = "__flang_workdistribute_from"}
+// CHECK: %[[VAL_22:.*]] = omp.map.info var_ptr(%[[VAL_20]] : !fir.ref<!fir.heap<index>>, !fir.heap<index>) map_clauses(to) capture(ByRef) -> !fir.ref<!fir.heap<index>> {name = "__flang_workdistribute_to"}
+// CHECK: %[[VAL_23:.*]] = arith.constant 1 : index
+// CHECK: %[[VAL_24:.*]] = fir.load %[[VAL_0]] : !fir.ref<index>
+// CHECK: %[[VAL_25:.*]] = fir.load %[[VAL_1]] : !fir.ref<index>
+// CHECK: %[[VAL_26:.*]] = fir.load %[[VAL_2]] : !fir.ref<index>
+// CHECK: %[[VAL_27:.*]] = arith.addi %[[VAL_25]], %[[VAL_25]] : index
+// CHECK: %[[VAL_28:.*]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK: %[[VAL_29:.*]] = "fir.omp_target_allocmem"(%[[VAL_28]], %[[VAL_23]]) <{in_type = index, operandSegmentSizes = array<i32: 1, 0, 1>, uniq_name = "dev_buf"}> : (i32, index) -> !fir.heap<index>
+// CHECK: fir.store %[[VAL_24]] to %[[VAL_11]] : !fir.ref<index>
+// CHECK: fir.store %[[VAL_25]] to %[[VAL_14]] : !fir.ref<index>
+// CHECK: fir.store %[[VAL_26]] to %[[VAL_17]] : !fir.ref<index>
+// CHECK: fir.store %[[VAL_29]] to %[[VAL_20]] : !fir.ref<!fir.heap<index>>
+// CHECK: omp.target map_entries(%[[VAL_7]] -> %[[VAL_30:.*]], %[[VAL_8]] -> %[[VAL_31:.*]], %[[VAL_9]] -> %[[VAL_32:.*]], %[[VAL_10]] -> %[[VAL_33:.*]], %[[VAL_13]] -> %[[VAL_34:.*]], %[[VAL_16]] -> %[[VAL_35:.*]], %[[VAL_19]] -> %[[VAL_36:.*]], %[[VAL_22]] -> %[[VAL_37:.*]] : !fir.ref<index>, !fir.ref<index>, !fir.ref<index>, !fir.ref<index>, !fir.ref<index>, !fir.ref<index>, !fir.ref<index>, !fir.ref<!fir.heap<index>>) {
+// CHECK: %[[VAL_38:.*]] = fir.load %[[VAL_34]] : !fir.llvm_ptr<index>
+// CHECK: %[[VAL_39:.*]] = fir.load %[[VAL_35]] : !fir.llvm_ptr<index>
+// CHECK: %[[VAL_40:.*]] = fir.load %[[VAL_36]] : !fir.llvm_ptr<index>
+// CHECK: %[[VAL_41:.*]] = fir.load %[[VAL_37]] : !fir.llvm_ptr<!fir.heap<index>>
+// CHECK: %[[VAL_42:.*]] = arith.addi %[[VAL_39]], %[[VAL_39]] : index
+// CHECK: omp.teams {
+// CHECK: omp.parallel {
+// CHECK: omp.distribute {
+// CHECK: omp.wsloop {
+// CHECK: omp.loop_nest (%[[VAL_43:.*]]) : index = (%[[VAL_38]]) to (%[[VAL_39]]) inclusive step (%[[VAL_40]]) {
+// CHECK: fir.store %[[VAL_42]] to %[[VAL_41]] : !fir.heap<index>
+// CHECK: omp.yield
+// CHECK: }
+// CHECK: } {omp.composite}
+// CHECK: } {omp.composite}
+// CHECK: omp.terminator
+// CHECK: } {omp.composite}
+// CHECK: omp.terminator
+// CHECK: }
+// CHECK: omp.terminator
+// CHECK: }
+// CHECK: %[[VAL_44:.*]] = fir.load %[[VAL_11]] : !fir.ref<index>
+// CHECK: %[[VAL_45:.*]] = fir.load %[[VAL_14]] : !fir.ref<index>
+// CHECK: %[[VAL_46:.*]] = fir.load %[[VAL_17]] : !fir.ref<index>
+// CHECK: %[[VAL_47:.*]] = fir.load %[[VAL_20]] : !fir.ref<!fir.heap<index>>
+// CHECK: %[[VAL_48:.*]] = arith.addi %[[VAL_45]], %[[VAL_45]] : index
+// CHECK: fir.store %[[VAL_44]] to %[[VAL_47]] : !fir.heap<index>
+// CHECK: %[[VAL_49:.*]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK: "fir.omp_target_freemem"(%[[VAL_49]], %[[VAL_47]]) : (i32, !fir.heap<index>) -> ()
+// CHECK: omp.terminator
+// CHECK: }
+// CHECK: return
+// CHECK: }
+
+func.func @x(%lb : index, %ub : index, %step : index, %addr : !fir.ref<index>) {
+  %lb_ref = fir.alloca index {bindc_name = "lb"}
+  fir.store %lb to %lb_ref : !fir.ref<index>
+  %ub_ref = fir.alloca index {bindc_name = "ub"}
+  fir.store %ub to %ub_ref : !fir.ref<index>
+  %step_ref = fir.alloca index {bindc_name = "step"}
+  fir.store %step to %step_ref : !fir.ref<index>
+
+  %lb_map = omp.map.info var_ptr(%lb_ref : !fir.ref<index>, index) map_clauses(to) capture(ByRef) -> !fir.ref<index> {name = "lb"}
+  %ub_map = omp.map.info var_ptr(%ub_ref : !fir.ref<index>, index) map_clauses(to) capture(ByRef) -> !fir.ref<index> {name = "ub"}
+  %step_map = omp.map.info var_ptr(%step_ref : !fir.ref<index>, index) map_clauses(to) capture(ByRef) -> !fir.ref<index> {name = "step"}
+  %addr_map = omp.map.info var_ptr(%addr : !fir.ref<index>, index) map_clauses(tofrom) capture(ByRef) -> !fir.ref<index> {name = "addr"}
+
+  omp.target map_entries(%lb_map -> %ARG0, %ub_map -> %ARG1, %step_map -> %ARG2, %addr_map -> %ARG3 : !fir.ref<index>, !fir.ref<index>, !fir.ref<index>, !fir.ref<index>) {
+    %lb_val = fir.load %ARG0 : !fir.ref<index>
+    %ub_val = fir.load %ARG1 : !fir.ref<index>
+    %step_val = fir.load %ARG2 : !fir.ref<index>
+    %one = arith.constant 1 : index
+
+    %20 = arith.addi %ub_val, %ub_val : index
+    omp.teams {
+      omp.workdistribute {
+        %dev_mem = fir.allocmem index, %one {uniq_name = "dev_buf"}
+        fir.do_loop %iv = %lb_val to %ub_val step %step_val unordered {
+          fir.store %20 to %dev_mem : !fir.heap<index>
+        }
+        fir.store %lb_val to %dev_mem : !fir.heap<index>
+        fir.freemem %dev_mem : !fir.heap<index>
+        omp.terminator
+      }
+      omp.terminator
+    }
+    omp.terminator
+  }
+  return
+}
diff --git a/flang/test/Transforms/OpenMP/lower-workdistribute-fission.mlir b/flang/test/Transforms/OpenMP/lower-workdistribute-fission.mlir
new file mode 100644
index 0000000000000..c562b7009664d
--- /dev/null
+++ b/flang/test/Transforms/OpenMP/lower-workdistribute-fission.mlir
@@ -0,0 +1,71 @@
+// RUN: fir-opt --lower-workdistribute %s | FileCheck %s
+
+// CHECK-LABEL: func.func @test_fission_workdistribute(
+// CHECK: %[[VAL_0:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_1:.*]] = arith.constant 1 : index
+// CHECK: %[[VAL_2:.*]] = arith.constant 9 : index
+// CHECK: %[[VAL_3:.*]] = arith.constant 5.000000e+00 : f32
+// CHECK: fir.store %[[VAL_3]] to %[[ARG2:.*]] : !fir.ref<f32>
+// CHECK: omp.teams {
+// CHECK: omp.parallel {
+// CHECK: omp.distribute {
+// CHECK: omp.wsloop {
+// CHECK: omp.loop_nest (%[[VAL_4:.*]]) : index = (%[[VAL_0]]) to (%[[VAL_2]]) inclusive step (%[[VAL_1]]) {
+// CHECK: %[[VAL_5:.*]] = fir.coordinate_of %[[ARG0:.*]], %[[VAL_4]] : (!fir.ref<!fir.array<10xf32>>, index) -> !fir.ref<f32>
+// CHECK: %[[VAL_6:.*]] = fir.load %[[VAL_5]] : !fir.ref<f32>
+// CHECK: %[[VAL_7:.*]] = fir.coordinate_of %[[ARG1:.*]], %[[VAL_4]] : (!fir.ref<!fir.array<10xf32>>, index) -> !fir.ref<f32>
+// CHECK: fir.store %[[VAL_6]] to %[[VAL_7]] : !fir.ref<f32>
+// CHECK: omp.yield
+// CHECK: }
+// CHECK: } {omp.composite}
+// CHECK: } {omp.composite}
+// CHECK: omp.terminator
+// CHECK: } {omp.composite}
+// CHECK: omp.terminator
+// CHECK: }
+// CHECK: fir.call @regular_side_effect_func(%[[ARG2:.*]]) : (!fir.ref<f32>) -> ()
+// CHECK: fir.call @my_fir_parallel_runtime_func(%[[ARG3:.*]]) : (!fir.ref<f32>) -> ()
+// CHECK: fir.do_loop %[[VAL_8:.*]] = %[[VAL_0]] to %[[VAL_2]] step %[[VAL_1]] {
+// CHECK: %[[VAL_9:.*]] = fir.coordinate_of %[[ARG0]], %[[VAL_8]] : (!fir.ref<!fir.array<10xf32>>, index) -> !fir.ref<f32>
+// CHECK: fir.store %[[VAL_3]] to %[[VAL_9]] : !fir.ref<f32>
+// CHECK: }
+// CHECK: %[[VAL_10:.*]] = fir.load %[[ARG2:.*]] : !fir.ref<f32>
+// CHECK: fir.store %[[VAL_10]] to %[[ARG3:.*]] : !fir.ref<f32>
+// CHECK: return
+// CHECK: }
+module {
+func.func @regular_side_effect_func(%arg0: !fir.ref<f32>) {
+  return
+}
+func.func @my_fir_parallel_runtime_func(%arg0: !fir.ref<f32>) attributes {fir.runtime} {
+  return
+}
+func.func @test_fission_workdistribute(%arr1: !fir.ref<!fir.array<10xf32>>, %arr2: !fir.ref<!fir.array<10xf32>>, %scalar_ref1: !fir.ref<f32>, %scalar_ref2: !fir.ref<f32>) {
+  %c0_idx = arith.constant 0 : index
+  %c1_idx = arith.constant 1 : index
+  %c9_idx = arith.constant 9 : index
+  %float_val = arith.constant 5.0 : f32
+  omp.teams {
+    omp.workdistribute {
+      fir.store %float_val to %scalar_ref1 : !fir.ref<f32>
+      fir.do_loop %iv = %c0_idx to %c9_idx step %c1_idx unordered {
+        %elem_ptr_arr1 = fir.coordinate_of %arr1, %iv : (!fir.ref<!fir.array<10xf32>>, index) -> !fir.ref<f32>
+        %loaded_val_loop1 = fir.load %elem_ptr_arr1 : !fir.ref<f32>
+        %elem_ptr_arr2 = fir.coordinate_of %arr2, %iv : (!fir.ref<!fir.array<10xf32>>, index) -> !fir.ref<f32>
+        fir.store %loaded_val_loop1 to %elem_ptr_arr2 : !fir.ref<f32>
+      }
+      fir.call @regular_side_effect_func(%scalar_ref1) : (!fir.ref<f32>) -> ()
+      fir.call @my_fir_parallel_runtime_func(%scalar_ref2) : (!fir.ref<f32>) -> ()
+      fir.do_loop %jv = %c0_idx to %c9_idx step %c1_idx {
+        %elem_ptr_ordered_loop = fir.coordinate_of %arr1, %jv : (!fir.ref<!fir.array<10xf32>>, index) -> !fir.ref<f32>
+        fir.store %float_val to %elem_ptr_ordered_loop : !fir.ref<f32>
+      }
+      %loaded_for_hoist = fir.load %scalar_ref1 : !fir.ref<f32>
+      fir.store %loaded_for_hoist to %scalar_ref2 : !fir.ref<f32>
+      omp.terminator
+    }
+    omp.terminator
+  }
+  return
+}
+}
diff --git a/flang/test/Transforms/OpenMP/lower-workdistribute-target.mlir b/flang/test/Transforms/OpenMP/lower-workdistribute-target.mlir
new file mode 100644
index 0000000000000..d96068b26ca2f
--- /dev/null
+++ b/flang/test/Transforms/OpenMP/lower-workdistribute-target.mlir
@@ -0,0 +1,32 @@
+// RUN: fir-opt --lower-workdistribute %s | FileCheck %s
+
+// CHECK-LABEL: func.func @test_nested_derived_type_map_operand_and_block_addition(
+// CHECK-SAME: %[[ARG0:.*]]: !fir.ref<!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTscalar_and_array{r:f32,n:!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>}>>) {
+// CHECK: %[[VAL_0:.*]] = fir.declare %[[ARG0]] {uniq_name = "_QFmaptype_derived_nested_explicit_multiple_membersEsa"} : (!fir.ref<!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTscalar_and_array{r:f32,n:!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>}>>) -> !fir.ref<!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTscalar_and_array{r:f32,n:!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>}>>
+// CHECK: %[[VAL_1:.*]] = fir.coordinate_of %[[VAL_0]], n : (!fir.ref<!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTscalar_and_array{r:f32,n:!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>}>>) -> !fir.ref<!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>>
+// CHECK: %[[VAL_2:.*]] = fir.coordinate_of %[[VAL_1]], i : (!fir.ref<!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>>) -> !fir.ref<i32>
+// CHECK: %[[VAL_3:.*]] = omp.map.info var_ptr(%[[VAL_2]] : !fir.ref<i32>, i32) map_clauses(tofrom) capture(ByRef) -> !fir.ref<i32> {name = "sa%[[VAL_4:.*]]%[[VAL_5:.*]]"}
+// CHECK: %[[VAL_6:.*]] = fir.coordinate_of %[[VAL_0]], n : (!fir.ref<!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTscalar_and_array{r:f32,n:!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>}>>) -> !fir.ref<!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>>
+// CHECK: %[[VAL_7:.*]] = fir.coordinate_of %[[VAL_6]], r : (!fir.ref<!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>>) -> !fir.ref<f32>
+// CHECK: %[[VAL_8:.*]] = omp.map.info var_ptr(%[[VAL_7]] : !fir.ref<f32>, f32) map_clauses(tofrom) capture(ByRef) -> !fir.ref<f32> {name = "sa%[[VAL_4]]%[[VAL_9:.*]]"}
+// CHECK: %[[VAL_10:.*]] = omp.map.info var_ptr(%[[VAL_0]] : !fir.ref<!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTscalar_and_array{r:f32,n:!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>}>>, !fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTscalar_and_array{r:f32,n:!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>}>) map_clauses(tofrom) capture(ByRef) members(%[[VAL_3]], %[[VAL_8]] : [1, 0], [1, 1] : !fir.ref<i32>, !fir.ref<f32>) -> !fir.ref<!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTscalar_and_array{r:f32,n:!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>}>> {name = "sa", partial_map = true}
+// CHECK: omp.target map_entries(%[[VAL_10]] -> %[[VAL_11:.*]] : !fir.ref<!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTscalar_and_array{r:f32,n:!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>}>>) {
+// CHECK: omp.terminator
+// CHECK: }
+// CHECK: return
+// CHECK: }
+
+func.func @test_nested_derived_type_map_operand_and_block_addition(%arg0: !fir.ref<!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTscalar_and_array{r:f32,n:!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>}>>) {
+  %0 = fir.declare %arg0 {uniq_name = "_QFmaptype_derived_nested_explicit_multiple_membersEsa"} : (!fir.ref<!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTscalar_and_array{r:f32,n:!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>}>>) -> !fir.ref<!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTscalar_and_array{r:f32,n:!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>}>>
+  %2 = fir.coordinate_of %0, n : (!fir.ref<!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTscalar_and_array{r:f32,n:!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>}>>) -> !fir.ref<!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>>
+  %4 = fir.coordinate_of %2, i : (!fir.ref<!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>>) -> !fir.ref<i32>
+  %5 = omp.map.info var_ptr(%4 : !fir.ref<i32>, i32) map_clauses(tofrom) capture(ByRef) -> !fir.ref<i32> {name = "sa%n%i"}
+  %7 = fir.coordinate_of %0, n : (!fir.ref<!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTscalar_and_array{r:f32,n:!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>}>>) -> !fir.ref<!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>>
  %9 = fir.coordinate_of %7, r : (!fir.ref<!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>>) -> !fir.ref<f32>
+  %10 = omp.map.info var_ptr(%9 : !fir.ref<f32>, f32) map_clauses(tofrom) capture(ByRef) -> !fir.ref<f32> {name = "sa%n%r"}
+  %11 = omp.map.info var_ptr(%0 : !fir.ref<!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTscalar_and_array{r:f32,n:!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>}>>, !fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTscalar_and_array{r:f32,n:!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>}>) map_clauses(tofrom) capture(ByRef) members(%5, %10 : [1,0], [1,1] : !fir.ref<i32>, !fir.ref<f32>) -> !fir.ref<!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTscalar_and_array{r:f32,n:!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>}>> {name = "sa", partial_map = true}
+  omp.target map_entries(%11 -> %arg1 : !fir.ref<!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTscalar_and_array{r:f32,n:!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>}>>) {
+    omp.terminator
+  }
+  return
+}
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index eece8573f00ec..3fed83112dc97 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -5246,6 +5246,9 @@ initTargetRuntimeAttrs(llvm::IRBuilderBase &builder,
   omp::LoopNestOp loopOp = castOrGetParentOfType<omp::LoopNestOp>(capturedOp);
   unsigned numLoops = loopOp ? loopOp.getNumLoops() : 0;
 
+  if (targetOp.getHostEvalVars().empty())
+    numLoops = 0;
+
   Value numThreads, numTeamsLower, numTeamsUpper, teamsThreadLimit;
   llvm::SmallVector<Value> lowerBounds(numLoops), upperBounds(numLoops),
       steps(numLoops);