Skip to content

Commit 943ecef

Browse files
committed
Initial implementation of tiling.
1 parent 74ec1c2 commit 943ecef

File tree

15 files changed

+459
-89
lines changed

15 files changed

+459
-89
lines changed

flang/include/flang/Lower/OpenMP.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,6 @@ void genOpenMPDeclarativeConstruct(AbstractConverter &,
7979
void genOpenMPSymbolProperties(AbstractConverter &converter,
8080
const pft::Variable &var);
8181

82-
int64_t getCollapseValue(const Fortran::parser::OmpClauseList &clauseList);
8382
void genThreadprivateOp(AbstractConverter &, const pft::Variable &);
8483
void genDeclareTargetIntGlobal(AbstractConverter &, const pft::Variable &);
8584
bool isOpenMPTargetConstruct(const parser::OpenMPConstruct &);

flang/include/flang/Parser/parse-tree.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5024,9 +5024,10 @@ struct OpenMPBlockConstruct {
50245024
struct OpenMPLoopConstruct {
50255025
TUPLE_CLASS_BOILERPLATE(OpenMPLoopConstruct);
50265026
OpenMPLoopConstruct(OmpBeginLoopDirective &&a)
5027-
: t({std::move(a), std::nullopt, std::nullopt}) {}
5027+
: t({std::move(a), std::nullopt, std::nullopt, std::nullopt}) {}
50285028
std::tuple<OmpBeginLoopDirective, std::optional<DoConstruct>,
5029-
std::optional<OmpEndLoopDirective>>
5029+
std::optional<common::Indirection<OpenMPLoopConstruct>>,
5030+
std::optional<OmpEndLoopDirective>>
50305031
t;
50315032
};
50325033

flang/lib/Lower/OpenMP/OpenMP.cpp

Lines changed: 55 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -456,6 +456,7 @@ static void processHostEvalClauses(lower::AbstractConverter &converter,
456456
return;
457457

458458
const parser::OmpClauseList *beginClauseList = nullptr;
459+
const parser::OmpClauseList *middleClauseList = nullptr;
459460
const parser::OmpClauseList *endClauseList = nullptr;
460461
common::visit(
461462
common::visitors{
@@ -473,6 +474,22 @@ static void processHostEvalClauses(lower::AbstractConverter &converter,
473474
beginClauseList =
474475
&std::get<parser::OmpClauseList>(beginDirective.t);
475476

477+
// FIXME(JAN): For now we check if there is an inner
478+
// OpenMPLoopConstruct, and extract the sizes clause from there.
479+
const auto &innerOptional = std::get<std::optional<
480+
common::Indirection<parser::OpenMPLoopConstruct>>>(
481+
ompConstruct.t);
482+
if (innerOptional.has_value()) {
483+
const auto &innerLoopDirective = innerOptional.value().value();
484+
const auto &innerBegin =
485+
std::get<parser::OmpBeginLoopDirective>(innerLoopDirective.t);
486+
const auto &innerDirective =
487+
std::get<parser::OmpLoopDirective>(innerBegin.t);
488+
if (innerDirective.v == llvm::omp::Directive::OMPD_tile) {
489+
middleClauseList =
490+
&std::get<parser::OmpClauseList>(innerBegin.t);
491+
}
492+
}
476493
if (auto &endDirective =
477494
std::get<std::optional<parser::OmpEndLoopDirective>>(
478495
ompConstruct.t))
@@ -485,6 +502,9 @@ static void processHostEvalClauses(lower::AbstractConverter &converter,
485502
assert(beginClauseList && "expected begin directive");
486503
clauses.append(makeClauses(*beginClauseList, semaCtx));
487504

505+
if (middleClauseList)
506+
clauses.append(makeClauses(*middleClauseList, semaCtx));
507+
488508
if (endClauseList)
489509
clauses.append(makeClauses(*endClauseList, semaCtx));
490510
};
@@ -960,6 +980,7 @@ static void genLoopVars(
960980
storeOp =
961981
createAndSetPrivatizedLoopVar(converter, loc, indexVal, argSymbol);
962982
}
983+
963984
firOpBuilder.setInsertionPointAfter(storeOp);
964985
}
965986

@@ -1712,6 +1733,23 @@ genLoopNestClauses(lower::AbstractConverter &converter,
17121733
cp.processCollapse(loc, eval, clauseOps, iv);
17131734

17141735
clauseOps.loopInclusive = converter.getFirOpBuilder().getUnitAttr();
1736+
1737+
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
1738+
for (auto &clause : clauses) {
1739+
if (clause.id == llvm::omp::Clause::OMPC_collapse) {
1740+
const auto &collapse = std::get<clause::Collapse>(clause.u);
1741+
int64_t collapseValue = evaluate::ToInt64(collapse.v).value();
1742+
clauseOps.numCollapse = firOpBuilder.getI64IntegerAttr(collapseValue);
1743+
} else if (clause.id == llvm::omp::Clause::OMPC_sizes) {
1744+
const auto &sizes = std::get<clause::Sizes>(clause.u);
1745+
llvm::SmallVector<int64_t> sizeValues;
1746+
for (auto &size : sizes.v) {
1747+
int64_t sizeValue = evaluate::ToInt64(size).value();
1748+
sizeValues.push_back(sizeValue);
1749+
}
1750+
clauseOps.tileSizes = sizeValues;
1751+
}
1752+
}
17151753
}
17161754

17171755
static void genLoopClauses(
@@ -2085,9 +2123,9 @@ static mlir::omp::LoopNestOp genLoopNestOp(
20852123
return llvm::SmallVector<const semantics::Symbol *>(iv);
20862124
};
20872125

2088-
auto *nestedEval =
2089-
getCollapsedLoopEval(eval, getCollapseValue(item->clauses));
2090-
2126+
uint64_t nestValue = getCollapseValue(item->clauses);
2127+
nestValue = nestValue < iv.size() ? iv.size() : nestValue;
2128+
auto *nestedEval = getCollapsedLoopEval(eval, nestValue);
20912129
return genOpWithBody<mlir::omp::LoopNestOp>(
20922130
OpWithBodyGenInfo(converter, symTable, semaCtx, loc, *nestedEval,
20932131
directive)
@@ -4186,6 +4224,20 @@ static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable,
41864224
std::get<parser::OmpBeginLoopDirective>(loopConstruct.t);
41874225
List<Clause> clauses = makeClauses(
41884226
std::get<parser::OmpClauseList>(beginLoopDirective.t), semaCtx);
4227+
4228+
const auto &innerOptional = std::get<std::optional<common::Indirection<parser::OpenMPLoopConstruct>>>(loopConstruct.t);
4229+
if (innerOptional.has_value()) {
4230+
const auto &innerLoopDirective = innerOptional.value().value();
4231+
const auto &innerBegin =
4232+
std::get<parser::OmpBeginLoopDirective>(innerLoopDirective.t);
4233+
const auto &innerDirective =
4234+
std::get<parser::OmpLoopDirective>(innerBegin.t);
4235+
if (innerDirective.v == llvm::omp::Directive::OMPD_tile) {
4236+
clauses.append(
4237+
makeClauses(std::get<parser::OmpClauseList>(innerBegin.t), semaCtx));
4238+
}
4239+
}
4240+
41894241
if (auto &endLoopDirective =
41904242
std::get<std::optional<parser::OmpEndLoopDirective>>(
41914243
loopConstruct.t)) {
@@ -4292,18 +4344,6 @@ void Fortran::lower::genOpenMPSymbolProperties(
42924344
lower::genDeclareTargetIntGlobal(converter, var);
42934345
}
42944346

4295-
int64_t
4296-
Fortran::lower::getCollapseValue(const parser::OmpClauseList &clauseList) {
4297-
for (const parser::OmpClause &clause : clauseList.v) {
4298-
if (const auto &collapseClause =
4299-
std::get_if<parser::OmpClause::Collapse>(&clause.u)) {
4300-
const auto *expr = semantics::GetExpr(collapseClause->v);
4301-
return evaluate::ToInt64(*expr).value();
4302-
}
4303-
}
4304-
return 1;
4305-
}
4306-
43074347
void Fortran::lower::genThreadprivateOp(lower::AbstractConverter &converter,
43084348
const lower::pft::Variable &var) {
43094349
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();

flang/lib/Lower/OpenMP/Utils.cpp

Lines changed: 25 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -38,14 +38,22 @@ namespace lower {
3838
namespace omp {
3939

4040
int64_t getCollapseValue(const List<Clause> &clauses) {
41-
auto iter = llvm::find_if(clauses, [](const Clause &clause) {
42-
return clause.id == llvm::omp::Clause::OMPC_collapse;
43-
});
44-
if (iter != clauses.end()) {
45-
const auto &collapse = std::get<clause::Collapse>(iter->u);
46-
return evaluate::ToInt64(collapse.v).value();
41+
int64_t collapseValue = 1;
42+
int64_t numTileSizes = 0;
43+
for (auto &clause : clauses) {
44+
if (clause.id == llvm::omp::Clause::OMPC_collapse) {
45+
const auto &collapse = std::get<clause::Collapse>(clause.u);
46+
collapseValue = evaluate::ToInt64(collapse.v).value();
47+
} else if (clause.id == llvm::omp::Clause::OMPC_sizes) {
48+
const auto &sizes = std::get<clause::Sizes>(clause.u);
49+
numTileSizes = sizes.v.size();
50+
}
4751
}
48-
return 1;
52+
53+
collapseValue = collapseValue - numTileSizes;
54+
int64_t result =
55+
collapseValue > numTileSizes ? collapseValue : numTileSizes;
56+
return result;
4957
}
5058

5159
void genObjectList(const ObjectList &objects,
@@ -611,6 +619,7 @@ bool collectLoopRelatedInfo(
611619
lower::pft::Evaluation &eval, const omp::List<omp::Clause> &clauses,
612620
mlir::omp::LoopRelatedClauseOps &result,
613621
llvm::SmallVectorImpl<const semantics::Symbol *> &iv) {
622+
614623
bool found = false;
615624
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
616625

@@ -626,7 +635,16 @@ bool collectLoopRelatedInfo(
626635
collapseValue = evaluate::ToInt64(clause->v).value();
627636
found = true;
628637
}
638+
std::int64_t sizesLengthValue = 0l;
639+
if (auto *clause =
640+
ClauseFinder::findUniqueClause<omp::clause::Sizes>(clauses)) {
641+
sizesLengthValue = clause->v.size();
642+
found = true;
643+
}
629644

645+
collapseValue = collapseValue - sizesLengthValue;
646+
collapseValue =
647+
collapseValue < sizesLengthValue ? sizesLengthValue : collapseValue;
630648
std::size_t loopVarTypeSize = 0;
631649
do {
632650
lower::pft::Evaluation *doLoop =
@@ -659,7 +677,6 @@ bool collectLoopRelatedInfo(
659677
} while (collapseValue > 0);
660678

661679
convertLoopBounds(converter, currentLocation, result, loopVarTypeSize);
662-
663680
return found;
664681
}
665682
} // namespace omp

flang/lib/Semantics/canonicalize-omp.cpp

Lines changed: 34 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
#include "canonicalize-omp.h"
1010
#include "flang/Parser/parse-tree-visitor.h"
11-
11+
# include <stack>
1212
// After Loop Canonicalization, rewrite OpenMP parse tree to make OpenMP
1313
// Constructs more structured which provide explicit scopes for later
1414
// structural checks and semantic analysis.
@@ -112,15 +112,17 @@ class CanonicalizationOfOmp {
112112
// in the same iteration
113113
//
114114
// Original:
115-
// ExecutableConstruct -> OpenMPConstruct -> OpenMPLoopConstruct
116-
// OmpBeginLoopDirective
115+
// ExecutableConstruct -> OpenMPConstruct -> OpenMPLoopConstruct t->
116+
// OmpBeginLoopDirective t-> OmpLoopDirective
117+
// [ExecutableConstruct -> OpenMPConstruct -> OpenMPLoopConstruct u->
118+
// OmpBeginLoopDirective t-> OmpLoopDirective t-> Tile v-> OMP_tile]
117119
// ExecutableConstruct -> DoConstruct
120+
// [ExecutableConstruct -> OmpEndLoopDirective]
118121
// ExecutableConstruct -> OmpEndLoopDirective (if available)
119122
//
120123
// After rewriting:
121-
// ExecutableConstruct -> OpenMPConstruct -> OpenMPLoopConstruct
122-
// OmpBeginLoopDirective
123-
// DoConstruct
124+
// ExecutableConstruct -> OpenMPConstruct -> OpenMPLoopConstruct t->
125+
// OmpBeginLoopDirective t -> OmpLoopDirective -> DoConstruct
124126
// OmpEndLoopDirective (if available)
125127
parser::Block::iterator nextIt;
126128
auto &beginDir{std::get<parser::OmpBeginLoopDirective>(x.t)};
@@ -132,19 +134,39 @@ class CanonicalizationOfOmp {
132134
if (GetConstructIf<parser::CompilerDirective>(*nextIt))
133135
continue;
134136

137+
// Keep track of the loops to handle the end loop directives
138+
std::stack<parser::OpenMPLoopConstruct *> loops;
139+
loops.push(&x);
140+
while (auto *innerConstruct{
141+
GetConstructIf<parser::OpenMPConstruct>(*nextIt)}) {
142+
if (auto *innerOmpLoop{
143+
std::get_if<parser::OpenMPLoopConstruct>(&innerConstruct->u)}) {
144+
std::get<
145+
std::optional<common::Indirection<parser::OpenMPLoopConstruct>>>(
146+
loops.top()->t) = std::move(*innerOmpLoop);
147+
// Retrieving the address so that DoConstruct or inner loop can be
148+
// set later.
149+
loops.push(&(std::get<std::optional<
150+
common::Indirection<parser::OpenMPLoopConstruct>>>(
151+
loops.top()->t)
152+
.value()
153+
.value()));
154+
nextIt = block.erase(nextIt);
155+
}
156+
}
135157
if (auto *doCons{GetConstructIf<parser::DoConstruct>(*nextIt)}) {
136158
if (doCons->GetLoopControl()) {
137-
// move DoConstruct
138-
std::get<std::optional<parser::DoConstruct>>(x.t) =
159+
std::get<std::optional<parser::DoConstruct>>(loops.top()->t) =
139160
std::move(*doCons);
140161
nextIt = block.erase(nextIt);
141162
// try to match OmpEndLoopDirective
142-
if (nextIt != block.end()) {
163+
while (nextIt != block.end() && !loops.empty()) {
143164
if (auto *endDir{
144165
GetConstructIf<parser::OmpEndLoopDirective>(*nextIt)}) {
145-
std::get<std::optional<parser::OmpEndLoopDirective>>(x.t) =
146-
std::move(*endDir);
147-
block.erase(nextIt);
166+
std::get<std::optional<parser::OmpEndLoopDirective>>(
167+
loops.top()->t) = std::move(*endDir);
168+
nextIt = block.erase(nextIt);
169+
loops.pop();
148170
}
149171
}
150172
} else {

0 commit comments

Comments
 (0)