[MLIR][Linalg] Remove matmul_transpose variants #147961

Open · wants to merge 2 commits into main

Conversation

rengolin
Member

Removes the (batch_)matmul_transpose_{a|b} variants from OpDSL and replaces them with matmul affine_maps [...] whenever appropriate. This is in line with the plan, and can be done since #104783 merged.

See: https://discourse.llvm.org/t/deprecate-batch-matmul-transpose-a-b-linalg-operations/87245

Issues investigated:

  • pad transform tests that could use matmul instead, so changed them to that.
  • the ArmSME test using transpose actually needed it, so it was changed to matmul + affine maps.

Arm tests validated by @banach-space (thanks!!).
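
For readers unfamiliar with the new form, here is a minimal sketch of the replacement, with assumed tensor shapes; the indexing maps are the ones from the matmul_transpose_a OpDSL definition removed below.

  // Before: dedicated named op, LHS transposed implicitly.
  %0 = linalg.matmul_transpose_a
         ins(%a, %b : tensor<5x3xf32>, tensor<5x7xf32>)
         outs(%c : tensor<3x7xf32>) -> tensor<3x7xf32>

  // After: plain matmul; the LHS map (d2, d0) encodes the transpose.
  %1 = linalg.matmul
         indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d0)>,
                          affine_map<(d0, d1, d2) -> (d2, d1)>,
                          affine_map<(d0, d1, d2) -> (d0, d1)>]
         ins(%a, %b : tensor<5x3xf32>, tensor<5x7xf32>)
         outs(%c : tensor<3x7xf32>) -> tensor<3x7xf32>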

Removes the `(batch_)matmul_transpose_{a|b}` variants from OpDSL and
replaces them with `matmul affine_maps [...]` whenever appropriate.
@llvmbot
Member

llvmbot commented Jul 10, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-sme

Author: Renato Golin (rengolin)

Changes

Removes the (batch_)matmul_transpose_{a|b} variants from OpDSL and replaces them with matmul affine_maps [...] whenever appropriate. This is in line with the plan, and can be done since #104783 merged.

See: https://discourse.llvm.org/t/deprecate-batch-matmul-transpose-a-b-linalg-operations/87245

Issues investigated:

  • pad transform tests that could use matmul instead, so changed them to that.
  • the ArmSME test using transpose actually needed it, so it was changed to matmul + affine maps.

Arm tests validated by @banach-space (thanks!!).


Patch is 75.13 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/147961.diff

20 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml (-286)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp (+1-5)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp (-16)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp (-11)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp (+48-20)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp (+1-7)
  • (modified) mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py (-93)
  • (modified) mlir/test/Dialect/Linalg/block-pack-matmul-layout.mlir (-50)
  • (modified) mlir/test/Dialect/Linalg/block-pack-matmul.mlir (-144)
  • (modified) mlir/test/Dialect/Linalg/fold-add-into-dest.mlir (-30)
  • (modified) mlir/test/Dialect/Linalg/named-ops.mlir (-111)
  • (modified) mlir/test/Dialect/Linalg/rank-reduce-contraction-ops.mlir (-85)
  • (modified) mlir/test/Dialect/Linalg/tile-to-forall.mlir (+1-1)
  • (modified) mlir/test/Dialect/Linalg/transform-op-pad.mlir (+3-3)
  • (modified) mlir/test/Dialect/Linalg/transform-op-specialize-matmul.mlir (-89)
  • (modified) mlir/test/Dialect/Linalg/transpose-matmul.mlir (+26-12)
  • (modified) mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/matmul-transpose-a.mlir (+7-2)
  • (modified) mlir/test/python/dialects/linalg/opdsl/test_core_named_ops.py (+1-1)
  • (modified) mlir/utils/tree-sitter-mlir/dialect/linalg.js (-2)
  • (modified) mlir/utils/tree-sitter-mlir/queries/highlights.scm (-2)
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
index 3637147c5a90d..9aae1b850c3a0 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
@@ -1055,152 +1055,6 @@ structured_op: !LinalgStructuredOpConfig
                     - !ScalarExpression
                       scalar_arg: BZp
 --- !LinalgOpConfig
-metadata: !LinalgOpMetadata
-  name: matmul_transpose_a
-  cpp_class_name: MatmulTransposeAOp
-  doc: |-
-    Performs a matrix multiplication of two 2D inputs with lhs operand
-    transposed.
-
-    Numeric casting is performed on the operands to the inner multiply, promoting
-    them to the same data type as the accumulator/output.
-  implements:
-  - LinalgContractionOpInterface
-structured_op: !LinalgStructuredOpConfig
-  args:
-  - !LinalgOperandDefConfig
-    name: A
-    kind: input_tensor
-    type_var: T1
-    shape_map: affine_map<()[s0, s1, s2] -> (s0, s1)>
-  - !LinalgOperandDefConfig
-    name: B
-    kind: input_tensor
-    type_var: T2
-    shape_map: affine_map<()[s0, s1, s2] -> (s0, s2)>
-  - !LinalgOperandDefConfig
-    name: C
-    kind: output_tensor
-    type_var: U
-    shape_map: affine_map<()[s0, s1, s2] -> (s2, s1)>
-  - !LinalgOperandDefConfig
-    name: cast
-    kind: type_fn_attr
-    default_fn: cast_signed
-  indexing_maps: !LinalgIndexingMapsConfig
-    static_indexing_maps:
-    - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d2, d0)>
-    - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d2, d1)>
-    - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d1)>
-  iterator_types:
-  - parallel
-  - parallel
-  - reduction
-  assignments:
-  - !ScalarAssign
-    arg: C
-    value: !ScalarExpression
-      scalar_fn:
-        kind: binary
-        fn_name: add
-        operands:
-        - !ScalarExpression
-          scalar_arg: C
-        - !ScalarExpression
-          scalar_fn:
-            kind: binary
-            fn_name: mul
-            operands:
-            - !ScalarExpression
-              scalar_fn:
-                kind: type
-                attr_name: cast
-                type_var: U
-                operands:
-                - !ScalarExpression
-                  scalar_arg: A
-            - !ScalarExpression
-              scalar_fn:
-                kind: type
-                attr_name: cast
-                type_var: U
-                operands:
-                - !ScalarExpression
-                  scalar_arg: B
---- !LinalgOpConfig
-metadata: !LinalgOpMetadata
-  name: matmul_transpose_b
-  cpp_class_name: MatmulTransposeBOp
-  doc: |-
-    Performs a matrix multiplication of two 2D inputs with rhs operand
-    transposed.
-
-    Numeric casting is performed on the operands to the inner multiply, promoting
-    them to the same data type as the accumulator/output.
-  implements:
-  - LinalgContractionOpInterface
-structured_op: !LinalgStructuredOpConfig
-  args:
-  - !LinalgOperandDefConfig
-    name: A
-    kind: input_tensor
-    type_var: T1
-    shape_map: affine_map<()[s0, s1, s2] -> (s0, s1)>
-  - !LinalgOperandDefConfig
-    name: B
-    kind: input_tensor
-    type_var: T2
-    shape_map: affine_map<()[s0, s1, s2] -> (s2, s1)>
-  - !LinalgOperandDefConfig
-    name: C
-    kind: output_tensor
-    type_var: U
-    shape_map: affine_map<()[s0, s1, s2] -> (s0, s2)>
-  - !LinalgOperandDefConfig
-    name: cast
-    kind: type_fn_attr
-    default_fn: cast_signed
-  indexing_maps: !LinalgIndexingMapsConfig
-    static_indexing_maps:
-    - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d2)>
-    - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d1, d2)>
-    - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d1)>
-  iterator_types:
-  - parallel
-  - parallel
-  - reduction
-  assignments:
-  - !ScalarAssign
-    arg: C
-    value: !ScalarExpression
-      scalar_fn:
-        kind: binary
-        fn_name: add
-        operands:
-        - !ScalarExpression
-          scalar_arg: C
-        - !ScalarExpression
-          scalar_fn:
-            kind: binary
-            fn_name: mul
-            operands:
-            - !ScalarExpression
-              scalar_fn:
-                kind: type
-                attr_name: cast
-                type_var: U
-                operands:
-                - !ScalarExpression
-                  scalar_arg: A
-            - !ScalarExpression
-              scalar_fn:
-                kind: type
-                attr_name: cast
-                type_var: U
-                operands:
-                - !ScalarExpression
-                  scalar_arg: B
---- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: mmt4d
   cpp_class_name: Mmt4DOp
@@ -1358,146 +1212,6 @@ structured_op: !LinalgStructuredOpConfig
                 - !ScalarExpression
                   scalar_arg: rhs
 --- !LinalgOpConfig
-metadata: !LinalgOpMetadata
-  name: batch_matmul_transpose_a
-  cpp_class_name: BatchMatmulTransposeAOp
-  doc: |-
-    Performs a batched matrix multiplication of two 3D inputs where lhs operand
-    has its non-batch dimensions transposed.
-
-    Numeric casting is performed on the operands to the inner multiply, promoting
-    them to the same data type as the accumulator/output.
-  implements:
-  - LinalgContractionOpInterface
-structured_op: !LinalgStructuredOpConfig
-  args:
-  - !LinalgOperandDefConfig
-    name: A
-    kind: input_tensor
-    type_var: T1
-    shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s1, s2)>
-  - !LinalgOperandDefConfig
-    name: B
-    kind: input_tensor
-    type_var: T2
-    shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s1, s3)>
-  - !LinalgOperandDefConfig
-    name: C
-    kind: output_tensor
-    type_var: U
-    shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s2, s3)>
-  indexing_maps: !LinalgIndexingMapsConfig
-    static_indexing_maps:
-    - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d3, d1)>
-    - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d3, d2)>
-    - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d1, d2)>
-  iterator_types:
-  - parallel
-  - parallel
-  - parallel
-  - reduction
-  assignments:
-  - !ScalarAssign
-    arg: C
-    value: !ScalarExpression
-      scalar_fn:
-        kind: binary
-        fn_name: add
-        operands:
-        - !ScalarExpression
-          scalar_arg: C
-        - !ScalarExpression
-          scalar_fn:
-            kind: binary
-            fn_name: mul
-            operands:
-            - !ScalarExpression
-              scalar_fn:
-                kind: type
-                fn_name: cast_signed
-                type_var: U
-                operands:
-                - !ScalarExpression
-                  scalar_arg: A
-            - !ScalarExpression
-              scalar_fn:
-                kind: type
-                fn_name: cast_signed
-                type_var: U
-                operands:
-                - !ScalarExpression
-                  scalar_arg: B
---- !LinalgOpConfig
-metadata: !LinalgOpMetadata
-  name: batch_matmul_transpose_b
-  cpp_class_name: BatchMatmulTransposeBOp
-  doc: |-
-    Performs a batched matrix multiplication of two 3D inputs where rhs operand
-    has its non-batch dimensions transposed.
-
-    Numeric casting is performed on the operands to the inner multiply, promoting
-    them to the same data type as the accumulator/output.
-  implements:
-  - LinalgContractionOpInterface
-structured_op: !LinalgStructuredOpConfig
-  args:
-  - !LinalgOperandDefConfig
-    name: A
-    kind: input_tensor
-    type_var: T1
-    shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s1, s2)>
-  - !LinalgOperandDefConfig
-    name: B
-    kind: input_tensor
-    type_var: T2
-    shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s3, s2)>
-  - !LinalgOperandDefConfig
-    name: C
-    kind: output_tensor
-    type_var: U
-    shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s1, s3)>
-  indexing_maps: !LinalgIndexingMapsConfig
-    static_indexing_maps:
-    - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d1, d3)>
-    - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d2, d3)>
-    - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d1, d2)>
-  iterator_types:
-  - parallel
-  - parallel
-  - parallel
-  - reduction
-  assignments:
-  - !ScalarAssign
-    arg: C
-    value: !ScalarExpression
-      scalar_fn:
-        kind: binary
-        fn_name: add
-        operands:
-        - !ScalarExpression
-          scalar_arg: C
-        - !ScalarExpression
-          scalar_fn:
-            kind: binary
-            fn_name: mul
-            operands:
-            - !ScalarExpression
-              scalar_fn:
-                kind: type
-                fn_name: cast_signed
-                type_var: U
-                operands:
-                - !ScalarExpression
-                  scalar_arg: A
-            - !ScalarExpression
-              scalar_fn:
-                kind: type
-                fn_name: cast_signed
-                type_var: U
-                operands:
-                - !ScalarExpression
-                  scalar_arg: B
---- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: quantized_batch_matmul
   cpp_class_name: QuantizedBatchMatmulOp
diff --git a/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp b/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp
index 3908d73f5e0e1..57f898458516e 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp
@@ -320,10 +320,6 @@ void linalg::populateBlockPackMatmulPatterns(
     RewritePatternSet &patterns, const ControlBlockPackMatmulFn &controlFn) {
   patterns.add<BlockPackMatmul<linalg::GenericOp>,
                BlockPackMatmul<linalg::MatmulOp>,
-               BlockPackMatmul<linalg::BatchMatmulOp>,
-               BlockPackMatmul<linalg::MatmulTransposeAOp>,
-               BlockPackMatmul<linalg::BatchMatmulTransposeAOp>,
-               BlockPackMatmul<linalg::MatmulTransposeBOp>,
-               BlockPackMatmul<linalg::BatchMatmulTransposeBOp>>(
+               BlockPackMatmul<linalg::BatchMatmulOp>>(
       patterns.getContext(), controlFn);
 }
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
index e0062d15e61ca..0cd2b6810ab9a 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -1013,12 +1013,8 @@ struct RankReduceMatmul : RankReduceContractionOps<FromOpTy, ToOpTy> {
   static bool constexpr reduceLeft =
       (std::is_same_v<FromOpTy, BatchMatmulOp> &&
        std::is_same_v<ToOpTy, BatchVecmatOp>) ||
-      (std::is_same_v<FromOpTy, BatchMatmulTransposeAOp> &&
-       std::is_same_v<ToOpTy, BatchVecmatOp>) ||
       (std::is_same_v<FromOpTy, MatmulOp> &&
        std::is_same_v<ToOpTy, VecmatOp>) ||
-      (std::is_same_v<FromOpTy, MatmulTransposeAOp> &&
-       std::is_same_v<ToOpTy, VecmatOp>) ||
       (std::is_same_v<FromOpTy, MatvecOp> && std::is_same_v<ToOpTy, DotOp>);
 
   /// Look for non-batch spatial dims to collapse.
@@ -1074,27 +1070,15 @@ void mlir::linalg::populateContractionOpRankReducingPatterns(
   MLIRContext *context = patterns.getContext();
   // Unbatching patterns for unit batch size
   patterns.add<RankReduceToUnBatched<BatchMatmulOp, MatmulOp>>(context);
-  patterns
-      .add<RankReduceToUnBatched<BatchMatmulTransposeAOp, MatmulTransposeAOp>>(
-          context);
-  patterns
-      .add<RankReduceToUnBatched<BatchMatmulTransposeBOp, MatmulTransposeBOp>>(
-          context);
   patterns.add<RankReduceToUnBatched<BatchMatvecOp, MatvecOp>>(context);
   patterns.add<RankReduceToUnBatched<BatchVecmatOp, VecmatOp>>(context);
 
   // Non-batch rank 1 reducing patterns
   patterns.add<RankReduceMatmul<MatmulOp, VecmatOp>>(context);
   patterns.add<RankReduceMatmul<MatmulOp, MatvecOp>>(context);
-  patterns.add<RankReduceMatmul<MatmulTransposeAOp, VecmatOp>>(context);
-  patterns.add<RankReduceMatmul<MatmulTransposeBOp, MatvecOp>>(context);
   // Batch rank 1 reducing patterns
   patterns.add<RankReduceMatmul<BatchMatmulOp, BatchVecmatOp>>(context);
   patterns.add<RankReduceMatmul<BatchMatmulOp, BatchMatvecOp>>(context);
-  patterns.add<RankReduceMatmul<BatchMatmulTransposeAOp, BatchVecmatOp>>(
-      context);
-  patterns.add<RankReduceMatmul<BatchMatmulTransposeBOp, BatchMatvecOp>>(
-      context);
 
   // Non-batch rank 0 reducing patterns
   patterns.add<RankReduceMatmul<MatvecOp, DotOp>>(context);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
index 455e1a6d146d1..35ba4f159113f 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
@@ -234,19 +234,8 @@ static FailureOr<LinalgOp> specializeLinalgContractions(RewriterBase &rewriter,
 
   /// Codegen the different matmul variants.
   if (numOfBatchDims) {
-    if (a == IndexMatchResult::Transposed)
-      return replaceWithMatmulVariant<BatchMatmulTransposeAOp>(rewriter,
-                                                               genericOp);
-    if (b == IndexMatchResult::Transposed)
-      return replaceWithMatmulVariant<BatchMatmulTransposeBOp>(rewriter,
-                                                               genericOp);
     return replaceWithMatmulVariant<BatchMatmulOp>(rewriter, genericOp);
   }
-
-  if (a == IndexMatchResult::Transposed)
-    return replaceWithMatmulVariant<MatmulTransposeAOp>(rewriter, genericOp);
-  if (b == IndexMatchResult::Transposed)
-    return replaceWithMatmulVariant<MatmulTransposeBOp>(rewriter, genericOp);
   return replaceWithMatmulVariant<MatmulOp>(rewriter, genericOp);
 }
 
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp b/mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp
index 934781d1cab75..086f9e5d05e6f 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp
@@ -23,10 +23,11 @@ using namespace mlir::linalg;
 ///
 /// with
 ///
-///   linalg.matmul_transpose_a(linalg.transpose(a), b)
+///   linalg.matmul affine_maps { #A^T, #B, #C } (linalg.transpose(a), b)
 ///
 /// By default the LHS is transposed. Set `transposeLHS=false` to
 /// transpose RHS instead.
+/// FIXME: This API is not intuitive, replace LHS=false with something better
 FailureOr<Operation *> mlir::linalg::transposeMatmul(RewriterBase &rewriter,
                                                      linalg::MatmulOp matmulOp,
                                                      bool transposeLHS) {
@@ -57,18 +58,31 @@ FailureOr<Operation *> mlir::linalg::transposeMatmul(RewriterBase &rewriter,
       dynamicDims);
   auto transposeOp = rewriter.create<linalg::TransposeOp>(
       loc, input, empty, ArrayRef<int64_t>{1, 0});
-  Operation *newMatmulOp;
+  Value newLHS, newRHS;
+  AffineMap mapLHS, mapRHS, mapOut;
+  AffineExpr d0, d1, d2;
+  auto context = rewriter.getContext();
+  bindDims(context, d0, d1, d2);
   if (transposeLHS) {
-    newMatmulOp = rewriter.create<linalg::MatmulTransposeAOp>(
-        loc, matmulOp.getResultTypes(),
-        ValueRange{transposeOp->getResult(0), matmulOp.getInputs()[1]},
-        matmulOp.getOutputs());
+    newLHS = transposeOp->getResult(0);
+    newRHS = matmulOp.getInputs()[1];
+    mapLHS = AffineMap::get(3, 0, {d2, d0}, context);
+    mapRHS = AffineMap::get(3, 0, {d2, d1}, context);
+    mapOut = AffineMap::get(3, 0, {d0, d1}, context);
   } else {
-    newMatmulOp = rewriter.create<linalg::MatmulTransposeBOp>(
-        loc, matmulOp.getResultTypes(),
-        ValueRange{matmulOp.getInputs()[0], transposeOp->getResult(0)},
-        matmulOp.getOutputs());
+    newLHS = matmulOp.getInputs()[0];
+    newRHS = transposeOp->getResult(0);
+    mapLHS = AffineMap::get(3, 0, {d0, d2}, context);
+    mapRHS = AffineMap::get(3, 0, {d1, d2}, context);
+    mapOut = AffineMap::get(3, 0, {d0, d1}, context);
   }
+  Operation *newMatmulOp = rewriter.create<linalg::MatmulOp>(
+      loc, matmulOp.getResultTypes(), ValueRange{newLHS, newRHS},
+      matmulOp.getOutputs());
+  newMatmulOp->setAttr("indexing_maps",
+                       rewriter.getArrayAttr({AffineMapAttr::get(mapLHS),
+                                              AffineMapAttr::get(mapRHS),
+                                              AffineMapAttr::get(mapOut)}));
   rewriter.replaceOp(matmulOp, newMatmulOp);
   return newMatmulOp;
 }
@@ -79,10 +93,11 @@ FailureOr<Operation *> mlir::linalg::transposeMatmul(RewriterBase &rewriter,
 ///
 /// with
 ///
-///   linalg.batch_matmul_transpose_a(linalg.transpose(a), b)
+///   linalg.batch_matmul affine_maps { #A^T, #B, #C } (linalg.transpose(a), b)
 ///
 /// Only the non-batch dimensions are transposed. By default the LHS is
 /// transposed. Set `transposeLHS=false` to transpose RHS instead.
+/// FIXME: This API is not intuitive, replace LHS=false with something better
 FailureOr<Operation *>
 mlir::linalg::transposeBatchMatmul(RewriterBase &rewriter,
                                    linalg::BatchMatmulOp batchMatmulOp,
@@ -114,18 +129,31 @@ mlir::linalg::transposeBatchMatmul(RewriterBase &rewriter,
       type.getElementType(), dynamicDims);
   auto transposeOp = rewriter.create<linalg::TransposeOp>(
       loc, input, empty, ArrayRef<int64_t>{0, 2, 1});
-  Operation *newMatmulOp;
+  Value newLHS, newRHS;
+  AffineMap mapLHS, mapRHS, mapOut;
+  AffineExpr d0, d1, d2, d3;
+  auto context = rewriter.getContext();
+  bindDims(context, d0, d1, d2, d3);
   if (transposeLHS) {
-    newMatmulOp = rewriter.create<linalg::BatchMatmulTransposeAOp>(
-        loc, batchMatmulOp.getResultTypes(),
-        ValueRange{transposeOp->getResult(0), batchMatmulOp.getInputs()[1]},
-        batchMatmulOp.getOutputs());
+    newLHS = transposeOp->getResult(0);
+    newRHS = batchMatmulOp.getInputs()[1];
+    mapLHS = AffineMap::get(4, 0, {d0, d3, d1}, context);
+    mapRHS = AffineMap::get(4, 0, {d0, d3, d2}, context);
+    mapOut = AffineMap::get(4, 0, {d0, d1, d2}, context);
   } else {
-    newMatmulOp = rewriter.create<linalg::BatchMatmulTransposeBOp>(
-        loc, batchMatmulOp.getResultTypes(),
-        ValueRange{batchMatmulOp.getInputs()[0], transposeOp->getResult(0)},
-        batchMatmulOp.getOutputs());
+    newLHS = batchMatmulOp.getInputs()[0];
+    newRHS = transposeOp->getResult(0);
+    mapLHS = AffineMap::get(4, 0, {d0, d1, d3}, context);
+    mapRHS = AffineMap::get(4, 0, {d0, d2, d3}, context);
+    mapOut = AffineMap::get(4, 0, {d0, d1, d2}, context);
   }
+  Operation *newMatmulOp = rewriter.create<linalg::BatchMatmulOp>(
+      loc, batchMatmulOp.getResultTypes(), ValueRange{newLHS, newRHS},
+      batchMatmulOp.getOutputs());
+  newMatmulOp->setAttr("indexing_maps",
+                       rewriter.getArrayAttr({AffineMapAttr::get(mapLHS),
+                                              AffineMapAttr::get(mapRHS),
+                                              AffineMapAttr::get(mapOut)}));
   rewriter.replaceOp(batchMatmulOp, newMatmulOp);
   return newMatmulOp;
 }
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 5a8c5eab3f444..7d6155218f422 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -2423,7 +2423,7 @@ vectorizeScalableVectorPrecondition(Operation *op,
            "vectorization\n");
       return failure();
     }
-    if (isa<linalg::MatmulOp>(op) || isa<linalg::MatmulTransposeAOp>(op)) {
+    if (isa<linalg::MatmulOp>(op)) {
       LDBG("Scalable vectorization of the reduction dim in Matmul-like ops "
            "is not supported\n");
       return failure();
@@ -2462,15 +2462,9 @@ vectorizeScalableVectorPrecondition(Operation *op,
       return failure();
   }
 
-  // Check to not let go the matmul with extended semantic, through this
-  // transform.
-  if (linalgOp.hasUserDefinedMaps())
-    return failure();
-
   // Cond 4: Only the following ops are supported in the
   // presence of scalable vectors
   return success(isElementwise(linalgOp) || isa<linalg::MatmulOp>(op) ||
-                 isa<linalg::MatmulTransposeAOp>(op) ||
                  isa<linalg::DepthwiseConv1DNwc...
[truncated]
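
The batch variants follow the same pattern. A minimal sketch, with assumed shapes, of what batch_matmul_transpose_a becomes, reusing the indexing maps from the removed YAML definition above (the diff's TransposeMatmul.cpp changes build exactly these maps):

  %0 = linalg.batch_matmul
         indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>,  // batch, K, M
                          affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,  // batch, K, N
                          affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>]  // batch, M, N
         ins(%a, %b : tensor<2x5x3xf32>, tensor<2x5x7xf32>)
         outs(%c : tensor<2x3x7xf32>) -> tensor<2x3x7xf32>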

@llvmbot
Member

llvmbot commented Jul 10, 2025

@llvm/pr-subscribers-mlir-linalg


github-actions bot commented Jul 10, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

@MaheshRavishankar
Contributor

Could we get some time to manage the churn with these changes?

@rengolin
Member Author

Could we get some time to manage the churn with these changes?

Of course!! No rush from my side. Let me know if you need some help.

@MaheshRavishankar
Contributor

Could we get some time to manage the churn with these changes?

Of course!! No rush from my side. Let me know if you need some help.

It would be ideal if we had a deprecation path instead of just a switch. The changes needed here are

  • Torch-MLIR (and all other "front end dialects" like StableHLO, etc.) will need to be updated to not generate the batch-matmul operations
  • and downstream compilers (like IREE) need to make sure that they can navigate away from this.

In LLVM this would be something that would be deprecated on one release and removed on the next release (it is pretty much a break-the-world change). I don't think all the front end dialect folks are paying attention to this.

newRHS = transposeOp->getResult(0);
mapLHS = AffineMap::get(3, 0, {d0, d2}, context);
mapRHS = AffineMap::get(3, 0, {d1, d2}, context);
mapOut = AffineMap::get(3, 0, {d0, d1}, context);
Contributor

Rather than requiring all builders to be replaced everywhere, we can add C++ specializations as intermediate replacements similar to ConstantIndexOp:

/// Specialization of `arith.constant` op that returns an integer of index type.
class ConstantIndexOp : public arith::ConstantOp {
public:
  using arith::ConstantOp::ConstantOp;
  static ::mlir::TypeID resolveTypeID() { return TypeID::get<ConstantOp>(); }
  /// Build a constant int op that produces an index.
  static void build(OpBuilder &builder, OperationState &result, int64_t value);
  inline int64_t value() {
    return cast<IntegerAttr>(arith::ConstantOp::getValue()).getInt();
  }
  static bool classof(Operation *op);
};

Contributor

Pattern matching can't be replaced as easily, but we can add bespoke C++ for it like matchMatmulTransposeB

Member Author

Good point. Or we could actually just add a new builder with an affine map attribute?

On the pattern matching, I'm looking into something like this to help:
https://github.com/libxsmm/tpp-mlir/blob/main/lib/TPP/IR/StructuredOpMatcher.cpp

Contributor

The point of adding a specialization would be to skip the need for callers to construct affine maps. IOW, no C++ changes at the point of operation construction; only the underlying generated IR changes.

Member Author

rengolin commented Jul 10, 2025

Right. "Technically", users can pass whatever affine maps (that make sense), but you're right that we don't want people to have to remember all variations and have them on the code like that.

Last year we discussed an attribute like transpose_a and transpose_b. That was not a good idea, because the affine map is mandatory and covers all those cases. But we can still have it in a builder, right? Like an enum that can work across all matmul variants (but not contract)?

  builder.create<linalg::MatmulOp>(..., transpose_a | broadcast_b);

I'm trying to avoid creating a specialization for every combination, if you notice. 😄

Contributor

  builder.create<linalg::MatmulOp>(..., transpose_a | broadcast_b);

This sounds good too. I mainly want to minimize the number of places where downstream users need to set up affine maps. A specialization would mean no changes at all for callers, but enums would be pretty lightweight (I suspect every frontend will end up redesigning the same "map generator" util if we don't provide it).

@rengolin
Member Author

It would be ideal if we had a deprecation path instead of just a switch.

Ack. I tried to gather that in the forum, but people were not paying attention to that, either. 😃

Also, this has been discussed for years and the original patch was almost 1 year ago.

  • Torch-MLIR (and all other "front end dialects" like StableHLO, etc.) will need to be updated to not generate the batch-matmul operations

This should be straightforward. This PR has the recipe: replace *_matmul_transpose_* with matmul + appropriate affine map. Everything else just works. I could help with that, too, if needed.

  • and downstream compilers (like IREE) need to make sure that they can navigate away from this.

This is less straightforward. From the discussion back then, I had assumed we would all work towards not using it internally, so that we could easily switch later. This is 1 year later. ☹️

In LLVM this would be something that would be deprecated on one release and removed on the next release (it is pretty much a break the world change). I dont think all the front end dialect folks are paying attention to this.

Not exactly. We have backwards compatibility for LLVM IR, and the support policy describes how to deprecate entire components, but neither of those points applies here: we have no backwards compatibility guarantees for MLIR, and we're not talking about removing an entire dialect. Changes in dialect operations occur more often and, to be honest, this one is already turning 1 year old. I have just removed some old element-wise operations and it didn't break the world.

I'm assuming IREE isn't really using releases here, so talking about releases doesn't make a lot of sense. But that doesn't mean we can be careless when removing stuff; it just means we can do best effort on a case-by-case basis. We also need to balance upstream versus downstream needs, where the latter cannot impose hard constraints on the former, but can suggest time-frames and ways around.

My proposal is to fix torch-mlir to stop emitting matmul_transpose_* and understand what the implications are inside IREE (and any other linalg user that still relies on those operations). If IREE can be fixed quickly, we wait. If not, IREE can duplicate those ops inside linalgx and deprecate at its own speed.
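
As a concrete instance of the recipe mentioned above, a sketch (shapes assumed; maps taken from the removed matmul_transpose_b OpDSL definition) where the RHS map (d1, d2) encodes the transpose:

  %0 = linalg.matmul
         indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
                          affine_map<(d0, d1, d2) -> (d1, d2)>,
                          affine_map<(d0, d1, d2) -> (d0, d1)>]
         ins(%a, %b : tensor<3x5xf32>, tensor<7x5xf32>)
         outs(%c : tensor<3x7xf32>) -> tensor<3x7xf32>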

@rengolin
Member Author

  • Torch-MLIR (and all other "front end dialects" like StableHLO, etc.) will need to be updated to not generate the batch-matmul operations

I have just noticed you're talking here about the batch_matmul variants, which are not the subject of this PR. Here I'm only removing the transposed variants, which do not seem to be anywhere in torch-mlir. Does that mean this PR is "fine" from your point of view, or do you still have concerns specifically with the transposed variants?

@rengolin
Member Author

rengolin commented Jul 10, 2025

I see a lot of usage of matmul transpose in IREE: https://github.com/search?q=repo%3Airee-org%2Firee%20matmul_transpose&type=code

A few things I can infer from just that search:

  • Some usages are "representational", i.e. could easily be replaced with the alternative.
  • if (!isa<linalg::MatmulTransposeBOp, linalg::BatchMatmulTransposeBOp>(op)) { could be replaced with the matcher that @qed discusses above. Here, just check something like isMapTranspose that works on both matmul and batch_matmul.
  • I assume the {lowering_config = ...} attribute could also be applied to matmul.
  • Upstream tests show that transforms/passes still work on the alternative, so IREE should benefit from those.
  • IREE specific passes may need to adapt, but if made on par with upstream, may also work out of the box.

Of course, my assumptions may be wrong, this is just me being proactive. 😃

@@ -504,7 +504,7 @@ func.func @matmul_tile_size_dynamic(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>, %C

module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["linalg.matmul_transpose_b"]} in %arg1 : (!transform.any_op) -> !transform.any_op
Contributor

Was the previous IR wrong? It seemed looking for matmul_transpose_b but line 500 was never that. I may be missing something.
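
For context, this PR updates the pad test to match the plain op instead; a minimal sketch of the new matcher line, assuming the rest of the named sequence is unchanged:

  %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1
         : (!transform.any_op) -> !transform.any_op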

@qedawkins
Contributor

qedawkins commented Jul 10, 2025

I see a lot of usage of matmul transpose in IREE: https://github.com/search?q=repo%3Airee-org%2Firee%20matmul_transpose&type=code

A few things I can infer from just that search:

  • Some usages are "representational", i.e. could easily be replaced with the alternative.
  • if (!isa<linalg::MatmulTransposeBOp, linalg::BatchMatmulTransposeBOp>(op)) { could be replaced with the matcher that @qed discusses above. Here, just check something like isMapTranspose that works on both matmul and batch_matmul.
  • I assume the {lowering_config = ...} attribute could also be applied to matmul.
  • Upstream tests show that transforms/passes still work on the alternative, so IREE should benefit from those.
  • IREE specific passes may need to adapt, but if made on par with upstream, may also work out of the box.

Of course, my assumptions may be wrong, this is just me being proactive. 😃

In theory everything there is coverage related or test related. Because Linalg is IREE's input, we have to support the existing named variants. We also try to retain the name for as long as possible for convenience and readability purposes, but that also creates an ostensible reliance on those ops, per your search.

What I believe Mahesh is primarily asking for is a grace period to align that illusion of independence with reality. I see largely two paths:

  1. Introduce the replacement (linalg.matmul w/ indexing maps exists and is that) and provide a deprecation period. That is unusual for MLIR though.
  2. Design the C++ API changes to be as non-intrusive as possible. Things like bespoke matcher replacements and specialized builders make the majority of the changes purely mechanical and mitigate the surface area for bug finding along the way. This is mainly what I was asking for in the above comment.

Last comment I'll add about IREE, there is a fair amount of historical code that is currently unmaintained. Current development has been careful about parity between named and generic variants, but the old stuff wasn't necessarily. With that said, when old unused stuff breaks it gets deleted so not a huge deal :)

Honestly the part about this change I'm looking forward to the least is all the random IR snippets I have stashed away as a personal testing bank that I'll have to go update :P. Otherwise the direction here is great!

@qedawkins
Contributor

Oh, also: concurrent multi-project LLVM bumps are really painful, so:

My proposal is to fix torch-mlir to stop emitting matmul_transpose_* and understand what the implications are inside IREE (and any other linalg user that still relies on those operations). If IREE can be fixed quickly, we wait. If not, IREE can duplicate those ops inside linalgx and deprecate at its own speed.

Any way we can stage it so torch-mlir gets to go first would be very much appreciated for us. I think it's just linalg.matmul in torch-mlir proper, but there are adjacent shark projects that embed hand written IR that use the transpose variants.

@rengolin
Member Author

  1. Introduce the replacement (linalg.matmul w/ indexing maps exists and is that) and provide a deprecation period. That is unusual for MLIR though.

Right, this is kinda what we did last year, minus the part where we agreed on a date. I'm trying to set the date now, but I don't think we should wait another year. 😃

  2. Design the C++ API changes to be as non-intrusive as possible. Things like bespoke matcher replacements and specialized builders make the majority of the changes purely mechanical and mitigate the surface area for bug finding along the way. This is mainly what I was asking for in the above comment.

I like this very much. But I also want to avoid too many bespoke matchers, because that explodes as well. I'm ok with specializations for the common cases, but that causes confusion for the non-common cases. I'm ok with C++ functions that match something, but then everyone adds theirs and we get duplicates. So, some plan here would be good to have.

Current development has been careful about parity between named and generic variants, but the old stuff wasn't necessarily. With that said, when old unused stuff breaks it gets deleted so not a huge deal

I empathise. That's why I proposed a linalgx.matmul_transpose_* to be a stop-gap if you find code that is too hairy and can't be deleted straight away.

Honestly the part about this change I'm looking forward to the least is all the random IR snippets I have stashed away as a personal testing bank that I'll have to go update

That's why we created the mlir-gen tool, but now I have no idea what half of the code is trying to do! 🤣

Any way we can stage it so torch-mlir gets to go first would be very much appreciated for us. I think it's just linalg.matmul in torch-mlir proper, but there are adjacent shark projects that embed hand written IR that use the transpose variants.

I see. How do we sync those changes with upstream?
