Skip to content

Commit 25c06f3

Browse files
committed
[mlir][core] update builder create API
1 parent 1e57cb8 commit 25c06f3

File tree

349 files changed

+6988
-6988
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

349 files changed

+6988
-6988
lines changed

mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp

Lines changed: 88 additions & 88 deletions
Large diffs are not rendered by default.

mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -50,9 +50,9 @@ static Value buildMinMaxReductionSeq(Location loc,
5050
Value value = *valueIt++;
5151
for (; valueIt != values.end(); ++valueIt) {
5252
if (predicate == arith::CmpIPredicate::sgt)
53-
value = builder.create<arith::MaxSIOp>(loc, value, *valueIt);
53+
value = arith::MaxSIOp::create(builder, loc, value, *valueIt);
5454
else
55-
value = builder.create<arith::MinSIOp>(loc, value, *valueIt);
55+
value = arith::MinSIOp::create(builder, loc, value, *valueIt);
5656
}
5757

5858
return value;
@@ -154,8 +154,8 @@ class AffineForLowering : public OpRewritePattern<AffineForOp> {
154154
Value lowerBound = lowerAffineLowerBound(op, rewriter);
155155
Value upperBound = lowerAffineUpperBound(op, rewriter);
156156
Value step =
157-
rewriter.create<arith::ConstantIndexOp>(loc, op.getStepAsInt());
158-
auto scfForOp = rewriter.create<scf::ForOp>(loc, lowerBound, upperBound,
157+
arith::ConstantIndexOp::create(rewriter, loc, op.getStepAsInt());
158+
auto scfForOp = scf::ForOp::create(rewriter, loc, lowerBound, upperBound,
159159
step, op.getInits());
160160
rewriter.eraseBlock(scfForOp.getBody());
161161
rewriter.inlineRegionBefore(op.getRegion(), scfForOp.getRegion(),
@@ -197,15 +197,15 @@ class AffineParallelLowering : public OpRewritePattern<AffineParallelOp> {
197197
}
198198
steps.reserve(op.getSteps().size());
199199
for (int64_t step : op.getSteps())
200-
steps.push_back(rewriter.create<arith::ConstantIndexOp>(loc, step));
200+
steps.push_back(arith::ConstantIndexOp::create(rewriter, loc, step));
201201

202202
// Get the terminator op.
203203
auto affineParOpTerminator =
204204
cast<AffineYieldOp>(op.getBody()->getTerminator());
205205
scf::ParallelOp parOp;
206206
if (op.getResults().empty()) {
207207
// Case with no reduction operations/return values.
208-
parOp = rewriter.create<scf::ParallelOp>(loc, lowerBoundTuple,
208+
parOp = scf::ParallelOp::create(rewriter, loc, lowerBoundTuple,
209209
upperBoundTuple, steps,
210210
/*bodyBuilderFn=*/nullptr);
211211
rewriter.eraseBlock(parOp.getBody());
@@ -233,7 +233,7 @@ class AffineParallelLowering : public OpRewritePattern<AffineParallelOp> {
233233
identityVals.push_back(
234234
arith::getIdentityValue(reductionOpValue, resultType, rewriter, loc));
235235
}
236-
parOp = rewriter.create<scf::ParallelOp>(
236+
parOp = scf::ParallelOp::create(rewriter,
237237
loc, lowerBoundTuple, upperBoundTuple, steps, identityVals,
238238
/*bodyBuilderFn=*/nullptr);
239239

@@ -261,7 +261,7 @@ class AffineParallelLowering : public OpRewritePattern<AffineParallelOp> {
261261
Value reductionResult = arith::getReductionOp(
262262
reductionOpValue, rewriter, loc, reductionBody.getArgument(0),
263263
reductionBody.getArgument(1));
264-
rewriter.create<scf::ReduceReturnOp>(loc, reductionResult);
264+
scf::ReduceReturnOp::create(rewriter, loc, reductionResult);
265265
}
266266
rewriter.replaceOp(op, parOp.getResults());
267267
return success();
@@ -278,7 +278,7 @@ class AffineIfLowering : public OpRewritePattern<AffineIfOp> {
278278

279279
// Now we just have to handle the condition logic.
280280
auto integerSet = op.getIntegerSet();
281-
Value zeroConstant = rewriter.create<arith::ConstantIndexOp>(loc, 0);
281+
Value zeroConstant = arith::ConstantIndexOp::create(rewriter, loc, 0);
282282
SmallVector<Value, 8> operands(op.getOperands());
283283
auto operandsRef = llvm::ArrayRef(operands);
284284

@@ -298,17 +298,17 @@ class AffineIfLowering : public OpRewritePattern<AffineIfOp> {
298298
auto pred =
299299
isEquality ? arith::CmpIPredicate::eq : arith::CmpIPredicate::sge;
300300
Value cmpVal =
301-
rewriter.create<arith::CmpIOp>(loc, pred, affResult, zeroConstant);
301+
arith::CmpIOp::create(rewriter, loc, pred, affResult, zeroConstant);
302302
cond = cond
303-
? rewriter.create<arith::AndIOp>(loc, cond, cmpVal).getResult()
303+
? arith::AndIOp::create(rewriter, loc, cond, cmpVal).getResult()
304304
: cmpVal;
305305
}
306306
cond = cond ? cond
307-
: rewriter.create<arith::ConstantIntOp>(loc, /*value=*/1,
307+
: arith::ConstantIntOp::create(rewriter, loc, /*value=*/1,
308308
/*width=*/1);
309309

310310
bool hasElseRegion = !op.getElseRegion().empty();
311-
auto ifOp = rewriter.create<scf::IfOp>(loc, op.getResultTypes(), cond,
311+
auto ifOp = scf::IfOp::create(rewriter, loc, op.getResultTypes(), cond,
312312
hasElseRegion);
313313
rewriter.inlineRegionBefore(op.getThenRegion(),
314314
&ifOp.getThenRegion().back());

mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp

Lines changed: 45 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -89,9 +89,9 @@ static Value castF32To(Type desType, Value f32, Location loc,
8989
if (elementType.isF32())
9090
return f32;
9191
if (elementType.getIntOrFloatBitWidth() < 32)
92-
return rewriter.create<arith::TruncFOp>(loc, desType, f32);
92+
return arith::TruncFOp::create(rewriter, loc, desType, f32);
9393
if (elementType.getIntOrFloatBitWidth() > 32)
94-
return rewriter.create<arith::ExtFOp>(loc, desType, f32);
94+
return arith::ExtFOp::create(rewriter, loc, desType, f32);
9595
llvm_unreachable("The only 32-bit float type is f32");
9696
}
9797

@@ -113,26 +113,26 @@ ExtFOnFloat8RewritePattern::matchAndRewrite(arith::ExtFOp op,
113113
Type outElemType = getElementTypeOrSelf(op.getOut().getType());
114114
VectorType extResType = VectorType::get(2, rewriter.getF32Type());
115115
if (!inVecType) {
116-
Value asFloat = rewriter.create<amdgpu::ExtPackedFp8Op>(
116+
Value asFloat = amdgpu::ExtPackedFp8Op::create(rewriter,
117117
loc, rewriter.getF32Type(), in, 0);
118118
Value result = castF32To(outElemType, asFloat, loc, rewriter);
119119
rewriter.replaceOp(op, result);
120120
return success();
121121
}
122122
int64_t numElements = inVecType.getNumElements();
123123

124-
Value zero = rewriter.create<arith::ConstantOp>(
124+
Value zero = arith::ConstantOp::create(rewriter,
125125
loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0));
126126
VectorType outType = cast<VectorType>(op.getOut().getType());
127127

128128
if (inVecType.getShape().empty()) {
129129
Value zerodSplat =
130130
rewriter.createOrFold<vector::SplatOp>(loc, outType, zero);
131131
Value scalarIn =
132-
rewriter.create<vector::ExtractOp>(loc, in, ArrayRef<int64_t>{});
132+
vector::ExtractOp::create(rewriter, loc, in, ArrayRef<int64_t>{});
133133
Value scalarExt =
134-
rewriter.create<arith::ExtFOp>(loc, outElemType, scalarIn);
135-
Value result = rewriter.create<vector::InsertOp>(loc, scalarExt, zerodSplat,
134+
arith::ExtFOp::create(rewriter, loc, outElemType, scalarIn);
135+
Value result = vector::InsertOp::create(rewriter, loc, scalarExt, zerodSplat,
136136
ArrayRef<int64_t>{});
137137
rewriter.replaceOp(op, result);
138138
return success();
@@ -145,32 +145,32 @@ ExtFOnFloat8RewritePattern::matchAndRewrite(arith::ExtFOp op,
145145
if (inVecType.getRank() > 1) {
146146
inVecType = VectorType::get(SmallVector<int64_t>{numElements},
147147
inVecType.getElementType());
148-
in = rewriter.create<vector::ShapeCastOp>(loc, inVecType, in);
148+
in = vector::ShapeCastOp::create(rewriter, loc, inVecType, in);
149149
}
150150

151151
for (int64_t i = 0; i < numElements; i += 4) {
152152
int64_t elemsThisOp = std::min(numElements, i + 4) - i;
153-
Value inSlice = rewriter.create<vector::ExtractStridedSliceOp>(
153+
Value inSlice = vector::ExtractStridedSliceOp::create(rewriter,
154154
loc, in, i, elemsThisOp, 1);
155155
for (int64_t j = 0; j < elemsThisOp; j += 2) {
156156
if (i + j + 1 < numElements) { // Convert two 8-bit elements
157-
Value asFloats = rewriter.create<amdgpu::ExtPackedFp8Op>(
157+
Value asFloats = amdgpu::ExtPackedFp8Op::create(rewriter,
158158
loc, extResType, inSlice, j / 2);
159159
Type desType = VectorType::get(2, outElemType);
160160
Value asType = castF32To(desType, asFloats, loc, rewriter);
161-
result = rewriter.create<vector::InsertStridedSliceOp>(
161+
result = vector::InsertStridedSliceOp::create(rewriter,
162162
loc, asType, result, i + j, 1);
163163
} else { // Convert a 8-bit element
164-
Value asFloat = rewriter.create<amdgpu::ExtPackedFp8Op>(
164+
Value asFloat = amdgpu::ExtPackedFp8Op::create(rewriter,
165165
loc, rewriter.getF32Type(), inSlice, j / 2 * 2);
166166
Value asType = castF32To(outElemType, asFloat, loc, rewriter);
167-
result = rewriter.create<vector::InsertOp>(loc, asType, result, i + j);
167+
result = vector::InsertOp::create(rewriter, loc, asType, result, i + j);
168168
}
169169
}
170170
}
171171

172172
if (inVecType.getRank() != outType.getRank()) {
173-
result = rewriter.create<vector::ShapeCastOp>(loc, outType, result);
173+
result = vector::ShapeCastOp::create(rewriter, loc, outType, result);
174174
}
175175

176176
rewriter.replaceOp(op, result);
@@ -182,9 +182,9 @@ static Value castToF32(Value value, Location loc, PatternRewriter &rewriter) {
182182
if (type.isF32())
183183
return value;
184184
if (type.getIntOrFloatBitWidth() < 32)
185-
return rewriter.create<arith::ExtFOp>(loc, rewriter.getF32Type(), value);
185+
return arith::ExtFOp::create(rewriter, loc, rewriter.getF32Type(), value);
186186
if (type.getIntOrFloatBitWidth() > 32)
187-
return rewriter.create<arith::TruncFOp>(loc, rewriter.getF32Type(), value);
187+
return arith::TruncFOp::create(rewriter, loc, rewriter.getF32Type(), value);
188188
llvm_unreachable("The only 32-bit float type is f32");
189189
}
190190

@@ -224,13 +224,13 @@ static Value clampInput(PatternRewriter &rewriter, Location loc,
224224
loc, arith::CmpFPredicate::OEQ, source, negInf);
225225
Value isNan = rewriter.createOrFold<arith::CmpFOp>(
226226
loc, arith::CmpFPredicate::UNO, source, source);
227-
Value isNonFinite = rewriter.create<arith::OrIOp>(
228-
loc, rewriter.create<arith::OrIOp>(loc, isInf, isNegInf), isNan);
227+
Value isNonFinite = arith::OrIOp::create(rewriter,
228+
loc, arith::OrIOp::create(rewriter, loc, isInf, isNegInf), isNan);
229229

230-
Value clampedBelow = rewriter.create<arith::MaximumFOp>(loc, source, minCst);
231-
Value clamped = rewriter.create<arith::MinimumFOp>(loc, clampedBelow, maxCst);
230+
Value clampedBelow = arith::MaximumFOp::create(rewriter, loc, source, minCst);
231+
Value clamped = arith::MinimumFOp::create(rewriter, loc, clampedBelow, maxCst);
232232
Value res =
233-
rewriter.create<arith::SelectOp>(loc, isNonFinite, source, clamped);
233+
arith::SelectOp::create(rewriter, loc, isNonFinite, source, clamped);
234234
return res;
235235
}
236236

@@ -264,24 +264,24 @@ TruncFToFloat8RewritePattern::matchAndRewrite(arith::TruncFOp op,
264264
VectorType truncResType = VectorType::get(4, outElemType);
265265
if (!inVectorTy) {
266266
Value asFloat = castToF32(in, loc, rewriter);
267-
Value asF8s = rewriter.create<amdgpu::PackedTrunc2xFp8Op>(
267+
Value asF8s = amdgpu::PackedTrunc2xFp8Op::create(rewriter,
268268
loc, truncResType, asFloat, /*sourceB=*/nullptr, 0,
269269
/*existing=*/nullptr);
270-
Value result = rewriter.create<vector::ExtractOp>(loc, asF8s, 0);
270+
Value result = vector::ExtractOp::create(rewriter, loc, asF8s, 0);
271271
rewriter.replaceOp(op, result);
272272
return success();
273273
}
274274

275275
int64_t numElements = outVecType.getNumElements();
276-
Value zero = rewriter.create<arith::ConstantOp>(
276+
Value zero = arith::ConstantOp::create(rewriter,
277277
loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0));
278278
if (outVecType.getShape().empty()) {
279279
Value scalarIn =
280-
rewriter.create<vector::ExtractOp>(loc, in, ArrayRef<int64_t>{});
280+
vector::ExtractOp::create(rewriter, loc, in, ArrayRef<int64_t>{});
281281
// Recurse to send the 0-D vector case to the 1-D vector case
282282
Value scalarTrunc =
283-
rewriter.create<arith::TruncFOp>(loc, outElemType, scalarIn);
284-
Value result = rewriter.create<vector::InsertOp>(loc, scalarTrunc, zero,
283+
arith::TruncFOp::create(rewriter, loc, outElemType, scalarIn);
284+
Value result = vector::InsertOp::create(rewriter, loc, scalarTrunc, zero,
285285
ArrayRef<int64_t>{});
286286
rewriter.replaceOp(op, result);
287287
return success();
@@ -294,32 +294,32 @@ TruncFToFloat8RewritePattern::matchAndRewrite(arith::TruncFOp op,
294294
if (inVectorTy.getRank() > 1) {
295295
inVectorTy = VectorType::get(SmallVector<int64_t>{numElements},
296296
inVectorTy.getElementType());
297-
in = rewriter.create<vector::ShapeCastOp>(loc, inVectorTy, in);
297+
in = vector::ShapeCastOp::create(rewriter, loc, inVectorTy, in);
298298
}
299299

300300
for (int64_t i = 0; i < numElements; i += 4) {
301301
int64_t elemsThisOp = std::min(numElements, i + 4) - i;
302302
Value thisResult = nullptr;
303303
for (int64_t j = 0; j < elemsThisOp; j += 2) {
304-
Value elemA = rewriter.create<vector::ExtractOp>(loc, in, i + j);
304+
Value elemA = vector::ExtractOp::create(rewriter, loc, in, i + j);
305305
Value asFloatA = castToF32(elemA, loc, rewriter);
306306
Value asFloatB = nullptr;
307307
if (j + 1 < elemsThisOp) {
308-
Value elemB = rewriter.create<vector::ExtractOp>(loc, in, i + j + 1);
308+
Value elemB = vector::ExtractOp::create(rewriter, loc, in, i + j + 1);
309309
asFloatB = castToF32(elemB, loc, rewriter);
310310
}
311-
thisResult = rewriter.create<amdgpu::PackedTrunc2xFp8Op>(
311+
thisResult = amdgpu::PackedTrunc2xFp8Op::create(rewriter,
312312
loc, truncResType, asFloatA, asFloatB, j / 2, thisResult);
313313
}
314314
if (elemsThisOp < 4)
315-
thisResult = rewriter.create<vector::ExtractStridedSliceOp>(
315+
thisResult = vector::ExtractStridedSliceOp::create(rewriter,
316316
loc, thisResult, 0, elemsThisOp, 1);
317-
result = rewriter.create<vector::InsertStridedSliceOp>(loc, thisResult,
317+
result = vector::InsertStridedSliceOp::create(rewriter, loc, thisResult,
318318
result, i, 1);
319319
}
320320

321321
if (inVectorTy.getRank() != outVecType.getRank()) {
322-
result = rewriter.create<vector::ShapeCastOp>(loc, outVecType, result);
322+
result = vector::ShapeCastOp::create(rewriter, loc, outVecType, result);
323323
}
324324

325325
rewriter.replaceOp(op, result);
@@ -347,10 +347,10 @@ LogicalResult TruncfToFloat16RewritePattern::matchAndRewrite(
347347

348348
// Handle the case where input type is not a vector type
349349
if (!inVectorTy) {
350-
auto sourceB = rewriter.create<LLVM::PoisonOp>(loc, rewriter.getF32Type());
350+
auto sourceB = LLVM::PoisonOp::create(rewriter, loc, rewriter.getF32Type());
351351
Value asF16s =
352-
rewriter.create<ROCDL::CvtPkRtz>(loc, truncResType, in, sourceB);
353-
Value result = rewriter.create<vector::ExtractOp>(loc, asF16s, 0);
352+
ROCDL::CvtPkRtz::create(rewriter, loc, truncResType, in, sourceB);
353+
Value result = vector::ExtractOp::create(rewriter, loc, asF16s, 0);
354354
rewriter.replaceOp(op, result);
355355
return success();
356356
}
@@ -362,33 +362,33 @@ LogicalResult TruncfToFloat16RewritePattern::matchAndRewrite(
362362
if (inVectorTy.getRank() > 1) {
363363
inVectorTy = VectorType::get(SmallVector<int64_t>{numElements},
364364
inVectorTy.getElementType());
365-
in = rewriter.create<vector::ShapeCastOp>(loc, inVectorTy, in);
365+
in = vector::ShapeCastOp::create(rewriter, loc, inVectorTy, in);
366366
}
367367

368368
// Handle the vector case. We also handle the (uncommon) case where the vector
369369
// length is odd
370370
for (int64_t i = 0; i < numElements; i += 2) {
371371
int64_t elemsThisOp = std::min(numElements, i + 2) - i;
372372
Value thisResult = nullptr;
373-
Value elemA = rewriter.create<vector::ExtractOp>(loc, in, i);
374-
Value elemB = rewriter.create<LLVM::PoisonOp>(loc, rewriter.getF32Type());
373+
Value elemA = vector::ExtractOp::create(rewriter, loc, in, i);
374+
Value elemB = LLVM::PoisonOp::create(rewriter, loc, rewriter.getF32Type());
375375

376376
if (elemsThisOp == 2) {
377-
elemB = rewriter.create<vector::ExtractOp>(loc, in, i + 1);
377+
elemB = vector::ExtractOp::create(rewriter, loc, in, i + 1);
378378
}
379379

380380
thisResult =
381-
rewriter.create<ROCDL::CvtPkRtz>(loc, truncResType, elemA, elemB);
381+
ROCDL::CvtPkRtz::create(rewriter, loc, truncResType, elemA, elemB);
382382
// Place back the truncated result into the possibly larger vector. If we
383383
// are operating on a size 2 vector, these operations should be folded away
384-
thisResult = rewriter.create<vector::ExtractStridedSliceOp>(
384+
thisResult = vector::ExtractStridedSliceOp::create(rewriter,
385385
loc, thisResult, 0, elemsThisOp, 1);
386-
result = rewriter.create<vector::InsertStridedSliceOp>(loc, thisResult,
386+
result = vector::InsertStridedSliceOp::create(rewriter, loc, thisResult,
387387
result, i, 1);
388388
}
389389

390390
if (inVectorTy.getRank() != outVecType.getRank()) {
391-
result = rewriter.create<vector::ShapeCastOp>(loc, outVecType, result);
391+
result = vector::ShapeCastOp::create(rewriter, loc, outVecType, result);
392392
}
393393

394394
rewriter.replaceOp(op, result);

mlir/lib/Conversion/ArithToArmSME/ArithToArmSME.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -74,14 +74,14 @@ struct ConstantOpToArmSMELowering : public OpRewritePattern<arith::ConstantOp> {
7474
VectorType tileSliceType = VectorType::Builder(tileType).dropDim(0);
7575
auto denseAttr1D = DenseElementsAttr::get(
7676
tileSliceType, denseAttr.getSplatValue<Attribute>());
77-
auto constantOp1D = rewriter.create<arith::ConstantOp>(loc, denseAttr1D);
77+
auto constantOp1D = arith::ConstantOp::create(rewriter, loc, denseAttr1D);
7878

79-
auto initTile = rewriter.create<arm_sme::GetTileOp>(loc, tileType);
79+
auto initTile = arm_sme::GetTileOp::create(rewriter, loc, tileType);
8080
auto makeLoopBody = [&](OpBuilder &b, Location loc, Value tileSliceIndex,
8181
Value currentTile) {
8282
// Create 'arm_sme.insert_tile_slice' to write vector to tile
8383
// slice.
84-
auto nextTile = b.create<arm_sme::InsertTileSliceOp>(
84+
auto nextTile = arm_sme::InsertTileSliceOp::create(b,
8585
loc, tileType, constantOp1D, currentTile, tileSliceIndex);
8686
return nextTile.getResult();
8787
};

0 commit comments

Comments
 (0)