Skip to content

[flang] Add support for workdistribute construct in flang frontend #146029

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions flang/include/flang/Semantics/openmp-directive-sets.h
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,7 @@ static const OmpDirectiveSet topTargetSet{
Directive::OMPD_target_teams_distribute_parallel_do_simd,
Directive::OMPD_target_teams_distribute_simd,
Directive::OMPD_target_teams_loop,
Directive::OMPD_target_teams_workdistribute,
};

static const OmpDirectiveSet allTargetSet{topTargetSet};
Expand Down Expand Up @@ -172,6 +173,7 @@ static const OmpDirectiveSet topTeamsSet{
Directive::OMPD_teams_distribute_parallel_do_simd,
Directive::OMPD_teams_distribute_simd,
Directive::OMPD_teams_loop,
Directive::OMPD_teams_workdistribute,
};

static const OmpDirectiveSet bottomTeamsSet{
Expand All @@ -187,9 +189,16 @@ static const OmpDirectiveSet allTeamsSet{
Directive::OMPD_target_teams_distribute_parallel_do_simd,
Directive::OMPD_target_teams_distribute_simd,
Directive::OMPD_target_teams_loop,
Directive::OMPD_target_teams_workdistribute,
} | topTeamsSet,
};

static const OmpDirectiveSet allWorkdistributeSet{
Directive::OMPD_workdistribute,
Directive::OMPD_teams_workdistribute,
Directive::OMPD_target_teams_workdistribute,
};

//===----------------------------------------------------------------------===//
// Directive sets for groups of multiple directives
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -230,6 +239,9 @@ static const OmpDirectiveSet blockConstructSet{
Directive::OMPD_taskgroup,
Directive::OMPD_teams,
Directive::OMPD_workshare,
Directive::OMPD_target_teams_workdistribute,
Directive::OMPD_teams_workdistribute,
Directive::OMPD_workdistribute,
};

static const OmpDirectiveSet loopConstructSet{
Expand Down Expand Up @@ -294,6 +306,7 @@ static const OmpDirectiveSet workShareSet{
Directive::OMPD_scope,
Directive::OMPD_sections,
Directive::OMPD_single,
Directive::OMPD_workdistribute,
} | allDoSet,
};

Expand Down Expand Up @@ -376,6 +389,7 @@ static const OmpDirectiveSet nestedReduceWorkshareAllowedSet{
};

