Skip to content

Commit d397885

Browse files
committed
[mlir][core] update builder create API
1 parent c62a6e4 commit d397885

File tree

349 files changed

+6989
-6987
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

349 files changed

+6989
-6987
lines changed

mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp

Lines changed: 88 additions & 88 deletions
Large diffs are not rendered by default.

mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -50,9 +50,9 @@ static Value buildMinMaxReductionSeq(Location loc,
5050
Value value = *valueIt++;
5151
for (; valueIt != values.end(); ++valueIt) {
5252
if (predicate == arith::CmpIPredicate::sgt)
53-
value = builder.create<arith::MaxSIOp>(loc, value, *valueIt);
53+
value = arith::MaxSIOp::create(builder, loc, value, *valueIt);
5454
else
55-
value = builder.create<arith::MinSIOp>(loc, value, *valueIt);
55+
value = arith::MinSIOp::create(builder, loc, value, *valueIt);
5656
}
5757

5858
return value;
@@ -154,8 +154,8 @@ class AffineForLowering : public OpRewritePattern<AffineForOp> {
154154
Value lowerBound = lowerAffineLowerBound(op, rewriter);
155155
Value upperBound = lowerAffineUpperBound(op, rewriter);
156156
Value step =
157-
rewriter.create<arith::ConstantIndexOp>(loc, op.getStepAsInt());
158-
auto scfForOp = rewriter.create<scf::ForOp>(loc, lowerBound, upperBound,
157+
arith::ConstantIndexOp::create(rewriter, loc, op.getStepAsInt());
158+
auto scfForOp = scf::ForOp::create(rewriter, loc, lowerBound, upperBound,
159159
step, op.getInits());
160160
rewriter.eraseBlock(scfForOp.getBody());
161161
rewriter.inlineRegionBefore(op.getRegion(), scfForOp.getRegion(),
@@ -197,15 +197,15 @@ class AffineParallelLowering : public OpRewritePattern<AffineParallelOp> {
197197
}
198198
steps.reserve(op.getSteps().size());
199199
for (int64_t step : op.getSteps())
200-
steps.push_back(rewriter.create<arith::ConstantIndexOp>(loc, step));
200+
steps.push_back(arith::ConstantIndexOp::create(rewriter, loc, step));
201201

202202
// Get the terminator op.
203203
auto affineParOpTerminator =
204204
cast<AffineYieldOp>(op.getBody()->getTerminator());
205205
scf::ParallelOp parOp;
206206
if (op.getResults().empty()) {
207207
// Case with no reduction operations/return values.
208-
parOp = rewriter.create<scf::ParallelOp>(loc, lowerBoundTuple,
208+
parOp = scf::ParallelOp::create(rewriter, loc, lowerBoundTuple,
209209
upperBoundTuple, steps,
210210
/*bodyBuilderFn=*/nullptr);
211211
rewriter.eraseBlock(parOp.getBody());
@@ -233,7 +233,7 @@ class AffineParallelLowering : public OpRewritePattern<AffineParallelOp> {
233233
identityVals.push_back(
234234
arith::getIdentityValue(reductionOpValue, resultType, rewriter, loc));
235235
}
236-
parOp = rewriter.create<scf::ParallelOp>(
236+
parOp = scf::ParallelOp::create(rewriter,
237237
loc, lowerBoundTuple, upperBoundTuple, steps, identityVals,
238238
/*bodyBuilderFn=*/nullptr);
239239

@@ -261,7 +261,7 @@ class AffineParallelLowering : public OpRewritePattern<AffineParallelOp> {
261261
Value reductionResult = arith::getReductionOp(
262262
reductionOpValue, rewriter, loc, reductionBody.getArgument(0),
263263
reductionBody.getArgument(1));
264-
rewriter.create<scf::ReduceReturnOp>(loc, reductionResult);
264+
scf::ReduceReturnOp::create(rewriter, loc, reductionResult);
265265
}
266266
rewriter.replaceOp(op, parOp.getResults());
267267
return success();
@@ -278,7 +278,7 @@ class AffineIfLowering : public OpRewritePattern<AffineIfOp> {
278278

279279
// Now we just have to handle the condition logic.
280280
auto integerSet = op.getIntegerSet();
281-
Value zeroConstant = rewriter.create<arith::ConstantIndexOp>(loc, 0);
281+
Value zeroConstant = arith::ConstantIndexOp::create(rewriter, loc, 0);
282282
SmallVector<Value, 8> operands(op.getOperands());
283283
auto operandsRef = llvm::ArrayRef(operands);
284284

@@ -298,17 +298,17 @@ class AffineIfLowering : public OpRewritePattern<AffineIfOp> {
298298
auto pred =
299299
isEquality ? arith::CmpIPredicate::eq : arith::CmpIPredicate::sge;
300300
Value cmpVal =
301-
rewriter.create<arith::CmpIOp>(loc, pred, affResult, zeroConstant);
301+
arith::CmpIOp::create(rewriter, loc, pred, affResult, zeroConstant);
302302
cond = cond
303-
? rewriter.create<arith::AndIOp>(loc, cond, cmpVal).getResult()
303+
? arith::AndIOp::create(rewriter, loc, cond, cmpVal).getResult()
304304
: cmpVal;
305305
}
306306
cond = cond ? cond
307-
: rewriter.create<arith::ConstantIntOp>(loc, /*value=*/1,
307+
: arith::ConstantIntOp::create(rewriter, loc, /*value=*/1,
308308
/*width=*/1);
309309

310310
bool hasElseRegion = !op.getElseRegion().empty();
311-
auto ifOp = rewriter.create<scf::IfOp>(loc, op.getResultTypes(), cond,
311+
auto ifOp = scf::IfOp::create(rewriter, loc, op.getResultTypes(), cond,
312312
hasElseRegion);
313313
rewriter.inlineRegionBefore(op.getThenRegion(),
314314
&ifOp.getThenRegion().back());

mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp

Lines changed: 45 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -115,9 +115,9 @@ static Value castF32To(Type desType, Value f32, Location loc,
115115
if (elementType.isF32())
116116
return f32;
117117
if (elementType.getIntOrFloatBitWidth() < 32)
118-
return rewriter.create<arith::TruncFOp>(loc, desType, f32);
118+
return arith::TruncFOp::create(rewriter, loc, desType, f32);
119119
if (elementType.getIntOrFloatBitWidth() > 32)
120-
return rewriter.create<arith::ExtFOp>(loc, desType, f32);
120+
return arith::ExtFOp::create(rewriter, loc, desType, f32);
121121
llvm_unreachable("The only 32-bit float type is f32");
122122
}
123123

@@ -139,26 +139,26 @@ ExtFOnFloat8RewritePattern::matchAndRewrite(arith::ExtFOp op,
139139
Type outElemType = getElementTypeOrSelf(op.getOut().getType());
140140
VectorType extResType = VectorType::get(2, rewriter.getF32Type());
141141
if (!inVecType) {
142-
Value asFloat = rewriter.create<amdgpu::ExtPackedFp8Op>(
142+
Value asFloat = amdgpu::ExtPackedFp8Op::create(rewriter,
143143
loc, rewriter.getF32Type(), in, 0);
144144
Value result = castF32To(outElemType, asFloat, loc, rewriter);
145145
rewriter.replaceOp(op, result);
146146
return success();
147147
}
148148
int64_t numElements = inVecType.getNumElements();
149149

150-
Value zero = rewriter.create<arith::ConstantOp>(
150+
Value zero = arith::ConstantOp::create(rewriter,
151151
loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0));
152152
VectorType outType = cast<VectorType>(op.getOut().getType());
153153

154154
if (inVecType.getShape().empty()) {
155155
Value zerodSplat =
156156
rewriter.createOrFold<vector::SplatOp>(loc, outType, zero);
157157
Value scalarIn =
158-
rewriter.create<vector::ExtractOp>(loc, in, ArrayRef<int64_t>{});
158+
vector::ExtractOp::create(rewriter, loc, in, ArrayRef<int64_t>{});
159159
Value scalarExt =
160-
rewriter.create<arith::ExtFOp>(loc, outElemType, scalarIn);
161-
Value result = rewriter.create<vector::InsertOp>(loc, scalarExt, zerodSplat,
160+
arith::ExtFOp::create(rewriter, loc, outElemType, scalarIn);
161+
Value result = vector::InsertOp::create(rewriter, loc, scalarExt, zerodSplat,
162162
ArrayRef<int64_t>{});
163163
rewriter.replaceOp(op, result);
164164
return success();
@@ -171,32 +171,32 @@ ExtFOnFloat8RewritePattern::matchAndRewrite(arith::ExtFOp op,
171171
if (inVecType.getRank() > 1) {
172172
inVecType = VectorType::get(SmallVector<int64_t>{numElements},
173173
inVecType.getElementType());
174-
in = rewriter.create<vector::ShapeCastOp>(loc, inVecType, in);
174+
in = vector::ShapeCastOp::create(rewriter, loc, inVecType, in);
175175
}
176176

177177
for (int64_t i = 0; i < numElements; i += 4) {
178178
int64_t elemsThisOp = std::min(numElements, i + 4) - i;
179-
Value inSlice = rewriter.create<vector::ExtractStridedSliceOp>(
179+
Value inSlice = vector::ExtractStridedSliceOp::create(rewriter,
180180
loc, in, i, elemsThisOp, 1);
181181
for (int64_t j = 0; j < elemsThisOp; j += 2) {
182182
if (i + j + 1 < numElements) { // Convert two 8-bit elements
183-
Value asFloats = rewriter.create<amdgpu::ExtPackedFp8Op>(
183+
Value asFloats = amdgpu::ExtPackedFp8Op::create(rewriter,
184184
loc, extResType, inSlice, j / 2);
185185
Type desType = VectorType::get(2, outElemType);
186186
Value asType = castF32To(desType, asFloats, loc, rewriter);
187-
result = rewriter.create<vector::InsertStridedSliceOp>(
187+
result = vector::InsertStridedSliceOp::create(rewriter,
188188
loc, asType, result, i + j, 1);
189189
} else { // Convert a 8-bit element
190-
Value asFloat = rewriter.create<amdgpu::ExtPackedFp8Op>(
190+
Value asFloat = amdgpu::ExtPackedFp8Op::create(rewriter,
191191
loc, rewriter.getF32Type(), inSlice, j / 2 * 2);
192192
Value asType = castF32To(outElemType, asFloat, loc, rewriter);
193-
result = rewriter.create<vector::InsertOp>(loc, asType, result, i + j);
193+
result = vector::InsertOp::create(rewriter, loc, asType, result, i + j);
194194
}
195195
}
196196
}
197197

198198
if (inVecType.getRank() != outType.getRank()) {
199-
result = rewriter.create<vector::ShapeCastOp>(loc, outType, result);
199+
result = vector::ShapeCastOp::create(rewriter, loc, outType, result);
200200
}
201201

202202
rewriter.replaceOp(op, result);
@@ -208,9 +208,9 @@ static Value castToF32(Value value, Location loc, PatternRewriter &rewriter) {
208208
if (type.isF32())
209209
return value;
210210
if (type.getIntOrFloatBitWidth() < 32)
211-
return rewriter.create<arith::ExtFOp>(loc, rewriter.getF32Type(), value);
211+
return arith::ExtFOp::create(rewriter, loc, rewriter.getF32Type(), value);
212212
if (type.getIntOrFloatBitWidth() > 32)
213-
return rewriter.create<arith::TruncFOp>(loc, rewriter.getF32Type(), value);
213+
return arith::TruncFOp::create(rewriter, loc, rewriter.getF32Type(), value);
214214
llvm_unreachable("The only 32-bit float type is f32");
215215
}
216216

@@ -250,13 +250,13 @@ static Value clampInput(PatternRewriter &rewriter, Location loc,
250250
loc, arith::CmpFPredicate::OEQ, source, negInf);
251251
Value isNan = rewriter.createOrFold<arith::CmpFOp>(
252252
loc, arith::CmpFPredicate::UNO, source, source);
253-
Value isNonFinite = rewriter.create<arith::OrIOp>(
254-
loc, rewriter.create<arith::OrIOp>(loc, isInf, isNegInf), isNan);
253+
Value isNonFinite = arith::OrIOp::create(rewriter,
254+
loc, arith::OrIOp::create(rewriter, loc, isInf, isNegInf), isNan);
255255

256-
Value clampedBelow = rewriter.create<arith::MaximumFOp>(loc, source, minCst);
257-
Value clamped = rewriter.create<arith::MinimumFOp>(loc, clampedBelow, maxCst);
256+
Value clampedBelow = arith::MaximumFOp::create(rewriter, loc, source, minCst);
257+
Value clamped = arith::MinimumFOp::create(rewriter, loc, clampedBelow, maxCst);
258258
Value res =
259-
rewriter.create<arith::SelectOp>(loc, isNonFinite, source, clamped);
259+
arith::SelectOp::create(rewriter, loc, isNonFinite, source, clamped);
260260
return res;
261261
}
262262

@@ -290,24 +290,24 @@ TruncFToFloat8RewritePattern::matchAndRewrite(arith::TruncFOp op,
290290
VectorType truncResType = VectorType::get(4, outElemType);
291291
if (!inVectorTy) {
292292
Value asFloat = castToF32(in, loc, rewriter);
293-
Value asF8s = rewriter.create<amdgpu::PackedTrunc2xFp8Op>(
293+
Value asF8s = amdgpu::PackedTrunc2xFp8Op::create(rewriter,
294294
loc, truncResType, asFloat, /*sourceB=*/nullptr, 0,
295295
/*existing=*/nullptr);
296-
Value result = rewriter.create<vector::ExtractOp>(loc, asF8s, 0);
296+
Value result = vector::ExtractOp::create(rewriter, loc, asF8s, 0);
297297
rewriter.replaceOp(op, result);
298298
return success();
299299
}
300300

301301
int64_t numElements = outVecType.getNumElements();
302-
Value zero = rewriter.create<arith::ConstantOp>(
302+
Value zero = arith::ConstantOp::create(rewriter,
303303
loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0));
304304
if (outVecType.getShape().empty()) {
305305
Value scalarIn =
306-
rewriter.create<vector::ExtractOp>(loc, in, ArrayRef<int64_t>{});
306+
vector::ExtractOp::create(rewriter, loc, in, ArrayRef<int64_t>{});
307307
// Recurse to send the 0-D vector case to the 1-D vector case
308308
Value scalarTrunc =
309-
rewriter.create<arith::TruncFOp>(loc, outElemType, scalarIn);
310-
Value result = rewriter.create<vector::InsertOp>(loc, scalarTrunc, zero,
309+
arith::TruncFOp::create(rewriter, loc, outElemType, scalarIn);
310+
Value result = vector::InsertOp::create(rewriter, loc, scalarTrunc, zero,
311311
ArrayRef<int64_t>{});
312312
rewriter.replaceOp(op, result);
313313
return success();
@@ -320,32 +320,32 @@ TruncFToFloat8RewritePattern::matchAndRewrite(arith::TruncFOp op,
320320
if (inVectorTy.getRank() > 1) {
321321
inVectorTy = VectorType::get(SmallVector<int64_t>{numElements},
322322
inVectorTy.getElementType());
323-
in = rewriter.create<vector::ShapeCastOp>(loc, inVectorTy, in);
323+
in = vector::ShapeCastOp::create(rewriter, loc, inVectorTy, in);
324324
}
325325

326326
for (int64_t i = 0; i < numElements; i += 4) {
327327
int64_t elemsThisOp = std::min(numElements, i + 4) - i;
328328
Value thisResult = nullptr;
329329
for (int64_t j = 0; j < elemsThisOp; j += 2) {
330-
Value elemA = rewriter.create<vector::ExtractOp>(loc, in, i + j);
330+
Value elemA = vector::ExtractOp::create(rewriter, loc, in, i + j);
331331
Value asFloatA = castToF32(elemA, loc, rewriter);
332332
Value asFloatB = nullptr;
333333
if (j + 1 < elemsThisOp) {
334-
Value elemB = rewriter.create<vector::ExtractOp>(loc, in, i + j + 1);
334+
Value elemB = vector::ExtractOp::create(rewriter, loc, in, i + j + 1);
335335
asFloatB = castToF32(elemB, loc, rewriter);
336336
}
337-
thisResult = rewriter.create<amdgpu::PackedTrunc2xFp8Op>(
337+
thisResult = amdgpu::PackedTrunc2xFp8Op::create(rewriter,
338338
loc, truncResType, asFloatA, asFloatB, j / 2, thisResult);
339339
}
340340
if (elemsThisOp < 4)
341-
thisResult = rewriter.create<vector::ExtractStridedSliceOp>(
341+
thisResult = vector::ExtractStridedSliceOp::create(rewriter,
342342
loc, thisResult, 0, elemsThisOp, 1);
343-
result = rewriter.create<vector::InsertStridedSliceOp>(loc, thisResult,
343+
result = vector::InsertStridedSliceOp::create(rewriter, loc, thisResult,
344344
result, i, 1);
345345
}
346346

347347
if (inVectorTy.getRank() != outVecType.getRank()) {
348-
result = rewriter.create<vector::ShapeCastOp>(loc, outVecType, result);
348+
result = vector::ShapeCastOp::create(rewriter, loc, outVecType, result);
349349
}
350350

351351
rewriter.replaceOp(op, result);
@@ -373,10 +373,10 @@ LogicalResult TruncfToFloat16RewritePattern::matchAndRewrite(
373373

374374
// Handle the case where input type is not a vector type
375375
if (!inVectorTy) {
376-
auto sourceB = rewriter.create<LLVM::PoisonOp>(loc, rewriter.getF32Type());
376+
auto sourceB = LLVM::PoisonOp::create(rewriter, loc, rewriter.getF32Type());
377377
Value asF16s =
378-
rewriter.create<ROCDL::CvtPkRtz>(loc, truncResType, in, sourceB);
379-
Value result = rewriter.create<vector::ExtractOp>(loc, asF16s, 0);
378+
ROCDL::CvtPkRtz::create(rewriter, loc, truncResType, in, sourceB);
379+
Value result = vector::ExtractOp::create(rewriter, loc, asF16s, 0);
380380
rewriter.replaceOp(op, result);
381381
return success();
382382
}
@@ -388,33 +388,33 @@ LogicalResult TruncfToFloat16RewritePattern::matchAndRewrite(
388388
if (inVectorTy.getRank() > 1) {
389389
inVectorTy = VectorType::get(SmallVector<int64_t>{numElements},
390390
inVectorTy.getElementType());
391-
in = rewriter.create<vector::ShapeCastOp>(loc, inVectorTy, in);
391+
in = vector::ShapeCastOp::create(rewriter, loc, inVectorTy, in);
392392
}
393393

394394
// Handle the vector case. We also handle the (uncommon) case where the vector
395395
// length is odd
396396
for (int64_t i = 0; i < numElements; i += 2) {
397397
int64_t elemsThisOp = std::min(numElements, i + 2) - i;
398398
Value thisResult = nullptr;
399-
Value elemA = rewriter.create<vector::ExtractOp>(loc, in, i);
400-
Value elemB = rewriter.create<LLVM::PoisonOp>(loc, rewriter.getF32Type());
399+
Value elemA = vector::ExtractOp::create(rewriter, loc, in, i);
400+
Value elemB = LLVM::PoisonOp::create(rewriter, loc, rewriter.getF32Type());
401401

402402
if (elemsThisOp == 2) {
403-
elemB = rewriter.create<vector::ExtractOp>(loc, in, i + 1);
403+
elemB = vector::ExtractOp::create(rewriter, loc, in, i + 1);
404404
}
405405

406406
thisResult =
407-
rewriter.create<ROCDL::CvtPkRtz>(loc, truncResType, elemA, elemB);
407+
ROCDL::CvtPkRtz::create(rewriter, loc, truncResType, elemA, elemB);
408408
// Place back the truncated result into the possibly larger vector. If we
409409
// are operating on a size 2 vector, these operations should be folded away
410-
thisResult = rewriter.create<vector::ExtractStridedSliceOp>(
410+
thisResult = vector::ExtractStridedSliceOp::create(rewriter,
411411
loc, thisResult, 0, elemsThisOp, 1);
412-
result = rewriter.create<vector::InsertStridedSliceOp>(loc, thisResult,
412+
result = vector::InsertStridedSliceOp::create(rewriter, loc, thisResult,
413413
result, i, 1);
414414
}
415415

416416
if (inVectorTy.getRank() != outVecType.getRank()) {
417-
result = rewriter.create<vector::ShapeCastOp>(loc, outVecType, result);
417+
result = vector::ShapeCastOp::create(rewriter, loc, outVecType, result);
418418
}
419419

420420
rewriter.replaceOp(op, result);

mlir/lib/Conversion/ArithToArmSME/ArithToArmSME.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -74,14 +74,14 @@ struct ConstantOpToArmSMELowering : public OpRewritePattern<arith::ConstantOp> {
7474
VectorType tileSliceType = VectorType::Builder(tileType).dropDim(0);
7575
auto denseAttr1D = DenseElementsAttr::get(
7676
tileSliceType, denseAttr.getSplatValue<Attribute>());
77-
auto constantOp1D = rewriter.create<arith::ConstantOp>(loc, denseAttr1D);
77+
auto constantOp1D = arith::ConstantOp::create(rewriter, loc, denseAttr1D);
7878

79-
auto initTile = rewriter.create<arm_sme::GetTileOp>(loc, tileType);
79+
auto initTile = arm_sme::GetTileOp::create(rewriter, loc, tileType);
8080
auto makeLoopBody = [&](OpBuilder &b, Location loc, Value tileSliceIndex,
8181
Value currentTile) {
8282
// Create 'arm_sme.insert_tile_slice' to write vector to tile
8383
// slice.
84-
auto nextTile = b.create<arm_sme::InsertTileSliceOp>(
84+
auto nextTile = arm_sme::InsertTileSliceOp::create(b,
8585
loc, tileType, constantOp1D, currentTile, tileSliceIndex);
8686
return nextTile.getResult();
8787
};

0 commit comments

Comments
 (0)