Skip to content

Commit bd8edf7

Browse files
committed
Initial implementation of tiling.
1 parent 90beda2 commit bd8edf7

File tree

15 files changed

+461
-86
lines changed

15 files changed

+461
-86
lines changed

flang/include/flang/Lower/OpenMP.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,6 @@ void genOpenMPDeclarativeConstruct(AbstractConverter &,
7979
void genOpenMPSymbolProperties(AbstractConverter &converter,
8080
const pft::Variable &var);
8181

82-
int64_t getCollapseValue(const Fortran::parser::OmpClauseList &clauseList);
8382
void genThreadprivateOp(AbstractConverter &, const pft::Variable &);
8483
void genDeclareTargetIntGlobal(AbstractConverter &, const pft::Variable &);
8584
bool isOpenMPTargetConstruct(const parser::OpenMPConstruct &);

flang/include/flang/Parser/parse-tree.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5075,9 +5075,10 @@ struct OpenMPBlockConstruct {
50755075
struct OpenMPLoopConstruct {
50765076
TUPLE_CLASS_BOILERPLATE(OpenMPLoopConstruct);
50775077
OpenMPLoopConstruct(OmpBeginLoopDirective &&a)
5078-
: t({std::move(a), std::nullopt, std::nullopt}) {}
5078+
: t({std::move(a), std::nullopt, std::nullopt, std::nullopt}) {}
50795079
std::tuple<OmpBeginLoopDirective, std::optional<DoConstruct>,
5080-
std::optional<OmpEndLoopDirective>>
5080+
std::optional<common::Indirection<OpenMPLoopConstruct>>,
5081+
std::optional<OmpEndLoopDirective>>
50815082
t;
50825083
};
50835084

flang/lib/Lower/OpenMP/OpenMP.cpp

Lines changed: 55 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -451,6 +451,7 @@ static void processHostEvalClauses(lower::AbstractConverter &converter,
451451
return;
452452

453453
const parser::OmpClauseList *beginClauseList = nullptr;
454+
const parser::OmpClauseList *middleClauseList = nullptr;
454455
const parser::OmpClauseList *endClauseList = nullptr;
455456
common::visit(
456457
common::visitors{
@@ -468,6 +469,22 @@ static void processHostEvalClauses(lower::AbstractConverter &converter,
468469
beginClauseList =
469470
&std::get<parser::OmpClauseList>(beginDirective.t);
470471

472+
// FIXME(JAN): For now we check if there is an inner
473+
// OpenMPLoopConstruct, and extract the size clause from there
474+
const auto &innerOptional = std::get<std::optional<
475+
common::Indirection<parser::OpenMPLoopConstruct>>>(
476+
ompConstruct.t);
477+
if (innerOptional.has_value()) {
478+
const auto &innerLoopDirective = innerOptional.value().value();
479+
const auto &innerBegin =
480+
std::get<parser::OmpBeginLoopDirective>(innerLoopDirective.t);
481+
const auto &innerDirective =
482+
std::get<parser::OmpLoopDirective>(innerBegin.t);
483+
if (innerDirective.v == llvm::omp::Directive::OMPD_tile) {
484+
middleClauseList =
485+
&std::get<parser::OmpClauseList>(innerBegin.t);
486+
}
487+
}
471488
if (auto &endDirective =
472489
std::get<std::optional<parser::OmpEndLoopDirective>>(
473490
ompConstruct.t))
@@ -480,6 +497,9 @@ static void processHostEvalClauses(lower::AbstractConverter &converter,
480497
assert(beginClauseList && "expected begin directive");
481498
clauses.append(makeClauses(*beginClauseList, semaCtx));
482499

500+
if (middleClauseList)
501+
clauses.append(makeClauses(*middleClauseList, semaCtx));
502+
483503
if (endClauseList)
484504
clauses.append(makeClauses(*endClauseList, semaCtx));
485505
};
@@ -955,6 +975,7 @@ static void genLoopVars(
955975
storeOp =
956976
createAndSetPrivatizedLoopVar(converter, loc, indexVal, argSymbol);
957977
}
978+
958979
firOpBuilder.setInsertionPointAfter(storeOp);
959980
}
960981

@@ -1697,6 +1718,23 @@ genLoopNestClauses(lower::AbstractConverter &converter,
16971718
cp.processCollapse(loc, eval, clauseOps, iv);
16981719

16991720
clauseOps.loopInclusive = converter.getFirOpBuilder().getUnitAttr();
1721+
1722+
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
1723+
for (auto &clause : clauses) {
1724+
if (clause.id == llvm::omp::Clause::OMPC_collapse) {
1725+
const auto &collapse = std::get<clause::Collapse>(clause.u);
1726+
int64_t collapseValue = evaluate::ToInt64(collapse.v).value();
1727+
clauseOps.numCollapse = firOpBuilder.getI64IntegerAttr(collapseValue);
1728+
} else if (clause.id == llvm::omp::Clause::OMPC_sizes) {
1729+
const auto &sizes = std::get<clause::Sizes>(clause.u);
1730+
llvm::SmallVector<int64_t> sizeValues;
1731+
for (auto &size : sizes.v) {
1732+
int64_t sizeValue = evaluate::ToInt64(size).value();
1733+
sizeValues.push_back(sizeValue);
1734+
}
1735+
clauseOps.tileSizes = sizeValues;
1736+
}
1737+
}
17001738
}
17011739

17021740
static void genLoopClauses(
@@ -2069,9 +2107,9 @@ static mlir::omp::LoopNestOp genLoopNestOp(
20692107
return llvm::SmallVector<const semantics::Symbol *>(iv);
20702108
};
20712109

2072-
auto *nestedEval =
2073-
getCollapsedLoopEval(eval, getCollapseValue(item->clauses));
2074-
2110+
uint64_t nestValue = getCollapseValue(item->clauses);
2111+
nestValue = nestValue < iv.size() ? iv.size() : nestValue;
2112+
auto *nestedEval = getCollapsedLoopEval(eval, nestValue);
20752113
return genOpWithBody<mlir::omp::LoopNestOp>(
20762114
OpWithBodyGenInfo(converter, symTable, semaCtx, loc, *nestedEval,
20772115
directive)
@@ -4385,6 +4423,20 @@ static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable,
43854423
std::get<parser::OmpBeginLoopDirective>(loopConstruct.t);
43864424
List<Clause> clauses = makeClauses(
43874425
std::get<parser::OmpClauseList>(beginLoopDirective.t), semaCtx);
4426+
4427+
const auto &innerOptional = std::get<std::optional<common::Indirection<parser::OpenMPLoopConstruct>>>(loopConstruct.t);
4428+
if (innerOptional.has_value()) {
4429+
const auto &innerLoopDirective = innerOptional.value().value();
4430+
const auto &innerBegin =
4431+
std::get<parser::OmpBeginLoopDirective>(innerLoopDirective.t);
4432+
const auto &innerDirective =
4433+
std::get<parser::OmpLoopDirective>(innerBegin.t);
4434+
if (innerDirective.v == llvm::omp::Directive::OMPD_tile) {
4435+
clauses.append(
4436+
makeClauses(std::get<parser::OmpClauseList>(innerBegin.t), semaCtx));
4437+
}
4438+
}
4439+
43884440
if (auto &endLoopDirective =
43894441
std::get<std::optional<parser::OmpEndLoopDirective>>(
43904442
loopConstruct.t)) {
@@ -4505,18 +4557,6 @@ void Fortran::lower::genOpenMPSymbolProperties(
45054557
lower::genDeclareTargetIntGlobal(converter, var);
45064558
}
45074559

4508-
int64_t
4509-
Fortran::lower::getCollapseValue(const parser::OmpClauseList &clauseList) {
4510-
for (const parser::OmpClause &clause : clauseList.v) {
4511-
if (const auto &collapseClause =
4512-
std::get_if<parser::OmpClause::Collapse>(&clause.u)) {
4513-
const auto *expr = semantics::GetExpr(collapseClause->v);
4514-
return evaluate::ToInt64(*expr).value();
4515-
}
4516-
}
4517-
return 1;
4518-
}
4519-
45204560
void Fortran::lower::genThreadprivateOp(lower::AbstractConverter &converter,
45214561
const lower::pft::Variable &var) {
45224562
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();

flang/lib/Lower/OpenMP/Utils.cpp

Lines changed: 25 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -38,14 +38,22 @@ namespace lower {
3838
namespace omp {
3939

4040
int64_t getCollapseValue(const List<Clause> &clauses) {
41-
auto iter = llvm::find_if(clauses, [](const Clause &clause) {
42-
return clause.id == llvm::omp::Clause::OMPC_collapse;
43-
});
44-
if (iter != clauses.end()) {
45-
const auto &collapse = std::get<clause::Collapse>(iter->u);
46-
return evaluate::ToInt64(collapse.v).value();
41+
int64_t collapseValue = 1;
42+
int64_t numTileSizes = 0;
43+
for (auto &clause : clauses) {
44+
if (clause.id == llvm::omp::Clause::OMPC_collapse) {
45+
const auto &collapse = std::get<clause::Collapse>(clause.u);
46+
collapseValue = evaluate::ToInt64(collapse.v).value();
47+
} else if (clause.id == llvm::omp::Clause::OMPC_sizes) {
48+
const auto &sizes = std::get<clause::Sizes>(clause.u);
49+
numTileSizes = sizes.v.size();
50+
}
4751
}
48-
return 1;
52+
53+
collapseValue = collapseValue - numTileSizes;
54+
int64_t result =
55+
collapseValue > numTileSizes ? collapseValue : numTileSizes;
56+
return result;
4957
}
5058

5159
void genObjectList(const ObjectList &objects,
@@ -611,6 +619,7 @@ bool collectLoopRelatedInfo(
611619
lower::pft::Evaluation &eval, const omp::List<omp::Clause> &clauses,
612620
mlir::omp::LoopRelatedClauseOps &result,
613621
llvm::SmallVectorImpl<const semantics::Symbol *> &iv) {
622+
614623
bool found = false;
615624
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
616625

@@ -626,7 +635,16 @@ bool collectLoopRelatedInfo(
626635
collapseValue = evaluate::ToInt64(clause->v).value();
627636
found = true;
628637
}
638+
std::int64_t sizesLengthValue = 0l;
639+
if (auto *clause =
640+
ClauseFinder::findUniqueClause<omp::clause::Sizes>(clauses)) {
641+
sizesLengthValue = clause->v.size();
642+
found = true;
643+
}
629644

645+
collapseValue = collapseValue - sizesLengthValue;
646+
collapseValue =
647+
collapseValue < sizesLengthValue ? sizesLengthValue : collapseValue;
630648
std::size_t loopVarTypeSize = 0;
631649
do {
632650
lower::pft::Evaluation *doLoop =
@@ -659,7 +677,6 @@ bool collectLoopRelatedInfo(
659677
} while (collapseValue > 0);
660678

661679
convertLoopBounds(converter, currentLocation, result, loopVarTypeSize);
662-
663680
return found;
664681
}
665682
} // namespace omp

flang/lib/Semantics/canonicalize-omp.cpp

Lines changed: 34 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
#include "canonicalize-omp.h"
1010
#include "flang/Parser/parse-tree-visitor.h"
11-
11+
# include <stack>
1212
// After Loop Canonicalization, rewrite OpenMP parse tree to make OpenMP
1313
// Constructs more structured which provide explicit scopes for later
1414
// structural checks and semantic analysis.
@@ -112,15 +112,17 @@ class CanonicalizationOfOmp {
112112
// in the same iteration
113113
//
114114
// Original:
115-
// ExecutableConstruct -> OpenMPConstruct -> OpenMPLoopConstruct
116-
// OmpBeginLoopDirective
115+
// ExecutableConstruct -> OpenMPConstruct -> OpenMPLoopConstruct t->
116+
// OmpBeginLoopDirective t-> OmpLoopDirective
117+
// [ExecutableConstruct -> OpenMPConstruct -> OpenMPLoopConstruct u->
118+
/// OmpBeginLoopDirective t-> OmpLoopDirective t-> Tile v-> OMP_tile]
117119
// ExecutableConstruct -> DoConstruct
120+
// [ExecutableConstruct -> OmpEndLoopDirective]
118121
// ExecutableConstruct -> OmpEndLoopDirective (if available)
119122
//
120123
// After rewriting:
121-
// ExecutableConstruct -> OpenMPConstruct -> OpenMPLoopConstruct
122-
// OmpBeginLoopDirective
123-
// DoConstruct
124+
// ExecutableConstruct -> OpenMPConstruct -> OpenMPLoopConstruct t->
125+
// OmpBeginLoopDirective t -> OmpLoopDirective -> DoConstruct
124126
// OmpEndLoopDirective (if available)
125127
parser::Block::iterator nextIt;
126128
auto &beginDir{std::get<parser::OmpBeginLoopDirective>(x.t)};
@@ -132,19 +134,39 @@ class CanonicalizationOfOmp {
132134
if (GetConstructIf<parser::CompilerDirective>(*nextIt))
133135
continue;
134136

137+
// Keep track of the loops to handle the end loop directives
138+
std::stack<parser::OpenMPLoopConstruct *> loops;
139+
loops.push(&x);
140+
while (auto *innerConstruct{
141+
GetConstructIf<parser::OpenMPConstruct>(*nextIt)}) {
142+
if (auto *innerOmpLoop{
143+
std::get_if<parser::OpenMPLoopConstruct>(&innerConstruct->u)}) {
144+
std::get<
145+
std::optional<common::Indirection<parser::OpenMPLoopConstruct>>>(
146+
loops.top()->t) = std::move(*innerOmpLoop);
147+
// Retrieveing the address so that DoConstruct or inner loop can be
148+
// set later.
149+
loops.push(&(std::get<std::optional<
150+
common::Indirection<parser::OpenMPLoopConstruct>>>(
151+
loops.top()->t)
152+
.value()
153+
.value()));
154+
nextIt = block.erase(nextIt);
155+
}
156+
}
135157
if (auto *doCons{GetConstructIf<parser::DoConstruct>(*nextIt)}) {
136158
if (doCons->GetLoopControl()) {
137-
// move DoConstruct
138-
std::get<std::optional<parser::DoConstruct>>(x.t) =
159+
std::get<std::optional<parser::DoConstruct>>(loops.top()->t) =
139160
std::move(*doCons);
140161
nextIt = block.erase(nextIt);
141162
// try to match OmpEndLoopDirective
142-
if (nextIt != block.end()) {
163+
while (nextIt != block.end() && !loops.empty()) {
143164
if (auto *endDir{
144165
GetConstructIf<parser::OmpEndLoopDirective>(*nextIt)}) {
145-
std::get<std::optional<parser::OmpEndLoopDirective>>(x.t) =
146-
std::move(*endDir);
147-
block.erase(nextIt);
166+
std::get<std::optional<parser::OmpEndLoopDirective>>(
167+
loops.top()->t) = std::move(*endDir);
168+
nextIt = block.erase(nextIt);
169+
loops.pop();
148170
}
149171
}
150172
} else {

0 commit comments

Comments
 (0)