static const OmpDirectiveSet nestedTeamsAllowedSet{
Directive::OMPD_workdistribute,
Directive::OMPD_distribute,
Directive::OMPD_distribute_parallel_do,
Directive::OMPD_distribute_parallel_do_simd,
Expand Down
26 changes: 25 additions & 1 deletion flang/lib/Lower/OpenMP/OpenMP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -585,6 +585,16 @@ static void processHostEvalClauses(lower::AbstractConverter &converter,
cp.processCollapse(loc, eval, hostInfo.ops, hostInfo.iv);
break;

case OMPD_teams_workdistribute:
cp.processThreadLimit(stmtCtx, hostInfo.ops);
[[fallthrough]];
case OMPD_target_teams_workdistribute:
cp.processNumTeams(stmtCtx, hostInfo.ops);
processSingleNestedIf([](Directive nestedDir) {
return topDistributeSet.test(nestedDir) || topLoopSet.test(nestedDir);
});
break;

// Standalone 'target' case.
case OMPD_target: {
processSingleNestedIf(
Expand Down Expand Up @@ -2682,6 +2692,17 @@ genTeamsOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
queue, item, clauseOps);
}

static mlir::omp::WorkdistributeOp genWorkdistributeOp(
lower::AbstractConverter &converter, lower::SymMap &symTable,
semantics::SemanticsContext &semaCtx, lower::pft::Evaluation &eval,
mlir::Location loc, const ConstructQueue &queue,
ConstructQueue::const_iterator item) {
return genOpWithBody<mlir::omp::WorkdistributeOp>(
OpWithBodyGenInfo(converter, symTable, semaCtx, loc, eval,
llvm::omp::Directive::OMPD_workdistribute),
queue, item);
}

//===----------------------------------------------------------------------===//
// Code generation functions for the standalone version of constructs that can
// also be a leaf of a composite construct
Expand Down Expand Up @@ -3302,7 +3323,10 @@ static void genOMPDispatch(lower::AbstractConverter &converter,
TODO(loc, "Unhandled loop directive (" +
llvm::omp::getOpenMPDirectiveName(dir, version) + ")");
}
// case llvm::omp::Directive::OMPD_workdistribute:
case llvm::omp::Directive::OMPD_workdistribute:
newOp = genWorkdistributeOp(converter, symTable, semaCtx, eval, loc, queue,
item);
break;
case llvm::omp::Directive::OMPD_workshare:
newOp = genWorkshareOp(converter, symTable, stmtCtx, semaCtx, eval, loc,
queue, item);
Expand Down
7 changes: 6 additions & 1 deletion flang/lib/Parser/openmp-parsers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1492,12 +1492,17 @@ TYPE_PARSER(
"SINGLE" >> pure(llvm::omp::Directive::OMPD_single),
"TARGET DATA" >> pure(llvm::omp::Directive::OMPD_target_data),
"TARGET PARALLEL" >> pure(llvm::omp::Directive::OMPD_target_parallel),
"TARGET TEAMS WORKDISTRIBUTE" >>
pure(llvm::omp::Directive::OMPD_target_teams_workdistribute),
"TARGET TEAMS" >> pure(llvm::omp::Directive::OMPD_target_teams),
"TARGET" >> pure(llvm::omp::Directive::OMPD_target),
"TASK"_id >> pure(llvm::omp::Directive::OMPD_task),
"TASKGROUP" >> pure(llvm::omp::Directive::OMPD_taskgroup),
"TEAMS WORKDISTRIBUTE" >>
pure(llvm::omp::Directive::OMPD_teams_workdistribute),
"TEAMS" >> pure(llvm::omp::Directive::OMPD_teams),
"WORKSHARE" >> pure(llvm::omp::Directive::OMPD_workshare))))
"WORKSHARE" >> pure(llvm::omp::Directive::OMPD_workshare),
"WORKDISTRIBUTE" >> pure(llvm::omp::Directive::OMPD_workdistribute))))

