Skip to content

Commit ff0dcc4

Browse files
authored
[MLIR][Linalg] Harden parsing Linalg named ops (#145337)
This thread through proper error handling / reporting capabilities to avoid hitting llvm_unreachable while parsing linalg ops. Fixes #132755 Fixes #132740 Fixes #129185
1 parent ac29858 commit ff0dcc4

File tree

10 files changed

+263
-75
lines changed

10 files changed

+263
-75
lines changed

mlir/include/mlir/Dialect/Linalg/IR/Linalg.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#include "mlir/IR/AffineMap.h"
1717
#include "mlir/IR/BuiltinDialect.h"
1818
#include "mlir/IR/BuiltinTypes.h"
19+
#include "mlir/IR/Diagnostics.h"
1920
#include "mlir/IR/Dialect.h"
2021
#include "mlir/IR/ImplicitLocOpBuilder.h"
2122
#include "mlir/IR/TypeUtilities.h"
@@ -26,6 +27,9 @@
2627
#include "mlir/Interfaces/SideEffectInterfaces.h"
2728
#include "mlir/Interfaces/TilingInterface.h"
2829
#include "mlir/Interfaces/ViewLikeInterface.h"
30+
31+
#include "llvm/ADT/STLFunctionalExtras.h"
32+
2933
#include <optional>
3034

3135
namespace mlir {

mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,8 @@ def Linalg_Dialect : Dialect {
5252
kMemoizedIndexingMapsAttrName = "linalg.memoized_indexing_maps";
5353

5454
using RegionBuilderFunType = llvm::function_ref<
55-
void(ImplicitLocOpBuilder &b, Block &, ArrayRef<NamedAttribute>)>;
55+
void(ImplicitLocOpBuilder &b, Block &, ArrayRef<NamedAttribute>,
56+
function_ref<InFlightDiagnostic()>)>;
5657
RegionBuilderFunType getRegionBuilder(StringRef name) {
5758
return namedStructuredOpRegionBuilders.lookup(name);
5859
}

mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -720,7 +720,7 @@ def LinalgStructuredInterface
720720
Returns a null function if this named op does not define a region
721721
builder.
722722
}],
723-
/*retTy=*/"std::function<void(ImplicitLocOpBuilder &, Block &, ArrayRef<NamedAttribute>)>",
723+
/*retTy=*/"std::function<void(ImplicitLocOpBuilder &, Block &, ArrayRef<NamedAttribute>, function_ref<InFlightDiagnostic()>)>",
724724
/*methodName=*/"getRegionBuilder",
725725
(ins),
726726
[{ return ConcreteOp::getRegionBuilder(); }]

mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td

Lines changed: 33 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,8 @@ def GenericOp : LinalgStructuredBase_Op<"generic", [
192192
}
193193

194194
static std::function<void(ImplicitLocOpBuilder &,
195-
Block &, ArrayRef<NamedAttribute>)>
195+
Block &, ArrayRef<NamedAttribute>,
196+
function_ref<InFlightDiagnostic()>)>
196197
getRegionBuilder() {
197198
return nullptr;
198199
}
@@ -300,7 +301,8 @@ def MapOp : LinalgStructuredBase_Op<"map", [
300301
}
301302

