Skip to content

Commit c4138a2

Browse files
authored
[mlir][acc][flang] Lower nested ACC loops with tile clause as collapsed loops (#147801)
In the case of nested loops, `acc.loop` is meant to subsume all of the loops that it applies to (when explicitly described as doing so in the OpenACC specification). So when there is a `acc loop tile(...)` present on nested Fortran DO loops, `acc.loop` should apply to the `n` loops that `tile` applies to. This change lowers such nested Fortran loops with tile clause into a collapsed `acc.loop` with `n` IVs, loop bounds, and step, in a similar fashion to the current lowering for acc loops with `collapse` clause.
1 parent 2197671 commit c4138a2

File tree

7 files changed

+80
-24
lines changed

7 files changed

+80
-24
lines changed

flang/include/flang/Lower/OpenACC.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ void attachDeclarePostDeallocAction(AbstractConverter &, fir::FirOpBuilder &,
114114
void genOpenACCTerminator(fir::FirOpBuilder &, mlir::Operation *,
115115
mlir::Location);
116116

117-
int64_t getCollapseValue(const Fortran::parser::AccClauseList &);
117+
int64_t getLoopCountForCollapseAndTile(const Fortran::parser::AccClauseList &);
118118

119119
bool isInOpenACCLoop(fir::FirOpBuilder &);
120120

flang/lib/Lower/Bridge.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3083,25 +3083,25 @@ class FirConverter : public Fortran::lower::AbstractConverter {
30833083
Fortran::lower::pft::Evaluation *curEval = &getEval();
30843084

30853085
if (accLoop || accCombined) {
3086-
int64_t collapseValue;
3086+
int64_t loopCount;
30873087
if (accLoop) {
30883088
const Fortran::parser::AccBeginLoopDirective &beginLoopDir =
30893089
std::get<Fortran::parser::AccBeginLoopDirective>(accLoop->t);
30903090
const Fortran::parser::AccClauseList &clauseList =
30913091
std::get<Fortran::parser::AccClauseList>(beginLoopDir.t);
3092-
collapseValue = Fortran::lower::getCollapseValue(clauseList);
3092+
loopCount = Fortran::lower::getLoopCountForCollapseAndTile(clauseList);
30933093
} else if (accCombined) {
30943094
const Fortran::parser::AccBeginCombinedDirective &beginCombinedDir =
30953095
std::get<Fortran::parser::AccBeginCombinedDirective>(
30963096
accCombined->t);
30973097
const Fortran::parser::AccClauseList &clauseList =
30983098
std::get<Fortran::parser::AccClauseList>(beginCombinedDir.t);
3099-
collapseValue = Fortran::lower::getCollapseValue(clauseList);
3099+
loopCount = Fortran::lower::getLoopCountForCollapseAndTile(clauseList);
31003100
}
31013101

31023102
if (curEval->lowerAsStructured()) {
31033103
curEval = &curEval->getFirstNestedEvaluation();
3104-
for (int64_t i = 1; i < collapseValue; i++)
3104+
for (int64_t i = 1; i < loopCount; i++)
31053105
curEval = &*std::next(curEval->getNestedEvaluations().begin());
31063106
}
31073107
}
@@ -6155,8 +6155,8 @@ class FirConverter : public Fortran::lower::AbstractConverter {
61556155

61566156
Fortran::lower::defineModuleVariable(*this, var);
61576157
}
6158-
for (auto &eval : mod.evaluationList)
6159-
genFIR(eval);
6158+
for (auto &eval : mod.evaluationList)
6159+
genFIR(eval);
61606160
}
61616161

61626162
/// Lower functions contained in a module.

flang/lib/Lower/OpenACC.cpp

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2442,8 +2442,9 @@ static mlir::acc::LoopOp createLoopOp(
24422442
inclusiveBounds.push_back(true);
24432443
}
24442444
} else {
2445-
int64_t collapseValue = Fortran::lower::getCollapseValue(accClauseList);
2446-
for (unsigned i = 0; i < collapseValue; ++i) {
2445+
int64_t loopCount =
2446+
Fortran::lower::getLoopCountForCollapseAndTile(accClauseList);
2447+
for (unsigned i = 0; i < loopCount; ++i) {
24472448
const Fortran::parser::LoopControl *loopControl;
24482449
if (i == 0) {
24492450
loopControl = &*outerDoConstruct.GetLoopControl();
@@ -2478,7 +2479,7 @@ static mlir::acc::LoopOp createLoopOp(
24782479

24792480
inclusiveBounds.push_back(true);
24802481

2481-
if (i < collapseValue - 1)
2482+
if (i < loopCount - 1)
24822483
crtEval = &*std::next(crtEval->getNestedEvaluations().begin());
24832484
}
24842485
}
@@ -4940,15 +4941,25 @@ void Fortran::lower::genEarlyReturnInOpenACCLoop(fir::FirOpBuilder &builder,
49404941
builder.create<mlir::acc::YieldOp>(loc, yieldValue);
49414942
}
49424943

4943-
int64_t Fortran::lower::getCollapseValue(
4944+
int64_t Fortran::lower::getLoopCountForCollapseAndTile(
49444945
const Fortran::parser::AccClauseList &clauseList) {
4946+
int64_t collapseLoopCount = 1;
4947+
int64_t tileLoopCount = 1;
49454948
for (const Fortran::parser::AccClause &clause : clauseList.v) {
49464949
if (const auto *collapseClause =
49474950
std::get_if<Fortran::parser::AccClause::Collapse>(&clause.u)) {
49484951
const parser::AccCollapseArg &arg = collapseClause->v;
49494952
const auto &collapseValue{std::get<parser::ScalarIntConstantExpr>(arg.t)};
4950-
return *Fortran::semantics::GetIntValue(collapseValue);
4953+
collapseLoopCount = *Fortran::semantics::GetIntValue(collapseValue);
4954+
}
4955+
if (const auto *tileClause =
4956+
std::get_if<Fortran::parser::AccClause::Tile>(&clause.u)) {
4957+
const parser::AccTileExprList &tileExprList = tileClause->v;
4958+
const std::list<parser::AccTileExpr> &listTileExpr = tileExprList.v;
4959+
tileLoopCount = listTileExpr.size();
49514960
}
49524961
}
4953-
return 1;
4962+
if (tileLoopCount > collapseLoopCount)
4963+
return tileLoopCount;
4964+
return collapseLoopCount;
49544965
}

flang/test/Lower/OpenACC/acc-kernels-loop.f90

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -663,7 +663,7 @@ subroutine acc_kernels_loop
663663
! CHECK: acc.kernels {{.*}} {
664664
! CHECK: [[TILESIZE1:%.*]] = arith.constant 2 : i32
665665
! CHECK: [[TILESIZE2:%.*]] = arith.constant 2 : i32
666-
! CHECK: acc.loop {{.*}} tile({[[TILESIZE1]] : i32, [[TILESIZE2]] : i32}) {{.*}} {
666+
! CHECK: acc.loop {{.*}} tile({[[TILESIZE1]] : i32, [[TILESIZE2]] : i32}) control(%arg0 : i32, %arg1 : i32) {{.*}} {
667667
! CHECK: acc.yield
668668
! CHECK-NEXT: }{{$}}
669669
! CHECK: acc.terminator
@@ -689,7 +689,7 @@ subroutine acc_kernels_loop
689689
END DO
690690

691691
! CHECK: acc.kernels {{.*}} {
692-
! CHECK: acc.loop {{.*}} tile({%{{.*}} : i32, %{{.*}} : i32}) {{.*}} {
692+
! CHECK: acc.loop {{.*}} tile({%{{.*}} : i32, %{{.*}} : i32}) control(%arg0 : i32, %arg1 : i32) {{.*}} {
693693
! CHECK: acc.yield
694694
! CHECK-NEXT: }{{$}}
695695
! CHECK: acc.terminator

flang/test/Lower/OpenACC/acc-loop.f90

Lines changed: 50 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
program acc_loop
1111

12-
integer :: i, j
12+
integer :: i, j, k
1313
integer, parameter :: n = 10
1414
real, dimension(n) :: a, b
1515
real, dimension(n, n) :: c, d
@@ -209,9 +209,9 @@ program acc_loop
209209

210210
! CHECK: [[TILESIZE1:%.*]] = arith.constant 2 : i32
211211
! CHECK: [[TILESIZE2:%.*]] = arith.constant 2 : i32
212-
! CHECK: acc.loop {{.*}} tile({[[TILESIZE1]] : i32, [[TILESIZE2]] : i32}) control(%arg0 : i32) = (%{{.*}} : i32) to (%{{.*}} : i32) step (%{{.*}} : i32) {
212+
! CHECK: acc.loop {{.*}} tile({[[TILESIZE1]] : i32, [[TILESIZE2]] : i32}) control(%arg0 : i32, %arg1 : i32) = (%{{.*}} : i32, i32) to (%{{.*}} : i32, i32) step (%{{.*}} : i32, i32) {
213213
! CHECK: acc.yield
214-
! CHECK-NEXT: } attributes {inclusiveUpperbound = array<i1: true>, independent = [#acc.device_type<none>]}
214+
! CHECK-NEXT: } attributes {inclusiveUpperbound = array<i1: true, true>, independent = [#acc.device_type<none>]}
215215

216216
!$acc loop tile(tileSize)
217217
DO i = 1, n
@@ -229,9 +229,9 @@ program acc_loop
229229
END DO
230230
END DO
231231

232-
! CHECK: acc.loop {{.*}} tile({%{{.*}} : i32, %{{.*}} : i32}) control(%arg0 : i32) = (%{{.*}} : i32) to (%{{.*}} : i32) step (%{{.*}} : i32) {
232+
! CHECK: acc.loop {{.*}} tile({%{{.*}} : i32, %{{.*}} : i32}) control(%arg0 : i32, %arg1 : i32) = (%{{.*}} : i32, i32) to (%{{.*}} : i32, i32) step (%{{.*}} : i32, i32) {
233233
! CHECK: acc.yield
234-
! CHECK-NEXT: } attributes {inclusiveUpperbound = array<i1: true>, independent = [#acc.device_type<none>]}
234+
! CHECK-NEXT: } attributes {inclusiveUpperbound = array<i1: true, true>, independent = [#acc.device_type<none>]}
235235

236236
!$acc loop collapse(2)
237237
DO i = 1, n
@@ -246,6 +246,51 @@ program acc_loop
246246
! CHECK: acc.yield
247247
! CHECK-NEXT: } attributes {collapse = [2], collapseDeviceType = [#acc.device_type<none>]{{.*}}}
248248

249+
!$acc loop collapse(2) tile(tileSize)
250+
DO i = 1, n
251+
DO j = 1, n
252+
c(i, j) = d(i, j)
253+
END DO
254+
END DO
255+
256+
! CHECK: acc.loop {{.*}} tile({%{{.*}} : i32}) control(%arg0 : i32, %arg1 : i32) = (%{{.*}} : i32, i32) to (%{{.*}} : i32, i32) step (%{{.*}} : i32, i32) {
257+
! CHECK: fir.store %arg0 to %{{.*}} : !fir.ref<i32>
258+
! CHECK: fir.store %arg1 to %{{.*}} : !fir.ref<i32>
259+
! CHECK: acc.yield
260+
! CHECK-NEXT: } attributes {collapse = [2], collapseDeviceType = [#acc.device_type<none>]{{.*}}}
261+
262+
!$acc loop collapse(2) tile(tileSize, tileSize, tileSize)
263+
DO i = 1, n
264+
DO j = 1, n
265+
DO k = 1, n
266+
c(i, j) = d(i, j)
267+
END DO
268+
END DO
269+
END DO
270+
271+
! CHECK: acc.loop {{.*}} tile({%{{.*}} : i32, %{{.*}} : i32, %{{.*}} : i32}) control(%arg0 : i32, %arg1 : i32, %arg2 : i32) = (%{{.*}} : i32, i32, i32) to (%{{.*}} : i32, i32, i32) step (%{{.*}} : i32, i32, i32) {
272+
! CHECK: fir.store %arg0 to %{{.*}} : !fir.ref<i32>
273+
! CHECK: fir.store %arg1 to %{{.*}} : !fir.ref<i32>
274+
! CHECK: fir.store %arg2 to %{{.*}} : !fir.ref<i32>
275+
! CHECK: acc.yield
276+
! CHECK-NEXT: } attributes {collapse = [2], collapseDeviceType = [#acc.device_type<none>]{{.*}}}
277+
278+
!$acc loop collapse(3) tile(tileSize, tileSize)
279+
DO i = 1, n
280+
DO j = 1, n
281+
DO k = 1, n
282+
c(i, j) = d(i, j)
283+
END DO
284+
END DO
285+
END DO
286+
287+
! CHECK: acc.loop {{.*}} tile({%{{.*}} : i32, %{{.*}} : i32}) control(%arg0 : i32, %arg1 : i32, %arg2 : i32) = (%{{.*}} : i32, i32, i32) to (%{{.*}} : i32, i32, i32) step (%{{.*}} : i32, i32, i32) {
288+
! CHECK: fir.store %arg0 to %{{.*}} : !fir.ref<i32>
289+
! CHECK: fir.store %arg1 to %{{.*}} : !fir.ref<i32>
290+
! CHECK: fir.store %arg2 to %{{.*}} : !fir.ref<i32>
291+
! CHECK: acc.yield
292+
! CHECK-NEXT: } attributes {collapse = [3], collapseDeviceType = [#acc.device_type<none>]{{.*}}}
293+
249294
!$acc loop
250295
DO i = 1, n
251296
!$acc loop

flang/test/Lower/OpenACC/acc-parallel-loop.f90

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -681,7 +681,7 @@ subroutine acc_parallel_loop
681681
! CHECK: acc.parallel {{.*}} {
682682
! CHECK: [[TILESIZE1:%.*]] = arith.constant 2 : i32
683683
! CHECK: [[TILESIZE2:%.*]] = arith.constant 2 : i32
684-
! CHECK: acc.loop {{.*}} tile({[[TILESIZE1]] : i32, [[TILESIZE2]] : i32}) {{.*}} {
684+
! CHECK: acc.loop {{.*}} tile({[[TILESIZE1]] : i32, [[TILESIZE2]] : i32}) control(%arg0 : i32, %arg1 : i32) {{.*}} {
685685
! CHECK: acc.yield
686686
! CHECK-NEXT: }{{$}}
687687
! CHECK: acc.yield
@@ -707,7 +707,7 @@ subroutine acc_parallel_loop
707707
END DO
708708

709709
! CHECK: acc.parallel {{.*}} {
710-
! CHECK: acc.loop {{.*}} tile({%{{.*}} : i32, %{{.*}} : i32}) {{.*}} {
710+
! CHECK: acc.loop {{.*}} tile({%{{.*}} : i32, %{{.*}} : i32}) control(%arg0 : i32, %arg1 : i32) {{.*}} {
711711
! CHECK: acc.yield
712712
! CHECK-NEXT: }{{$}}
713713
! CHECK: acc.yield

flang/test/Lower/OpenACC/acc-serial-loop.f90

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -620,7 +620,7 @@ subroutine acc_serial_loop
620620
! CHECK: acc.serial {{.*}} {
621621
! CHECK: [[TILESIZE1:%.*]] = arith.constant 2 : i32
622622
! CHECK: [[TILESIZE2:%.*]] = arith.constant 2 : i32
623-
! CHECK: acc.loop {{.*}} tile({[[TILESIZE1]] : i32, [[TILESIZE2]] : i32}) {{.*}} {
623+
! CHECK: acc.loop {{.*}} tile({[[TILESIZE1]] : i32, [[TILESIZE2]] : i32}) control(%arg0 : i32, %arg1 : i32) {{.*}} {
624624
! CHECK: acc.yield
625625
! CHECK-NEXT: }{{$}}
626626
! CHECK: acc.yield
@@ -646,7 +646,7 @@ subroutine acc_serial_loop
646646
END DO
647647

648648
! CHECK: acc.serial {{.*}} {
649-
! CHECK: acc.loop {{.*}} tile({%{{.*}} : i32, %{{.*}} : i32}) {{.*}} {
649+
! CHECK: acc.loop {{.*}} tile({%{{.*}} : i32, %{{.*}} : i32}) control(%arg0 : i32, %arg1 : i32) {{.*}} {
650650
! CHECK: acc.yield
651651
! CHECK-NEXT: }{{$}}
652652
! CHECK: acc.yield

0 commit comments

Comments
 (0)