[MLIR][Linalg] Remove matmul_transpose variants #147961
Removes the `(batch_)matmul_transpose_{a|b}` variants from OpDSL and replaces them with `matmul affine_maps [...]` whenever appropriate.
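To make the replacement concrete, here is a minimal plain C++ sketch (no MLIR dependency) of what the indexing maps express. It assumes a hypothetical `Matrix` helper and a hypothetical `contract` function, neither of which is part of MLIR. Each `IndexMap` stands for a map such as `affine_map<(d0, d1, d2) -> (d2, d0)>`, and the output map is fixed to `(d0, d1)`. The map values are the ones listed in the removed YAML definitions.

```cpp
#include <array>
#include <cstddef>
#include <vector>

// Hypothetical row-major 2D matrix helper (not an MLIR type).
struct Matrix {
  std::size_t rows, cols;
  std::vector<int> data;
  Matrix(std::size_t r, std::size_t c) : rows(r), cols(c), data(r * c, 0) {}
  int &at(std::size_t i, std::size_t j) { return data[i * cols + j]; }
  int at(std::size_t i, std::size_t j) const { return data[i * cols + j]; }
};

// Picks two of the three loop indices (d0, d1, d2).
// {2, 0} stands for affine_map<(d0, d1, d2) -> (d2, d0)>.
using IndexMap = std::array<int, 2>;

// Contraction over d0 < m, d1 < n (parallel) and d2 < k (reduction):
//   C[d0, d1] += A[mapA(d0, d1, d2)] * B[mapB(d0, d1, d2)]
Matrix contract(const Matrix &a, const Matrix &b, std::size_t m, std::size_t n,
                std::size_t k, IndexMap mapA, IndexMap mapB) {
  Matrix c(m, n);
  for (std::size_t d0 = 0; d0 < m; ++d0)
    for (std::size_t d1 = 0; d1 < n; ++d1)
      for (std::size_t d2 = 0; d2 < k; ++d2) {
        const std::size_t d[3] = {d0, d1, d2};
        c.at(d0, d1) += a.at(d[mapA[0]], d[mapA[1]]) *
                        b.at(d[mapB[0]], d[mapB[1]]);
      }
  return c;
}

// Former matmul_transpose_a: A is K x M, B is K x N.
// Maps: A -> (d2, d0), B -> (d2, d1).
Matrix matmulTransposeA(const Matrix &a, const Matrix &b) {
  return contract(a, b, a.cols, b.cols, a.rows, {2, 0}, {2, 1});
}

// Former matmul_transpose_b: A is M x K, B is N x K.
// Maps: A -> (d0, d2), B -> (d1, d2).
Matrix matmulTransposeB(const Matrix &a, const Matrix &b) {
  return contract(a, b, a.rows, b.rows, a.cols, {0, 2}, {1, 2});
}
```

The point of the sketch is that both former variants are the same loop nest with different operand maps, which is why a single `matmul` op carrying the maps can stand in for them.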
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-sme
Author: Renato Golin (rengolin)
Changes
Removes the `(batch_)matmul_transpose_{a|b}` variants from OpDSL and replaces them with `matmul affine_maps [...]` whenever appropriate.
See: https://discourse.llvm.org/t/deprecate-batch-matmul-transpose-a-b-linalg-operations/87245
Issues investigated:
Arm tests validated by @banach-space (thanks!!).
Patch is 75.13 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/147961.diff
20 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
index 3637147c5a90d..9aae1b850c3a0 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
@@ -1055,152 +1055,6 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarExpression
scalar_arg: BZp
--- !LinalgOpConfig
-metadata: !LinalgOpMetadata
- name: matmul_transpose_a
- cpp_class_name: MatmulTransposeAOp
- doc: |-
- Performs a matrix multiplication of two 2D inputs with lhs operand
- transposed.
-
- Numeric casting is performed on the operands to the inner multiply, promoting
- them to the same data type as the accumulator/output.
- implements:
- - LinalgContractionOpInterface
-structured_op: !LinalgStructuredOpConfig
- args:
- - !LinalgOperandDefConfig
- name: A
- kind: input_tensor
- type_var: T1
- shape_map: affine_map<()[s0, s1, s2] -> (s0, s1)>
- - !LinalgOperandDefConfig
- name: B
- kind: input_tensor
- type_var: T2
- shape_map: affine_map<()[s0, s1, s2] -> (s0, s2)>
- - !LinalgOperandDefConfig
- name: C
- kind: output_tensor
- type_var: U
- shape_map: affine_map<()[s0, s1, s2] -> (s2, s1)>
- - !LinalgOperandDefConfig
- name: cast
- kind: type_fn_attr
- default_fn: cast_signed
- indexing_maps: !LinalgIndexingMapsConfig
- static_indexing_maps:
- - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d2, d0)>
- - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d2, d1)>
- - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d1)>
- iterator_types:
- - parallel
- - parallel
- - reduction
- assignments:
- - !ScalarAssign
- arg: C
- value: !ScalarExpression
- scalar_fn:
- kind: binary
- fn_name: add
- operands:
- - !ScalarExpression
- scalar_arg: C
- - !ScalarExpression
- scalar_fn:
- kind: binary
- fn_name: mul
- operands:
- - !ScalarExpression
- scalar_fn:
- kind: type
- attr_name: cast
- type_var: U
- operands:
- - !ScalarExpression
- scalar_arg: A
- - !ScalarExpression
- scalar_fn:
- kind: type
- attr_name: cast
- type_var: U
- operands:
- - !ScalarExpression
- scalar_arg: B
---- !LinalgOpConfig
-metadata: !LinalgOpMetadata
- name: matmul_transpose_b
- cpp_class_name: MatmulTransposeBOp
- doc: |-
- Performs a matrix multiplication of two 2D inputs with rhs operand
- transposed.
-
- Numeric casting is performed on the operands to the inner multiply, promoting
- them to the same data type as the accumulator/output.
- implements:
- - LinalgContractionOpInterface
-structured_op: !LinalgStructuredOpConfig
- args:
- - !LinalgOperandDefConfig
- name: A
- kind: input_tensor
- type_var: T1
- shape_map: affine_map<()[s0, s1, s2] -> (s0, s1)>
- - !LinalgOperandDefConfig
- name: B
- kind: input_tensor
- type_var: T2
- shape_map: affine_map<()[s0, s1, s2] -> (s2, s1)>
- - !LinalgOperandDefConfig
- name: C
- kind: output_tensor
- type_var: U
- shape_map: affine_map<()[s0, s1, s2] -> (s0, s2)>
- - !LinalgOperandDefConfig
- name: cast
- kind: type_fn_attr
- default_fn: cast_signed
- indexing_maps: !LinalgIndexingMapsConfig
- static_indexing_maps:
- - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d2)>
- - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d1, d2)>
- - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d1)>
- iterator_types:
- - parallel
- - parallel
- - reduction
- assignments:
- - !ScalarAssign
- arg: C
- value: !ScalarExpression
- scalar_fn:
- kind: binary
- fn_name: add
- operands:
- - !ScalarExpression
- scalar_arg: C
- - !ScalarExpression
- scalar_fn:
- kind: binary
- fn_name: mul
- operands:
- - !ScalarExpression
- scalar_fn:
- kind: type
- attr_name: cast
- type_var: U
- operands:
- - !ScalarExpression
- scalar_arg: A
- - !ScalarExpression
- scalar_fn:
- kind: type
- attr_name: cast
- type_var: U
- operands:
- - !ScalarExpression
- scalar_arg: B
---- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: mmt4d
cpp_class_name: Mmt4DOp
@@ -1358,146 +1212,6 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarExpression
scalar_arg: rhs
--- !LinalgOpConfig
-metadata: !LinalgOpMetadata
- name: batch_matmul_transpose_a
- cpp_class_name: BatchMatmulTransposeAOp
- doc: |-
- Performs a batched matrix multiplication of two 3D inputs where lhs operand
- has its non-batch dimensions transposed.
-
- Numeric casting is performed on the operands to the inner multiply, promoting
- them to the same data type as the accumulator/output.
- implements:
- - LinalgContractionOpInterface
-structured_op: !LinalgStructuredOpConfig
- args:
- - !LinalgOperandDefConfig
- name: A
- kind: input_tensor
- type_var: T1
- shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s1, s2)>
- - !LinalgOperandDefConfig
- name: B
- kind: input_tensor
- type_var: T2
- shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s1, s3)>
- - !LinalgOperandDefConfig
- name: C
- kind: output_tensor
- type_var: U
- shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s2, s3)>
- indexing_maps: !LinalgIndexingMapsConfig
- static_indexing_maps:
- - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d3, d1)>
- - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d3, d2)>
- - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d1, d2)>
- iterator_types:
- - parallel
- - parallel
- - parallel
- - reduction
- assignments:
- - !ScalarAssign
- arg: C
- value: !ScalarExpression
- scalar_fn:
- kind: binary
- fn_name: add
- operands:
- - !ScalarExpression
- scalar_arg: C
- - !ScalarExpression
- scalar_fn:
- kind: binary
- fn_name: mul
- operands:
- - !ScalarExpression
- scalar_fn:
- kind: type
- fn_name: cast_signed
- type_var: U
- operands:
- - !ScalarExpression
- scalar_arg: A
- - !ScalarExpression
- scalar_fn:
- kind: type
- fn_name: cast_signed
- type_var: U
- operands:
- - !ScalarExpression
- scalar_arg: B
---- !LinalgOpConfig
-metadata: !LinalgOpMetadata
- name: batch_matmul_transpose_b
- cpp_class_name: BatchMatmulTransposeBOp
- doc: |-
- Performs a batched matrix multiplication of two 3D inputs where rhs operand
- has its non-batch dimensions transposed.
-
- Numeric casting is performed on the operands to the inner multiply, promoting
- them to the same data type as the accumulator/output.
- implements:
- - LinalgContractionOpInterface
-structured_op: !LinalgStructuredOpConfig
- args:
- - !LinalgOperandDefConfig
- name: A
- kind: input_tensor
- type_var: T1
- shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s1, s2)>
- - !LinalgOperandDefConfig
- name: B
- kind: input_tensor
- type_var: T2
- shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s3, s2)>
- - !LinalgOperandDefConfig
- name: C
- kind: output_tensor
- type_var: U
- shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s1, s3)>
- indexing_maps: !LinalgIndexingMapsConfig
- static_indexing_maps:
- - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d1, d3)>
- - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d2, d3)>
- - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d1, d2)>
- iterator_types:
- - parallel
- - parallel
- - parallel
- - reduction
- assignments:
- - !ScalarAssign
- arg: C
- value: !ScalarExpression
- scalar_fn:
- kind: binary
- fn_name: add
- operands:
- - !ScalarExpression
- scalar_arg: C
- - !ScalarExpression
- scalar_fn:
- kind: binary
- fn_name: mul
- operands:
- - !ScalarExpression
- scalar_fn:
- kind: type
- fn_name: cast_signed
- type_var: U
- operands:
- - !ScalarExpression
- scalar_arg: A
- - !ScalarExpression
- scalar_fn:
- kind: type
- fn_name: cast_signed
- type_var: U
- operands:
- - !ScalarExpression
- scalar_arg: B
---- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: quantized_batch_matmul
cpp_class_name: QuantizedBatchMatmulOp
diff --git a/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp b/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp
index 3908d73f5e0e1..57f898458516e 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp
@@ -320,10 +320,6 @@ void linalg::populateBlockPackMatmulPatterns(
RewritePatternSet &patterns, const ControlBlockPackMatmulFn &controlFn) {
patterns.add<BlockPackMatmul<linalg::GenericOp>,
BlockPackMatmul<linalg::MatmulOp>,
- BlockPackMatmul<linalg::BatchMatmulOp>,
- BlockPackMatmul<linalg::MatmulTransposeAOp>,
- BlockPackMatmul<linalg::BatchMatmulTransposeAOp>,
- BlockPackMatmul<linalg::MatmulTransposeBOp>,
- BlockPackMatmul<linalg::BatchMatmulTransposeBOp>>(
+ BlockPackMatmul<linalg::BatchMatmulOp>>(
patterns.getContext(), controlFn);
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
index e0062d15e61ca..0cd2b6810ab9a 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -1013,12 +1013,8 @@ struct RankReduceMatmul : RankReduceContractionOps<FromOpTy, ToOpTy> {
static bool constexpr reduceLeft =
(std::is_same_v<FromOpTy, BatchMatmulOp> &&
std::is_same_v<ToOpTy, BatchVecmatOp>) ||
- (std::is_same_v<FromOpTy, BatchMatmulTransposeAOp> &&
- std::is_same_v<ToOpTy, BatchVecmatOp>) ||
(std::is_same_v<FromOpTy, MatmulOp> &&
std::is_same_v<ToOpTy, VecmatOp>) ||
- (std::is_same_v<FromOpTy, MatmulTransposeAOp> &&
- std::is_same_v<ToOpTy, VecmatOp>) ||
(std::is_same_v<FromOpTy, MatvecOp> && std::is_same_v<ToOpTy, DotOp>);
/// Look for non-batch spatial dims to collapse.
@@ -1074,27 +1070,15 @@ void mlir::linalg::populateContractionOpRankReducingPatterns(
MLIRContext *context = patterns.getContext();
// Unbatching patterns for unit batch size
patterns.add<RankReduceToUnBatched<BatchMatmulOp, MatmulOp>>(context);
- patterns
- .add<RankReduceToUnBatched<BatchMatmulTransposeAOp, MatmulTransposeAOp>>(
- context);
- patterns
- .add<RankReduceToUnBatched<BatchMatmulTransposeBOp, MatmulTransposeBOp>>(
- context);
patterns.add<RankReduceToUnBatched<BatchMatvecOp, MatvecOp>>(context);
patterns.add<RankReduceToUnBatched<BatchVecmatOp, VecmatOp>>(context);
// Non-batch rank 1 reducing patterns
patterns.add<RankReduceMatmul<MatmulOp, VecmatOp>>(context);
patterns.add<RankReduceMatmul<MatmulOp, MatvecOp>>(context);
- patterns.add<RankReduceMatmul<MatmulTransposeAOp, VecmatOp>>(context);
- patterns.add<RankReduceMatmul<MatmulTransposeBOp, MatvecOp>>(context);
// Batch rank 1 reducing patterns
patterns.add<RankReduceMatmul<BatchMatmulOp, BatchVecmatOp>>(context);
patterns.add<RankReduceMatmul<BatchMatmulOp, BatchMatvecOp>>(context);
- patterns.add<RankReduceMatmul<BatchMatmulTransposeAOp, BatchVecmatOp>>(
- context);
- patterns.add<RankReduceMatmul<BatchMatmulTransposeBOp, BatchMatvecOp>>(
- context);
// Non-batch rank 0 reducing patterns
patterns.add<RankReduceMatmul<MatvecOp, DotOp>>(context);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
index 455e1a6d146d1..35ba4f159113f 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
@@ -234,19 +234,8 @@ static FailureOr<LinalgOp> specializeLinalgContractions(RewriterBase &rewriter,
/// Codegen the different matmul variants.
if (numOfBatchDims) {
- if (a == IndexMatchResult::Transposed)
- return replaceWithMatmulVariant<BatchMatmulTransposeAOp>(rewriter,
- genericOp);
- if (b == IndexMatchResult::Transposed)
- return replaceWithMatmulVariant<BatchMatmulTransposeBOp>(rewriter,
- genericOp);
return replaceWithMatmulVariant<BatchMatmulOp>(rewriter, genericOp);
}
-
- if (a == IndexMatchResult::Transposed)
- return replaceWithMatmulVariant<MatmulTransposeAOp>(rewriter, genericOp);
- if (b == IndexMatchResult::Transposed)
- return replaceWithMatmulVariant<MatmulTransposeBOp>(rewriter, genericOp);
return replaceWithMatmulVariant<MatmulOp>(rewriter, genericOp);
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp b/mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp
index 934781d1cab75..086f9e5d05e6f 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp
@@ -23,10 +23,11 @@ using namespace mlir::linalg;
///
/// with
///
-/// linalg.matmul_transpose_a(linalg.transpose(a), b)
+/// linalg.matmul affine_maps { #A^T, #B, #C } (linalg.transpose(a), b)
///
/// By default the LHS is transposed. Set `transposeLHS=false` to
/// transpose RHS instead.
+/// FIXME: This API is not intuitive, replace LHS=false with something better
FailureOr<Operation *> mlir::linalg::transposeMatmul(RewriterBase &rewriter,
linalg::MatmulOp matmulOp,
bool transposeLHS) {
@@ -57,18 +58,31 @@ FailureOr<Operation *> mlir::linalg::transposeMatmul(RewriterBase &rewriter,
dynamicDims);
auto transposeOp = rewriter.create<linalg::TransposeOp>(
loc, input, empty, ArrayRef<int64_t>{1, 0});
- Operation *newMatmulOp;
+ Value newLHS, newRHS;
+ AffineMap mapLHS, mapRHS, mapOut;
+ AffineExpr d0, d1, d2;
+ auto context = rewriter.getContext();
+ bindDims(context, d0, d1, d2);
if (transposeLHS) {
- newMatmulOp = rewriter.create<linalg::MatmulTransposeAOp>(
- loc, matmulOp.getResultTypes(),
- ValueRange{transposeOp->getResult(0), matmulOp.getInputs()[1]},
- matmulOp.getOutputs());
+ newLHS = transposeOp->getResult(0);
+ newRHS = matmulOp.getInputs()[1];
+ mapLHS = AffineMap::get(3, 0, {d2, d0}, context);
+ mapRHS = AffineMap::get(3, 0, {d2, d1}, context);
+ mapOut = AffineMap::get(3, 0, {d0, d1}, context);
} else {
- newMatmulOp = rewriter.create<linalg::MatmulTransposeBOp>(
- loc, matmulOp.getResultTypes(),
- ValueRange{matmulOp.getInputs()[0], transposeOp->getResult(0)},
- matmulOp.getOutputs());
+ newLHS = matmulOp.getInputs()[0];
+ newRHS = transposeOp->getResult(0);
+ mapLHS = AffineMap::get(3, 0, {d0, d2}, context);
+ mapRHS = AffineMap::get(3, 0, {d1, d2}, context);
+ mapOut = AffineMap::get(3, 0, {d0, d1}, context);
}
+ Operation *newMatmulOp = rewriter.create<linalg::MatmulOp>(
+ loc, matmulOp.getResultTypes(), ValueRange{newLHS, newRHS},
+ matmulOp.getOutputs());
+ newMatmulOp->setAttr("indexing_maps",
+ rewriter.getArrayAttr({AffineMapAttr::get(mapLHS),
+ AffineMapAttr::get(mapRHS),
+ AffineMapAttr::get(mapOut)}));
rewriter.replaceOp(matmulOp, newMatmulOp);
return newMatmulOp;
}
@@ -79,10 +93,11 @@ FailureOr<Operation *> mlir::linalg::transposeMatmul(RewriterBase &rewriter,
///
/// with
///
-/// linalg.batch_matmul_transpose_a(linalg.transpose(a), b)
+/// linalg.batch_matmul affine_maps { #A^T, #B, #C } (linalg.transpose(a), b)
///
/// Only the non-batch dimensions are transposed. By default the LHS is
/// transposed. Set `transposeLHS=false` to transpose RHS instead.
+/// FIXME: This API is not intuitive, replace LHS=false with something better
FailureOr<Operation *>
mlir::linalg::transposeBatchMatmul(RewriterBase &rewriter,
linalg::BatchMatmulOp batchMatmulOp,
@@ -114,18 +129,31 @@ mlir::linalg::transposeBatchMatmul(RewriterBase &rewriter,
type.getElementType(), dynamicDims);
auto transposeOp = rewriter.create<linalg::TransposeOp>(
loc, input, empty, ArrayRef<int64_t>{0, 2, 1});
- Operation *newMatmulOp;
+ Value newLHS, newRHS;
+ AffineMap mapLHS, mapRHS, mapOut;
+ AffineExpr d0, d1, d2, d3;
+ auto context = rewriter.getContext();
+ bindDims(context, d0, d1, d2, d3);
if (transposeLHS) {
- newMatmulOp = rewriter.create<linalg::BatchMatmulTransposeAOp>(
- loc, batchMatmulOp.getResultTypes(),
- ValueRange{transposeOp->getResult(0), batchMatmulOp.getInputs()[1]},
- batchMatmulOp.getOutputs());
+ newLHS = transposeOp->getResult(0);
+ newRHS = batchMatmulOp.getInputs()[1];
+ mapLHS = AffineMap::get(4, 0, {d0, d3, d1}, context);
+ mapRHS = AffineMap::get(4, 0, {d0, d3, d2}, context);
+ mapOut = AffineMap::get(4, 0, {d0, d1, d2}, context);
} else {
- newMatmulOp = rewriter.create<linalg::BatchMatmulTransposeBOp>(
- loc, batchMatmulOp.getResultTypes(),
- ValueRange{batchMatmulOp.getInputs()[0], transposeOp->getResult(0)},
- batchMatmulOp.getOutputs());
+ newLHS = batchMatmulOp.getInputs()[0];
+ newRHS = transposeOp->getResult(0);
+ mapLHS = AffineMap::get(4, 0, {d0, d1, d3}, context);
+ mapRHS = AffineMap::get(4, 0, {d0, d2, d3}, context);
+ mapOut = AffineMap::get(4, 0, {d0, d1, d2}, context);
}
+ Operation *newMatmulOp = rewriter.create<linalg::BatchMatmulOp>(
+ loc, batchMatmulOp.getResultTypes(), ValueRange{newLHS, newRHS},
+ batchMatmulOp.getOutputs());
+ newMatmulOp->setAttr("indexing_maps",
+ rewriter.getArrayAttr({AffineMapAttr::get(mapLHS),
+ AffineMapAttr::get(mapRHS),
+ AffineMapAttr::get(mapOut)}));
rewriter.replaceOp(batchMatmulOp, newMatmulOp);
return newMatmulOp;
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 5a8c5eab3f444..7d6155218f422 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -2423,7 +2423,7 @@ vectorizeScalableVectorPrecondition(Operation *op,
"vectorization\n");
return failure();
}
- if (isa<linalg::MatmulOp>(op) || isa<linalg::MatmulTransposeAOp>(op)) {
+ if (isa<linalg::MatmulOp>(op)) {
LDBG("Scalable vectorization of the reduction dim in Matmul-like ops "
"is not supported\n");
return failure();
@@ -2462,15 +2462,9 @@ vectorizeScalableVectorPrecondition(Operation *op,
return failure();
}
- // Check to not let go the matmul with extended semantic, through this
- // transform.
- if (linalgOp.hasUserDefinedMaps())
- return failure();
-
// Cond 4: Only the following ops are supported in the
// presence of scalable vectors
return success(isElementwise(linalgOp) || isa<linalg::MatmulOp>(op) ||
- isa<linalg::MatmulTransposeAOp>(op) ||
isa<linalg::DepthwiseConv1DNwc...
[truncated]
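The rewrite in the `TransposeMatmul.cpp` hunks above materializes a transpose of one operand and compensates through the indexing maps on a plain `linalg.matmul`. As a rough illustration only, the following plain C++ sketch mimics that idea without MLIR. `Matrix`, `transpose`, `contract`, and `transposedMatmul` are hypothetical helpers introduced here, and the maps are the 2D ones from the diff (`{d2, d0}`, `{d2, d1}` for the LHS case; `{d0, d2}`, `{d1, d2}` for the RHS case).

```cpp
#include <array>
#include <cstddef>
#include <vector>

// Hypothetical helpers; inputs are assumed non-empty and rectangular.
using Matrix = std::vector<std::vector<int>>;
using IndexMap = std::array<int, 2>; // {2, 0} means (d0, d1, d2) -> (d2, d0)

// Explicit transpose, standing in for linalg.transpose with permutation {1, 0}.
Matrix transpose(const Matrix &m) {
  Matrix t(m[0].size(), std::vector<int>(m.size()));
  for (std::size_t i = 0; i < m.size(); ++i)
    for (std::size_t j = 0; j < m[0].size(); ++j)
      t[j][i] = m[i][j];
  return t;
}

// C[d0][d1] += A[mapA(d0, d1, d2)] * B[mapB(d0, d1, d2)], with d2 reduced.
Matrix contract(const Matrix &a, const Matrix &b, std::size_t m, std::size_t n,
                std::size_t k, IndexMap mapA, IndexMap mapB) {
  Matrix c(m, std::vector<int>(n, 0));
  for (std::size_t d0 = 0; d0 < m; ++d0)
    for (std::size_t d1 = 0; d1 < n; ++d1)
      for (std::size_t d2 = 0; d2 < k; ++d2) {
        const std::size_t d[3] = {d0, d1, d2};
        c[d0][d1] += a[d[mapA[0]]][d[mapA[1]]] * b[d[mapB[0]]][d[mapB[1]]];
      }
  return c;
}

// Default matmul maps: A -> (d0, d2), B -> (d2, d1).
Matrix matmul(const Matrix &a, const Matrix &b) {
  return contract(a, b, a.size(), b[0].size(), b.size(), {0, 2}, {2, 1});
}

// Analogue of the rewrite: transpose one operand, then read it back through
// transposed indexing maps so the result is unchanged.
Matrix transposedMatmul(const Matrix &a, const Matrix &b, bool transposeLHS) {
  const std::size_t m = a.size(), n = b[0].size(), k = b.size();
  if (transposeLHS)
    return contract(transpose(a), b, m, n, k, {2, 0}, {2, 1});
  return contract(a, transpose(b), m, n, k, {0, 2}, {1, 2});
}
```

Under these assumptions, `transposedMatmul(a, b, true)` and `transposedMatmul(a, b, false)` both return the same values as `matmul(a, b)`; only the memory layout of the transposed operand differs.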
+ AffineExpr d0, d1, d2, d3;
+ auto context = rewriter.getContext();
+ bindDims(context, d0, d1, d2, d3);
if (transposeLHS) {
- newMatmulOp = rewriter.create<linalg::BatchMatmulTransposeAOp>(
- loc, batchMatmulOp.getResultTypes(),
- ValueRange{transposeOp->getResult(0), batchMatmulOp.getInputs()[1]},
- batchMatmulOp.getOutputs());
+ newLHS = transposeOp->getResult(0);
+ newRHS = batchMatmulOp.getInputs()[1];
+ mapLHS = AffineMap::get(4, 0, {d0, d3, d1}, context);
+ mapRHS = AffineMap::get(4, 0, {d0, d3, d2}, context);
+ mapOut = AffineMap::get(4, 0, {d0, d1, d2}, context);
} else {
- newMatmulOp = rewriter.create<linalg::BatchMatmulTransposeBOp>(
- loc, batchMatmulOp.getResultTypes(),
- ValueRange{batchMatmulOp.getInputs()[0], transposeOp->getResult(0)},
- batchMatmulOp.getOutputs());
+ newLHS = batchMatmulOp.getInputs()[0];
+ newRHS = transposeOp->getResult(0);
+ mapLHS = AffineMap::get(4, 0, {d0, d1, d3}, context);
+ mapRHS = AffineMap::get(4, 0, {d0, d2, d3}, context);
+ mapOut = AffineMap::get(4, 0, {d0, d1, d2}, context);
}
+ Operation *newMatmulOp = rewriter.create<linalg::BatchMatmulOp>(
+ loc, batchMatmulOp.getResultTypes(), ValueRange{newLHS, newRHS},
+ batchMatmulOp.getOutputs());
+ newMatmulOp->setAttr("indexing_maps",
+ rewriter.getArrayAttr({AffineMapAttr::get(mapLHS),
+ AffineMapAttr::get(mapRHS),
+ AffineMapAttr::get(mapOut)}));
rewriter.replaceOp(batchMatmulOp, newMatmulOp);
return newMatmulOp;
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 5a8c5eab3f444..7d6155218f422 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -2423,7 +2423,7 @@ vectorizeScalableVectorPrecondition(Operation *op,
"vectorization\n");
return failure();
}
- if (isa<linalg::MatmulOp>(op) || isa<linalg::MatmulTransposeAOp>(op)) {
+ if (isa<linalg::MatmulOp>(op)) {
LDBG("Scalable vectorization of the reduction dim in Matmul-like ops "
"is not supported\n");
return failure();
@@ -2462,15 +2462,9 @@ vectorizeScalableVectorPrecondition(Operation *op,
return failure();
}
- // Check to not let go the matmul with extended semantic, through this
- // transform.
- if (linalgOp.hasUserDefinedMaps())
- return failure();
-
// Cond 4: Only the following ops are supported in the
// presence of scalable vectors
return success(isElementwise(linalgOp) || isa<linalg::MatmulOp>(op) ||
- isa<linalg::MatmulTransposeAOp>(op) ||
isa<linalg::DepthwiseConv1DNwc...
[truncated]
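The indexing maps chosen in `transposeMatmul` above can be sanity-checked outside of MLIR. The sketch below is a stand-alone model, not MLIR code: `Map2`, `Mat`, and `contract` are hypothetical names introduced only for illustration, and each map simply names which of the loop dimensions (d0, d1, d2) = (m, n, k) indexes an operand dimension. With the LHS map (d2, d0), as in the `transposeLHS` branch, the loop nest reads the LHS as a K x M buffer, so it computes C = A^T * B.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for a 2-result affine map over (d0, d1, d2):
// each entry names the loop dimension used for that operand dimension.
using Map2 = std::array<int, 2>;

// Dense row-major matrix, for illustration only.
struct Mat {
  std::size_t rows, cols;
  std::vector<double> data;
  Mat(std::size_t r, std::size_t c) : rows(r), cols(c), data(r * c, 0.0) {}
  double &at(std::size_t i, std::size_t j) { return data[i * cols + j]; }
  double at(std::size_t i, std::size_t j) const { return data[i * cols + j]; }
};

// Runs the matmul loop nest over (d0, d1, d2) = (m, n, k), indexing every
// operand through its map, and accumulates the products into c.
inline void contract(const Mat &a, Map2 mapA, const Mat &b, Map2 mapB,
                     Mat &c, Map2 mapC, std::array<std::size_t, 3> bounds) {
  for (std::size_t d0 = 0; d0 < bounds[0]; ++d0)
    for (std::size_t d1 = 0; d1 < bounds[1]; ++d1)
      for (std::size_t d2 = 0; d2 < bounds[2]; ++d2) {
        const std::array<std::size_t, 3> d = {d0, d1, d2};
        // Guard against maps that do not fit the operand shapes.
        assert(d[mapA[0]] < a.rows && d[mapA[1]] < a.cols);
        assert(d[mapB[0]] < b.rows && d[mapB[1]] < b.cols);
        c.at(d[mapC[0]], d[mapC[1]]) +=
            a.at(d[mapA[0]], d[mapA[1]]) * b.at(d[mapB[0]], d[mapB[1]]);
      }
}
```

Calling `contract(aT, {2, 0}, b, {2, 1}, c, {0, 1}, {M, N, K})` mirrors the `transposeLHS` branch, and swapping in `{0, 2}` and `{1, 2}` for the first two maps mirrors the `else` branch.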
✅ With the latest revision this PR passed the C/C++ code formatter.
Could we get some time to manage the churn with these changes?
Of course!! No rush from my side. Let me know if you need some help.
It would be ideal if we had a deprecation path instead of just a switch. The changes needed here are
In LLVM this would be something that would be deprecated on one release and removed on the next release (it is pretty much a break-the-world change). I don't think all the front end dialect folks are paying attention to this.
newRHS = transposeOp->getResult(0);
mapLHS = AffineMap::get(3, 0, {d0, d2}, context);
mapRHS = AffineMap::get(3, 0, {d1, d2}, context);
mapOut = AffineMap::get(3, 0, {d0, d1}, context);
Rather than requiring all builders to be replaced everywhere, we can add C++ specializations as intermediate replacements similar to ConstantIndexOp:
llvm-project/mlir/include/mlir/Dialect/Arith/IR/Arith.h
Lines 96 to 109 in 9a0e03f
/// Specialization of `arith.constant` op that returns an integer of index type.
class ConstantIndexOp : public arith::ConstantOp {
public:
  using arith::ConstantOp::ConstantOp;
  static ::mlir::TypeID resolveTypeID() { return TypeID::get<ConstantOp>(); }
  /// Build a constant int op that produces an index.
  static void build(OpBuilder &builder, OperationState &result, int64_t value);
  inline int64_t value() {
    return cast<IntegerAttr>(arith::ConstantOp::getValue()).getInt();
  }
  static bool classof(Operation *op);
};
Pattern matching can't be replaced as easily, but we can add bespoke C++ for it like matchMatmulTransposeB
Good point. Or we could actually just add a new builder with an affine map attribute?
On the pattern matching, I'm looking into something like this to help:
https://github.com/libxsmm/tpp-mlir/blob/main/lib/TPP/IR/StructuredOpMatcher.cpp
The point of adding a specialization would be to skip the need for callers to construct affine maps. IOW no C++ changes at the point of operation construction, only the underlying IR generated.
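To make the specialization idea concrete, here is a toy, MLIR-free sketch of the pattern, modeled loosely on how `ConstantIndexOp` layers a custom `build` and `classof` over `arith::ConstantOp`. Everything in it is hypothetical: `Operation` is a plain struct, the maps are index pairs over (d0, d1, d2) = (m, n, k), and the `MatmulTransposeBOp` shown here is not the class removed by this PR. It only illustrates that callers keep calling `build` with no affine maps in sight, while the generated operation is a plain `linalg.matmul` carrying the transposed RHS map.

```cpp
#include <array>
#include <string>

// Toy stand-in for an operation: a name plus three 2-result index maps
// (LHS, RHS, output). Not the MLIR Operation class.
using Map2 = std::array<int, 2>;
struct Operation {
  std::string name;
  std::array<Map2, 3> maps;
};

// Plain matmul: A is (m, k), B is (k, n), C is (m, n).
struct MatmulOp {
  static Operation build() {
    Operation op;
    op.name = "linalg.matmul";
    op.maps[0] = {0, 2};
    op.maps[1] = {2, 1};
    op.maps[2] = {0, 1};
    return op;
  }
  static bool classof(const Operation &op) {
    return op.name == "linalg.matmul";
  }
};

// Hypothetical specialization: same underlying op, but build() installs the
// transposed RHS map and classof() matches on it.
struct MatmulTransposeBOp : MatmulOp {
  static Operation build() {
    Operation op = MatmulOp::build();
    op.maps[1] = {1, 2}; // B is stored as (n, k)
    return op;
  }
  static bool classof(const Operation &op) {
    return MatmulOp::classof(op) && op.maps[1] == Map2{1, 2};
  }
};
```

The trade-off raised in the thread still applies: each such wrapper covers exactly one map combination, so the number of wrappers grows with the number of variants one wants to name.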
Right. "Technically", users can pass whatever affine maps (that make sense), but you're right that we don't want people to have to remember all variations and have them on the code like that.
Last year we discussed an attribute like `transpose_a` and `transpose_b`. That was not a good idea, because the affine map is mandatory and covers all those cases. But we can still have it in a builder, right? Like an enum that can work across all `matmul` variants (but not `contract`)?
builder.create<linalg::MatmulOp>(..., transpose_a | broadcast_b);
I'm trying to avoid creating a specialization to all combinations, if you notice. 😄
builder.create<linalg::MatmulOp>(..., transpose_a | broadcast_b);
This sounds good too. I mainly want to minimize the number of places downstream users need to setup affine maps. A specialization would be completely no change, but enums would be pretty lightweight (I suspect every frontend will end up redesigning the same "map generator" util if we don't provide it).
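A minimal sketch of what such a shared "map generator" could look like, written as plain C++ so it stands alone. None of these names exist in MLIR: `MatmulVariant`, `MatmulMaps`, and `makeMatmulMaps` are made up for illustration, and maps are modeled as lists of loop-dimension indices over (d0, d1, d2) = (m, n, k). The broadcast case assumes the operand is a 1-D vector along the reduction dimension, i.e. the map (d2).

```cpp
#include <cstdint>
#include <vector>

// Hypothetical flag set; a real builder would pick its own names.
enum MatmulVariant : std::uint8_t {
  kPlain = 0,
  kTransposeA = 1 << 0,
  kTransposeB = 1 << 1,
  kBroadcastA = 1 << 2,
  kBroadcastB = 1 << 3,
};

// An index map as the list of loop dimensions it returns.
using IndexMap = std::vector<int>;

struct MatmulMaps {
  IndexMap lhs, rhs, out;
};

// Translates a combination of flags into the three indexing maps.
inline MatmulMaps makeMatmulMaps(unsigned flags) {
  // Default matmul: A is (m, k), B is (k, n), C is (m, n).
  MatmulMaps maps{{0, 2}, {2, 1}, {0, 1}};
  if (flags & kTransposeA)
    maps.lhs = {2, 0};
  if (flags & kTransposeB)
    maps.rhs = {1, 2};
  // A broadcast operand is 1-D here, so it overrides any transpose flag
  // set for the same operand.
  if (flags & kBroadcastA)
    maps.lhs = {2};
  if (flags & kBroadcastB)
    maps.rhs = {2};
  return maps;
}
```

With this shape, `makeMatmulMaps(kTransposeA | kBroadcastB)` yields the LHS map (d2, d0) and the RHS map (d2), which is the kind of call site the `transpose_a | broadcast_b` suggestion describes. A builder overload could then convert these index lists into `AffineMap`s in one place instead of in every frontend.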
Ack. I tried to gather that in the forum, but people were not paying attention to that, either. 😃 Also, this has been discussed for years and the original patch was almost 1 year ago.
This should be straightforward. This PR has the recipe: replace
This is less straightforward. From the discussion back then, I had assumed we would all work towards not using it internally, so that we could easily switch later. This is 1 year later.
Not exactly. We have backwards compatibility for LLVM IR, and the support policy describes how to deprecate entire components, but those points apply. We neither have backwards compatibility guarantees for MLIR, nor are we talking about removing an entire dialect. Changes in dialect operations occur more often, and to be honest, this one is already turning 1 year old. I have just removed some old element-wise operations and it didn't break the world. I'm assuming IREE isn't really using releases here, so talking about releases doesn't make a lot of sense. But that doesn't mean we can be careless when removing stuff; it just means we can do best effort on a case-by-case basis. We also need to balance upstream versus downstream needs, where the latter cannot impose hard constraints on the former, but can suggest time-frames and ways around. My proposal is to fix
I have just noticed you're talking here about the
I see a lot of usage of matmul transpose in IREE: https://github.com/search?q=repo%3Airee-org%2Firee%20matmul_transpose&type=code A few things I can infer from just that search:
Of course, my assumptions may be wrong, this is just me being proactive. 😃
@@ -504,7 +504,7 @@ func.func @matmul_tile_size_dynamic(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>, %C

module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["linalg.matmul_transpose_b"]} in %arg1 : (!transform.any_op) -> !transform.any_op
Was the previous IR wrong? It seemed to be looking for `matmul_transpose_b`, but line 500 was never that. I may be missing something.
In theory everything there is coverage related or test related. Because Linalg is IREE's input, we have to support the existing named variants. We also try to retain the name for as long as possible for convenience and readability purposes, but also that gives ostensible reliance on those ops per your search. What I believe Mahesh is primarily asking for is a grace period to align that illusion of independence with reality. I see largely two paths:
Last comment I'll add about IREE, there is a fair amount of historical code that is currently unmaintained. Current development has been careful about parity between named and generic variants, but the old stuff wasn't necessarily. With that said, when old unused stuff breaks it gets deleted so not a huge deal :) Honestly the part about this change I'm looking forward to the least is all the random IR snippets I have stashed away as a personal testing bank that I'll have to go update :P. Otherwise the direction here is great!
Oh also concurrent multi-project llvm bumps are really painful so
Any way we can stage it so torch-mlir gets to go first would be very much appreciated for us. I think it's just
Right, this is kinda what we did last year, minus the part where we agreed on a date. I'm trying to set the date now, but I don't think we should wait another year. 😃
I like this very much. But I also want to try to avoid too much bespoke matchers because that explodes as well. I'm ok with specializations for the common cases, but that causes confusion for the non-common cases. I'm ok with C++ functions that match something, but then everyone adds theirs and we get duplicates. So, some plan here would be good to have.
I empathise. That's why I proposed a
That's why we created the
I see. How do we sync those changes with upstream?
Removes the `(batch_)matmul_transpose_{a|b}` variants from OpDSL and replaces them with `matmul affine_maps [...]` whenever appropriate. This is in line with the plan, and can be done since #104783 merged.
See: https://discourse.llvm.org/t/deprecate-batch-matmul-transpose-a-b-linalg-operations/87245
Issues investigated:
matmul instead, so change to that.
matmul + affine maps.
Arm tests validated by @banach-space (thanks!!).