TYPE_PARSER(sourced(construct<OmpBeginBlockDirective>(
sourced(Parser<OmpBlockDirective>{}), Parser<OmpClauseList>{})))
Expand Down
8 changes: 7 additions & 1 deletion flang/lib/Semantics/resolve-directives.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1656,10 +1656,13 @@ bool OmpAttributeVisitor::Pre(const parser::OpenMPBlockConstruct &x) {
case llvm::omp::Directive::OMPD_task:
case llvm::omp::Directive::OMPD_taskgroup:
case llvm::omp::Directive::OMPD_teams:
case llvm::omp::Directive::OMPD_workdistribute:
case llvm::omp::Directive::OMPD_workshare:
case llvm::omp::Directive::OMPD_parallel_workshare:
case llvm::omp::Directive::OMPD_target_teams:
case llvm::omp::Directive::OMPD_target_teams_workdistribute:
case llvm::omp::Directive::OMPD_target_parallel:
case llvm::omp::Directive::OMPD_teams_workdistribute:
PushContext(beginDir.source, beginDir.v);
break;
default:
Expand Down Expand Up @@ -1689,9 +1692,12 @@ void OmpAttributeVisitor::Post(const parser::OpenMPBlockConstruct &x) {
case llvm::omp::Directive::OMPD_target:
case llvm::omp::Directive::OMPD_task:
case llvm::omp::Directive::OMPD_teams:
case llvm::omp::Directive::OMPD_workdistribute:
case llvm::omp::Directive::OMPD_parallel_workshare:
case llvm::omp::Directive::OMPD_target_teams:
case llvm::omp::Directive::OMPD_target_parallel: {
case llvm::omp::Directive::OMPD_target_parallel:
case llvm::omp::Directive::OMPD_target_teams_workdistribute:
case llvm::omp::Directive::OMPD_teams_workdistribute: {
bool hasPrivate;
for (const auto *allocName : allocateNames_) {
hasPrivate = false;
Expand Down
59 changes: 59 additions & 0 deletions flang/test/Lower/OpenMP/workdistribute.f90
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
! RUN: %flang_fc1 -emit-hlfir -fopenmp %s -o - | FileCheck %s

! CHECK-LABEL: func @_QPtarget_teams_workdistribute
subroutine target_teams_workdistribute()
! CHECK: omp.target
! CHECK: omp.teams
! CHECK: omp.workdistribute
!$omp target teams workdistribute
! CHECK: fir.call
call f1()
! CHECK: omp.terminator
! CHECK: omp.terminator
! CHECK: omp.terminator
!$omp end target teams workdistribute
end subroutine target_teams_workdistribute

! CHECK-LABEL: func @_QPteams_workdistribute
subroutine teams_workdistribute()
! CHECK: omp.teams
! CHECK: omp.workdistribute
!$omp teams workdistribute
! CHECK: fir.call
call f1()
! CHECK: omp.terminator
! CHECK: omp.terminator
!$omp end teams workdistribute
end subroutine teams_workdistribute

! CHECK-LABEL: func @_QPtarget_teams_workdistribute_m
subroutine target_teams_workdistribute_m()
! CHECK: omp.target
! CHECK: omp.teams
! CHECK: omp.workdistribute
!$omp target
!$omp teams
!$omp workdistribute
! CHECK: fir.call
call f1()
! CHECK: omp.terminator
! CHECK: omp.terminator
! CHECK: omp.terminator
!$omp end workdistribute
!$omp end teams
!$omp end target
end subroutine target_teams_workdistribute_m

! CHECK-LABEL: func @_QPteams_workdistribute_m
subroutine teams_workdistribute_m()
! CHECK: omp.teams
! CHECK: omp.workdistribute
!$omp teams
!$omp workdistribute
! CHECK: fir.call
call f1()
! CHECK: omp.terminator
! CHECK: omp.terminator
!$omp end workdistribute
!$omp end teams
end subroutine teams_workdistribute_m
55 changes: 55 additions & 0 deletions llvm/include/llvm/Frontend/OpenMP/OMP.td
Original file line number Diff line number Diff line change
Expand Up @@ -1286,6 +1286,15 @@ def OMP_EndWorkshare : Directive<[Spelling<"end workshare">]> {
let category = OMP_Workshare.category;
let languages = [L_Fortran];
}
def OMP_Workdistribute : Directive<[Spelling<"workdistribute">]> {
let association = AS_Block;
let category = CA_Executable;
}
def OMP_EndWorkdistribute : Directive<[Spelling<"end workdistribute">]> {
let leafConstructs = OMP_Workdistribute.leafConstructs;
let association = OMP_Workdistribute.association;
let category = OMP_Workdistribute.category;
}

//===----------------------------------------------------------------------===//
// Definitions of OpenMP compound directives
Expand Down Expand Up @@ -2429,6 +2438,34 @@ def OMP_TargetTeamsDistributeSimd
let leafConstructs = [OMP_Target, OMP_Teams, OMP_Distribute, OMP_Simd];
let category = CA_Executable;
}
def OMP_TargetTeamsWorkdistribute : Directive<[Spelling<"target teams workdistribute">]> {
let allowedClauses = [
VersionedClause<OMPC_Allocate>,
VersionedClause<OMPC_Depend>,
VersionedClause<OMPC_FirstPrivate>,
VersionedClause<OMPC_HasDeviceAddr, 51>,
VersionedClause<OMPC_If>,
VersionedClause<OMPC_IsDevicePtr>,
VersionedClause<OMPC_Map>,
VersionedClause<OMPC_OMPX_Attribute>,
VersionedClause<OMPC_Private>,
VersionedClause<OMPC_Reduction>,
VersionedClause<OMPC_Shared>,
VersionedClause<OMPC_UsesAllocators, 50>,
];
let allowedOnceClauses = [
VersionedClause<OMPC_Default>,
VersionedClause<OMPC_DefaultMap>,
VersionedClause<OMPC_Device>,
VersionedClause<OMPC_NoWait>,
VersionedClause<OMPC_NumTeams>,
VersionedClause<OMPC_OMPX_DynCGroupMem>,
VersionedClause<OMPC_OMPX_Bare>,
VersionedClause<OMPC_ThreadLimit>,
];
let leafConstructs = [OMP_Target, OMP_Teams, OMP_Workdistribute];
let category = CA_Executable;
}
def OMP_target_teams_loop : Directive<[Spelling<"target teams loop">]> {
let allowedClauses = [
VersionedClause<OMPC_Allocate>,
Expand Down Expand Up @@ -2659,6 +2696,24 @@ def OMP_TeamsDistributeSimd : Directive<[Spelling<"teams distribute simd">]> {
let leafConstructs = [OMP_Teams, OMP_Distribute, OMP_Simd];
let category = CA_Executable;
}
def OMP_TeamsWorkdistribute : Directive<[Spelling<"teams workdistribute">]> {
let allowedClauses = [
VersionedClause<OMPC_Allocate>,
VersionedClause<OMPC_FirstPrivate>,
VersionedClause<OMPC_OMPX_Attribute>,
VersionedClause<OMPC_Private>,
VersionedClause<OMPC_Reduction>,
VersionedClause<OMPC_Shared>,
];
let allowedOnceClauses = [
VersionedClause<OMPC_Default>,
VersionedClause<OMPC_If, 52>,
VersionedClause<OMPC_NumTeams>,
VersionedClause<OMPC_ThreadLimit>,
];
let leafConstructs = [OMP_Teams, OMP_Workdistribute];
let category = CA_Executable;
}
def OMP_teams_loop : Directive<[Spelling<"teams loop">]> {
let allowedClauses = [
VersionedClause<OMPC_Allocate>,
Expand Down
23 changes: 23 additions & 0 deletions mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1887,4 +1887,27 @@ def MaskedOp : OpenMP_Op<"masked", clauses = [
];
}

//===----------------------------------------------------------------------===//
// workdistribute Construct
//===----------------------------------------------------------------------===//

def WorkdistributeOp : OpenMP_Op<"workdistribute"> {
let summary = "workdistribute directive";
let description = [{
workdistribute divides execution of the enclosed structured block into
separate units of work, each executed only once by each
initial thread in the league.
```
!$omp target teams
!$omp workdistribute
y = a * x + y
!$omp end workdistribute
!$omp end target teams
```
}];
let regions = (region AnyRegion:$region);
let hasVerifier = 1;
let assemblyFormat = "$region attr-dict";
}

#endif // OPENMP_OPS
15 changes: 15 additions & 0 deletions mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3493,6 +3493,21 @@ LogicalResult ScanOp::verify() {
"reduction modifier");
}

//===----------------------------------------------------------------------===//
// WorkdistributeOp
//===----------------------------------------------------------------------===//

LogicalResult WorkdistributeOp::verify() {
Region &region = getRegion();
if (!region.hasOneBlock())
return emitOpError("region must contain exactly one block");

Operation *parentOp = (*this)->getParentOp();
if (!llvm::dyn_cast<TeamsOp>(parentOp))
return emitOpError("workdistribute must be nested under teams");
return success();
}

#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"

Expand Down
21 changes: 21 additions & 0 deletions mlir/test/Dialect/OpenMP/invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -2960,3 +2960,24 @@ llvm.func @invalid_mapper(%0 : !llvm.ptr) {
}
llvm.return
}

func.func @invalid_workdistribute_with_multiple_blocks() {
// expected-error @below {{workdistribute must be nested under teams}}
omp.workdistribute {
omp.terminator
}
return
}

func.func @invalid_workdistribute_with_multiple_blocks() {
omp.teams {
// expected-error @below {{region must contain exactly one block}}
omp.workdistribute {
cf.br ^bb1
^bb1:
omp.terminator
}
omp.terminator
}
return
}
13 changes: 13 additions & 0 deletions mlir/test/Dialect/OpenMP/ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -3197,3 +3197,16 @@ func.func @omp_workshare_loop_wrapper_attrs(%idx : index) {
}
return
}

// CHECK-LABEL: func @omp_workdistribute
func.func @omp_workdistribute() {
// CHECK: omp.teams
omp.teams {
// CHECK: omp.workdistribute
omp.workdistribute {
omp.terminator
}
omp.terminator
}
return
}
Loading