Skip to content

Commit d85eb4e

Browse files
committed
[AsmParser] Introduce a new "Argument" abstraction + supporting logic
MLIR has a common pattern for "arguments" that uses syntax like `%x : i32 {attrs} loc("sourceloc")` which is implemented in adhoc ways throughout the codebase. The approach this uses is verbose (because it is implemented with parallel arrays) and inconsistent (e.g. lots of things drop source location info). Solve this by introducing OpAsmParser::Argument and make addRegion (which sets up BlockArguments for the region) take it. Convert the world to propagating this down. This means that we correctly capture and propagate source location information in a lot more cases (e.g. see the affine.for testcase example), and it also simplifies much code. Differential Revision: https://reviews.llvm.org/D124649
1 parent 6e689cb commit d85eb4e

File tree

19 files changed

+455
-514
lines changed

19 files changed

+455
-514
lines changed

flang/lib/Optimizer/Dialect/FIROps.cpp

Lines changed: 22 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1200,8 +1200,8 @@ mlir::ParseResult fir::GlobalOp::parse(mlir::OpAsmParser &parser,
12001200
result.addRegion();
12011201
} else {
12021202
// Parse the optional initializer body.
1203-
auto parseResult = parser.parseOptionalRegion(
1204-
*result.addRegion(), /*arguments=*/llvm::None, /*argTypes=*/llvm::None);
1203+
auto parseResult =
1204+
parser.parseOptionalRegion(*result.addRegion(), /*arguments=*/{});
12051205
if (parseResult.hasValue() && mlir::failed(*parseResult))
12061206
return mlir::failure();
12071207
}
@@ -1562,9 +1562,9 @@ void fir::IterWhileOp::build(mlir::OpBuilder &builder,
15621562
mlir::ParseResult fir::IterWhileOp::parse(mlir::OpAsmParser &parser,
15631563
mlir::OperationState &result) {
15641564
auto &builder = parser.getBuilder();
1565-
mlir::OpAsmParser::UnresolvedOperand inductionVariable, lb, ub, step;
1566-
if (parser.parseLParen() ||
1567-
parser.parseOperand(inductionVariable, /*allowResultNumber=*/false) ||
1565+
mlir::OpAsmParser::Argument inductionVariable, iterateVar;
1566+
mlir::OpAsmParser::UnresolvedOperand lb, ub, step, iterateInput;
1567+
if (parser.parseLParen() || parser.parseArgument(inductionVariable) ||
15681568
parser.parseEqual())
15691569
return mlir::failure();
15701570

@@ -1577,22 +1577,18 @@ mlir::ParseResult fir::IterWhileOp::parse(mlir::OpAsmParser &parser,
15771577
parser.resolveOperand(ub, indexType, result.operands) ||
15781578
parser.parseKeyword("step") || parser.parseOperand(step) ||
15791579
parser.parseRParen() ||
1580-
parser.resolveOperand(step, indexType, result.operands))
1581-
return mlir::failure();
1582-
1583-
mlir::OpAsmParser::UnresolvedOperand iterateVar, iterateInput;
1584-
if (parser.parseKeyword("and") || parser.parseLParen() ||
1585-
parser.parseOperand(iterateVar, /*allowResultNumber=*/false) ||
1586-
parser.parseEqual() || parser.parseOperand(iterateInput) ||
1587-
parser.parseRParen() ||
1580+
parser.resolveOperand(step, indexType, result.operands) ||
1581+
parser.parseKeyword("and") || parser.parseLParen() ||
1582+
parser.parseArgument(iterateVar) || parser.parseEqual() ||
1583+
parser.parseOperand(iterateInput) || parser.parseRParen() ||
15881584
parser.resolveOperand(iterateInput, i1Type, result.operands))
15891585
return mlir::failure();
15901586

15911587
// Parse the initial iteration arguments.
1592-
llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand> regionArgs;
15931588
auto prependCount = false;
15941589

15951590
// Induction variable.
1591+
llvm::SmallVector<mlir::OpAsmParser::Argument> regionArgs;
15961592
regionArgs.push_back(inductionVariable);
15971593
regionArgs.push_back(iterateVar);
15981594

@@ -1652,7 +1648,10 @@ mlir::ParseResult fir::IterWhileOp::parse(mlir::OpAsmParser &parser,
16521648
parser.getNameLoc(),
16531649
"mismatch in number of loop-carried values and defined values");
16541650

1655-
if (parser.parseRegion(*body, regionArgs, argTypes))
1651+
for (size_t i = 0, e = regionArgs.size(); i != e; ++i)
1652+
regionArgs[i].type = argTypes[i];
1653+
1654+
if (parser.parseRegion(*body, regionArgs))
16561655
return mlir::failure();
16571656

16581657
fir::IterWhileOp::ensureTerminator(*body, builder, result.location);
@@ -1876,10 +1875,10 @@ void fir::DoLoopOp::build(mlir::OpBuilder &builder,
18761875
mlir::ParseResult fir::DoLoopOp::parse(mlir::OpAsmParser &parser,
18771876
mlir::OperationState &result) {
18781877
auto &builder = parser.getBuilder();
1879-
mlir::OpAsmParser::UnresolvedOperand inductionVariable, lb, ub, step;
1878+
mlir::OpAsmParser::Argument inductionVariable;
1879+
mlir::OpAsmParser::UnresolvedOperand lb, ub, step;
18801880
// Parse the induction variable followed by '='.
1881-
if (parser.parseOperand(inductionVariable, /*allowResultNumber=*/false) ||
1882-
parser.parseEqual())
1881+
if (parser.parseArgument(inductionVariable) || parser.parseEqual())
18831882
return mlir::failure();
18841883

18851884
// Parse loop bounds.
@@ -1896,7 +1895,8 @@ mlir::ParseResult fir::DoLoopOp::parse(mlir::OpAsmParser &parser,
18961895
result.addAttribute("unordered", builder.getUnitAttr());
18971896

18981897
// Parse the optional initial iteration arguments.
1899-
llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand> regionArgs, operands;
1898+
llvm::SmallVector<mlir::OpAsmParser::Argument> regionArgs;
1899+
llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand> operands;
19001900
llvm::SmallVector<mlir::Type> argTypes;
19011901
bool prependCount = false;
19021902
regionArgs.push_back(inductionVariable);
@@ -1939,8 +1939,10 @@ mlir::ParseResult fir::DoLoopOp::parse(mlir::OpAsmParser &parser,
19391939
return parser.emitError(
19401940
parser.getNameLoc(),
19411941
"mismatch in number of loop-carried values and defined values");
1942+
for (size_t i = 0, e = regionArgs.size(); i != e; ++i)
1943+
regionArgs[i].type = argTypes[i];
19421944

1943-
if (parser.parseRegion(*body, regionArgs, argTypes))
1945+
if (parser.parseRegion(*body, regionArgs))
19441946
return mlir::failure();
19451947

19461948
DoLoopOp::ensureTerminator(*body, builder, result.location);

mlir/include/mlir/IR/FunctionImplementation.h

Lines changed: 12 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,8 @@ void addArgAndResultAttrs(Builder &builder, OperationState &result,
4141
ArrayRef<DictionaryAttr> argAttrs,
4242
ArrayRef<DictionaryAttr> resultAttrs);
4343
void addArgAndResultAttrs(Builder &builder, OperationState &result,
44-
ArrayRef<NamedAttrList> argAttrs,
45-
ArrayRef<NamedAttrList> resultAttrs);
44+
ArrayRef<OpAsmParser::Argument> argAttrs,
45+
ArrayRef<DictionaryAttr> resultAttrs);
4646

4747
/// Callback type for `parseFunctionOp`, the callback should produce the
4848
/// type that will be associated with a function-like operation from lists of
@@ -52,26 +52,20 @@ void addArgAndResultAttrs(Builder &builder, OperationState &result,
5252
using FuncTypeBuilder = function_ref<Type(
5353
Builder &, ArrayRef<Type>, ArrayRef<Type>, VariadicFlag, std::string &)>;
5454

55-
/// Parses function arguments using `parser`. The `allowVariadic` argument
56-
/// indicates whether functions with variadic arguments are supported. The
57-
/// trailing arguments are populated by this function with names, types,
58-
/// attributes and locations of the arguments.
59-
ParseResult parseFunctionArgumentList(
60-
OpAsmParser &parser, bool allowAttributes, bool allowVariadic,
61-
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &argNames,
62-
SmallVectorImpl<Type> &argTypes, SmallVectorImpl<NamedAttrList> &argAttrs,
63-
bool &isVariadic);
64-
6555
/// Parses a function signature using `parser`. The `allowVariadic` argument
6656
/// indicates whether functions with variadic arguments are supported. The
6757
/// trailing arguments are populated by this function with names, types,
6858
/// attributes and locations of the arguments and those of the results.
69-
ParseResult parseFunctionSignature(
70-
OpAsmParser &parser, bool allowVariadic,
71-
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &argNames,
72-
SmallVectorImpl<Type> &argTypes, SmallVectorImpl<NamedAttrList> &argAttrs,
73-
bool &isVariadic, SmallVectorImpl<Type> &resultTypes,
74-
SmallVectorImpl<NamedAttrList> &resultAttrs);
59+
ParseResult
60+
parseFunctionSignature(OpAsmParser &parser, bool allowVariadic,
61+
SmallVectorImpl<OpAsmParser::Argument> &arguments,
62+
bool &isVariadic, SmallVectorImpl<Type> &resultTypes,
63+
SmallVectorImpl<DictionaryAttr> &resultAttrs);
64+
65+
/// Get a function type corresponding to an array of arguments (which have
66+
/// types) and a set of result types.
67+
Type getFunctionType(Builder &builder, ArrayRef<OpAsmParser::Argument> argAttrs,
68+
ArrayRef<Type> resultTypes);
7569

7670
/// Parser implementation for function-like operations. Uses
7771
/// `funcTypeBuilder` to construct the custom function type given lists of

mlir/include/mlir/IR/OpImplementation.h

Lines changed: 46 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -633,14 +633,14 @@ class AsmParser {
633633
/// unlike `OpBuilder::getType`, this method does not implicitly insert a
634634
/// context parameter.
635635
template <typename T, typename... ParamsT>
636-
T getChecked(SMLoc loc, ParamsT &&... params) {
636+
T getChecked(SMLoc loc, ParamsT &&...params) {
637637
return T::getChecked([&] { return emitError(loc); },
638638
std::forward<ParamsT>(params)...);
639639
}
640640
/// A variant of `getChecked` that uses the result of `getNameLoc` to emit
641641
/// errors.
642642
template <typename T, typename... ParamsT>
643-
T getChecked(ParamsT &&... params) {
643+
T getChecked(ParamsT &&...params) {
644644
return T::getChecked([&] { return emitError(getNameLoc()); },
645645
std::forward<ParamsT>(params)...);
646646
}
@@ -1093,7 +1093,6 @@ class OpAsmParser : public AsmParser {
10931093
SMLoc location; // Location of the token.
10941094
StringRef name; // Value name, e.g. %42 or %abc
10951095
unsigned number; // Number, e.g. 12 for an operand like %xyz#12
1096-
Optional<Location> sourceLoc; // Source location specifier if present.
10971096
};
10981097

10991098
/// Parse different components, viz., use-info of operand(s), successor(s),
@@ -1219,34 +1218,64 @@ class OpAsmParser : public AsmParser {
12191218
SmallVectorImpl<UnresolvedOperand> &symbOperands,
12201219
AffineExpr &expr) = 0;
12211220

1221+
//===--------------------------------------------------------------------===//
1222+
// Argument Parsing
1223+
//===--------------------------------------------------------------------===//
1224+
1225+
struct Argument {
1226+
UnresolvedOperand ssaName; // SourceLoc, SSA name, result #.
1227+
Type type; // Type.
1228+
DictionaryAttr attrs; // Attributes if present.
1229+
Optional<Location> sourceLoc; // Source location specifier if present.
1230+
};
1231+
1232+
/// Parse a single argument with the following syntax:
1233+
///
1234+
/// `%ssaName : !type { optionalAttrDict} loc(optionalSourceLoc)`
1235+
///
1236+
/// If `allowType` is false or `allowAttrs` are false then the respective
1237+
/// parts of the grammar are not parsed.
1238+
virtual ParseResult parseArgument(Argument &result, bool allowType = false,
1239+
bool allowAttrs = false) = 0;
1240+
1241+
/// Parse a single argument if present.
1242+
virtual OptionalParseResult
1243+
parseOptionalArgument(Argument &result, bool allowType = false,
1244+
bool allowAttrs = false) = 0;
1245+
1246+
/// Parse zero or more arguments with a specified surrounding delimiter.
1247+
virtual ParseResult parseArgumentList(SmallVectorImpl<Argument> &result,
1248+
Delimiter delimiter = Delimiter::None,
1249+
bool allowType = false,
1250+
bool allowAttrs = false) = 0;
1251+
12221252
//===--------------------------------------------------------------------===//
12231253
// Region Parsing
12241254
//===--------------------------------------------------------------------===//
12251255

12261256
/// Parses a region. Any parsed blocks are appended to 'region' and must be
12271257
/// moved to the op regions after the op is created. The first block of the
1228-
/// region takes 'arguments' of types 'argTypes'. If 'enableNameShadowing' is
1229-
/// set to true, the argument names are allowed to shadow the names of other
1230-
/// existing SSA values defined above the region scope. 'enableNameShadowing'
1231-
/// can only be set to true for regions attached to operations that are
1232-
/// 'IsolatedFromAbove'.
1258+
/// region takes 'arguments'.
1259+
///
1260+
/// If 'enableNameShadowing' is set to true, the argument names are allowed to
1261+
/// shadow the names of other existing SSA values defined above the region
1262+
/// scope. 'enableNameShadowing' can only be set to true for regions attached
1263+
/// to operations that are 'IsolatedFromAbove'.
12331264
virtual ParseResult parseRegion(Region &region,
1234-
ArrayRef<UnresolvedOperand> arguments = {},
1235-
ArrayRef<Type> argTypes = {},
1265+
ArrayRef<Argument> arguments = {},
12361266
bool enableNameShadowing = false) = 0;
12371267

12381268
/// Parses a region if present.
1239-
virtual OptionalParseResult parseOptionalRegion(
1240-
Region &region, ArrayRef<UnresolvedOperand> arguments = {},
1241-
ArrayRef<Type> argTypes = {}, bool enableNameShadowing = false) = 0;
1269+
virtual OptionalParseResult
1270+
parseOptionalRegion(Region &region, ArrayRef<Argument> arguments = {},
1271+
bool enableNameShadowing = false) = 0;
12421272

12431273
/// Parses a region if present. If the region is present, a new region is
12441274
/// allocated and placed in `region`. If no region is present or on failure,
12451275
/// `region` remains untouched.
12461276
virtual OptionalParseResult
12471277
parseOptionalRegion(std::unique_ptr<Region> &region,
1248-
ArrayRef<UnresolvedOperand> arguments = {},
1249-
ArrayRef<Type> argTypes = {},
1278+
ArrayRef<Argument> arguments = {},
12501279
bool enableNameShadowing = false) = 0;
12511280

12521281
//===--------------------------------------------------------------------===//
@@ -1269,7 +1298,7 @@ class OpAsmParser : public AsmParser {
12691298

12701299
/// Parse a list of assignments of the form
12711300
/// (%x1 = %y1, %x2 = %y2, ...)
1272-
ParseResult parseAssignmentList(SmallVectorImpl<UnresolvedOperand> &lhs,
1301+
ParseResult parseAssignmentList(SmallVectorImpl<Argument> &lhs,
12731302
SmallVectorImpl<UnresolvedOperand> &rhs) {
12741303
OptionalParseResult result = parseOptionalAssignmentList(lhs, rhs);
12751304
if (!result.hasValue())
@@ -1278,26 +1307,8 @@ class OpAsmParser : public AsmParser {
12781307
}
12791308

12801309
virtual OptionalParseResult
1281-
parseOptionalAssignmentList(SmallVectorImpl<UnresolvedOperand> &lhs,
1310+
parseOptionalAssignmentList(SmallVectorImpl<Argument> &lhs,
12821311
SmallVectorImpl<UnresolvedOperand> &rhs) = 0;
1283-
1284-
/// Parse a list of assignments of the form
1285-
/// (%x1 = %y1 : type1, %x2 = %y2 : type2, ...)
1286-
ParseResult
1287-
parseAssignmentListWithTypes(SmallVectorImpl<UnresolvedOperand> &lhs,
1288-
SmallVectorImpl<UnresolvedOperand> &rhs,
1289-
SmallVectorImpl<Type> &types) {
1290-
OptionalParseResult result =
1291-
parseOptionalAssignmentListWithTypes(lhs, rhs, types);
1292-
if (!result.hasValue())
1293-
return emitError(getCurrentLocation(), "expected '('");
1294-
return result.getValue();
1295-
}
1296-
1297-
virtual OptionalParseResult
1298-
parseOptionalAssignmentListWithTypes(SmallVectorImpl<UnresolvedOperand> &lhs,
1299-
SmallVectorImpl<UnresolvedOperand> &rhs,
1300-
SmallVectorImpl<Type> &types) = 0;
13011312
};
13021313

13031314
//===--------------------------------------------------------------------===//
@@ -1339,7 +1350,6 @@ class OpAsmDialectInterface
13391350
virtual AliasResult getAlias(Type type, raw_ostream &os) const {
13401351
return AliasResult::NoAlias;
13411352
}
1342-
13431353
};
13441354
} // namespace mlir
13451355

mlir/lib/Dialect/Affine/IR/AffineOps.cpp

Lines changed: 24 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1431,10 +1431,10 @@ static ParseResult parseBound(bool isLower, OperationState &result,
14311431

14321432
ParseResult AffineForOp::parse(OpAsmParser &parser, OperationState &result) {
14331433
auto &builder = parser.getBuilder();
1434-
OpAsmParser::UnresolvedOperand inductionVariable;
1434+
OpAsmParser::Argument inductionVariable;
1435+
inductionVariable.type = builder.getIndexType();
14351436
// Parse the induction variable followed by '='.
1436-
if (parser.parseOperand(inductionVariable, /*allowResultNumber=*/false) ||
1437-
parser.parseEqual())
1437+
if (parser.parseArgument(inductionVariable) || parser.parseEqual())
14381438
return failure();
14391439

14401440
// Parse loop bounds.
@@ -1463,8 +1463,10 @@ ParseResult AffineForOp::parse(OpAsmParser &parser, OperationState &result) {
14631463
}
14641464

14651465
// Parse the optional initial iteration arguments.
1466-
SmallVector<OpAsmParser::UnresolvedOperand, 4> regionArgs, operands;
1467-
SmallVector<Type, 4> argTypes;
1466+
SmallVector<OpAsmParser::Argument, 4> regionArgs;
1467+
SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
1468+
1469+
// Induction variable.
14681470
regionArgs.push_back(inductionVariable);
14691471

14701472
if (succeeded(parser.parseOptionalKeyword("iter_args"))) {
@@ -1473,23 +1475,23 @@ ParseResult AffineForOp::parse(OpAsmParser &parser, OperationState &result) {
14731475
parser.parseArrowTypeList(result.types))
14741476
return failure();
14751477
// Resolve input operands.
1476-
for (auto operandType : llvm::zip(operands, result.types))
1477-
if (parser.resolveOperand(std::get<0>(operandType),
1478-
std::get<1>(operandType), result.operands))
1478+
for (auto argOperandType :
1479+
llvm::zip(llvm::drop_begin(regionArgs), operands, result.types)) {
1480+
Type type = std::get<2>(argOperandType);
1481+
std::get<0>(argOperandType).type = type;
1482+
if (parser.resolveOperand(std::get<1>(argOperandType), type,
1483+
result.operands))
14791484
return failure();
1485+
}
14801486
}
1481-
// Induction variable.
1482-
Type indexType = builder.getIndexType();
1483-
argTypes.push_back(indexType);
1484-
// Loop carried variables.
1485-
argTypes.append(result.types.begin(), result.types.end());
1487+
14861488
// Parse the body region.
14871489
Region *body = result.addRegion();
1488-
if (regionArgs.size() != argTypes.size())
1490+
if (regionArgs.size() != result.types.size() + 1)
14891491
return parser.emitError(
14901492
parser.getNameLoc(),
14911493
"mismatch between the number of loop-carried values and results");
1492-
if (parser.parseRegion(*body, regionArgs, argTypes))
1494+
if (parser.parseRegion(*body, regionArgs))
14931495
return failure();
14941496

14951497
AffineForOp::ensureTerminator(*body, builder, result.location);
@@ -1548,7 +1550,8 @@ unsigned AffineForOp::getNumIterOperands() {
15481550

15491551
void AffineForOp::print(OpAsmPrinter &p) {
15501552
p << ' ';
1551-
p.printOperand(getBody()->getArgument(0));
1553+
p.printRegionArgument(getBody()->getArgument(0), /*argAtrs=*/{},
1554+
/*omitType=*/true);
15521555
p << " = ";
15531556
printBound(getLowerBoundMapAttr(), getLowerBoundOperands(), "max", p);
15541557
p << " to ";
@@ -3527,9 +3530,8 @@ ParseResult AffineParallelOp::parse(OpAsmParser &parser,
35273530
OperationState &result) {
35283531
auto &builder = parser.getBuilder();
35293532
auto indexType = builder.getIndexType();
3530-
SmallVector<OpAsmParser::UnresolvedOperand, 4> ivs;
3531-
if (parser.parseOperandList(ivs, OpAsmParser::Delimiter::Paren,
3532-
/*allowResultNumber=*/false) ||
3533+
SmallVector<OpAsmParser::Argument, 4> ivs;
3534+
if (parser.parseArgumentList(ivs, OpAsmParser::Delimiter::Paren) ||
35333535
parser.parseEqual() ||
35343536
parseAffineMapWithMinMax(parser, result, MinMaxKind::Max) ||
35353537
parser.parseKeyword("to") ||
@@ -3600,8 +3602,9 @@ ParseResult AffineParallelOp::parse(OpAsmParser &parser,
36003602

36013603
// Now parse the body.
36023604
Region *body = result.addRegion();
3603-
SmallVector<Type, 4> types(ivs.size(), indexType);
3604-
if (parser.parseRegion(*body, ivs, types) ||
3605+
for (auto &iv : ivs)
3606+
iv.type = indexType;
3607+
if (parser.parseRegion(*body, ivs) ||
36053608
parser.parseOptionalAttrDict(result.attributes))
36063609
return failure();
36073610

0 commit comments

Comments
 (0)