From 5e399775e16693a332a0e2437bc13b6dc6d36e8c Mon Sep 17 00:00:00 2001 From: skc7 Date: Fri, 27 Jun 2025 11:05:35 +0530 Subject: [PATCH 1/2] [flang] Add support for workdistribute construct in flang frontend --- .../flang/Semantics/openmp-directive-sets.h | 14 +++++ flang/lib/Lower/OpenMP/OpenMP.cpp | 26 +++++++- flang/lib/Parser/openmp-parsers.cpp | 7 ++- flang/lib/Semantics/resolve-directives.cpp | 8 ++- flang/test/Lower/OpenMP/workdistribute.f90 | 59 +++++++++++++++++++ llvm/include/llvm/Frontend/OpenMP/OMP.td | 55 +++++++++++++++++ mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td | 22 +++++++ 7 files changed, 188 insertions(+), 3 deletions(-) create mode 100644 flang/test/Lower/OpenMP/workdistribute.f90 diff --git a/flang/include/flang/Semantics/openmp-directive-sets.h b/flang/include/flang/Semantics/openmp-directive-sets.h index dd610c9702c28..7ced6ed9b44d6 100644 --- a/flang/include/flang/Semantics/openmp-directive-sets.h +++ b/flang/include/flang/Semantics/openmp-directive-sets.h @@ -143,6 +143,7 @@ static const OmpDirectiveSet topTargetSet{ Directive::OMPD_target_teams_distribute_parallel_do_simd, Directive::OMPD_target_teams_distribute_simd, Directive::OMPD_target_teams_loop, + Directive::OMPD_target_teams_workdistribute, }; static const OmpDirectiveSet allTargetSet{topTargetSet}; @@ -172,6 +173,7 @@ static const OmpDirectiveSet topTeamsSet{ Directive::OMPD_teams_distribute_parallel_do_simd, Directive::OMPD_teams_distribute_simd, Directive::OMPD_teams_loop, + Directive::OMPD_teams_workdistribute, }; static const OmpDirectiveSet bottomTeamsSet{ @@ -187,9 +189,16 @@ static const OmpDirectiveSet allTeamsSet{ Directive::OMPD_target_teams_distribute_parallel_do_simd, Directive::OMPD_target_teams_distribute_simd, Directive::OMPD_target_teams_loop, + Directive::OMPD_target_teams_workdistribute, } | topTeamsSet, }; +static const OmpDirectiveSet allWorkdistributeSet{ + Directive::OMPD_workdistribute, + Directive::OMPD_teams_workdistribute, + Directive::OMPD_target_teams_workdistribute, +}; + //===----------------------------------------------------------------------===// // Directive sets for groups of multiple directives //===----------------------------------------------------------------------===// @@ -230,6 +239,9 @@ static const OmpDirectiveSet blockConstructSet{ Directive::OMPD_taskgroup, Directive::OMPD_teams, Directive::OMPD_workshare, + Directive::OMPD_target_teams_workdistribute, + Directive::OMPD_teams_workdistribute, + Directive::OMPD_workdistribute, }; static const OmpDirectiveSet loopConstructSet{ @@ -294,6 +306,7 @@ static const OmpDirectiveSet workShareSet{ Directive::OMPD_scope, Directive::OMPD_sections, Directive::OMPD_single, + Directive::OMPD_workdistribute, } | allDoSet, }; @@ -376,6 +389,7 @@ static const OmpDirectiveSet nestedReduceWorkshareAllowedSet{ }; static const OmpDirectiveSet nestedTeamsAllowedSet{ + Directive::OMPD_workdistribute, Directive::OMPD_distribute, Directive::OMPD_distribute_parallel_do, Directive::OMPD_distribute_parallel_do_simd, diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp index ebd1d038716e4..16d58b6be535f 100644 --- a/flang/lib/Lower/OpenMP/OpenMP.cpp +++ b/flang/lib/Lower/OpenMP/OpenMP.cpp @@ -585,6 +585,16 @@ static void processHostEvalClauses(lower::AbstractConverter &converter, cp.processCollapse(loc, eval, hostInfo.ops, hostInfo.iv); break; + case OMPD_teams_workdistribute: + cp.processThreadLimit(stmtCtx, hostInfo.ops); + [[fallthrough]]; + case OMPD_target_teams_workdistribute: + cp.processNumTeams(stmtCtx, hostInfo.ops); + processSingleNestedIf([](Directive nestedDir) { + return topDistributeSet.test(nestedDir) || topLoopSet.test(nestedDir); + }); + break; + // Standalone 'target' case. case OMPD_target: { processSingleNestedIf( @@ -2682,6 +2692,17 @@ genTeamsOp(lower::AbstractConverter &converter, lower::SymMap &symTable, queue, item, clauseOps); } +static mlir::omp::WorkdistributeOp genWorkdistributeOp( + lower::AbstractConverter &converter, lower::SymMap &symTable, + semantics::SemanticsContext &semaCtx, lower::pft::Evaluation &eval, + mlir::Location loc, const ConstructQueue &queue, + ConstructQueue::const_iterator item) { + return genOpWithBody( + OpWithBodyGenInfo(converter, symTable, semaCtx, loc, eval, + llvm::omp::Directive::OMPD_workdistribute), + queue, item); +} + //===----------------------------------------------------------------------===// // Code generation functions for the standalone version of constructs that can // also be a leaf of a composite construct @@ -3302,7 +3323,10 @@ static void genOMPDispatch(lower::AbstractConverter &converter, TODO(loc, "Unhandled loop directive (" + llvm::omp::getOpenMPDirectiveName(dir, version) + ")"); } - // case llvm::omp::Directive::OMPD_workdistribute: + case llvm::omp::Directive::OMPD_workdistribute: + newOp = genWorkdistributeOp(converter, symTable, semaCtx, eval, loc, queue, + item); + break; case llvm::omp::Directive::OMPD_workshare: newOp = genWorkshareOp(converter, symTable, stmtCtx, semaCtx, eval, loc, queue, item); diff --git a/flang/lib/Parser/openmp-parsers.cpp b/flang/lib/Parser/openmp-parsers.cpp index c55642d969503..ad729932a5f00 100644 --- a/flang/lib/Parser/openmp-parsers.cpp +++ b/flang/lib/Parser/openmp-parsers.cpp @@ -1492,12 +1492,17 @@ TYPE_PARSER( "SINGLE" >> pure(llvm::omp::Directive::OMPD_single), "TARGET DATA" >> pure(llvm::omp::Directive::OMPD_target_data), "TARGET PARALLEL" >> pure(llvm::omp::Directive::OMPD_target_parallel), + "TARGET TEAMS WORKDISTRIBUTE" >> + pure(llvm::omp::Directive::OMPD_target_teams_workdistribute), "TARGET TEAMS" >> pure(llvm::omp::Directive::OMPD_target_teams), "TARGET" >> pure(llvm::omp::Directive::OMPD_target), "TASK"_id >> pure(llvm::omp::Directive::OMPD_task), "TASKGROUP" >> pure(llvm::omp::Directive::OMPD_taskgroup), + "TEAMS WORKDISTRIBUTE" >> + pure(llvm::omp::Directive::OMPD_teams_workdistribute), "TEAMS" >> pure(llvm::omp::Directive::OMPD_teams), - "WORKSHARE" >> pure(llvm::omp::Directive::OMPD_workshare)))) + "WORKSHARE" >> pure(llvm::omp::Directive::OMPD_workshare), + "WORKDISTRIBUTE" >> pure(llvm::omp::Directive::OMPD_workdistribute)))) TYPE_PARSER(sourced(construct( sourced(Parser{}), Parser{}))) diff --git a/flang/lib/Semantics/resolve-directives.cpp b/flang/lib/Semantics/resolve-directives.cpp index 885c02e6ec74b..2e4e05f9e293b 100644 --- a/flang/lib/Semantics/resolve-directives.cpp +++ b/flang/lib/Semantics/resolve-directives.cpp @@ -1656,10 +1656,13 @@ bool OmpAttributeVisitor::Pre(const parser::OpenMPBlockConstruct &x) { case llvm::omp::Directive::OMPD_task: case llvm::omp::Directive::OMPD_taskgroup: case llvm::omp::Directive::OMPD_teams: + case llvm::omp::Directive::OMPD_workdistribute: case llvm::omp::Directive::OMPD_workshare: case llvm::omp::Directive::OMPD_parallel_workshare: case llvm::omp::Directive::OMPD_target_teams: + case llvm::omp::Directive::OMPD_target_teams_workdistribute: case llvm::omp::Directive::OMPD_target_parallel: + case llvm::omp::Directive::OMPD_teams_workdistribute: PushContext(beginDir.source, beginDir.v); break; default: @@ -1689,9 +1692,12 @@ void OmpAttributeVisitor::Post(const parser::OpenMPBlockConstruct &x) { case llvm::omp::Directive::OMPD_target: case llvm::omp::Directive::OMPD_task: case llvm::omp::Directive::OMPD_teams: + case llvm::omp::Directive::OMPD_workdistribute: case llvm::omp::Directive::OMPD_parallel_workshare: case llvm::omp::Directive::OMPD_target_teams: - case llvm::omp::Directive::OMPD_target_parallel: { + case llvm::omp::Directive::OMPD_target_parallel: + case llvm::omp::Directive::OMPD_target_teams_workdistribute: + case llvm::omp::Directive::OMPD_teams_workdistribute: { bool hasPrivate; for (const auto *allocName : allocateNames_) { hasPrivate = false; diff --git a/flang/test/Lower/OpenMP/workdistribute.f90 b/flang/test/Lower/OpenMP/workdistribute.f90 new file mode 100644 index 0000000000000..924205bb72e5e --- /dev/null +++ b/flang/test/Lower/OpenMP/workdistribute.f90 @@ -0,0 +1,59 @@ +! RUN: %flang_fc1 -emit-hlfir -fopenmp %s -o - | FileCheck %s + +! CHECK-LABEL: func @_QPtarget_teams_workdistribute +subroutine target_teams_workdistribute() + ! CHECK: omp.target + ! CHECK: omp.teams + ! CHECK: omp.workdistribute + !$omp target teams workdistribute + ! CHECK: fir.call + call f1() + ! CHECK: omp.terminator + ! CHECK: omp.terminator + ! CHECK: omp.terminator + !$omp end target teams workdistribute +end subroutine target_teams_workdistribute + +! CHECK-LABEL: func @_QPteams_workdistribute +subroutine teams_workdistribute() + ! CHECK: omp.teams + ! CHECK: omp.workdistribute + !$omp teams workdistribute + ! CHECK: fir.call + call f1() + ! CHECK: omp.terminator + ! CHECK: omp.terminator + !$omp end teams workdistribute +end subroutine teams_workdistribute + +! CHECK-LABEL: func @_QPtarget_teams_workdistribute_m +subroutine target_teams_workdistribute_m() + ! CHECK: omp.target + ! CHECK: omp.teams + ! CHECK: omp.workdistribute + !$omp target + !$omp teams + !$omp workdistribute + ! CHECK: fir.call + call f1() + ! CHECK: omp.terminator + ! CHECK: omp.terminator + ! CHECK: omp.terminator + !$omp end workdistribute + !$omp end teams + !$omp end target +end subroutine target_teams_workdistribute_m + +! CHECK-LABEL: func @_QPteams_workdistribute_m +subroutine teams_workdistribute_m() + ! CHECK: omp.teams + ! CHECK: omp.workdistribute + !$omp teams + !$omp workdistribute + ! CHECK: fir.call + call f1() + ! CHECK: omp.terminator + ! CHECK: omp.terminator + !$omp end workdistribute + !$omp end teams +end subroutine teams_workdistribute_m diff --git a/llvm/include/llvm/Frontend/OpenMP/OMP.td b/llvm/include/llvm/Frontend/OpenMP/OMP.td index a87111cb5a11d..d1831db37fc46 100644 --- a/llvm/include/llvm/Frontend/OpenMP/OMP.td +++ b/llvm/include/llvm/Frontend/OpenMP/OMP.td @@ -1286,6 +1286,15 @@ def OMP_EndWorkshare : Directive<[Spelling<"end workshare">]> { let category = OMP_Workshare.category; let languages = [L_Fortran]; } +def OMP_Workdistribute : Directive<[Spelling<"workdistribute">]> { + let association = AS_Block; + let category = CA_Executable; +} +def OMP_EndWorkdistribute : Directive<[Spelling<"end workdistribute">]> { + let leafConstructs = OMP_Workdistribute.leafConstructs; + let association = OMP_Workdistribute.association; + let category = OMP_Workdistribute.category; +} //===----------------------------------------------------------------------===// // Definitions of OpenMP compound directives @@ -2429,6 +2438,34 @@ def OMP_TargetTeamsDistributeSimd let leafConstructs = [OMP_Target, OMP_Teams, OMP_Distribute, OMP_Simd]; let category = CA_Executable; } +def OMP_TargetTeamsWorkdistribute : Directive<[Spelling<"target teams workdistribute">]> { + let allowedClauses = [ + VersionedClause, + VersionedClause, + VersionedClause, + VersionedClause, + VersionedClause, + VersionedClause, + VersionedClause, + VersionedClause, + VersionedClause, + VersionedClause, + VersionedClause, + VersionedClause, + ]; + let allowedOnceClauses = [ + VersionedClause, + VersionedClause, + VersionedClause, + VersionedClause, + VersionedClause, + VersionedClause, + VersionedClause, + VersionedClause, + ]; + let leafConstructs = [OMP_Target, OMP_Teams, OMP_Workdistribute]; + let category = CA_Executable; +} def OMP_target_teams_loop : Directive<[Spelling<"target teams loop">]> { let allowedClauses = [ VersionedClause, @@ -2659,6 +2696,24 @@ def OMP_TeamsDistributeSimd : Directive<[Spelling<"teams distribute simd">]> { let leafConstructs = [OMP_Teams, OMP_Distribute, OMP_Simd]; let category = CA_Executable; } +def OMP_TeamsWorkdistribute : Directive<[Spelling<"teams workdistribute">]> { + let allowedClauses = [ + VersionedClause, + VersionedClause, + VersionedClause, + VersionedClause, + VersionedClause, + VersionedClause, + ]; + let allowedOnceClauses = [ + VersionedClause, + VersionedClause, + VersionedClause, + VersionedClause, + ]; + let leafConstructs = [OMP_Teams, OMP_Workdistribute]; + let category = CA_Executable; +} def OMP_teams_loop : Directive<[Spelling<"teams loop">]> { let allowedClauses = [ VersionedClause, diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td index ac80926053a2d..49824cba733b1 100644 --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td @@ -1887,4 +1887,26 @@ def MaskedOp : OpenMP_Op<"masked", clauses = [ ]; } +//===----------------------------------------------------------------------===// +// workdistribute Construct +//===----------------------------------------------------------------------===// + +def WorkdistributeOp : OpenMP_Op<"workdistribute"> { + let summary = "workdistribute directive"; + let description = [{ + workdistribute divides execution of the enclosed structured block into + separate units of work, each executed only once by each + initial thread in the league. + ``` + !$omp target teams + !$omp workdistribute + y = a * x + y + !$omp end workdistribute + !$omp end target teams + ``` + }]; + let regions = (region AnyRegion:$region); + let assemblyFormat = "$region attr-dict"; +} + #endif // OPENMP_OPS From 54fb2288d3ec1e23566662387ee484e292af8f33 Mon Sep 17 00:00:00 2001 From: skc7 Date: Fri, 27 Jun 2025 16:33:58 +0530 Subject: [PATCH 2/2] [OpenMP] Add verifier and tests for workdistribute mlir op. --- mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td | 1 + mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp | 15 +++++++++++++ mlir/test/Dialect/OpenMP/invalid.mlir | 21 +++++++++++++++++++ mlir/test/Dialect/OpenMP/ops.mlir | 13 ++++++++++++ 4 files changed, 50 insertions(+) diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td index 49824cba733b1..a58e09d7bda71 100644 --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td @@ -1906,6 +1906,7 @@ def WorkdistributeOp : OpenMP_Op<"workdistribute"> { ``` }]; let regions = (region AnyRegion:$region); + let hasVerifier = 1; let assemblyFormat = "$region attr-dict"; } diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp index e94d570b57122..e2dd338829e76 100644 --- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp +++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp @@ -3493,6 +3493,21 @@ LogicalResult ScanOp::verify() { "reduction modifier"); } +//===----------------------------------------------------------------------===// +// WorkdistributeOp +//===----------------------------------------------------------------------===// + +LogicalResult WorkdistributeOp::verify() { + Region ®ion = getRegion(); + if (!region.hasOneBlock()) + return emitOpError("region must contain exactly one block"); + + Operation *parentOp = (*this)->getParentOp(); + if (!llvm::dyn_cast(parentOp)) + return emitOpError("workdistribute must be nested under teams"); + return success(); +} + #define GET_ATTRDEF_CLASSES #include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc" diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir index 060b3cd2455a0..522d20558a2b5 100644 --- a/mlir/test/Dialect/OpenMP/invalid.mlir +++ b/mlir/test/Dialect/OpenMP/invalid.mlir @@ -2960,3 +2960,24 @@ llvm.func @invalid_mapper(%0 : !llvm.ptr) { } llvm.return } + +func.func @invalid_workdistribute_with_multiple_blocks() { + // expected-error @below {{workdistribute must be nested under teams}} + omp.workdistribute { + omp.terminator + } + return +} + +func.func @invalid_workdistribute_with_multiple_blocks() { + omp.teams { + // expected-error @below {{region must contain exactly one block}} + omp.workdistribute { + cf.br ^bb1 + ^bb1: + omp.terminator + } + omp.terminator + } + return +} diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir index 47cfc5278a5d0..af80284e53537 100644 --- a/mlir/test/Dialect/OpenMP/ops.mlir +++ b/mlir/test/Dialect/OpenMP/ops.mlir @@ -3197,3 +3197,16 @@ func.func @omp_workshare_loop_wrapper_attrs(%idx : index) { } return } + +// CHECK-LABEL: func @omp_workdistribute +func.func @omp_workdistribute() { + // CHECK: omp.teams + omp.teams { + // CHECK: omp.workdistribute + omp.workdistribute { + omp.terminator + } + omp.terminator + } + return +}