302303
static std::function<void(mlir::ImplicitLocOpBuilder &, mlir::Block &,
303-
mlir::ArrayRef<mlir::NamedAttribute>)>
304+
mlir::ArrayRef<mlir::NamedAttribute>,
305+
function_ref<InFlightDiagnostic()>)>
304306
getRegionBuilder() {
305307
return nullptr;
306308
}
@@ -380,7 +382,8 @@ def ReduceOp : LinalgStructuredBase_Op<"reduce", [
380382

381383
// Implement functions necessary for DestinationStyleOpInterface.
382384
static std::function<void(mlir::ImplicitLocOpBuilder &, mlir::Block &,
383-
mlir::ArrayRef<mlir::NamedAttribute>)>
385+
mlir::ArrayRef<mlir::NamedAttribute>,
386+
function_ref<InFlightDiagnostic()>)>
384387
getRegionBuilder() {
385388
return nullptr;
386389
}
@@ -449,13 +452,14 @@ def TransposeOp : LinalgStructuredBase_Op<"transpose", [
449452
MutableOperandRange getDpsInitsMutable() { return getInitMutable(); }
450453

451454
static void regionBuilder(mlir::ImplicitLocOpBuilder &b, mlir::Block &block,
452-
mlir::ArrayRef<mlir::NamedAttribute>) {
455+
mlir::ArrayRef<mlir::NamedAttribute>, function_ref<InFlightDiagnostic()> emitError) {
453456
OpBuilder::InsertionGuard guard(b);
454457
b.create<linalg::YieldOp>(b.getLoc(), block.getArgument(0));
455458
}
456459

457460
static std::function<void(mlir::ImplicitLocOpBuilder &, mlir::Block &,
458-
mlir::ArrayRef<mlir::NamedAttribute>)>
461+
mlir::ArrayRef<mlir::NamedAttribute>,
462+
function_ref<InFlightDiagnostic()>)>
459463
getRegionBuilder() {
460464
return regionBuilder;
461465
}
@@ -521,13 +525,15 @@ def BroadcastOp : LinalgStructuredBase_Op<"broadcast", [
521525
MutableOperandRange getDpsInitsMutable() { return getInitMutable(); }
522526

523527
static void regionBuilder(mlir::ImplicitLocOpBuilder &b, mlir::Block &block,
524-
mlir::ArrayRef<mlir::NamedAttribute>) {
528+
mlir::ArrayRef<mlir::NamedAttribute>,
529+
function_ref<InFlightDiagnostic()> emitError) {
525530
OpBuilder::InsertionGuard guard(b);
526531
b.create<linalg::YieldOp>(b.getLoc(), block.getArgument(0));
527532
}
528533

529534
static std::function<void(mlir::ImplicitLocOpBuilder &, mlir::Block &,
530-
mlir::ArrayRef<mlir::NamedAttribute>)>
535+
mlir::ArrayRef<mlir::NamedAttribute>,
536+
function_ref<InFlightDiagnostic()>)>
531537
getRegionBuilder() {
532538
return regionBuilder;
533539
}
@@ -631,10 +637,12 @@ def ElementwiseOp : LinalgStructuredBase_Op<"elementwise", [
631637
/// Implements the block region builder for the elementwiseOp. This is
632638
/// called by the 'fillStructuredOpRegion'.
633639
static void regionBuilder(ImplicitLocOpBuilder &b,
634-
Block &block, ArrayRef<NamedAttribute> attrs);
640+
Block &block, ArrayRef<NamedAttribute> attrs,
641+
function_ref<InFlightDiagnostic()> emitError);
635642

636643
static std::function<void(ImplicitLocOpBuilder &,
637-
Block &, ArrayRef<NamedAttribute>)>
644+
Block &, ArrayRef<NamedAttribute>,
645+
function_ref<InFlightDiagnostic()>)>
638646
getRegionBuilder() {
639647
return regionBuilder;
640648
}
@@ -771,7 +779,8 @@ def MatmulOp : LinalgStructuredBase_Op<"matmul", [
771779

772780
/// Implements the block region builder.
773781
static void regionBuilder(ImplicitLocOpBuilder &b,
774-
Block &block, ArrayRef<NamedAttribute> attrs);
782+
Block &block, ArrayRef<NamedAttribute> attrs,
783+
function_ref<InFlightDiagnostic()> emitError);
775784

