|
21 | 21 | #include "flang/Optimizer/HLFIR/HLFIROps.h"
|
22 | 22 | #include "flang/Optimizer/HLFIR/Passes.h"
|
23 | 23 | #include "flang/Optimizer/OpenMP/Passes.h"
|
| 24 | +#include "flang/Optimizer/Support/Utils.h" |
24 | 25 | #include "flang/Optimizer/Transforms/Utils.h"
|
25 | 26 | #include "mlir/Dialect/Func/IR/FuncOps.h"
|
26 | 27 | #include "mlir/IR/Dominance.h"
|
@@ -786,13 +787,55 @@ llvm::LogicalResult BroadcastAssignBufferization::matchAndRewrite(
|
786 | 787 | mlir::Value shape = hlfir::genShape(loc, builder, lhs);
|
787 | 788 | llvm::SmallVector<mlir::Value> extents =
|
788 | 789 | hlfir::getIndexExtents(loc, builder, shape);
|
789 |
| - hlfir::LoopNest loopNest = |
790 |
| - hlfir::genLoopNest(loc, builder, extents, /*isUnordered=*/true, |
791 |
| - flangomp::shouldUseWorkshareLowering(assign)); |
792 |
| - builder.setInsertionPointToStart(loopNest.body); |
793 |
| - auto arrayElement = |
794 |
| - hlfir::getElementAt(loc, builder, lhs, loopNest.oneBasedIndices); |
795 |
| - builder.create<hlfir::AssignOp>(loc, rhs, arrayElement); |
| 790 | + |
| 791 | + if (lhs.isSimplyContiguous() && extents.size() > 1) { |
| 792 | + // Flatten the array to use a single assign loop, that can be better |
| 793 | + // optimized. |
| 794 | + mlir::Value n = extents[0]; |
| 795 | + for (size_t i = 1; i < extents.size(); ++i) |
| 796 | + n = builder.create<mlir::arith::MulIOp>(loc, n, extents[i]); |
| 797 | + llvm::SmallVector<mlir::Value> flatExtents = {n}; |
| 798 | + |
| 799 | + mlir::Type flatArrayType; |
| 800 | + mlir::Value flatArray = lhs.getBase(); |
| 801 | + if (mlir::isa<fir::BoxType>(lhs.getType())) { |
| 802 | + shape = builder.genShape(loc, flatExtents); |
| 803 | + flatArrayType = fir::BoxType::get(fir::SequenceType::get(eleTy, 1)); |
| 804 | + flatArray = builder.create<fir::ReboxOp>(loc, flatArrayType, flatArray, |
| 805 | + shape, /*slice=*/mlir::Value{}); |
| 806 | + } else { |
| 807 | + // Array references must have fixed shape, when used in assignments. |
| 808 | + int64_t flatExtent = 1; |
| 809 | + for (const mlir::Value &extent : extents) { |
| 810 | + mlir::Operation *op = extent.getDefiningOp(); |
| 811 | + assert(op && "no defining operation for constant array extent"); |
| 812 | + flatExtent *= fir::toInt(mlir::cast<mlir::arith::ConstantOp>(*op)); |
| 813 | + } |
| 814 | + |
| 815 | + flatArrayType = |
| 816 | + fir::ReferenceType::get(fir::SequenceType::get({flatExtent}, eleTy)); |
| 817 | + flatArray = builder.createConvert(loc, flatArrayType, flatArray); |
| 818 | + } |
| 819 | + |
| 820 | + hlfir::LoopNest loopNest = |
| 821 | + hlfir::genLoopNest(loc, builder, flatExtents, /*isUnordered=*/true, |
| 822 | + flangomp::shouldUseWorkshareLowering(assign)); |
| 823 | + builder.setInsertionPointToStart(loopNest.body); |
| 824 | + |
| 825 | + mlir::Value arrayElement = |
| 826 | + builder.create<hlfir::DesignateOp>(loc, fir::ReferenceType::get(eleTy), |
| 827 | + flatArray, loopNest.oneBasedIndices); |
| 828 | + builder.create<hlfir::AssignOp>(loc, rhs, arrayElement); |
| 829 | + } else { |
| 830 | + hlfir::LoopNest loopNest = |
| 831 | + hlfir::genLoopNest(loc, builder, extents, /*isUnordered=*/true, |
| 832 | + flangomp::shouldUseWorkshareLowering(assign)); |
| 833 | + builder.setInsertionPointToStart(loopNest.body); |
| 834 | + auto arrayElement = |
| 835 | + hlfir::getElementAt(loc, builder, lhs, loopNest.oneBasedIndices); |
| 836 | + builder.create<hlfir::AssignOp>(loc, rhs, arrayElement); |
| 837 | + } |
| 838 | + |
796 | 839 | rewriter.eraseOp(assign);
|
797 | 840 | return mlir::success();
|
798 | 841 | }
|
|
0 commit comments