[mlir][linalg] Vectorize directly to a named contraction #147296

Open · wants to merge 1 commit into main

Conversation

adam-smnk
Contributor

Extends the linalg vectorizer with a path to lower contraction ops directly into `vector.contract`.

The direct rewrite preserves the high-level op semantics and provides a more progressive lowering than reconstructing the contraction from a multi-dimensional reduction.
The added lowering focuses on named linalg ops and leverages their well-defined semantics to avoid complex precondition verification.

The new path is optional and disabled by default so that the default vectorizer behavior is unchanged.
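For illustration, here is a minimal sketch of the new path, adapted from the `contraction-named.mlir` test added in this patch: a static `linalg.matmul` vectorized with the `create_named_contraction` unit attribute is lowered to plain transfer reads feeding a single `vector.contract`, rather than a broadcast-and-multi-reduction sequence.

```mlir
// Input: a named contraction on static shapes.
func.func @matmul(%A: tensor<8x4xf32>, %B: tensor<4x16xf32>,
    %C: tensor<8x16xf32>) -> tensor<8x16xf32> {
  %0 = linalg.matmul
    ins(%A, %B : tensor<8x4xf32>, tensor<4x16xf32>)
    outs(%C : tensor<8x16xf32>) -> tensor<8x16xf32>
  return %0 : tensor<8x16xf32>
}

// Transform script enabling the optional path.
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1
      : (!transform.any_op) -> !transform.any_op
    transform.structured.vectorize %0 {create_named_contraction} : !transform.any_op
    transform.yield
  }
}
```

The expected output (abridged) is three `vector.transfer_read` ops for A, B, and C, a `vector.contract` carrying the matmul indexing maps and `kind = #vector.kind<add>`, and a `vector.transfer_write` of the result back into %C.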

@llvmbot
Member

llvmbot commented Jul 7, 2025

@llvm/pr-subscribers-mlir-linalg

Author: Adam Siemieniuk (adam-smnk)

Changes

Extends the linalg vectorizer with a path to lower contraction ops directly into `vector.contract`.

The direct rewrite preserves the high-level op semantics and provides a more progressive lowering than reconstructing the contraction from a multi-dimensional reduction.
The added lowering focuses on named linalg ops and leverages their well-defined semantics to avoid complex precondition verification.

The new path is optional and disabled by default so that the default vectorizer behavior is unchanged.
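As a further sketch, drawn from the `@matmul_dynamic` and `@matmul_scalable` tests in this patch: when explicit `vector_sizes` are supplied, dynamically shaped operands are read through masked transfers (and scalable dimensions are propagated) before the `vector.contract` is formed.

```mlir
// Dynamic shapes: vector sizes [8, 16, 4] correspond to the (m, n, k) iteration space.
transform.structured.vectorize %0 vector_sizes [8, 16, 4]
  {create_named_contraction} : !transform.any_op

// Scalable variant: the n dimension is read and written as a scalable vector, e.g. vector<8x[16]xf32>.
transform.structured.vectorize %0 vector_sizes [8, [16], 4]
  {create_named_contraction} : !transform.any_op
```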


Patch is 32.50 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/147296.diff

7 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td (+2)
  • (modified) mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h (+4-1)
  • (modified) mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h (+2-1)
  • (modified) mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp (+3-1)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp (+103-7)
  • (modified) mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp (+9-4)
  • (added) mlir/test/Dialect/Linalg/vectorization/contraction-named.mlir (+400)
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 4360055e78691..3bbde1240286c 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -2445,6 +2445,8 @@ def VectorizeOp : Op<Transform_Dialect, "structured.vectorize",
                        DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:
                           $static_vector_sizes,
                        OptionalAttr<UnitAttr>:$vectorize_nd_extract,
+                       OptionalAttr<UnitAttr>:$flatten1D_depthwise_conv,
+                       OptionalAttr<UnitAttr>:$create_named_contraction,
                        DefaultValuedOptionalAttr<DenseBoolArrayAttr, "{}">:
                           $scalable_sizes);
 
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 74280fdd82f4e..eee856f3eba68 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -876,11 +876,14 @@ struct VectorizationResult {
 /// greater than or equal to their counterpart iteration space sizes, if static.
 /// `inputVectorShapes` also allows the vectorization of operations with dynamic
 /// shapes.
+/// Optionally, `createNamedContraction` can force compatible contractions to be
+/// vectorized directly to vector.contract operation.
 FailureOr<VectorizationResult>
 vectorize(RewriterBase &rewriter, Operation *op,
           ArrayRef<int64_t> inputVectorSizes = {},
           ArrayRef<bool> inputScalableVecDims = {},
-          bool vectorizeNDExtract = false, bool flatten1DDepthwiseConv = false);
+          bool vectorizeNDExtract = false, bool flatten1DDepthwiseConv = false,
+          bool createNamedContraction = false);
 
 /// Emit a suitable vector form for a Copy op with fully static shape.
 LogicalResult vectorizeCopy(RewriterBase &builder, memref::CopyOp copyOp);
diff --git a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
index cc8421b23a074..9b765d0b8ede6 100644
--- a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
+++ b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
@@ -226,7 +226,8 @@ bool isLinearizableVector(VectorType type);
 /// Note: all read offsets are set to 0.
 Value createReadOrMaskedRead(OpBuilder &builder, Location loc, Value source,
                              ArrayRef<int64_t> inputVectorSizes, Value padValue,
-                             bool useInBoundsInsteadOfMasking = false);
+                             bool useInBoundsInsteadOfMasking = false,
+                             ArrayRef<bool> scalableDims = {});
 
 /// Returns success if `inputVectorSizes` is a valid masking configuraion for
 /// given `shape`, i.e., it meets:
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 7ec3f3445281a..1038615f21b17 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -3920,7 +3920,9 @@ DiagnosedSilenceableFailure transform::VectorizeOp::apply(
     }
     FailureOr<VectorizationResult> vectorResults =
         linalg::vectorize(rewriter, target, vectorSizes, getScalableSizes(),
-                          getVectorizeNdExtract().value_or(false));
+                          getVectorizeNdExtract().value_or(false),
+                          getFlatten1DDepthwiseConv().value_or(false),
+                          getCreateNamedContraction().value_or(false));
     if (failed(vectorResults)) {
       return mlir::emitSilenceableFailure(target->getLoc())
              << "Attempted to vectorize, but failed";
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 2864c5b807d69..bb1454cf43932 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -25,6 +25,7 @@
 #include "mlir/Dialect/Vector/Interfaces/MaskableOpInterface.h"
 #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
 #include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/AffineMap.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinTypeInterfaces.h"
 #include "mlir/IR/BuiltinTypes.h"
@@ -1681,10 +1682,13 @@ createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vecToStore,
     return write;
 
   // Compute the mask and mask the write Op.
-  auto writeMaskType = VectorType::get(vecToStoreShape, builder.getI1Type());
+  auto writeMaskType = VectorType::get(vecToStoreShape, builder.getI1Type(),
+                                       vecToStoreType.getScalableDims());
 
   SmallVector<OpFoldResult> destSizes =
-      tensor::getMixedSizes(builder, loc, dest);
+      isa<MemRefType>(dest.getType())
+          ? memref::getMixedSizes(builder, loc, dest)
+          : tensor::getMixedSizes(builder, loc, dest);
   SmallVector<OpFoldResult> maskSizes(destSizes.end() - vecToStoreRank,
                                       destSizes.end());
 
@@ -2093,6 +2097,84 @@ vectorizeInsertSliceOpPrecondition(tensor::InsertSliceOp sliceOp,
   return success();
 }
 
+/// Vectorize a named linalg contraction op into:
+///   vector::TransferReadOp - Reads vectors from the operands
+///   vector::ContractionOp - Performs contraction
+///   vector::TransferWriteOp - Write the result vector back to the
+///   destination
+/// The operands shapes are preserved and loaded directly into vectors.
+/// Any further permutations or numerical casting remain within contraction.
+static LogicalResult
+vectorizeAsLinalgContraction(RewriterBase &rewriter, VectorizationState &state,
+                             LinalgOp linalgOp,
+                             SmallVectorImpl<Value> &newResults) {
+  Location loc = linalgOp.getLoc();
+  MLIRContext *ctx = linalgOp.getContext();
+
+  if (!isa<ContractionOpInterface>(linalgOp.getOperation()))
+    return failure();
+
+  OpOperand *outOperand = linalgOp.getDpsInitOperand(0);
+  Operation *reduceOp = matchLinalgReduction(outOperand);
+  auto maybeKind = getCombinerOpKind(reduceOp);
+  if (!maybeKind)
+    return failure();
+
+  // Check that all dimensions are present in the input operands.
+  // Arbitrary broadcasts are not supported by the vector contraction.
+  // Broadcasts are expected to be materialized before vectorization.
+  AffineMap lhsMap = linalgOp.getIndexingMapsArray()[0];
+  AffineMap rhsMap = linalgOp.getIndexingMapsArray()[1];
+  if (getUnusedDimsBitVector({lhsMap, rhsMap}).any())
+    return failure();
+
+  // Load operands.
+  SmallVector<Value> vecOperands;
+  for (OpOperand &opOperand : linalgOp->getOpOperands()) {
+    // The operand vector shape is computed by mapping the canonical vector
+    // shape to the operand's domain. Further permutations are left as a part of
+    // the contraction.
+    AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);
+    AffineMap readMap = AffineMap::getMultiDimIdentityMap(
+        indexingMap.getNumResults(), rewriter.getContext());
+    Type elemType = getElementTypeOrSelf(opOperand.get());
+    VectorType readType =
+        state.getCanonicalVecType(elemType, readMap.compose(indexingMap));
+
+    Value read = mlir::vector::createReadOrMaskedRead(
+        rewriter, loc, opOperand.get(), readType.getShape(),
+        /*padding=*/arith::getZeroConstant(rewriter, loc, elemType),
+        /*useInBoundsInsteadOfMasking=*/false, readType.getScalableDims());
+    vecOperands.push_back(read);
+  }
+
+  // Remap iterators from linalg to vector.
+  SmallVector<Attribute> iterAttrs;
+  auto iterators = linalgOp.getIteratorTypesArray();
+  for (utils::IteratorType iter : iterators) {
+    auto vecIter = iter == utils::IteratorType::parallel
+                       ? vector::IteratorType::parallel
+                       : vector::IteratorType::reduction;
+    iterAttrs.push_back(vector::IteratorTypeAttr::get(ctx, vecIter));
+  }
+
+  // Create contraction.
+  Value contractOp = rewriter.create<vector::ContractionOp>(
+      loc, /*lhs=*/vecOperands[0],
+      /*rhs=*/vecOperands[1], /*acc=*/vecOperands[2],
+      linalgOp.getIndexingMaps(), rewriter.getArrayAttr(iterAttrs), *maybeKind);
+
+  // Store result.
+  Operation *write =
+      createWriteOrMaskedWrite(rewriter, loc, contractOp, outOperand->get());
+
+  // Finalize.
+  if (!write->getResults().empty())
+    newResults.push_back(write->getResult(0));
+
+  return success();
+}
+
 namespace {
 enum class ConvOperationKind { Conv, Pool };
 } // namespace
@@ -2528,11 +2610,10 @@ bool mlir::linalg::hasVectorizationImpl(Operation *op) {
              tensor::InsertSliceOp>(op);
 }
 
-FailureOr<VectorizationResult>
-mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
-                        ArrayRef<int64_t> inputVectorSizes,
-                        ArrayRef<bool> inputScalableVecDims,
-                        bool vectorizeNDExtract, bool flatten1DDepthwiseConv) {
+FailureOr<VectorizationResult> mlir::linalg::vectorize(
+    RewriterBase &rewriter, Operation *op, ArrayRef<int64_t> inputVectorSizes,
+    ArrayRef<bool> inputScalableVecDims, bool vectorizeNDExtract,
+    bool flatten1DDepthwiseConv, bool createNamedContraction) {
   LDBG("Attempting to vectorize:\n" << *op << "\n");
   LDBG("Input vector sizes: ");
   LLVM_DEBUG(llvm::interleaveComma(inputVectorSizes, llvm::dbgs()));
@@ -2578,6 +2659,21 @@ mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
               return failure();
             }
 
+            // For simplicity, contraction vectorization is limited to linalg
+            // named ops. Generic op is ignored as not every arbitrary
+            // contraction body can be expressed by a vector.contract.
+            if (createNamedContraction &&
+                isa<ContractionOpInterface>(linalgOp.getOperation())) {
+              // Attempt vectorizing directly into a named contraction.
+              // In case of failure, fall back to the generic path.
+              LogicalResult res = vectorizeAsLinalgContraction(
+                  rewriter, state, linalgOp, results);
+              if (succeeded(res))
+                return success();
+
+              LDBG("Failed to vectorize as a named contraction.\n");
+            }
+
             LDBG("Vectorize generic by broadcasting to the canonical vector "
                  "shape\n");
 
diff --git a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
index 674a93d7c520e..594bfce57e598 100644
--- a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
+++ b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
@@ -320,14 +320,16 @@ Value vector::createReadOrMaskedRead(OpBuilder &builder, Location loc,
                                      Value source,
                                      ArrayRef<int64_t> inputVectorSizes,
                                      Value padValue,
-                                     bool useInBoundsInsteadOfMasking) {
+                                     bool useInBoundsInsteadOfMasking,
+                                     ArrayRef<bool> scalableDims) {
   assert(!llvm::is_contained(inputVectorSizes, ShapedType::kDynamic) &&
          "invalid input vector sizes");
   auto sourceShapedType = cast<ShapedType>(source.getType());
   auto sourceShape = sourceShapedType.getShape();
   assert(sourceShape.size() == inputVectorSizes.size() &&
          "expected same ranks.");
-  auto vectorType = VectorType::get(inputVectorSizes, padValue.getType());
+  auto vectorType =
+      VectorType::get(inputVectorSizes, padValue.getType(), scalableDims);
   assert(padValue.getType() == sourceShapedType.getElementType() &&
          "expected same pad element type to match source element type");
   int64_t readRank = inputVectorSizes.size();
@@ -352,9 +354,12 @@ Value vector::createReadOrMaskedRead(OpBuilder &builder, Location loc,
   if (llvm::equal(inputVectorSizes, sourceShape) || useInBoundsInsteadOfMasking)
     return transferReadOp;
   SmallVector<OpFoldResult> mixedSourceDims =
-      tensor::getMixedSizes(builder, loc, source);
+      isa<MemRefType>(source.getType())
+          ? memref::getMixedSizes(builder, loc, source)
+          : tensor::getMixedSizes(builder, loc, source);
 
-  auto maskType = VectorType::get(inputVectorSizes, builder.getI1Type());
+  auto maskType =
+      VectorType::get(inputVectorSizes, builder.getI1Type(), scalableDims);
   Value mask =
       builder.create<vector::CreateMaskOp>(loc, maskType, mixedSourceDims);
   return mlir::vector::maskOperation(builder, transferReadOp, mask)
diff --git a/mlir/test/Dialect/Linalg/vectorization/contraction-named.mlir b/mlir/test/Dialect/Linalg/vectorization/contraction-named.mlir
new file mode 100644
index 0000000000000..1831acf092afb
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/vectorization/contraction-named.mlir
@@ -0,0 +1,400 @@
+// RUN: mlir-opt %s -transform-interpreter -split-input-file | FileCheck %s
+
+func.func @matmul(%A: tensor<8x4xf32>, %B: tensor<4x16xf32>,
+    %C: tensor<8x16xf32>) -> tensor<8x16xf32> {
+  %0 = linalg.matmul
+    ins(%A, %B : tensor<8x4xf32>, tensor<4x16xf32>)
+    outs(%C: tensor<8x16xf32>) -> tensor<8x16xf32>
+  return %0 : tensor<8x16xf32>
+}
+
+// CHECK: #[[$MAP_A:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK: #[[$MAP_B:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+// CHECK: #[[$MAP_C:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-LABEL: func.func @matmul(
+// CHECK-SAME:    %[[A:.*]]: tensor<8x4xf32>, %[[B:.*]]: tensor<4x16xf32>,
+// CHECK-SAME:    %[[C:.*]]: tensor<8x16xf32>)
+//      CHECK: %[[LOAD_A:.*]] = vector.transfer_read %[[A]]{{.*}}: tensor<8x4xf32>, vector<8x4xf32>
+//      CHECK: %[[LOAD_B:.*]] = vector.transfer_read %[[B]]{{.*}}: tensor<4x16xf32>, vector<4x16xf32>
+//      CHECK: %[[LOAD_C:.*]] = vector.transfer_read %[[C]]{{.*}}: tensor<8x16xf32>, vector<8x16xf32>
+//      CHECK: %[[CONTRACT:.*]] = vector.contract
+// CHECK-SAME:   indexing_maps = [#[[$MAP_A]], #[[$MAP_B]], #[[$MAP_C]]]
+// CHECK-SAME:   kind = #vector.kind<add>
+// CHECK-SAME:   %[[LOAD_A]], %[[LOAD_B]], %[[LOAD_C]]
+// CHECK: vector.transfer_write %[[CONTRACT]], %[[C]]{{.*}}: vector<8x16xf32>, tensor<8x16xf32>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 {create_named_contraction} : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+func.func @matmul_dynamic(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>,
+    %C: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %0 = linalg.matmul
+    ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>)
+    outs(%C: tensor<?x?xf32>) -> tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}
+
+// CHECK: #[[$MAP_A:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK: #[[$MAP_B:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+// CHECK: #[[$MAP_C:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-LABEL: func.func @matmul_dynamic(
+// CHECK-SAME:    %[[A:.*]]: tensor<?x?xf32>, %[[B:.*]]: tensor<?x?xf32>,
+// CHECK-SAME:    %[[C:.*]]: tensor<?x?xf32>)
+//      CHECK: %[[LOAD_A:.*]] = vector.mask{{.*}}{ vector.transfer_read %[[A]]{{.*}}: tensor<?x?xf32>, vector<8x4xf32>
+//      CHECK: %[[LOAD_B:.*]] = vector.mask{{.*}}{ vector.transfer_read %[[B]]{{.*}}: tensor<?x?xf32>, vector<4x16xf32>
+//      CHECK: %[[LOAD_C:.*]] = vector.mask{{.*}}{ vector.transfer_read %[[C]]{{.*}}: tensor<?x?xf32>, vector<8x16xf32>
+//      CHECK: %[[CONTRACT:.*]] = vector.contract
+// CHECK-SAME:   indexing_maps = [#[[$MAP_A]], #[[$MAP_B]], #[[$MAP_C]]]
+// CHECK-SAME:   kind = #vector.kind<add>
+// CHECK-SAME:   %[[LOAD_A]], %[[LOAD_B]], %[[LOAD_C]]
+// CHECK: vector.mask{{.*}}{ vector.transfer_write %[[CONTRACT]], %[[C]]{{.*}}: vector<8x16xf32>, tensor<?x?xf32>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 vector_sizes [8, 16, 4]
+      {create_named_contraction} : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+func.func @matmul_dynamic_memref(%A: memref<?x?xf32>, %B: memref<?x?xf32>,
+    %C: memref<?x?xf32>) {
+  linalg.matmul
+    ins(%A, %B : memref<?x?xf32>, memref<?x?xf32>)
+    outs(%C: memref<?x?xf32>)
+  return
+}
+
+// CHECK: #[[$MAP_A:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK: #[[$MAP_B:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+// CHECK: #[[$MAP_C:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-LABEL: func.func @matmul_dynamic_memref(
+// CHECK-SAME:    %[[A:.*]]: memref<?x?xf32>, %[[B:.*]]: memref<?x?xf32>,
+// CHECK-SAME:    %[[C:.*]]: memref<?x?xf32>)
+//      CHECK: %[[LOAD_A:.*]] = vector.mask{{.*}}{ vector.transfer_read %[[A]]{{.*}}: memref<?x?xf32>, vector<8x4xf32>
+//      CHECK: %[[LOAD_B:.*]] = vector.mask{{.*}}{ vector.transfer_read %[[B]]{{.*}}: memref<?x?xf32>, vector<4x16xf32>
+//      CHECK: %[[LOAD_C:.*]] = vector.mask{{.*}}{ vector.transfer_read %[[C]]{{.*}}: memref<?x?xf32>, vector<8x16xf32>
+//      CHECK: %[[CONTRACT:.*]] = vector.contract
+// CHECK-SAME:   indexing_maps = [#[[$MAP_A]], #[[$MAP_B]], #[[$MAP_C]]]
+// CHECK-SAME:   kind = #vector.kind<add>
+// CHECK-SAME:   %[[LOAD_A]], %[[LOAD_B]], %[[LOAD_C]]
+// CHECK: vector.mask{{.*}}{ vector.transfer_write %[[CONTRACT]], %[[C]]{{.*}}: vector<8x16xf32>, memref<?x?xf32>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 vector_sizes [8, 16, 4]
+      {create_named_contraction} : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+func.func @matmul_scalable(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>,
+    %C: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %0 = linalg.matmul
+    ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>)
+    outs(%C: tensor<?x?xf32>) -> tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}
+
+// CHECK: #[[$MAP_A:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK: #[[$MAP_B:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+// CHECK: #[[$MAP_C:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-LABEL: func.func @matmul_scalable(
+// CHECK-SAME:    %[[A:.*]]: tensor<?x?xf32>, %[[B:.*]]: tensor<?x?xf32>,
+// CHECK-SAME:    %[[C:.*]]: tensor<?x?xf32>)
+//      CHECK: %[[LOAD_A:.*]] = vector.mask{{.*}}{ vector.transfer_read %[[A]]{{.*}}: tensor<?x?xf32>, vector<8x4xf32>
+//      CHECK: %[[LOAD_B:.*]] = vector.mask{{.*}}{ vector.transfer_read %[[B]]{{.*}}: tensor<?x?xf32>, vector<4x[16]xf32>
+//      CHECK: %[[LOAD_C:.*]] = vector.mask{{.*}}{ vector.transfer_read %[[C]]{{.*}}: tensor<?x?xf32>, vector<8x[16]xf32>
+//      CHECK: %[[CONTRACT:.*]] = vector.contract
+// CHECK-SAME:   indexing_maps = [#[[$MAP_A]], #[[$MAP_B]], #[[$MAP_C]]]
+// CHECK-SAME:   kind = #vector.kind<add>
+// CHECK-SAME:   %[[LOAD_A]], %[[LOAD_B]], %[[LOAD_C]]
+// CHECK: vector.mask{{.*}}{ vector.transfer_write %[[CONTRACT]], %[[C]]{{.*}}: vector<8x[16]xf32>, tensor<?x?xf32>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 vector_sizes [8, [16], 4]
+      {create_named_contraction} : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+func.func @matmul_transpose(%A: tensor<4x8xf32>, %B: tensor<16x4xf32>,
+    %C: tensor<8x16xf32>) -> tensor<8x16xf32> {
+  %0 = linalg.matmul
+    indexing_maps = [affine_map<(m, n, k) -> (k, m)>, // transpose A
+                     affine_map<(m, n, k) -> (n, k)>, // transpose B
+                     affine_...
[truncated]

@llvmbot
Member

llvmbot commented Jul 7, 2025

@llvm/pr-subscribers-mlir-vector


@llvmbot
Member

llvmbot commented Jul 7, 2025

@llvm/pr-subscribers-mlir

+//      CHECK: %[[LOAD_B:.*]] = vector.transfer_read %[[B]]{{.*}}: tensor<4x16xf32>, vector<4x16xf32>
+//      CHECK: %[[LOAD_C:.*]] = vector.transfer_read %[[C]]{{.*}}: tensor<8x16xf32>, vector<8x16xf32>
+//      CHECK: %[[CONTRACT:.*]] = vector.contract
+// CHECK-SAME:   indexing_maps = [#[[$MAP_A]], #[[$MAP_B]], #[[$MAP_C]]]
+// CHECK-SAME:   kind = #vector.kind<add>
+// CHECK-SAME:   %[[LOAD_A]], %[[LOAD_B]], %[[LOAD_C]]
+// CHECK: vector.transfer_write %[[CONTRACT]], %[[C]]{{.*}}: vector<8x16xf32>, tensor<8x16xf32>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 {create_named_contraction} : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+func.func @matmul_dynamic(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>,
+    %C: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %0 = linalg.matmul
+    ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>)
+    outs(%C: tensor<?x?xf32>) -> tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}
+
+// CHECK: #[[$MAP_A:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK: #[[$MAP_B:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+// CHECK: #[[$MAP_C:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-LABEL: func.func @matmul_dynamic(
+// CHECK-SAME:    %[[A:.*]]: tensor<?x?xf32>, %[[B:.*]]: tensor<?x?xf32>,
+// CHECK-SAME:    %[[C:.*]]: tensor<?x?xf32>)
+//      CHECK: %[[LOAD_A:.*]] = vector.mask{{.*}}{ vector.transfer_read %[[A]]{{.*}}: tensor<?x?xf32>, vector<8x4xf32>
+//      CHECK: %[[LOAD_B:.*]] = vector.mask{{.*}}{ vector.transfer_read %[[B]]{{.*}}: tensor<?x?xf32>, vector<4x16xf32>
+//      CHECK: %[[LOAD_C:.*]] = vector.mask{{.*}}{ vector.transfer_read %[[C]]{{.*}}: tensor<?x?xf32>, vector<8x16xf32>
+//      CHECK: %[[CONTRACT:.*]] = vector.contract
+// CHECK-SAME:   indexing_maps = [#[[$MAP_A]], #[[$MAP_B]], #[[$MAP_C]]]
+// CHECK-SAME:   kind = #vector.kind<add>
+// CHECK-SAME:   %[[LOAD_A]], %[[LOAD_B]], %[[LOAD_C]]
+// CHECK: vector.mask{{.*}}{ vector.transfer_write %[[CONTRACT]], %[[C]]{{.*}}: vector<8x16xf32>, tensor<?x?xf32>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 vector_sizes [8, 16, 4]
+      {create_named_contraction} : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+func.func @matmul_dynamic_memref(%A: memref<?x?xf32>, %B: memref<?x?xf32>,
+    %C: memref<?x?xf32>) {
+  linalg.matmul
+    ins(%A, %B : memref<?x?xf32>, memref<?x?xf32>)
+    outs(%C: memref<?x?xf32>)
+  return
+}
+
+// CHECK: #[[$MAP_A:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK: #[[$MAP_B:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+// CHECK: #[[$MAP_C:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-LABEL: func.func @matmul_dynamic_memref(
+// CHECK-SAME:    %[[A:.*]]: memref<?x?xf32>, %[[B:.*]]: memref<?x?xf32>,
+// CHECK-SAME:    %[[C:.*]]: memref<?x?xf32>)
+//      CHECK: %[[LOAD_A:.*]] = vector.mask{{.*}}{ vector.transfer_read %[[A]]{{.*}}: memref<?x?xf32>, vector<8x4xf32>
+//      CHECK: %[[LOAD_B:.*]] = vector.mask{{.*}}{ vector.transfer_read %[[B]]{{.*}}: memref<?x?xf32>, vector<4x16xf32>
+//      CHECK: %[[LOAD_C:.*]] = vector.mask{{.*}}{ vector.transfer_read %[[C]]{{.*}}: memref<?x?xf32>, vector<8x16xf32>
+//      CHECK: %[[CONTRACT:.*]] = vector.contract
+// CHECK-SAME:   indexing_maps = [#[[$MAP_A]], #[[$MAP_B]], #[[$MAP_C]]]
+// CHECK-SAME:   kind = #vector.kind<add>
+// CHECK-SAME:   %[[LOAD_A]], %[[LOAD_B]], %[[LOAD_C]]
+// CHECK: vector.mask{{.*}}{ vector.transfer_write %[[CONTRACT]], %[[C]]{{.*}}: vector<8x16xf32>, memref<?x?xf32>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 vector_sizes [8, 16, 4]
+      {create_named_contraction} : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+func.func @matmul_scalable(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>,
+    %C: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %0 = linalg.matmul
+    ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>)
+    outs(%C: tensor<?x?xf32>) -> tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}
+
+// CHECK: #[[$MAP_A:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK: #[[$MAP_B:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+// CHECK: #[[$MAP_C:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-LABEL: func.func @matmul_scalable(
+// CHECK-SAME:    %[[A:.*]]: tensor<?x?xf32>, %[[B:.*]]: tensor<?x?xf32>,
+// CHECK-SAME:    %[[C:.*]]: tensor<?x?xf32>)
+//      CHECK: %[[LOAD_A:.*]] = vector.mask{{.*}}{ vector.transfer_read %[[A]]{{.*}}: tensor<?x?xf32>, vector<8x4xf32>
+//      CHECK: %[[LOAD_B:.*]] = vector.mask{{.*}}{ vector.transfer_read %[[B]]{{.*}}: tensor<?x?xf32>, vector<4x[16]xf32>
+//      CHECK: %[[LOAD_C:.*]] = vector.mask{{.*}}{ vector.transfer_read %[[C]]{{.*}}: tensor<?x?xf32>, vector<8x[16]xf32>
+//      CHECK: %[[CONTRACT:.*]] = vector.contract
+// CHECK-SAME:   indexing_maps = [#[[$MAP_A]], #[[$MAP_B]], #[[$MAP_C]]]
+// CHECK-SAME:   kind = #vector.kind<add>
+// CHECK-SAME:   %[[LOAD_A]], %[[LOAD_B]], %[[LOAD_C]]
+// CHECK: vector.mask{{.*}}{ vector.transfer_write %[[CONTRACT]], %[[C]]{{.*}}: vector<8x[16]xf32>, tensor<?x?xf32>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 vector_sizes [8, [16], 4]
+      {create_named_contraction} : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+func.func @matmul_transpose(%A: tensor<4x8xf32>, %B: tensor<16x4xf32>,
+    %C: tensor<8x16xf32>) -> tensor<8x16xf32> {
+  %0 = linalg.matmul
+    indexing_maps = [affine_map<(m, n, k) -> (k, m)>, // transpose A
+                     affine_map<(m, n, k) -> (n, k)>, // transpose B
+                     affine_...
[truncated]

@rengolin rengolin left a comment

Some comments, but this looks great, very lean and to the point.

I'll let others review and approve, but for the record, it looks good to me already, thanks!

@@ -2445,6 +2445,8 @@ def VectorizeOp : Op<Transform_Dialect, "structured.vectorize",
                       DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:
                          $static_vector_sizes,
                       OptionalAttr<UnitAttr>:$vectorize_nd_extract,
                       OptionalAttr<UnitAttr>:$flatten1D_depthwise_conv,
Member

Any special reason to add this unrelated option?

Contributor Author

Just for completeness.
Since I'm already tweaking the transform op with a new option, I added this one, which is present in the linalg::vectorize API but was missing here.

Contributor

Let me add a bit more context.

I added the option to flatten the depthwise convs as an optimisation for convs with a low channel dim count. While great for NEON (i.e. fixed-width vectors), it's something that's tricky to generalise to scalable vectors. So I deliberately avoided extending the support (I am waiting to see whether others find it useful).

I am fine with extending this Op, but if we do, we should also add more tests. Your call :)

Contributor Author

Thanks for the insight 🙂
In this case, I'll remove the option to limit the scope of this PR.


// Check that all dimensions are present in the input operands.
// Arbitrary broadcasts are not supported by the vector contraction.
// Broadcasts are expected to be materialized before vectorization.
Member

Would the broadcast vectorization work in tandem with this one? Or can you pinpoint vectorization on the contract and not on the surrounding code (say, via a transform)?

If the latter, then we may (at some point later) validate the producers and consumers to make sure the vectorization won't break anything around.

Contributor Author

Yes to both. You can also vectorize selectively (either in a pass or a transform).

In general, the current vectorizer rewrites one op at a time. It creates read and write ops at the boundaries exactly to ensure a seamless transition between a vectorized op and its consumers and producers.
At the tensor level, these read-write pairs can easily cancel out thanks to value semantics. With memrefs, cleanup is obviously trickier due to possible aliasing.
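
For illustration, a minimal sketch (hypothetical IR, not taken from this patch) of such a pair at the tensor level:

// Producer op wrote its vector result into the tensor.
%t1 = vector.transfer_write %v, %t[%c0, %c0] : vector<8x16xf32>, tensor<8x16xf32>
// Consumer op reads the same data straight back.
%r = vector.transfer_read %t1[%c0, %c0], %pad : tensor<8x16xf32>, vector<8x16xf32>
// With value semantics, %r can fold back to %v and the round-trip disappears.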

Member

At the tensor level, these read-write pairs can easily cancel out thanks to value semantics. With memrefs, cleanup is obviously trickier due to possible aliasing.

Yes, that's my worry, but I guess it's up to the transform user to know memrefs are harder and adjust strategy. This won't be the only problem they'll have with memrefs anyway.

Contributor Author

I also don't want to guess how to materialize such broadcasts. We could just pick one default broadcasting scheme (e.g., the canonical vector shape like in the linalg generic vectorizer) but it's likely to be suboptimal too.

Perhaps this broadcasting information could be encoded in the operand's layout? Something I could experiment with later.

Contributor

I think I am missing "materialize" context in the vectorization concept. Is it a blocker to make the option default? What does materializing broadcasts mean? Is it breaking a matmul into something like broadcast(LHS)->matmul?

Contributor Author

Is it a blocker to make the option default?

Not necessarily, but it'd be good to align on whether broadcasts should be handled at all and, if so, how.

What does materializing broadcasts mean? Is it breaking a matmul into something like broadcast(LHS)->matmul?

Correct. Today, broadcast semantics can't be preserved within vector.contract, so the extra dimension has to be created somewhere.
I'll elaborate more in a separate answer.

Contributor

My terminology would be decompose, because that is what we use for pack/unpack/pad/etc. ops upstream. If we break an op into a sequence of simpler ops, I'd call it decomposition. E.g., DecomposeGenericByUnfoldingPermutation:

/// Rewrite a tensor::PadOp into a sequence of EmptyOp, FillOp and
/// InsertSliceOp. For now, only constant padding values are supported.
struct DecomposePadOpPattern : public OpRewritePattern<tensor::PadOp> {
  DecomposePadOpPattern(MLIRContext *context, PatternBenefit benefit = 1)
      : OpRewritePattern<tensor::PadOp>(context, benefit) {}
  LogicalResult matchAndRewrite(tensor::PadOp padOp,
                                PatternRewriter &rewriter) const override;

protected:
  Value createFillOrGenerateOp(RewriterBase &rewriter, tensor::PadOp padOp,
                               Value dest,
                               const SmallVector<Value> &dynSizes) const;
};

/// Rewrites a linalg::PackOp into a sequence of:
///   * tensor::PadOp + linalg::TransposeOp + tensor::EmptyOp +
///     tensor::InsertSliceOp ops.
///
/// Requires that all the outer dims of the input linalg::PackOp are 1.
///
/// Before:
/// ```
///   %packed = linalg.pack %input
///     padding_value(%pad : f32)
///     inner_dims_pos = [1, 0]
///     inner_tiles = [2, %high]
///     into %output : tensor<5x1xf32> -> tensor<1x1x2x?xf32>
/// ```
///
/// After:
/// ```
///   // PadOp
///   %padded = tensor.pad %arg0 low[0, 0] high[%0, 1] {
///     ^bb0(...):
///       tensor.yield %arg2 : f32
///   } : tensor<5x1xf32> to tensor<?x2xf32>
///   // EmptyOp + TransposeOp
///   %empty = tensor.empty(%arg3) : tensor<2x?xf32>
///   %transposed = linalg.transpose
///     ins(%extracted_slice : tensor<?x2xf32>)
///     outs(%empty : tensor<2x?xf32>)
///     permutation = [1, 0]
///   // InsertSliceOp
///   %inserted_slice = tensor.insert_slice %transposed
///     into %arg1[0, 0, 0, 0] [1, 1, 2, %tile_dim_1] [1, 1, 1, 1]
///     : tensor<2x?xf32> into tensor<1x1x2x?xf32>
/// ```
struct DecomposeOuterUnitDimsPackOpPattern
    : public OpRewritePattern<linalg::PackOp> {
  using OpRewritePattern<linalg::PackOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(linalg::PackOp packOp,
                                PatternRewriter &rewriter) const override;
};

/// Rewrites a linalg::UnPackOp into a sequence of rank-reduced
///   * tensor::ExtractSliceOp + linalg::TransposeOp + tensor::InsertSliceOp
///
/// Requires that all the outer dims of the input linalg::PackOp are 1.
///
/// Before:
/// ```
///   %packed = linalg.unpack %input
///     inner_dims_pos = [1, 0]
///     inner_tiles = [2, 8]
///     into %output : tensor<1x1x2x8xf32> -> tensor<5x1xf32>
/// ```
///
/// After:
/// ```
///   // Rank-reduced extract to obtain the tile
///   %slice = tensor.extract_slice %arg0[0, 0, 0, 0] [1, 1, 2, 8] [1, 1, 1, 1]
///     : tensor<1x1x2x8xf32> to tensor<2x8xf32>
///   // EmptyOp + TransposeOp
///   %init = tensor.empty() : tensor<8x2xf32>
///   %transposed = linalg.transpose
///     ins(%extracted_slice : tensor<2x8xf32>)
///     outs(%0 : tensor<8x2xf32>) permutation = [1, 0]
///   // Extract a slice matching the specified output size
///   %result = tensor.extract_slice %transposed[0, 0] [5, 1] [1, 1]
///     : tensor<8x2xf32> to tensor<5x1xf32>
/// ```
struct DecomposeOuterUnitDimsUnPackOpPattern
    : public OpRewritePattern<linalg::UnPackOp> {
  using OpRewritePattern<linalg::UnPackOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(linalg::UnPackOp unpackOp,
                                PatternRewriter &rewriter) const override;
};

@banach-space banach-space left a comment

Very nice, thanks!

I've made some small suggestions inline, but nothing major. The approach that you are taking makes a lot of sense to me.

The new path is optional and disabled by default to avoid changing the default vectorizer behavior.

This is a good time to discuss the future direction for this.

@dcaballe has already hinted in Vector Dialect: Refactoring + Re-design ideas that we should look into removing vector.multi_reduction. I am not sure about that myself just yet, but avoiding vector.multi_reduction in the vectorization path makes a lot of sense to me (i.e. specifically for "Linalg contractions"). So:

  • Why shouldn't we make it the default?
  • What are the next steps for you?

I am happy to work with you to make this the default if we/you discover that something is missing. Otherwise, I would prepare people to flip the switch soonish.


Contributor

Note, I have recently refactored the vectorisation tests and grouped them as discussed here:

Specifically, this is what we have today:

mlir/test/Dialect/Linalg/vectorization/
├── conv.mlir
├── conv-flatten.mlir
├── conv-with-patterns.mlir
├── extract.mlir
├── extract-with-patterns.mlir
├── insert-slice.mlir
├── insert-slice-with-patterns.mlir
├── linalg-ops.mlir
├── linalg-ops-with-patterns.mlir
├── pad.mlir
├── pad-with-patterns.mlir
├── unsupported.mlir

So, contraction-named.mlir doesn't quite fit. However, one of these options would:

  • contraction.mlir (1 file), or
  • contraction-interface.mlir (1 file), or
  • contract.mlir + matmul.mlir + batch_matmul.mlir ... (multiple files)

Given that the updated logic looks at ContractionOpInterface, I vote for contraction-interface.mlir. Other names are fine too; I would just add a top-level comment to explain what this file covers (this is the first test file to focus on a specific interface, so it's worth clarifying).


// -----

func.func @matmul_scalable(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>,
Contributor

[nit] Given that this is a variation of @matmul_dynamic ...

Suggested change
func.func @matmul_scalable(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>,
func.func @matmul_dynamic_scalable(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>,


// -----

func.func @negative_matmul_broadcast(%A: tensor<4xf32>, %B: tensor<4x16xf32>,
Contributor

Perhaps add a note that "negative" in this case means: "Vectorization works, but you are not getting vector.contract as explicitly requested"?

@dcaballe dcaballe left a comment

Thanks! The direction looks good to me. This is long overdue. I think we should discuss what is missing to make this the default path and work towards that. Do you know what is missing?

@hanhanW hanhanW left a comment

The approach looks okay to me because I'm -1 on recovering information to achieve the same thing when we already have enough information. It'd be good to see what is missing to make it the default, as I think it is usually bad to have boolean options, especially when they are linalg specific. It makes generalization harder if we want a vectorization interface in the future. I'd like to learn why it can't be the default and what is missing.

I also have a question about "materializing broadcast" in my inline comment.



@adam-smnk
Contributor Author

adam-smnk commented Jul 9, 2025

Thanks a lot for the feedback!
Since I have all your attention, I'm happy to iterate on the design and work toward making it the default path.

This new lowering tries to preserve as much information as possible to avoid any need for reconstruction (like the current multi_reduction plus raising back to contract).
Thanks to the narrow and well-defined semantics of the existing linalg contraction ops (matmul, contract, etc.), the lowering is pretty straightforward for all their default representations and in the presence of any transposes.

Generics

One point I want to address first is why we should ignore linalg.generic here.
Simpler and cheaper matching aside, both linalg named contraction ops and vector.contract have narrower semantics. They cannot represent all possible contractions as defined by linalg::detail::isContractionBody - for example, a unary operation between the binary elementwise op and the binary reduction op forms a valid contraction body which currently cannot be represented by the other ops.
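
As an illustrative sketch (not taken from the patch), such a body could look like:

// Valid contraction body per isContractionBody: elementwise multiply,
// an extra unary op in the middle, then the reduction add.
// Not expressible as a single vector.contract.
#mapA = affine_map<(m, n, k) -> (m, k)>
#mapB = affine_map<(m, n, k) -> (k, n)>
#mapC = affine_map<(m, n, k) -> (m, n)>
%0 = linalg.generic
    {indexing_maps = [#mapA, #mapB, #mapC],
     iterator_types = ["parallel", "parallel", "reduction"]}
    ins(%A, %B : tensor<8x4xf32>, tensor<4x16xf32>)
    outs(%C : tensor<8x16xf32>) {
  ^bb0(%a: f32, %b: f32, %acc: f32):
    %mul = arith.mulf %a, %b : f32
    %abs = math.absf %mul : f32
    %add = arith.addf %acc, %abs : f32
    linalg.yield %add : f32
} -> tensor<8x16xf32>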

I think generics should be treated as such (well, generic) and be vectorized as they are today. One can always specialize a generic to capture narrower behavior. Equally, linalg ops can be generalized before vectorization to retain the current vectorizer behavior. Thus, there is no obstacle here preventing specialized vectorization from being enabled by default.

Mixed precision

Both linalg and vector support the same mixed precision semantics where inputs are converted into the output precision before computation.

Today, the casts are externalized by default and folding them back into vector.contract is optional.
The new path keeps the mixed-precision semantics within the op. A minor difference, but it might be better to give users a heads-up before making it the default.
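
A quick sketch of the difference (hypothetical types and sizes, not checked against the patch):

// Mixed-precision contraction keeping the input-to-accumulator conversion
// inside the op, instead of explicit casts feeding a same-type contraction.
%res = vector.contract {
    indexing_maps = [affine_map<(m, n, k) -> (m, k)>,
                     affine_map<(m, n, k) -> (k, n)>,
                     affine_map<(m, n, k) -> (m, n)>],
    iterator_types = ["parallel", "parallel", "reduction"],
    kind = #vector.kind<add>}
  %lhs, %rhs, %acc : vector<8x4xi8>, vector<4x16xi8> into vector<8x16xi32>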

Broadcasts

Now, let's look into the broadcast semantics present in linalg named ops.
Many general named contraction ops (like matmul, batch_matmul, contract) allow expressing arbitrary input broadcasts through their indexing_maps. As an example, consider the following matmul:

%0 = linalg.matmul
    indexing_maps = [affine_map<(m, n, k) -> (k)>,     // broadcast LHS
                     affine_map<(m, n, k) -> (k)>,     // broadcast RHS
                     affine_map<(m, n, k) -> (m, n)>]
    ins(%arg0, %arg1 : tensor<12xf32>, tensor<12xf32>)
    outs(%arg2 : tensor<24x25xf32>) -> tensor<24x25xf32>

The matmul broadcasts both of its inputs. Considering the named op semantics, one would expect the LHS to expand into an (m, k) or (k, m) shape and the RHS into (k, n) or (n, k). While all these possible combinations are valid, this choice might have implications further down the line. Currently, there is no way to encode the desired broadcast variant within the op.

vector.contract doesn't support broadcast semantics in its indexing maps. Therefore, the missing dimensions have to be materialized before or during vectorization. Currently, vectorization through multi reduction and reconstruction (transform.structured.vectorize_children_and_apply_patterns) results in the following contraction:

%1 = vector.transfer_read %LHS : tensor<12xf32>, vector<12xf32>
%2 = vector.broadcast %1 : vector<12xf32> to vector<24x25x12xf32>
%3 = vector.transfer_read %RHS : tensor<12xf32>, vector<12xf32>
%4 = vector.transfer_read %ACC : tensor<24x25xf32>, vector<24x25xf32>
%5 = vector.contract ... kind = #vector.kind<add>} %2, %3, %4
  : vector<24x25x12xf32>, vector<12xf32> into vector<24x25xf32>

which is a valid realization. However, the 2D matmul semantics are lost and the computation turns into an arbitrary contraction.

At the moment, I went with the option to bail (fall back to the generic behavior) in the presence of broadcasts. This keeps the overall vectorization consistent between the specialized lowering and generic+reconstruction. So, that makes for a fine default path as well.
Now a user has the option to create separate broadcasts before vectorization for cleaner lowering (see the sketch below) or to keep it as is.
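
A minimal sketch of that decomposition (assuming the default matmul maps; illustrative only, not part of this patch):

// Materialize the LHS broadcast (k) -> (m, k) and the RHS broadcast
// (k) -> (k, n) explicitly, then use a plain matmul with default maps.
%lhs_init = tensor.empty() : tensor<24x12xf32>
%lhs2d = linalg.broadcast ins(%arg0 : tensor<12xf32>)
           outs(%lhs_init : tensor<24x12xf32>) dimensions = [0]
%rhs_init = tensor.empty() : tensor<12x25xf32>
%rhs2d = linalg.broadcast ins(%arg1 : tensor<12xf32>)
           outs(%rhs_init : tensor<12x25xf32>) dimensions = [1]
%0 = linalg.matmul
    ins(%lhs2d, %rhs2d : tensor<24x12xf32>, tensor<12x25xf32>)
    outs(%arg2 : tensor<24x25xf32>) -> tensor<24x25xf32>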

As a compromise, the same broadcast into the canonical vector shape could be performed, or somewhat better broadcasting for contraction could be added to this specialized path.
I'd say a less surprising scheme would be to broadcast the LHS into (batch, m, k) and the RHS into (batch, n, k) when needed. The dimension order might not be optimal but it resembles a more typical contraction.
It's just a first thought though. Handling "weird" linalg.contract cases might be more difficult.

This makes specialized vs. generic vectorization diverge a bit - it is fine, but it might be better to give users time to adjust before making it the default.

@adam-smnk
Contributor Author

For the next steps, I'd love to hear your thoughts on how broadcasts should be handled.
About making it the default, I feel it might be better to give a PSA and a bit of time before flipping the switch, so that this path can be adopted and further tested by downstream users.

we should look into removing vector.multi_reduction. I am not sure about that myself just yet

Me neither 🙂
I agree we shouldn't have to go through reduction and raising in the presence of clear contractions. But I imagine multi reduction is still useful for representing, well... just reductions.

@adam-smnk adam-smnk requested a review from rolfmorel July 9, 2025 11:50
@banach-space
Contributor

The matmul broadcasts both of its inputs. Considering the named op semantics, one would expect the LHS to expand into an (m, k) or (k, m) shape and the RHS into (k, n) or (n, k). While all these possible combinations are valid, this choice might have implications further down the line. Currently, there is no way to encode the desired broadcast variant within the op.

When we broadcast like this, the maps should be the default indexing maps for linalg.matmul, no? Otherwise there would be an ambiguity.

@adam-smnk
Contributor Author

The matmul broadcasts both of its inputs. Considering the named op semantics, one would expect the LHS to expand into an (m, k) or (k, m) shape and the RHS into (k, n) or (n, k). While all these possible combinations are valid, this choice might have implications further down the line. Currently, there is no way to encode the desired broadcast variant within the op.

When we broadcast like this, the maps should be the default indexing maps for linalg.matmul, no? Otherwise there would be an ambiguity.

I think there is ambiguity to begin with. Default maps are a reasonable ("canonical"?) choice in this case.
However, this still leaves linalg.contract, which can express the same computation but has no defaults.
In any case, somebody has to make this decision.

@banach-space
Contributor

However, this still leaves linalg.contract, which can express the same computation but has no defaults.

Does it not? To me, if the maps are not uniquely defined, that would be a "bug" (as in, if there's ambiguity, it should be removed). However, taking this example:

func.func @example(%A : memref<?x?x?xf32>, %B: memref<?x?xf32>, %C: memref<?x?x?xf32>) {
  linalg.contract
      indexing_maps = [affine_map<(batch, m, n, k) -> (batch, k, m)>,
                       affine_map<(batch, m, n, k) -> (k, n)>,
                       affine_map<(batch, m, n, k) -> (batch, m, n)>]
      ins(%A, %B: memref<?x?x?xf32>, memref<?x?xf32>)
      outs(%C: memref<?x?x?xf32>)
  return
}

"generalisation" does yield unique maps (mlir-opt %s -split-input-file -linalg-generalize-named-ops):

#map = affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>
#map1 = affine_map<(d0, d1, d2, d3) -> (d3, d2)>
#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
module {
  func.func @generalize_matmul_buffer(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?xf32>, %arg2: memref<?x?x?xf32>) {
    linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?x?xf32>) outs(%arg2 : memref<?x?x?xf32>) {
    ^bb0(%in: f32, %in_0: f32, %out: f32):
      %0 = arith.mulf %in, %in_0 : f32
      %1 = arith.addf %out, %0 : f32
      linalg.yield %1 : f32
    }
    return
  }
}

What am I missing here? 🤔

@adam-smnk
Contributor Author

I think we're talking about two slightly different things. Let me try to clarify.

linalg.contract has unique maps in the sense that the operation requires its indexing_maps to be explicitly defined. As in your example above, contract already has maps and generalization only has to generate the correct body. It's basically a 1-to-1 conversion.

For the earlier case I took your statement:

When we broadcast like this, the maps should be the default indexing maps for linalg.matmul

as in resolving broadcasting for this matmul:

%0 = linalg.matmul
    indexing_maps = [affine_map<(m, n, k) -> (k)>,     // broadcast LHS
                     affine_map<(m, n, k) -> (k)>,     // broadcast RHS
                     affine_map<(m, n, k) -> (m, n)>]

should yield

    indexing_maps = [affine_map<(m, n, k) -> (m, k)>,  // broadcast LHS
                     affine_map<(m, n, k) -> (k, n)>,  // broadcast RHS
                     affine_map<(m, n, k) -> (m, n)>]

as these can be derived from the default indexing maps for `linalg.matmul`.

I agree that's a valid strategy. However, the same linalg.contract with implicit (broadcasted) parallel dimensions on the LHS and RHS can be realized with any possible permutation of these maps, as there's no default definition here. For example:

indexing_maps = [affine_map<(m, n, k) -> (m, n, k)>,
                 affine_map<(m, n, k) -> (k)>,
                 affine_map<(m, n, k) -> (m, n)>]

is a valid interpretation for broadcast decomposition.

@adam-smnk
Contributor Author

adam-smnk commented Jul 10, 2025

The impedance mismatch between linalg and vector contractions comes from the fact that vector.contract doesn't support all possible broadcasts (EDIT: I just misled myself - the above example wants to imply a batch broadcast, but there's no real broadcasting done there). So, before or during vectorization, the broadcasts have to be decomposed and the extra dimensions created before they go into vector.contract.

Today, the vectorizer obviously can vectorize such linalg ops with broadcasts.
But I claim the default realization is a bit poor, and I wonder if we could do better.

@adam-smnk
Contributor Author

adam-smnk commented Jul 10, 2025

Personally, I lean toward forcing users to first decompose unsupported affine maps to not have to guess or compromise on an arbitrary default decomposition for these edge cases.

All the default named op contractions and maps like these (note: it's not really a batch dimension, just a standard parallel dimension, but I'm sticking to the previous example's naming):

      indexing_maps = [affine_map<(batch, m, n, k) -> (batch, k, m)>,
                       affine_map<(batch, m, n, k) -> (k, n)>,
                       affine_map<(batch, m, n, k) -> (batch, m, n)>]

are directly representable by vector.contract and should be directly vectorized.
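
For example, a sketch of the expected lowering (hypothetical vector sizes, not checked against the patch), where the maps carry over unchanged:

// The linalg indexing maps and iterators map 1:1 onto the contraction.
%res = vector.contract {
    indexing_maps = [affine_map<(batch, m, n, k) -> (batch, k, m)>,
                     affine_map<(batch, m, n, k) -> (k, n)>,
                     affine_map<(batch, m, n, k) -> (batch, m, n)>],
    iterator_types = ["parallel", "parallel", "parallel", "reduction"],
    kind = #vector.kind<add>}
  %lhs, %rhs, %acc : vector<2x4x8xf32>, vector<4x16xf32> into vector<2x8x16xf32>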

@banach-space
Contributor

Thanks for elaborating, Adam! Now I see what the source of confusion was for me. Basically, there is no ambiguity at the Linalg level, but rather in how we materialise broadcasts at the Vector level.

I lean toward forcing users to first decompose unsupported affine maps to not have to guess or compromise on an arbitrary default decomposition for these edge cases.

+1
