25
25
// Codegen rewrite: rewriting of subgraphs of ops
26
26
// ===----------------------------------------------------------------------===//
27
27
28
- using namespace fir ;
29
- using namespace mlir ;
30
-
31
28
#define DEBUG_TYPE " flang-codegen-rewrite"
32
29
33
30
static void populateShape (llvm::SmallVectorImpl<mlir::Value> &vec,
34
- ShapeOp shape) {
31
+ fir:: ShapeOp shape) {
35
32
vec.append (shape.getExtents ().begin (), shape.getExtents ().end ());
36
33
}
37
34
38
35
// Operands of fir.shape_shift split into two vectors.
39
36
static void populateShapeAndShift (llvm::SmallVectorImpl<mlir::Value> &shapeVec,
40
37
llvm::SmallVectorImpl<mlir::Value> &shiftVec,
41
- ShapeShiftOp shift) {
42
- auto endIter = shift.getPairs ().end ();
43
- for ( auto i = shift. getPairs (). begin (); i != endIter;) {
38
+ fir:: ShapeShiftOp shift) {
39
+ for ( auto i = shift. getPairs (). begin (), endIter = shift.getPairs ().end ();
40
+ i != endIter;) {
44
41
shiftVec.push_back (*i++);
45
42
shapeVec.push_back (*i++);
46
43
}
47
44
}
48
45
49
46
static void populateShift (llvm::SmallVectorImpl<mlir::Value> &vec,
50
- ShiftOp shift) {
47
+ fir:: ShiftOp shift) {
51
48
vec.append (shift.getOrigins ().begin (), shift.getOrigins ().end ());
52
49
}
53
50
@@ -72,27 +69,26 @@ namespace {
72
69
// / (!fir.ref<!fir.array<?xi32>>, index, index, index, index, index) ->
73
70
// / !fir.box<!fir.array<?xi32>>
74
71
// / ```
75
- class EmboxConversion : public mlir ::OpRewritePattern<EmboxOp> {
72
+ class EmboxConversion : public mlir ::OpRewritePattern<fir:: EmboxOp> {
76
73
public:
77
74
using OpRewritePattern::OpRewritePattern;
78
75
79
76
mlir::LogicalResult
80
- matchAndRewrite (EmboxOp embox,
77
+ matchAndRewrite (fir:: EmboxOp embox,
81
78
mlir::PatternRewriter &rewriter) const override {
82
- auto shapeVal = embox.getShape ();
83
79
// If the embox does not include a shape, then do not convert it
84
- if (shapeVal)
80
+ if (auto shapeVal = embox. getShape () )
85
81
return rewriteDynamicShape (embox, rewriter, shapeVal);
86
- if (auto boxTy = embox.getType ().dyn_cast <BoxType>())
87
- if (auto seqTy = boxTy.getEleTy ().dyn_cast <SequenceType>())
82
+ if (auto boxTy = embox.getType ().dyn_cast <fir:: BoxType>())
83
+ if (auto seqTy = boxTy.getEleTy ().dyn_cast <fir:: SequenceType>())
88
84
if (seqTy.hasConstantShape ())
89
85
return rewriteStaticShape (embox, rewriter, seqTy);
90
86
return mlir::failure ();
91
87
}
92
88
93
- mlir::LogicalResult rewriteStaticShape (EmboxOp embox,
89
+ mlir::LogicalResult rewriteStaticShape (fir:: EmboxOp embox,
94
90
mlir::PatternRewriter &rewriter,
95
- SequenceType seqTy) const {
91
+ fir:: SequenceType seqTy) const {
96
92
auto loc = embox.getLoc ();
97
93
llvm::SmallVector<mlir::Value> shapeOpers;
98
94
auto idxTy = rewriter.getIndexType ();
@@ -101,41 +97,42 @@ class EmboxConversion : public mlir::OpRewritePattern<EmboxOp> {
101
97
auto extVal = rewriter.create <mlir::arith::ConstantOp>(loc, idxTy, iAttr);
102
98
shapeOpers.push_back (extVal);
103
99
}
104
- auto xbox = rewriter.create <cg::XEmboxOp>(
100
+ auto xbox = rewriter.create <fir:: cg::XEmboxOp>(
105
101
loc, embox.getType (), embox.getMemref (), shapeOpers, llvm::None,
106
102
llvm::None, llvm::None, llvm::None, embox.getTypeparams ());
107
103
LLVM_DEBUG (llvm::dbgs () << " rewriting " << embox << " to " << xbox << ' \n ' );
108
104
rewriter.replaceOp (embox, xbox.getOperation ()->getResults ());
109
105
return mlir::success ();
110
106
}
111
107
112
- mlir::LogicalResult rewriteDynamicShape (EmboxOp embox,
108
+ mlir::LogicalResult rewriteDynamicShape (fir:: EmboxOp embox,
113
109
mlir::PatternRewriter &rewriter,
114
110
mlir::Value shapeVal) const {
115
111
auto loc = embox.getLoc ();
116
- auto shapeOp = dyn_cast<ShapeOp>(shapeVal.getDefiningOp ());
117
112
llvm::SmallVector<mlir::Value> shapeOpers;
118
113
llvm::SmallVector<mlir::Value> shiftOpers;
119
- if (shapeOp) {
114
+ if (auto shapeOp = mlir::dyn_cast<fir::ShapeOp>(shapeVal. getDefiningOp ()) ) {
120
115
populateShape (shapeOpers, shapeOp);
121
116
} else {
122
- auto shiftOp = dyn_cast<ShapeShiftOp>(shapeVal.getDefiningOp ());
117
+ auto shiftOp =
118
+ mlir::dyn_cast<fir::ShapeShiftOp>(shapeVal.getDefiningOp ());
123
119
assert (shiftOp && " shape is neither fir.shape nor fir.shape_shift" );
124
120
populateShapeAndShift (shapeOpers, shiftOpers, shiftOp);
125
121
}
126
122
llvm::SmallVector<mlir::Value> sliceOpers;
127
123
llvm::SmallVector<mlir::Value> subcompOpers;
128
124
llvm::SmallVector<mlir::Value> substrOpers;
129
125
if (auto s = embox.getSlice ())
130
- if (auto sliceOp = dyn_cast_or_null<SliceOp>(s.getDefiningOp ())) {
126
+ if (auto sliceOp =
127
+ mlir::dyn_cast_or_null<fir::SliceOp>(s.getDefiningOp ())) {
131
128
sliceOpers.assign (sliceOp.getTriples ().begin (),
132
129
sliceOp.getTriples ().end ());
133
130
subcompOpers.assign (sliceOp.getFields ().begin (),
134
131
sliceOp.getFields ().end ());
135
132
substrOpers.assign (sliceOp.getSubstr ().begin (),
136
133
sliceOp.getSubstr ().end ());
137
134
}
138
- auto xbox = rewriter.create <cg::XEmboxOp>(
135
+ auto xbox = rewriter.create <fir:: cg::XEmboxOp>(
139
136
loc, embox.getType (), embox.getMemref (), shapeOpers, shiftOpers,
140
137
sliceOpers, subcompOpers, substrOpers, embox.getTypeparams ());
141
138
LLVM_DEBUG (llvm::dbgs () << " rewriting " << embox << " to " << xbox << ' \n ' );
@@ -156,22 +153,24 @@ class EmboxConversion : public mlir::OpRewritePattern<EmboxOp> {
156
153
// / %5 = fircg.ext_rebox %3(%13) origin %12 : (!fir.box<!fir.array<?xi32>>,
157
154
// / index, index) -> !fir.box<!fir.array<?xi32>>
158
155
// / ```
159
- class ReboxConversion : public mlir ::OpRewritePattern<ReboxOp> {
156
+ class ReboxConversion : public mlir ::OpRewritePattern<fir:: ReboxOp> {
160
157
public:
161
158
using OpRewritePattern::OpRewritePattern;
162
159
163
160
mlir::LogicalResult
164
- matchAndRewrite (ReboxOp rebox,
161
+ matchAndRewrite (fir:: ReboxOp rebox,
165
162
mlir::PatternRewriter &rewriter) const override {
166
163
auto loc = rebox.getLoc ();
167
164
llvm::SmallVector<mlir::Value> shapeOpers;
168
165
llvm::SmallVector<mlir::Value> shiftOpers;
169
166
if (auto shapeVal = rebox.getShape ()) {
170
- if (auto shapeOp = dyn_cast<ShapeOp>(shapeVal.getDefiningOp ()))
167
+ if (auto shapeOp = mlir:: dyn_cast<fir:: ShapeOp>(shapeVal.getDefiningOp ()))
171
168
populateShape (shapeOpers, shapeOp);
172
- else if (auto shiftOp = dyn_cast<ShapeShiftOp>(shapeVal.getDefiningOp ()))
169
+ else if (auto shiftOp =
170
+ mlir::dyn_cast<fir::ShapeShiftOp>(shapeVal.getDefiningOp ()))
173
171
populateShapeAndShift (shapeOpers, shiftOpers, shiftOp);
174
- else if (auto shiftOp = dyn_cast<ShiftOp>(shapeVal.getDefiningOp ()))
172
+ else if (auto shiftOp =
173
+ mlir::dyn_cast<fir::ShiftOp>(shapeVal.getDefiningOp ()))
175
174
populateShift (shiftOpers, shiftOp);
176
175
else
177
176
return mlir::failure ();
@@ -180,7 +179,8 @@ class ReboxConversion : public mlir::OpRewritePattern<ReboxOp> {
180
179
llvm::SmallVector<mlir::Value> subcompOpers;
181
180
llvm::SmallVector<mlir::Value> substrOpers;
182
181
if (auto s = rebox.getSlice ())
183
- if (auto sliceOp = dyn_cast_or_null<SliceOp>(s.getDefiningOp ())) {
182
+ if (auto sliceOp =
183
+ mlir::dyn_cast_or_null<fir::SliceOp>(s.getDefiningOp ())) {
184
184
sliceOpers.append (sliceOp.getTriples ().begin (),
185
185
sliceOp.getTriples ().end ());
186
186
subcompOpers.append (sliceOp.getFields ().begin (),
@@ -189,7 +189,7 @@ class ReboxConversion : public mlir::OpRewritePattern<ReboxOp> {
189
189
sliceOp.getSubstr ().end ());
190
190
}
191
191
192
- auto xRebox = rewriter.create <cg::XReboxOp>(
192
+ auto xRebox = rewriter.create <fir:: cg::XReboxOp>(
193
193
loc, rebox.getType (), rebox.getBox (), shapeOpers, shiftOpers,
194
194
sliceOpers, subcompOpers, substrOpers);
195
195
LLVM_DEBUG (llvm::dbgs ()
@@ -212,30 +212,33 @@ class ReboxConversion : public mlir::OpRewritePattern<ReboxOp> {
212
212
// / (!fir.ref<!fir.array<?xi32>>, index, index, index, index, index, index) ->
213
213
// / !fir.ref<i32>
214
214
// / ```
215
- class ArrayCoorConversion : public mlir ::OpRewritePattern<ArrayCoorOp> {
215
+ class ArrayCoorConversion : public mlir ::OpRewritePattern<fir:: ArrayCoorOp> {
216
216
public:
217
217
using OpRewritePattern::OpRewritePattern;
218
218
219
219
mlir::LogicalResult
220
- matchAndRewrite (ArrayCoorOp arrCoor,
220
+ matchAndRewrite (fir:: ArrayCoorOp arrCoor,
221
221
mlir::PatternRewriter &rewriter) const override {
222
222
auto loc = arrCoor.getLoc ();
223
223
llvm::SmallVector<mlir::Value> shapeOpers;
224
224
llvm::SmallVector<mlir::Value> shiftOpers;
225
225
if (auto shapeVal = arrCoor.getShape ()) {
226
- if (auto shapeOp = dyn_cast<ShapeOp>(shapeVal.getDefiningOp ()))
226
+ if (auto shapeOp = mlir:: dyn_cast<fir:: ShapeOp>(shapeVal.getDefiningOp ()))
227
227
populateShape (shapeOpers, shapeOp);
228
- else if (auto shiftOp = dyn_cast<ShapeShiftOp>(shapeVal.getDefiningOp ()))
228
+ else if (auto shiftOp =
229
+ mlir::dyn_cast<fir::ShapeShiftOp>(shapeVal.getDefiningOp ()))
229
230
populateShapeAndShift (shapeOpers, shiftOpers, shiftOp);
230
- else if (auto shiftOp = dyn_cast<ShiftOp>(shapeVal.getDefiningOp ()))
231
+ else if (auto shiftOp =
232
+ mlir::dyn_cast<fir::ShiftOp>(shapeVal.getDefiningOp ()))
231
233
populateShift (shiftOpers, shiftOp);
232
234
else
233
235
return mlir::failure ();
234
236
}
235
237
llvm::SmallVector<mlir::Value> sliceOpers;
236
238
llvm::SmallVector<mlir::Value> subcompOpers;
237
239
if (auto s = arrCoor.getSlice ())
238
- if (auto sliceOp = dyn_cast_or_null<SliceOp>(s.getDefiningOp ())) {
240
+ if (auto sliceOp =
241
+ mlir::dyn_cast_or_null<fir::SliceOp>(s.getDefiningOp ())) {
239
242
sliceOpers.append (sliceOp.getTriples ().begin (),
240
243
sliceOp.getTriples ().end ());
241
244
subcompOpers.append (sliceOp.getFields ().begin (),
@@ -244,7 +247,7 @@ class ArrayCoorConversion : public mlir::OpRewritePattern<ArrayCoorOp> {
244
247
" Don't allow substring operations on array_coor. This "
245
248
" restriction may be lifted in the future." );
246
249
}
247
- auto xArrCoor = rewriter.create <cg::XArrayCoorOp>(
250
+ auto xArrCoor = rewriter.create <fir:: cg::XArrayCoorOp>(
248
251
loc, arrCoor.getType (), arrCoor.getMemref (), shapeOpers, shiftOpers,
249
252
sliceOpers, subcompOpers, arrCoor.getIndices (),
250
253
arrCoor.getTypeparams ());
@@ -255,20 +258,22 @@ class ArrayCoorConversion : public mlir::OpRewritePattern<ArrayCoorOp> {
255
258
}
256
259
};
257
260
258
- class CodeGenRewrite : public CodeGenRewriteBase <CodeGenRewrite> {
261
+ class CodeGenRewrite : public fir :: CodeGenRewriteBase<CodeGenRewrite> {
259
262
public:
260
263
void runOnOperation () override final {
261
264
auto op = getOperation ();
262
265
auto &context = getContext ();
263
266
mlir::OpBuilder rewriter (&context);
264
267
mlir::ConversionTarget target (context);
265
- target.addLegalDialect <mlir::arith::ArithmeticDialect, FIROpsDialect,
266
- FIRCodeGenDialect, mlir::func::FuncDialect>();
267
- target.addIllegalOp <ArrayCoorOp>();
268
- target.addIllegalOp <ReboxOp>();
269
- target.addDynamicallyLegalOp <EmboxOp>([](EmboxOp embox) {
270
- return !(embox.getShape () ||
271
- embox.getType ().cast <BoxType>().getEleTy ().isa <SequenceType>());
268
+ target.addLegalDialect <mlir::arith::ArithmeticDialect, fir::FIROpsDialect,
269
+ fir::FIRCodeGenDialect, mlir::func::FuncDialect>();
270
+ target.addIllegalOp <fir::ArrayCoorOp>();
271
+ target.addIllegalOp <fir::ReboxOp>();
272
+ target.addDynamicallyLegalOp <fir::EmboxOp>([](fir::EmboxOp embox) {
273
+ return !(embox.getShape () || embox.getType ()
274
+ .cast <fir::BoxType>()
275
+ .getEleTy ()
276
+ .isa <fir::SequenceType>());
272
277
});
273
278
mlir::RewritePatternSet patterns (&context);
274
279
patterns.insert <EmboxConversion, ArrayCoorConversion, ReboxConversion>(
0 commit comments