Skip to content

Commit 4110972

Browse files
committed
[MLIR][Linalg] Remove matmul_transpose variants
Removes the `(batch_)matmul_transpose_{a|b}` variants from OpDSL and replace it with `matmul affine_maps [...]` whenever appropriate.
1 parent 7bbb65c commit 4110972

File tree

20 files changed

+88
-970
lines changed

20 files changed

+88
-970
lines changed

mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml

Lines changed: 0 additions & 286 deletions
Original file line numberDiff line numberDiff line change
@@ -1055,152 +1055,6 @@ structured_op: !LinalgStructuredOpConfig
10551055
- !ScalarExpression
10561056
scalar_arg: BZp
10571057
--- !LinalgOpConfig
1058-
metadata: !LinalgOpMetadata
1059-
name: matmul_transpose_a
1060-
cpp_class_name: MatmulTransposeAOp
1061-
doc: |-
1062-
Performs a matrix multiplication of two 2D inputs with lhs operand
1063-
transposed.
1064-
1065-
Numeric casting is performed on the operands to the inner multiply, promoting
1066-
them to the same data type as the accumulator/output.
1067-
implements:
1068-
- LinalgContractionOpInterface
1069-
structured_op: !LinalgStructuredOpConfig
1070-
args:
1071-
- !LinalgOperandDefConfig
1072-
name: A
1073-
kind: input_tensor
1074-
type_var: T1
1075-
shape_map: affine_map<()[s0, s1, s2] -> (s0, s1)>
1076-
- !LinalgOperandDefConfig
1077-
name: B
1078-
kind: input_tensor
1079-
type_var: T2
1080-
shape_map: affine_map<()[s0, s1, s2] -> (s0, s2)>
1081-
- !LinalgOperandDefConfig
1082-
name: C
1083-
kind: output_tensor
1084-
type_var: U
1085-
shape_map: affine_map<()[s0, s1, s2] -> (s2, s1)>
1086-
- !LinalgOperandDefConfig
1087-
name: cast
1088-
kind: type_fn_attr
1089-
default_fn: cast_signed
1090-
indexing_maps: !LinalgIndexingMapsConfig
1091-
static_indexing_maps:
1092-
- affine_map<(d0, d1, d2)[s0, s1, s2] -> (d2, d0)>
1093-
- affine_map<(d0, d1, d2)[s0, s1, s2] -> (d2, d1)>
1094-
- affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d1)>
1095-
iterator_types:
1096-
- parallel
1097-
- parallel
1098-
- reduction
1099-
assignments:
1100-
- !ScalarAssign
1101-
arg: C
1102-
value: !ScalarExpression
1103-
scalar_fn:
1104-
kind: binary
1105-
fn_name: add
1106-
operands:
1107-
- !ScalarExpression
1108-
scalar_arg: C
1109-
- !ScalarExpression
1110-
scalar_fn:
1111-
kind: binary
1112-
fn_name: mul
1113-
operands:
1114-
- !ScalarExpression
1115-
scalar_fn:
1116-
kind: type
1117-
attr_name: cast
1118-
type_var: U
1119-
operands:
1120-
- !ScalarExpression
1121-
scalar_arg: A
1122-
- !ScalarExpression
1123-
scalar_fn:
1124-
kind: type
1125-
attr_name: cast
1126-
type_var: U
1127-
operands:
1128-
- !ScalarExpression
1129-
scalar_arg: B
1130-
--- !LinalgOpConfig
1131-
metadata: !LinalgOpMetadata
1132-
name: matmul_transpose_b
1133-
cpp_class_name: MatmulTransposeBOp
1134-
doc: |-
1135-
Performs a matrix multiplication of two 2D inputs with rhs operand
1136-
transposed.
1137-
1138-
Numeric casting is performed on the operands to the inner multiply, promoting
1139-
them to the same data type as the accumulator/output.
1140-
implements:
1141-
- LinalgContractionOpInterface
1142-
structured_op: !LinalgStructuredOpConfig
1143-
args:
1144-
- !LinalgOperandDefConfig
1145-
name: A
1146-
kind: input_tensor
1147-
type_var: T1
1148-
shape_map: affine_map<()[s0, s1, s2] -> (s0, s1)>
1149-
- !LinalgOperandDefConfig
1150-
name: B
1151-
kind: input_tensor
1152-
type_var: T2
1153-
shape_map: affine_map<()[s0, s1, s2] -> (s2, s1)>
1154-
- !LinalgOperandDefConfig
1155-
name: C
1156-
kind: output_tensor
1157-
type_var: U
1158-
shape_map: affine_map<()[s0, s1, s2] -> (s0, s2)>
1159-
- !LinalgOperandDefConfig
1160-
name: cast
1161-
kind: type_fn_attr
1162-
default_fn: cast_signed
1163-
indexing_maps: !LinalgIndexingMapsConfig
1164-
static_indexing_maps:
1165-
- affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d2)>
1166-
- affine_map<(d0, d1, d2)[s0, s1, s2] -> (d1, d2)>
1167-
- affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d1)>
1168-
iterator_types:
1169-
- parallel
1170-
- parallel
1171-
- reduction
1172-
assignments:
1173-
- !ScalarAssign
1174-
arg: C
1175-
value: !ScalarExpression
1176-
scalar_fn:
1177-
kind: binary
1178-
fn_name: add
1179-
operands:
1180-
- !ScalarExpression
1181-
scalar_arg: C
1182-
- !ScalarExpression
1183-
scalar_fn:
1184-
kind: binary
1185-
fn_name: mul
1186-
operands:
1187-
- !ScalarExpression
1188-
scalar_fn:
1189-
kind: type
1190-
attr_name: cast
1191-
type_var: U
1192-
operands:
1193-
- !ScalarExpression
1194-
scalar_arg: A
1195-
- !ScalarExpression
1196-
scalar_fn:
1197-
kind: type
1198-
attr_name: cast
1199-
type_var: U
1200-
operands:
1201-
- !ScalarExpression
1202-
scalar_arg: B
1203-
--- !LinalgOpConfig
12041058
metadata: !LinalgOpMetadata
12051059
name: mmt4d
12061060
cpp_class_name: Mmt4DOp
@@ -1358,146 +1212,6 @@ structured_op: !LinalgStructuredOpConfig
13581212
- !ScalarExpression
13591213
scalar_arg: rhs
13601214
--- !LinalgOpConfig
1361-
metadata: !LinalgOpMetadata
1362-
name: batch_matmul_transpose_a
1363-
cpp_class_name: BatchMatmulTransposeAOp
1364-
doc: |-
1365-
Performs a batched matrix multiplication of two 3D inputs where lhs operand
1366-
has its non-batch dimensions transposed.
1367-
1368-
Numeric casting is performed on the operands to the inner multiply, promoting
1369-
them to the same data type as the accumulator/output.
1370-
implements:
1371-
- LinalgContractionOpInterface
1372-
structured_op: !LinalgStructuredOpConfig
1373-
args:
1374-
- !LinalgOperandDefConfig
1375-
name: A
1376-
kind: input_tensor
1377-
type_var: T1
1378-
shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s1, s2)>
1379-
- !LinalgOperandDefConfig
1380-
name: B
1381-
kind: input_tensor
1382-
type_var: T2
1383-
shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s1, s3)>
1384-
- !LinalgOperandDefConfig
1385-
name: C
1386-
kind: output_tensor
1387-
type_var: U
1388-
shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s2, s3)>
1389-
indexing_maps: !LinalgIndexingMapsConfig
1390-
static_indexing_maps:
1391-
- affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d3, d1)>
1392-
- affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d3, d2)>
1393-
- affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d1, d2)>
1394-
iterator_types:
1395-
- parallel
1396-
- parallel
1397-
- parallel
1398-
- reduction
1399-
assignments:
1400-
- !ScalarAssign
1401-
arg: C
1402-
value: !ScalarExpression
1403-
scalar_fn:
1404-
kind: binary
1405-
fn_name: add
1406-
operands:
1407-
- !ScalarExpression
1408-
scalar_arg: C
1409-
- !ScalarExpression
1410-
scalar_fn:
1411-
kind: binary
1412-
fn_name: mul
1413-
operands:
1414-
- !ScalarExpression
1415-
scalar_fn:
1416-
kind: type
1417-
fn_name: cast_signed
1418-
type_var: U
1419-
operands:
1420-
- !ScalarExpression
1421-
scalar_arg: A
1422-
- !ScalarExpression
1423-
scalar_fn:
1424-
kind: type
1425-
fn_name: cast_signed
1426-
type_var: U
1427-
operands:
1428-
- !ScalarExpression
1429-
scalar_arg: B
1430-
--- !LinalgOpConfig
1431-
metadata: !LinalgOpMetadata
1432-
name: batch_matmul_transpose_b
1433-
cpp_class_name: BatchMatmulTransposeBOp
1434-
doc: |-
1435-
Performs a batched matrix multiplication of two 3D inputs where rhs operand
1436-
has its non-batch dimensions transposed.
1437-
1438-
Numeric casting is performed on the operands to the inner multiply, promoting
1439-
them to the same data type as the accumulator/output.
1440-
implements:
1441-
- LinalgContractionOpInterface
1442-
structured_op: !LinalgStructuredOpConfig
1443-
args:
1444-
- !LinalgOperandDefConfig
1445-
name: A
1446-
kind: input_tensor
1447-
type_var: T1
1448-
shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s1, s2)>
1449-
- !LinalgOperandDefConfig
1450-
name: B
1451-
kind: input_tensor
1452-
type_var: T2
1453-
shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s3, s2)>
1454-
- !LinalgOperandDefConfig
1455-
name: C
1456-
kind: output_tensor
1457-
type_var: U
1458-
shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s1, s3)>
1459-
indexing_maps: !LinalgIndexingMapsConfig
1460-
static_indexing_maps:
1461-
- affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d1, d3)>
1462-
- affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d2, d3)>
1463-
- affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d1, d2)>
1464-
iterator_types:
1465-
- parallel
1466-
- parallel
1467-
- parallel
1468-
- reduction
1469-
assignments:
1470-
- !ScalarAssign
1471-
arg: C
1472-
value: !ScalarExpression
1473-
scalar_fn:
1474-
kind: binary
1475-
fn_name: add
1476-
operands:
1477-
- !ScalarExpression
1478-
scalar_arg: C
1479-
- !ScalarExpression
1480-
scalar_fn:
1481-
kind: binary
1482-
fn_name: mul
1483-
operands:
1484-
- !ScalarExpression
1485-
scalar_fn:
1486-
kind: type
1487-
fn_name: cast_signed
1488-
type_var: U
1489-
operands:
1490-
- !ScalarExpression
1491-
scalar_arg: A
1492-
- !ScalarExpression
1493-
scalar_fn:
1494-
kind: type
1495-
fn_name: cast_signed
1496-
type_var: U
1497-
operands:
1498-
- !ScalarExpression
1499-
scalar_arg: B
1500-
--- !LinalgOpConfig
15011215
metadata: !LinalgOpMetadata
15021216
name: quantized_batch_matmul
15031217
cpp_class_name: QuantizedBatchMatmulOp

mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -320,10 +320,6 @@ void linalg::populateBlockPackMatmulPatterns(
320320
RewritePatternSet &patterns, const ControlBlockPackMatmulFn &controlFn) {
321321
patterns.add<BlockPackMatmul<linalg::GenericOp>,
322322
BlockPackMatmul<linalg::MatmulOp>,
323-
BlockPackMatmul<linalg::BatchMatmulOp>,
324-
BlockPackMatmul<linalg::MatmulTransposeAOp>,
325-
BlockPackMatmul<linalg::BatchMatmulTransposeAOp>,
326-
BlockPackMatmul<linalg::MatmulTransposeBOp>,
327-
BlockPackMatmul<linalg::BatchMatmulTransposeBOp>>(
323+
BlockPackMatmul<linalg::BatchMatmulOp>>(
328324
patterns.getContext(), controlFn);
329325
}

mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1013,12 +1013,8 @@ struct RankReduceMatmul : RankReduceContractionOps<FromOpTy, ToOpTy> {
10131013
static bool constexpr reduceLeft =
10141014
(std::is_same_v<FromOpTy, BatchMatmulOp> &&
10151015
std::is_same_v<ToOpTy, BatchVecmatOp>) ||
1016-
(std::is_same_v<FromOpTy, BatchMatmulTransposeAOp> &&
1017-
std::is_same_v<ToOpTy, BatchVecmatOp>) ||
10181016
(std::is_same_v<FromOpTy, MatmulOp> &&
10191017
std::is_same_v<ToOpTy, VecmatOp>) ||
1020-
(std::is_same_v<FromOpTy, MatmulTransposeAOp> &&
1021-
std::is_same_v<ToOpTy, VecmatOp>) ||
10221018
(std::is_same_v<FromOpTy, MatvecOp> && std::is_same_v<ToOpTy, DotOp>);
10231019

10241020
/// Look for non-batch spatial dims to collapse.
@@ -1074,27 +1070,15 @@ void mlir::linalg::populateContractionOpRankReducingPatterns(
10741070
MLIRContext *context = patterns.getContext();
10751071
// Unbatching patterns for unit batch size
10761072
patterns.add<RankReduceToUnBatched<BatchMatmulOp, MatmulOp>>(context);
1077-
patterns
1078-
.add<RankReduceToUnBatched<BatchMatmulTransposeAOp, MatmulTransposeAOp>>(
1079-
context);
1080-
patterns
1081-
.add<RankReduceToUnBatched<BatchMatmulTransposeBOp, MatmulTransposeBOp>>(
1082-
context);
10831073
patterns.add<RankReduceToUnBatched<BatchMatvecOp, MatvecOp>>(context);
10841074
patterns.add<RankReduceToUnBatched<BatchVecmatOp, VecmatOp>>(context);
10851075

10861076
// Non-batch rank 1 reducing patterns
10871077
patterns.add<RankReduceMatmul<MatmulOp, VecmatOp>>(context);
10881078
patterns.add<RankReduceMatmul<MatmulOp, MatvecOp>>(context);
1089-
patterns.add<RankReduceMatmul<MatmulTransposeAOp, VecmatOp>>(context);
1090-
patterns.add<RankReduceMatmul<MatmulTransposeBOp, MatvecOp>>(context);
10911079
// Batch rank 1 reducing patterns
10921080
patterns.add<RankReduceMatmul<BatchMatmulOp, BatchVecmatOp>>(context);
10931081
patterns.add<RankReduceMatmul<BatchMatmulOp, BatchMatvecOp>>(context);
1094-
patterns.add<RankReduceMatmul<BatchMatmulTransposeAOp, BatchVecmatOp>>(
1095-
context);
1096-
patterns.add<RankReduceMatmul<BatchMatmulTransposeBOp, BatchMatvecOp>>(
1097-
context);
10981082

10991083
// Non-batch rank 0 reducing patterns
11001084
patterns.add<RankReduceMatmul<MatvecOp, DotOp>>(context);

mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -234,19 +234,8 @@ static FailureOr<LinalgOp> specializeLinalgContractions(RewriterBase &rewriter,
234234

235235
/// Codegen the different matmul variants.
236236
if (numOfBatchDims) {
237-
if (a == IndexMatchResult::Transposed)
238-
return replaceWithMatmulVariant<BatchMatmulTransposeAOp>(rewriter,
239-
genericOp);
240-
if (b == IndexMatchResult::Transposed)
241-
return replaceWithMatmulVariant<BatchMatmulTransposeBOp>(rewriter,
242-
genericOp);
243237
return replaceWithMatmulVariant<BatchMatmulOp>(rewriter, genericOp);
244238
}
245-
246-
if (a == IndexMatchResult::Transposed)
247-
return replaceWithMatmulVariant<MatmulTransposeAOp>(rewriter, genericOp);
248-
if (b == IndexMatchResult::Transposed)
249-
return replaceWithMatmulVariant<MatmulTransposeBOp>(rewriter, genericOp);
250239
return replaceWithMatmulVariant<MatmulOp>(rewriter, genericOp);
251240
}
252241

0 commit comments

Comments
 (0)