Skip to content

Commit 5e24434

Browse files
committed
fix cpp custom builders
1 parent d397885 commit 5e24434

File tree

3 files changed

+76
-15
lines changed

3 files changed

+76
-15
lines changed

mlir/include/mlir/Dialect/Arith/IR/Arith.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,15 +59,21 @@ class ConstantIntOp : public arith::ConstantOp {
5959
/// Build a constant int op that produces an integer of the specified width.
6060
static void build(OpBuilder &builder, OperationState &result, int64_t value,
6161
unsigned width);
62+
static ConstantIntOp create(OpBuilder &builder, Location location,
63+
int64_t value, unsigned width);
6264

6365
/// Build a constant int op that produces an integer of the specified type,
6466
/// which must be an integer type.
6567
static void build(OpBuilder &builder, OperationState &result, Type type,
6668
int64_t value);
69+
static ConstantIntOp create(OpBuilder &builder, Location location, Type type,
70+
int64_t value);
6771

6872
/// Build a constant int op that produces an integer from an APInt
6973
static void build(OpBuilder &builder, OperationState &result, Type type,
7074
const APInt &value);
75+
static ConstantIntOp create(OpBuilder &builder, Location location, Type type,
76+
const APInt &value);
7177

7278
inline int64_t value() {
7379
return cast<IntegerAttr>(arith::ConstantOp::getValue()).getInt();
@@ -85,6 +91,8 @@ class ConstantFloatOp : public arith::ConstantOp {
8591
/// Build a constant float op that produces a float of the specified type.
8692
static void build(OpBuilder &builder, OperationState &result, FloatType type,
8793
const APFloat &value);
94+
static ConstantFloatOp create(OpBuilder &builder, Location location,
95+
FloatType type, const APFloat &value);
8896

8997
inline APFloat value() {
9098
return cast<FloatAttr>(arith::ConstantOp::getValue()).getValue();
@@ -100,6 +108,8 @@ class ConstantIndexOp : public arith::ConstantOp {
100108
static ::mlir::TypeID resolveTypeID() { return TypeID::get<ConstantOp>(); }
101109
/// Build a constant int op that produces an index.
102110
static void build(OpBuilder &builder, OperationState &result, int64_t value);
111+
static ConstantIndexOp create(OpBuilder &builder, Location location,
112+
int64_t value);
103113

104114
inline int64_t value() {
105115
return cast<IntegerAttr>(arith::ConstantOp::getValue()).getInt();

mlir/lib/Dialect/Arith/IR/ArithOps.cpp

Lines changed: 65 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -256,18 +256,49 @@ void arith::ConstantIntOp::build(OpBuilder &builder, OperationState &result,
256256
builder.getIntegerAttr(type, value));
257257
}
258258

259+
arith::ConstantIntOp arith::ConstantIntOp::create(OpBuilder &builder,
260+
Location location,
261+
int64_t value,
262+
unsigned width) {
263+
mlir::OperationState state(location, getOperationName());
264+
build(builder, state, value, width);
265+
auto result = llvm::dyn_cast<ConstantIntOp>(builder.create(state));
266+
assert(result && "builder didn't return the right type");
267+
return result;
268+
}
269+
259270
void arith::ConstantIntOp::build(OpBuilder &builder, OperationState &result,
260271
Type type, int64_t value) {
261272
arith::ConstantOp::build(builder, result, type,
262273
builder.getIntegerAttr(type, value));
263274
}
264275

276+
arith::ConstantIntOp arith::ConstantIntOp::create(OpBuilder &builder,
277+
Location location, Type type,
278+
int64_t value) {
279+
mlir::OperationState state(location, getOperationName());
280+
build(builder, state, type, value);
281+
auto result = llvm::dyn_cast<ConstantIntOp>(builder.create(state));
282+
assert(result && "builder didn't return the right type");
283+
return result;
284+
}
285+
265286
void arith::ConstantIntOp::build(OpBuilder &builder, OperationState &result,
266287
Type type, const APInt &value) {
267288
arith::ConstantOp::build(builder, result, type,
268289
builder.getIntegerAttr(type, value));
269290
}
270291

292+
arith::ConstantIntOp arith::ConstantIntOp::create(OpBuilder &builder,
293+
Location location, Type type,
294+
const APInt &value) {
295+
mlir::OperationState state(location, getOperationName());
296+
build(builder, state, type, value);
297+
auto result = llvm::dyn_cast<ConstantIntOp>(builder.create(state));
298+
assert(result && "builder didn't return the right type");
299+
return result;
300+
}
301+
271302
bool arith::ConstantIntOp::classof(Operation *op) {
272303
if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
273304
return constOp.getType().isSignlessInteger();
@@ -280,6 +311,17 @@ void arith::ConstantFloatOp::build(OpBuilder &builder, OperationState &result,
280311
builder.getFloatAttr(type, value));
281312
}
282313

314+
arith::ConstantFloatOp arith::ConstantFloatOp::create(OpBuilder &builder,
315+
Location location,
316+
FloatType type,
317+
const APFloat &value) {
318+
mlir::OperationState state(location, getOperationName());
319+
build(builder, state, type, value);
320+
auto result = llvm::dyn_cast<ConstantFloatOp>(builder.create(state));
321+
assert(result && "builder didn't return the right type");
322+
return result;
323+
}
324+
283325
bool arith::ConstantFloatOp::classof(Operation *op) {
284326
if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
285327
return llvm::isa<FloatType>(constOp.getType());
@@ -292,6 +334,16 @@ void arith::ConstantIndexOp::build(OpBuilder &builder, OperationState &result,
292334
builder.getIndexAttr(value));
293335
}
294336

337+
arith::ConstantIndexOp arith::ConstantIndexOp::create(OpBuilder &builder,
338+
Location location,
339+
int64_t value) {
340+
mlir::OperationState state(location, getOperationName());
341+
build(builder, state, value);
342+
auto result = llvm::dyn_cast<ConstantIndexOp>(builder.create(state));
343+
assert(result && "builder didn't return the right type");
344+
return result;
345+
}
346+
295347
bool arith::ConstantIndexOp::classof(Operation *op) {
296348
if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
297349
return constOp.getType().isIndex();
@@ -2335,9 +2387,8 @@ class CmpFIntToFPConst final : public OpRewritePattern<CmpFOp> {
23352387
// comparison.
23362388
rewriter.replaceOpWithNewOp<CmpIOp>(
23372389
op, pred, intVal,
2338-
ConstantOp::create(rewriter,
2339-
op.getLoc(), intVal.getType(),
2340-
rewriter.getIntegerAttr(intVal.getType(), rhsInt)));
2390+
ConstantOp::create(rewriter, op.getLoc(), intVal.getType(),
2391+
rewriter.getIntegerAttr(intVal.getType(), rhsInt)));
23412392
return success();
23422393
}
23432394
};
@@ -2374,10 +2425,10 @@ struct SelectToExtUI : public OpRewritePattern<arith::SelectOp> {
23742425
matchPattern(op.getFalseValue(), m_One())) {
23752426
rewriter.replaceOpWithNewOp<arith::ExtUIOp>(
23762427
op, op.getType(),
2377-
arith::XOrIOp::create(rewriter,
2378-
op.getLoc(), op.getCondition(),
2379-
arith::ConstantIntOp::create(rewriter,
2380-
op.getLoc(), op.getCondition().getType(), 1)));
2428+
arith::XOrIOp::create(
2429+
rewriter, op.getLoc(), op.getCondition(),
2430+
arith::ConstantIntOp::create(rewriter, op.getLoc(),
2431+
op.getCondition().getType(), 1)));
23812432
return success();
23822433
}
23832434

@@ -2440,12 +2491,12 @@ OpFoldResult arith::SelectOp::fold(FoldAdaptor adaptor) {
24402491

24412492
// Constant-fold constant operands over non-splat constant condition.
24422493
// select %cst_vec, %cst0, %cst1 => %cst2
2443-
if (auto cond =
2444-
llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getCondition())) {
2445-
if (auto lhs =
2446-
llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getTrueValue())) {
2447-
if (auto rhs =
2448-
llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getFalseValue())) {
2494+
if (auto cond = llvm::dyn_cast_if_present<DenseElementsAttr>(
2495+
adaptor.getCondition())) {
2496+
if (auto lhs = llvm::dyn_cast_if_present<DenseElementsAttr>(
2497+
adaptor.getTrueValue())) {
2498+
if (auto rhs = llvm::dyn_cast_if_present<DenseElementsAttr>(
2499+
adaptor.getFalseValue())) {
24492500
SmallVector<Attribute> results;
24502501
results.reserve(static_cast<size_t>(cond.getNumElements()));
24512502
auto condVals = llvm::make_range(cond.value_begin<BoolAttr>(),
@@ -2713,7 +2764,7 @@ Value mlir::arith::getReductionOp(AtomicRMWKind op, OpBuilder &builder,
27132764
return arith::MaximumFOp::create(builder, loc, lhs, rhs);
27142765
case AtomicRMWKind::minimumf:
27152766
return arith::MinimumFOp::create(builder, loc, lhs, rhs);
2716-
case AtomicRMWKind::maxnumf:
2767+
case AtomicRMWKind::maxnumf:
27172768
return arith::MaxNumFOp::create(builder, loc, lhs, rhs);
27182769
case AtomicRMWKind::minnumf:
27192770
return arith::MinNumFOp::create(builder, loc, lhs, rhs);

mlir/lib/IR/BuiltinDialect.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ void ModuleOp::build(OpBuilder &builder, OperationState &state,
132132
/// Construct a module from the given context.
133133
ModuleOp ModuleOp::create(Location loc, std::optional<StringRef> name) {
134134
OpBuilder builder(loc->getContext());
135-
return ModuleOp::create(builder, loc, name);
135+
return builder.create<ModuleOp>(loc, name);
136136
}
137137

138138
DataLayoutSpecInterface ModuleOp::getDataLayoutSpec() {

0 commit comments

Comments
 (0)