15
15
#include " flang/Optimizer/HLFIR/HLFIRDialect.h"
16
16
#include " flang/Optimizer/HLFIR/HLFIROps.h"
17
17
#include " flang/Optimizer/Transforms/Passes.h"
18
+ #include " mlir/Analysis/SliceAnalysis.h"
19
+ #include " mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
18
20
#include " mlir/Dialect/Func/IR/FuncOps.h"
21
+ #include " mlir/Dialect/Math/IR/Math.h"
19
22
#include " mlir/Dialect/OpenMP/OpenMPDialect.h"
20
23
#include " mlir/IR/Diagnostics.h"
21
24
#include " mlir/IR/IRMapping.h"
@@ -468,6 +471,77 @@ void sinkLoopIVArgs(mlir::ConversionPatternRewriter &rewriter,
468
471
++idx;
469
472
}
470
473
}
474
+
475
+ // / Collects values are that are destroyed by the Fortran runtime within the
476
+ // / loop's scope.
477
+ // /
478
+ // / \param [in] doLoop - the loop within which the function searches for locally
479
+ // / destroyed values.
480
+ // /
481
+ // / \param [out] local - a map from locally destroyed values to the runtime
482
+ // / destroy opertaions that destroy them.
483
+ void collectLocallyDestroyedValuesInLoop (
484
+ fir::DoLoopOp doLoop,
485
+ llvm::DenseMap<mlir::Value, mlir::Operation *> &locals) {
486
+ constexpr static auto destroy{" _FortranADestroy" };
487
+ doLoop.getRegion ().walk ([&](fir::CallOp call) {
488
+ auto callee = call.getCallee ();
489
+
490
+ if (!callee.has_value ())
491
+ return ;
492
+
493
+ if (callee.value ().getLeafReference ().str () != destroy)
494
+ return ;
495
+
496
+ assert (call.getNumOperands () == 1 );
497
+
498
+ mlir::BackwardSliceOptions options;
499
+ options.inclusive = true ;
500
+ llvm::SetVector<mlir::Operation *> opSlice;
501
+ mlir::getBackwardSlice (call, &opSlice, options);
502
+
503
+ if (auto alloca = mlir::dyn_cast_if_present<fir::AllocaOp>(opSlice.front ()))
504
+ locals.try_emplace (alloca.getResult (), call);
505
+ });
506
+ }
507
+
508
+ // / For a locally destroyed value \p local within a loop's scope, localizes that
509
+ // / value within the scope of the parallel region the loop maps to. Towards that
510
+ // / end, this function allocates a private copy of \p local within \p
511
+ // / allocRegion.
512
+ // /
513
+ // / \param local - the locally destroyed value within a loop's scope (see
514
+ // / collectLocallyDestroyedValuesInLoop).
515
+ // /
516
+ // / \param localDestroyer - the Fortran runtime call operation that destroys \p
517
+ // / local.
518
+ // /
519
+ // / \param allocRegion - the parallel region where \p local's allocation will be
520
+ // / cloned (i.e. privatized).
521
+ // /
522
+ // / \param rewriter - builder used for updating \p allocRegion.
523
+ // /
524
+ // / \param mapper - mapper to track updated references \p local within \p
525
+ // / allocRegion.
526
+ void localizeLocallyDestroyedValue (mlir::Value local,
527
+ mlir::Operation *localDestroyer,
528
+ mlir::Region &allocRegion,
529
+ mlir::ConversionPatternRewriter &rewriter,
530
+ mlir::IRMapping &mapper) {
531
+ mlir::Region *loopRegion = localDestroyer->getParentRegion ();
532
+ assert (loopRegion != nullptr );
533
+
534
+ mlir::IRRewriter::InsertPoint ip = rewriter.saveInsertionPoint ();
535
+ rewriter.setInsertionPointToStart (&allocRegion.front ());
536
+ mlir::Operation *newLocalDef = rewriter.clone (*local.getDefiningOp (), mapper);
537
+ rewriter.replaceUsesWithIf (
538
+ local, newLocalDef->getResult (0 ), [&](mlir::OpOperand &operand) {
539
+ return operand.getOwner ()->getParentRegion () == loopRegion;
540
+ });
541
+ mapper.map (local, newLocalDef->getResult (0 ));
542
+
543
+ rewriter.restoreInsertionPoint (ip);
544
+ }
471
545
} // namespace looputils
472
546
473
547
class DoConcurrentConversion : public mlir ::OpConversionPattern<fir::DoLoopOp> {
@@ -519,9 +593,14 @@ class DoConcurrentConversion : public mlir::OpConversionPattern<fir::DoLoopOp> {
519
593
bool hasRemainingNestedLoops =
520
594
failed (looputils::collectLoopNest (doLoop, loopNest));
521
595
596
+ mlir::IRMapping mapper;
597
+
598
+ llvm::DenseMap<mlir::Value, mlir::Operation *> locals;
599
+ looputils::collectLocallyDestroyedValuesInLoop (loopNest.back ().first ,
600
+ locals);
601
+
522
602
looputils::sinkLoopIVArgs (rewriter, loopNest);
523
603
524
- mlir::IRMapping mapper;
525
604
mlir::omp::TargetOp targetOp;
526
605
mlir::omp::LoopNestClauseOps loopNestClauseOps;
527
606
@@ -541,8 +620,13 @@ class DoConcurrentConversion : public mlir::OpConversionPattern<fir::DoLoopOp> {
541
620
genDistributeOp (doLoop.getLoc (), rewriter);
542
621
}
543
622
544
- genParallelOp (doLoop.getLoc (), rewriter, loopNest, mapper,
545
- loopNestClauseOps);
623
+ mlir::omp::ParallelOp parallelOp = genParallelOp (
624
+ doLoop.getLoc (), rewriter, loopNest, mapper, loopNestClauseOps);
625
+
626
+ for (auto &[local, localDestroyer] : locals)
627
+ looputils::localizeLocallyDestroyedValue (
628
+ local, localDestroyer, parallelOp.getRegion (), rewriter, mapper);
629
+
546
630
mlir::omp::LoopNestOp ompLoopNest =
547
631
genWsLoopOp (rewriter, loopNest.back ().first , mapper, loopNestClauseOps);
548
632
@@ -919,9 +1003,10 @@ class DoConcurrentConversionPass
919
1003
context, mapTo == fir::omp::DoConcurrentMappingKind::DCMK_Device,
920
1004
concurrentLoopsToSkip);
921
1005
mlir::ConversionTarget target (*context);
922
- target.addLegalDialect <fir::FIROpsDialect, hlfir::hlfirDialect,
923
- mlir::arith::ArithDialect, mlir::func::FuncDialect,
924
- mlir::omp::OpenMPDialect>();
1006
+ target.addLegalDialect <
1007
+ fir::FIROpsDialect, hlfir::hlfirDialect, mlir::arith::ArithDialect,
1008
+ mlir::func::FuncDialect, mlir::omp::OpenMPDialect,
1009
+ mlir::cf::ControlFlowDialect, mlir::math::MathDialect>();
925
1010
926
1011
target.addDynamicallyLegalOp <fir::DoLoopOp>([&](fir::DoLoopOp op) {
927
1012
return !op.getUnordered () || concurrentLoopsToSkip.contains (op);
0 commit comments