@@ -4557,11 +4557,10 @@ static std::optional<int64_t> extractConstInteger(Value value) {
4557
4557
/// function for the target region, so that they can be used to initialize the
4558
4558
/// corresponding global `ConfigurationEnvironmentTy` structure.
4559
4559
static void
4560
- initTargetDefaultAttrs (omp::TargetOp targetOp,
4560
+ initTargetDefaultAttrs (omp::TargetOp targetOp, Operation *capturedOp,
4561
4561
llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs &attrs,
4562
4562
bool isTargetDevice) {
4563
4563
// TODO: Handle constant 'if' clauses.
4564
- Operation *capturedOp = targetOp.getInnermostCapturedOmpOp ();
4565
4564
4566
4565
Value numThreads, numTeamsLower, numTeamsUpper, threadLimit;
4567
4566
if (!isTargetDevice) {
@@ -4643,7 +4642,7 @@ initTargetDefaultAttrs(omp::TargetOp targetOp,
4643
4642
combinedMaxThreadsVal = maxThreadsVal;
4644
4643
4645
4644
// Update kernel bounds structure for the `OpenMPIRBuilder` to use.
4646
- attrs.ExecFlags = targetOp.getKernelExecFlags ();
4645
+ attrs.ExecFlags = targetOp.getKernelExecFlags (capturedOp );
4647
4646
attrs.MinTeams = minTeamsVal;
4648
4647
attrs.MaxTeams .front () = maxTeamsVal;
4649
4648
attrs.MinThreads = 1 ;
@@ -4659,10 +4658,9 @@ initTargetDefaultAttrs(omp::TargetOp targetOp,
4659
4658
static void
4660
4659
initTargetRuntimeAttrs (llvm::IRBuilderBase &builder,
4661
4660
LLVM::ModuleTranslation &moduleTranslation,
4662
- omp::TargetOp targetOp,
4661
+ omp::TargetOp targetOp, Operation *capturedOp,
4663
4662
llvm::OpenMPIRBuilder::TargetKernelRuntimeAttrs &attrs) {
4664
- omp::LoopNestOp loopOp = castOrGetParentOfType<omp::LoopNestOp>(
4665
- targetOp.getInnermostCapturedOmpOp ());
4663
+ omp::LoopNestOp loopOp = castOrGetParentOfType<omp::LoopNestOp>(capturedOp);
4666
4664
unsigned numLoops = loopOp ? loopOp.getNumLoops () : 0 ;
4667
4665
4668
4666
Value numThreads, numTeamsLower, numTeamsUpper, teamsThreadLimit;
@@ -4689,7 +4687,8 @@ initTargetRuntimeAttrs(llvm::IRBuilderBase &builder,
4689
4687
if (numThreads)
4690
4688
attrs.MaxThreads = moduleTranslation.lookupValue (numThreads);
4691
4689
4692
- if (targetOp.getKernelExecFlags () != llvm::omp::OMP_TGT_EXEC_MODE_GENERIC) {
4690
+ if (targetOp.getKernelExecFlags (capturedOp) !=
4691
+ llvm::omp::OMP_TGT_EXEC_MODE_GENERIC) {
4693
4692
llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder ();
4694
4693
attrs.LoopTripCount = nullptr ;
4695
4694
@@ -4938,12 +4937,15 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
4938
4937
4939
4938
llvm::OpenMPIRBuilder::TargetKernelRuntimeAttrs runtimeAttrs;
4940
4939
llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs defaultAttrs;
4941
- initTargetDefaultAttrs (targetOp, defaultAttrs, isTargetDevice);
4940
+ Operation *targetCapturedOp = targetOp.getInnermostCapturedOmpOp ();
4941
+ initTargetDefaultAttrs (targetOp, targetCapturedOp, defaultAttrs,
4942
+ isTargetDevice);
4942
4943
4943
4944
// Collect host-evaluated values needed to properly launch the kernel from the
4944
4945
// host.
4945
4946
if (!isTargetDevice)
4946
- initTargetRuntimeAttrs (builder, moduleTranslation, targetOp, runtimeAttrs);
4947
+ initTargetRuntimeAttrs (builder, moduleTranslation, targetOp,
4948
+ targetCapturedOp, runtimeAttrs);
4947
4949
4948
4950
// Pass host-evaluated values as parameters to the kernel / host fallback,
4949
4951
// except if they are constants. In any case, map the MLIR block argument to
0 commit comments