@@ -234,7 +234,7 @@ static LogicalResult checkImplementationStatus(Operation &op) {
234
234
}
235
235
};
236
236
auto checkReduction = [&todo](auto op, LogicalResult &result) {
237
- if (isa<omp::TeamsOp>(op) || isa<omp::SimdOp>(op) )
237
+ if (isa<omp::TeamsOp>(op))
238
238
if (!op.getReductionVars ().empty () || op.getReductionByref () ||
239
239
op.getReductionSyms ())
240
240
result = todo (" reduction" );
@@ -313,10 +313,6 @@ static LogicalResult checkImplementationStatus(Operation &op) {
313
313
if (!op.getNontemporalVars ().empty ())
314
314
op.emitWarning ()
315
315
<< " ignored clause: nontemporal in omp.simd operation" ;
316
-
317
- if (!op.getReductionVars ().empty () || op.getReductionByref () ||
318
- op.getReductionSyms ())
319
- op.emitWarning () << " ignored clause: reduction in omp.simd operation" ;
320
316
})
321
317
.Case <omp::AtomicReadOp, omp::AtomicWriteOp, omp::AtomicUpdateOp,
322
318
omp::AtomicCaptureOp>([&](auto op) { checkHint (op, result); })
@@ -2693,17 +2689,19 @@ convertOmpSimd(Operation &opInst, llvm::IRBuilderBase &builder,
2693
2689
if (failed (checkImplementationStatus (opInst)))
2694
2690
return failure ();
2695
2691
2696
- // This is needed to make sure that uses of entry block arguments for the
2697
- // reduction clause, which is not yet being translated, are mapped to the
2698
- // outside values. This has the effect of ignoring the clause without causing
2699
- // a compiler crash.
2700
- auto blockArgIface = cast<omp::BlockArgOpenMPOpInterface>(*simdOp);
2701
- for (auto [arg, var] : llvm::zip_equal (blockArgIface.getReductionBlockArgs (),
2702
- simdOp.getReductionVars ()))
2703
- moduleTranslation.mapValue (arg, moduleTranslation.lookupValue (var));
2704
-
2705
2692
PrivateVarsInfo privateVarsInfo (simdOp);
2706
2693
2694
+ MutableArrayRef<BlockArgument> reductionArgs =
2695
+ cast<omp::BlockArgOpenMPOpInterface>(opInst).getReductionBlockArgs ();
2696
+ DenseMap<Value, llvm::Value *> reductionVariableMap;
2697
+ SmallVector<llvm::Value *> privateReductionVariables (
2698
+ simdOp.getNumReductionVars ());
2699
+ SmallVector<DeferredStore> deferredStores;
2700
+ SmallVector<omp::DeclareReductionOp> reductionDecls;
2701
+ collectReductionDecls (simdOp, reductionDecls);
2702
+ llvm::ArrayRef<bool > isByRef = getIsByRef (simdOp.getReductionByref ());
2703
+ assert (isByRef.size () == simdOp.getNumReductionVars ());
2704
+
2707
2705
llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
2708
2706
findAllocaInsertPoint (builder, moduleTranslation);
2709
2707
@@ -2712,11 +2710,27 @@ convertOmpSimd(Operation &opInst, llvm::IRBuilderBase &builder,
2712
2710
if (handleError (afterAllocas, opInst).failed ())
2713
2711
return failure ();
2714
2712
2713
+ if (failed (allocReductionVars (simdOp, reductionArgs, builder,
2714
+ moduleTranslation, allocaIP, reductionDecls,
2715
+ privateReductionVariables, reductionVariableMap,
2716
+ deferredStores, isByRef)))
2717
+ return failure ();
2718
+
2715
2719
if (handleError (initPrivateVars (builder, moduleTranslation, privateVarsInfo),
2716
2720
opInst)
2717
2721
.failed ())
2718
2722
return failure ();
2719
2723
2724
+ // TODO: no call to copyFirstPrivateVars?
2725
+
2726
+ assert (afterAllocas.get ()->getSinglePredecessor ());
2727
+ if (failed (initReductionVars (simdOp, reductionArgs, builder,
2728
+ moduleTranslation,
2729
+ afterAllocas.get ()->getSinglePredecessor (),
2730
+ reductionDecls, privateReductionVariables,
2731
+ reductionVariableMap, isByRef, deferredStores)))
2732
+ return failure ();
2733
+
2720
2734
llvm::ConstantInt *simdlen = nullptr ;
2721
2735
if (std::optional<uint64_t > simdlenVar = simdOp.getSimdlen ())
2722
2736
simdlen = builder.getInt64 (simdlenVar.value ());
@@ -2761,6 +2775,50 @@ convertOmpSimd(Operation &opInst, llvm::IRBuilderBase &builder,
2761
2775
: nullptr ,
2762
2776
order, simdlen, safelen);
2763
2777
2778
+ // We now need to reduce the per-simd-lane reduction variable into the
2779
+ // original variable. This works a bit differently to other reductions (e.g.
2780
+ // wsloop) because we don't need to call into the OpenMP runtime to handle
2781
+ // threads: everything happened in this one thread.
2782
+ for (auto [i, tuple] : llvm::enumerate (
2783
+ llvm::zip (reductionDecls, isByRef, simdOp.getReductionVars (),
2784
+ privateReductionVariables))) {
2785
+ auto [decl, byRef, reductionVar, privateReductionVar] = tuple;
2786
+
2787
+ OwningReductionGen gen = makeReductionGen (decl, builder, moduleTranslation);
2788
+ llvm::Value *originalVariable = moduleTranslation.lookupValue (reductionVar);
2789
+ llvm::Type *reductionType = moduleTranslation.convertType (decl.getType ());
2790
+
2791
+ // We have one less load for by-ref case because that load is now inside of
2792
+ // the reduction region.
2793
+ llvm::Value *redValue = originalVariable;
2794
+ if (!byRef)
2795
+ redValue =
2796
+ builder.CreateLoad (reductionType, redValue, " red.value." + Twine (i));
2797
+ llvm::Value *privateRedValue = builder.CreateLoad (
2798
+ reductionType, privateReductionVar, " red.private.value." + Twine (i));
2799
+ llvm::Value *reduced;
2800
+
2801
+ auto res = gen (builder.saveIP (), redValue, privateRedValue, reduced);
2802
+ if (failed (handleError (res, opInst)))
2803
+ return failure ();
2804
+ builder.restoreIP (res.get ());
2805
+
2806
+ // For by-ref case, the store is inside of the reduction region.
2807
+ if (!byRef)
2808
+ builder.CreateStore (reduced, originalVariable);
2809
+ }
2810
+
2811
+ // After the construct, deallocate private reduction variables.
2812
+ SmallVector<Region *> reductionRegions;
2813
+ llvm::transform (reductionDecls, std::back_inserter (reductionRegions),
2814
+ [](omp::DeclareReductionOp reductionDecl) {
2815
+ return &reductionDecl.getCleanupRegion ();
2816
+ });
2817
+ if (failed (inlineOmpRegionCleanup (reductionRegions, privateReductionVariables,
2818
+ moduleTranslation, builder,
2819
+ " omp.reduction.cleanup" )))
2820
+ return failure ();
2821
+
2764
2822
return cleanupPrivateVars (builder, moduleTranslation, simdOp.getLoc (),
2765
2823
privateVarsInfo.llvmVars ,
2766
2824
privateVarsInfo.privatizers );
0 commit comments