776785
/// Returns a list of AffineMap with the default matmul indexing charactristic.
777786
static SmallVector<AffineMap> getDefaultIndexingMaps(MLIRContext *context);
@@ -780,7 +789,8 @@ def MatmulOp : LinalgStructuredBase_Op<"matmul", [
780789
bool isValidLhsRhsBroadcastMap(AffineMap bcastMap);
781790

782791
static std::function<void(ImplicitLocOpBuilder &,
783-
Block &, ArrayRef<NamedAttribute>)>
792+
Block &, ArrayRef<NamedAttribute>,
793+
function_ref<InFlightDiagnostic()>)>
784794
getRegionBuilder() {
785795
return regionBuilder;
786796
}
@@ -916,10 +926,12 @@ def ContractOp : LinalgStructuredBase_Op<"contract", [
916926
static unsigned getNumRegionArgs();
917927

918928
static void regionBuilder(ImplicitLocOpBuilder &b,
919-
Block &block, ArrayRef<NamedAttribute> attrs);
929+
Block &block, ArrayRef<NamedAttribute> attrs,
930+
function_ref<InFlightDiagnostic()> emitError);
920931

921932
static std::function<void(ImplicitLocOpBuilder &,
922-
Block &, ArrayRef<NamedAttribute>)>
933+
Block &, ArrayRef<NamedAttribute>,
934+
function_ref<InFlightDiagnostic()>)>
923935
getRegionBuilder() {
924936
return regionBuilder;
925937
}
@@ -1033,9 +1045,11 @@ def BatchMatmulOp : LinalgStructuredBase_Op<"batch_matmul", !listconcat([AttrSiz
10331045

10341046
SmallVector<utils::IteratorType> getIteratorTypesArray();
10351047
static void regionBuilder(ImplicitLocOpBuilder &b,
1036-
Block &block, ArrayRef<NamedAttribute> attrs);
1048+
Block &block, ArrayRef<NamedAttribute> attrs,
1049+
function_ref<InFlightDiagnostic()> emitError);
10371050
static std::function<void(ImplicitLocOpBuilder &,
1038-
Block &, ArrayRef<NamedAttribute>)>
1051+
Block &, ArrayRef<NamedAttribute>,
1052+
function_ref<InFlightDiagnostic()>)>
10391053
getRegionBuilder() {
10401054
return regionBuilder;
10411055
}
@@ -1161,7 +1175,8 @@ def BatchReduceMatmulOp : LinalgStructuredBase_Op<"batch_reduce_matmul", [
11611175

11621176
/// Implements the block region builder.
11631177
static void regionBuilder(ImplicitLocOpBuilder &b,
1164-
Block &block, ArrayRef<NamedAttribute> attrs);
1178+
Block &block, ArrayRef<NamedAttribute> attrs,
1179+
function_ref<InFlightDiagnostic()> emitError);
11651180

11661181
/// Returns a list of AffineMap with the default batch_reduce_matmul indexing charactristic.
11671182
static SmallVector<AffineMap> getDefaultIndexingMaps(MLIRContext *context);
@@ -1170,7 +1185,8 @@ def BatchReduceMatmulOp : LinalgStructuredBase_Op<"batch_reduce_matmul", [
11701185
bool isValidLhsRhsBroadcastMap(AffineMap bcastMap, bool isLHS = true);
11711186

11721187
static std::function<void(ImplicitLocOpBuilder &,
1173-
Block &, ArrayRef<NamedAttribute>)>
1188+
Block &, ArrayRef<NamedAttribute>,
1189+
function_ref<InFlightDiagnostic()>)>
11741190
getRegionBuilder() {
11751191
return regionBuilder;
11761192
}

mlir/lib/CAPI/Dialect/Linalg.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ void mlirLinalgFillBuiltinNamedOpRegion(MlirOperation mlirOp) {
3838
Region &region = op->getRegion(0);
3939
Block *body = b.createBlock(&region, /*insertPt=*/{}, argTypes, argLocs);
4040
b.setInsertionPointToStart(body);
41-
fun(b, *body, op->getAttrs());
41+
fun(b, *body, op->getAttrs(), /*emitError=*/{});
4242
}
4343

4444
MLIR_CAPI_EXPORTED bool mlirLinalgIsAContractionOp(MlirOperation op) {

0 commit comments

Comments
 (0)