@@ -256,18 +256,49 @@ void arith::ConstantIntOp::build(OpBuilder &builder, OperationState &result,
256
256
builder.getIntegerAttr (type, value));
257
257
}
258
258
259
+ arith::ConstantIntOp arith::ConstantIntOp::create (OpBuilder &builder,
260
+ Location location,
261
+ int64_t value,
262
+ unsigned width) {
263
+ mlir::OperationState state (location, getOperationName ());
264
+ build (builder, state, value, width);
265
+ auto result = llvm::dyn_cast<ConstantIntOp>(builder.create (state));
266
+ assert (result && " builder didn't return the right type" );
267
+ return result;
268
+ }
269
+
259
270
void arith::ConstantIntOp::build (OpBuilder &builder, OperationState &result,
260
271
Type type, int64_t value) {
261
272
arith::ConstantOp::build (builder, result, type,
262
273
builder.getIntegerAttr (type, value));
263
274
}
264
275
276
+ arith::ConstantIntOp arith::ConstantIntOp::create (OpBuilder &builder,
277
+ Location location, Type type,
278
+ int64_t value) {
279
+ mlir::OperationState state (location, getOperationName ());
280
+ build (builder, state, type, value);
281
+ auto result = llvm::dyn_cast<ConstantIntOp>(builder.create (state));
282
+ assert (result && " builder didn't return the right type" );
283
+ return result;
284
+ }
285
+
265
286
void arith::ConstantIntOp::build (OpBuilder &builder, OperationState &result,
266
287
Type type, const APInt &value) {
267
288
arith::ConstantOp::build (builder, result, type,
268
289
builder.getIntegerAttr (type, value));
269
290
}
270
291
292
+ arith::ConstantIntOp arith::ConstantIntOp::create (OpBuilder &builder,
293
+ Location location, Type type,
294
+ const APInt &value) {
295
+ mlir::OperationState state (location, getOperationName ());
296
+ build (builder, state, type, value);
297
+ auto result = llvm::dyn_cast<ConstantIntOp>(builder.create (state));
298
+ assert (result && " builder didn't return the right type" );
299
+ return result;
300
+ }
301
+
271
302
bool arith::ConstantIntOp::classof (Operation *op) {
272
303
if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
273
304
return constOp.getType ().isSignlessInteger ();
@@ -280,6 +311,17 @@ void arith::ConstantFloatOp::build(OpBuilder &builder, OperationState &result,
280
311
builder.getFloatAttr (type, value));
281
312
}
282
313
314
+ arith::ConstantFloatOp arith::ConstantFloatOp::create (OpBuilder &builder,
315
+ Location location,
316
+ FloatType type,
317
+ const APFloat &value) {
318
+ mlir::OperationState state (location, getOperationName ());
319
+ build (builder, state, type, value);
320
+ auto result = llvm::dyn_cast<ConstantFloatOp>(builder.create (state));
321
+ assert (result && " builder didn't return the right type" );
322
+ return result;
323
+ }
324
+
283
325
bool arith::ConstantFloatOp::classof (Operation *op) {
284
326
if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
285
327
return llvm::isa<FloatType>(constOp.getType ());
@@ -292,6 +334,16 @@ void arith::ConstantIndexOp::build(OpBuilder &builder, OperationState &result,
292
334
builder.getIndexAttr (value));
293
335
}
294
336
337
+ arith::ConstantIndexOp arith::ConstantIndexOp::create (OpBuilder &builder,
338
+ Location location,
339
+ int64_t value) {
340
+ mlir::OperationState state (location, getOperationName ());
341
+ build (builder, state, value);
342
+ auto result = llvm::dyn_cast<ConstantIndexOp>(builder.create (state));
343
+ assert (result && " builder didn't return the right type" );
344
+ return result;
345
+ }
346
+
295
347
bool arith::ConstantIndexOp::classof (Operation *op) {
296
348
if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
297
349
return constOp.getType ().isIndex ();
@@ -2335,9 +2387,8 @@ class CmpFIntToFPConst final : public OpRewritePattern<CmpFOp> {
2335
2387
// comparison.
2336
2388
rewriter.replaceOpWithNewOp <CmpIOp>(
2337
2389
op, pred, intVal,
2338
- ConstantOp::create (rewriter,
2339
- op.getLoc (), intVal.getType (),
2340
- rewriter.getIntegerAttr (intVal.getType (), rhsInt)));
2390
+ ConstantOp::create (rewriter, op.getLoc (), intVal.getType (),
2391
+ rewriter.getIntegerAttr (intVal.getType (), rhsInt)));
2341
2392
return success ();
2342
2393
}
2343
2394
};
@@ -2374,10 +2425,10 @@ struct SelectToExtUI : public OpRewritePattern<arith::SelectOp> {
2374
2425
matchPattern (op.getFalseValue (), m_One ())) {
2375
2426
rewriter.replaceOpWithNewOp <arith::ExtUIOp>(
2376
2427
op, op.getType (),
2377
- arith::XOrIOp::create (rewriter,
2378
- op.getLoc (), op.getCondition (),
2379
- arith::ConstantIntOp::create (rewriter,
2380
- op. getLoc (), op.getCondition ().getType (), 1 )));
2428
+ arith::XOrIOp::create (
2429
+ rewriter, op.getLoc (), op.getCondition (),
2430
+ arith::ConstantIntOp::create (rewriter, op. getLoc (),
2431
+ op.getCondition ().getType (), 1 )));
2381
2432
return success ();
2382
2433
}
2383
2434
@@ -2440,12 +2491,12 @@ OpFoldResult arith::SelectOp::fold(FoldAdaptor adaptor) {
2440
2491
2441
2492
// Constant-fold constant operands over non-splat constant condition.
2442
2493
// select %cst_vec, %cst0, %cst1 => %cst2
2443
- if (auto cond =
2444
- llvm::dyn_cast_if_present<DenseElementsAttr>( adaptor.getCondition ())) {
2445
- if (auto lhs =
2446
- llvm::dyn_cast_if_present<DenseElementsAttr>( adaptor.getTrueValue ())) {
2447
- if (auto rhs =
2448
- llvm::dyn_cast_if_present<DenseElementsAttr>( adaptor.getFalseValue ())) {
2494
+ if (auto cond = llvm::dyn_cast_if_present<DenseElementsAttr>(
2495
+ adaptor.getCondition ())) {
2496
+ if (auto lhs = llvm::dyn_cast_if_present<DenseElementsAttr>(
2497
+ adaptor.getTrueValue ())) {
2498
+ if (auto rhs = llvm::dyn_cast_if_present<DenseElementsAttr>(
2499
+ adaptor.getFalseValue ())) {
2449
2500
SmallVector<Attribute> results;
2450
2501
results.reserve (static_cast <size_t >(cond.getNumElements ()));
2451
2502
auto condVals = llvm::make_range (cond.value_begin <BoolAttr>(),
@@ -2713,7 +2764,7 @@ Value mlir::arith::getReductionOp(AtomicRMWKind op, OpBuilder &builder,
2713
2764
return arith::MaximumFOp::create (builder, loc, lhs, rhs);
2714
2765
case AtomicRMWKind::minimumf:
2715
2766
return arith::MinimumFOp::create (builder, loc, lhs, rhs);
2716
- case AtomicRMWKind::maxnumf:
2767
+ case AtomicRMWKind::maxnumf:
2717
2768
return arith::MaxNumFOp::create (builder, loc, lhs, rhs);
2718
2769
case AtomicRMWKind::minnumf:
2719
2770
return arith::MinNumFOp::create (builder, loc, lhs, rhs);
0 commit comments