Skip to content

[mlir][linalg] Add support for scalable vectorization of linalg.mmt4d #146531

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from

Conversation

banach-space
Copy link
Contributor

This patch introduces support for scalable vectorization of linalg.mmt4d.
The key design addition is a new state variable in the Linalg vectorizer:

  • assumeScalableVecSizesMatchDimSize

This flag informs the vectorizer that the memref/tensor dimensions
corresponding to scalable vector sizes (typically dynamic) match the
vector sizes
at runtime.

While this assumption is not generally valid, it does hold for
linalg.mmt4d because inputs and outputs are explicitly packed (via
linalg.pack). Packing includes padding, which ensures that dimension
sizes align with the scalable vector lengths (*).

See discussion here:

(*) Provided that the tile sizes used for packing match the vector sizes used
during vectorization. Enforcing this is left to the user.

This patch introduces support for scalable vectorization of `linalg.mmt4d`.
The key design addition is a new state variable in the Linalg vectorizer:

  * `assumeScalableVecSizesMatchDimSize`

This flag informs the vectorizer that the memref/tensor dimensions
corresponding to scalable vector sizes (typically dynamic) _match the
vector sizes_ at runtime.

While this assumption is not generally valid, it does hold for
`linalg.mmt4d` because inputs and outputs are explicitly packed (via
`linalg.pack`). Packing includes padding, which ensures that dimension
sizes align with the scalable vector lengths (*).

See discussion here:
* llvm#143920

(*) Provided that the tile sizes used for packing match the vector sizes used
during vectorization. Enforcing this is left to the user.
@llvmbot
Copy link
Member

llvmbot commented Jul 1, 2025

@llvm/pr-subscribers-mlir-linalg

Author: Andrzej Warzyński (banach-space)

Changes

This patch introduces support for scalable vectorization of linalg.mmt4d.
The key design addition is a new state variable in the Linalg vectorizer:

  • assumeScalableVecSizesMatchDimSize

This flag informs the vectorizer that the memref/tensor dimensions
corresponding to scalable vector sizes (typically dynamic) match the
vector sizes
at runtime.

While this assumption is not generally valid, it does hold for
linalg.mmt4d because inputs and outputs are explicitly packed (via
linalg.pack). Packing includes padding, which ensures that dimension
sizes align with the scalable vector lengths (*).

See discussion here:

(*) Provided that the tile sizes used for packing match the vector sizes used
during vectorization. Enforcing this is left to the user.


Full diff: https://github.com/llvm/llvm-project/pull/146531.diff

5 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td (+5-6)
  • (modified) mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h (+2-1)
  • (modified) mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp (+2-1)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp (+55-24)
  • (modified) mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir (+93-24)
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index d64f94a49f781..baa17f75e53b6 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -2440,12 +2440,11 @@ def VectorizeOp : Op<Transform_Dialect, "structured.vectorize",
   }];
 
   let arguments = (ins TransformHandleTypeInterface:$target,
-                       Variadic<TransformAnyParamTypeOrAnyHandle>:$vector_sizes,
-                       DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:
-                          $static_vector_sizes,
-                       OptionalAttr<UnitAttr>:$vectorize_nd_extract,
-                       DefaultValuedOptionalAttr<DenseBoolArrayAttr, "{}">:
-                          $scalable_sizes);
+      Variadic<TransformAnyParamTypeOrAnyHandle>:$vector_sizes,
+      DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_vector_sizes,
+      OptionalAttr<UnitAttr>:$vectorize_nd_extract,
+      OptionalAttr<UnitAttr>:$assume_scalable_sizes_match_dim_size,
+      DefaultValuedOptionalAttr<DenseBoolArrayAttr, "{}">:$scalable_sizes);
 
   let results = (outs);
 
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 2b4855f49695c..a6d697b43c0b7 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -871,7 +871,8 @@ FailureOr<VectorizationResult>
 vectorize(RewriterBase &rewriter, Operation *op,
           ArrayRef<int64_t> inputVectorSizes = {},
           ArrayRef<bool> inputScalableVecDims = {},
-          bool vectorizeNDExtract = false, bool flatten1DDepthwiseConv = false);
+          bool vectorizeNDExtract = false, bool flatten1DDepthwiseConv = false,
+          bool assumeScalableSizesMultipleOfDim = false);
 
 /// Emit a suitable vector form for a Copy op with fully static shape.
 LogicalResult vectorizeCopy(RewriterBase &builder, memref::CopyOp copyOp);
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 8571d641e26d1..49b9a41831fc6 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -3921,7 +3921,8 @@ DiagnosedSilenceableFailure transform::VectorizeOp::apply(
     }
     FailureOr<VectorizationResult> vectorResults =
         linalg::vectorize(rewriter, target, vectorSizes, getScalableSizes(),
-                          getVectorizeNdExtract().value_or(false));
+                          getVectorizeNdExtract().value_or(false), false,
+                          getAssumeScalableSizesMatchDimSize().value_or(false));
     if (failed(vectorResults)) {
       return mlir::emitSilenceableFailure(target->getLoc())
              << "Attempted to vectorize, but failed";
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index b467114c72f7d..3a533322a3c7f 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -222,9 +222,11 @@ struct VectorizationState {
   /// canonical vector shape for vectorization.
   LogicalResult initState(RewriterBase &rewriter, LinalgOp linalgOp,
                           ArrayRef<int64_t> inputVectorSizes,
-                          ArrayRef<bool> inputScalableVecDims);
+                          ArrayRef<bool> inputScalableVecDims,
+                          bool assumeScalableVecSizesMatchDimSize = false);
 
-  /// Returns the canonical vector shape used to vectorize the iteration space.
+  /// Returns the canonical vector shape used to vectorize the iteration
+  /// space.
   ArrayRef<int64_t> getCanonicalVecShape() const { return canonicalVecShape; }
 
   /// Returns the vector dimensions that are scalable in the canonical vector
@@ -233,8 +235,8 @@ struct VectorizationState {
 
   /// Returns a vector type of the provided `elementType` with the canonical
   /// vector shape and the corresponding fixed/scalable dimensions bit. If
-  /// `dimPermutation` is provided, the canonical vector dimensions are permuted
-  /// accordingly.
+  /// `dimPermutation` is provided, the canonical vector dimensions are
+  /// permuted accordingly.
   VectorType getCanonicalVecType(
       Type elementType,
       std::optional<AffineMap> dimPermutation = std::nullopt) const {
@@ -254,9 +256,9 @@ struct VectorizationState {
   }
 
   /// Masks an operation with the canonical vector mask if the operation needs
-  /// masking. Returns the masked operation or the original operation if masking
-  /// is not needed. If provided, the canonical mask for this operation is
-  /// permuted using `maybeIndexingMap`.
+  /// masking. Returns the masked operation or the original operation if
+  /// masking is not needed. If provided, the canonical mask for this
+  /// operation is permuted using `maybeIndexingMap`.
   Operation *
   maskOperation(RewriterBase &rewriter, Operation *opToMask, LinalgOp linalgOp,
                 std::optional<AffineMap> maybeIndexingMap = std::nullopt);
@@ -276,15 +278,15 @@ struct VectorizationState {
 
   /// Create or retrieve an existing mask value to mask `opToMask` in the
   /// canonical vector iteration space. If `maybeMaskingMap` the mask is
-  /// permuted using that permutation map. If a new mask is created, it will be
-  /// cached for future users.
+  /// permuted using that permutation map. If a new mask is created, it will
+  /// be cached for future users.
   Value getOrCreateMaskFor(RewriterBase &rewriter, Operation *opToMask,
                            LinalgOp linalgOp,
                            std::optional<AffineMap> maybeMaskingMap);
 
   /// Check whether this permutation map can be used for masking. At the
-  /// moment we only make sure that there are no broadcast dimensions, but this
-  /// might change if indexing maps evolve.
+  /// moment we only make sure that there are no broadcast dimensions, but
+  /// this might change if indexing maps evolve.
   bool isValidMaskingMap(AffineMap maskingMap) {
     return maskingMap.getBroadcastDims().size() == 0;
   }
@@ -324,13 +326,24 @@ struct VectorizationState {
   /// shape.
   SmallVector<bool> scalableVecDims;
 
-  /// Holds the active masks for permutations of the canonical vector iteration
-  /// space.
+  /// Holds the active masks for permutations of the canonical vector
+  /// iteration space.
   DenseMap<AffineMap, Value> activeMaskCache;
 
   /// Global vectorization guard for the incoming rewriter. It's initialized
   /// when the vectorization state is initialized.
   OpBuilder::InsertionGuard rewriterGuard;
+
+  /// Do all scalable vector sizes match the corresponding input dim sizes?
+  /// (tensor or memref)
+  ///
+  /// At the Tensor + MemRef levels, scalable sizes are modelled using
+  /// dynamic dimensions (i.e. `?`). In many cases these sizes result from
+  /// e.g. "scalable packing + tiling" and are known to always match the
+  /// scalable vector sizes. In such cases, masking can be safely skipped,
+  /// despite the presence of dynamic shapes. Use this flag with care and
+  /// only for cases where you are confident the assumption holds.
+  bool assumeScalableVecSizesMatchDimSize = false;
 };
 
 LogicalResult
@@ -367,10 +380,12 @@ VectorizationState::precomputeIterSpaceValueSizes(RewriterBase &rewriter,
 /// Initializes the vectorization state, including the computation of the
 /// canonical vector shape for vectorization.
 // TODO: Move this to the constructor when we can remove the failure cases.
-LogicalResult
-VectorizationState::initState(RewriterBase &rewriter, LinalgOp linalgOp,
-                              ArrayRef<int64_t> inputVectorSizes,
-                              ArrayRef<bool> inputScalableVecDims) {
+LogicalResult VectorizationState::initState(RewriterBase &rewriter,
+                                            LinalgOp linalgOp,
+                                            ArrayRef<int64_t> inputVectorSizes,
+                                            ArrayRef<bool> inputScalableVecDims,
+                                            bool assumeScalableSizes) {
+  assumeScalableVecSizesMatchDimSize = assumeScalableSizes;
   // Initialize the insertion point.
   rewriter.setInsertionPoint(linalgOp);
 
@@ -470,6 +485,21 @@ Value VectorizationState::getOrCreateMaskFor(
     return Value();
   }
 
+  if (assumeScalableVecSizesMatchDimSize) {
+    // Given that all _scalable vector sizes_ match the corresponding
+    // memref/tensor dim sizes, masking can be skipped provided that:
+    // * all vector sizes corresponding to dynamic dims are scalable.
+    if (llvm::all_of(llvm::zip(permutedStaticSizes, maskType.getScalableDims()),
+                     [](auto it) {
+                       return std::get<0>(it) == ShapedType::kDynamic
+                                  ? std::get<1>(it)
+                                  : false;
+                     }))
+      LDBG("Masking is not needed for masking map: " << maskingMap << "\n");
+    activeMaskCache[maskingMap] = Value();
+    return Value();
+  }
+
   // Permute the iteration space value sizes to compute the mask upper bounds.
   SmallVector<Value> upperBounds =
       applyPermutationMap(maskingMap, ArrayRef<Value>(iterSpaceValueSizes));
@@ -2479,7 +2509,8 @@ vectorizeScalableVectorPrecondition(Operation *op,
   return success(isElementwise(linalgOp) || isa<linalg::MatmulOp>(op) ||
                  isa<linalg::MatmulTransposeAOp>(op) ||
                  isa<linalg::DepthwiseConv1DNwcWcOp>(op) ||
-                 isa<linalg::MatvecOp>(op) || hasReductionIterator(linalgOp));
+                 isa<linalg::MatvecOp>(op) || isa<linalg::Mmt4DOp>(op) ||
+                 hasReductionIterator(linalgOp));
 }
 
 LogicalResult mlir::linalg::vectorizeOpPrecondition(
@@ -2535,11 +2566,10 @@ bool mlir::linalg::hasVectorizationImpl(Operation *op) {
              tensor::InsertSliceOp>(op);
 }
 
-FailureOr<VectorizationResult>
-mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
-                        ArrayRef<int64_t> inputVectorSizes,
-                        ArrayRef<bool> inputScalableVecDims,
-                        bool vectorizeNDExtract, bool flatten1DDepthwiseConv) {
+FailureOr<VectorizationResult> mlir::linalg::vectorize(
+    RewriterBase &rewriter, Operation *op, ArrayRef<int64_t> inputVectorSizes,
+    ArrayRef<bool> inputScalableVecDims, bool vectorizeNDExtract,
+    bool flatten1DDepthwiseConv, bool assumeScalableSizesMultipleOfDim) {
   LDBG("Attempting to vectorize:\n" << *op << "\n");
   LDBG("Input vector sizes: ");
   LLVM_DEBUG(llvm::interleaveComma(inputVectorSizes, llvm::dbgs()));
@@ -2559,7 +2589,8 @@ mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
   VectorizationState state(rewriter);
   if (auto linalgOp = dyn_cast<linalg::LinalgOp>(op)) {
     if (failed(state.initState(rewriter, linalgOp, inputVectorSizes,
-                               inputScalableVecDims))) {
+                               inputScalableVecDims,
+                               assumeScalableSizesMultipleOfDim))) {
       LDBG("Vectorization state couldn't be initialized\n");
       return failure();
     }
diff --git a/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir b/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir
index 6722de817f6bf..188f03069938f 100644
--- a/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir
@@ -840,6 +840,99 @@ module attributes {transform.with_named_sequence} {
   }
 }
 
+// -----
+
+///----------------------------------------------------------------------------------------
+/// Tests for linalg.mmt4d
+///----------------------------------------------------------------------------------------
+
+func.func @mmt4d(%A: memref<16x16x8x1xf32>, %B: memref<16x16x8x1xf32>, %C_in: memref<16x16x8x8xf32>) {
+  linalg.mmt4d ins(%A, %B: memref<16x16x8x1xf32>, memref<16x16x8x1xf32>)
+               outs(%C_in: memref<16x16x8x8xf32>)
+  return
+}
+
+// CHECK-LABEL:   func.func @mmt4d(
+// CHECK-SAME:      %[[A:.*]]: memref<16x16x8x1xf32>, %[[B:.*]]: memref<16x16x8x1xf32>, %[[C:.*]]: memref<16x16x8x8xf32>) {
+// CHECK:           %[[VEC_A:.*]] = vector.transfer_read %[[A]]{{.*}} : memref<16x16x8x1xf32>, vector<16x16x16x8x8x1xf32>
+// CHECK:           %[[VEC_B:.*]] = vector.transfer_read %[[B]]{{.*}} : memref<16x16x8x1xf32>, vector<16x16x16x8x8x1xf32>
+// CHECK:           %[[VEC_C:.*]] = vector.transfer_read %[[C]]{{.*}} : memref<16x16x8x8xf32>, vector<16x16x8x8xf32>
+// CHECK:           %[[MUL:.*]] = arith.mulf %[[VEC_A]], %[[VEC_B]] : vector<16x16x16x8x8x1xf32>
+// CHECK:           %[[RED:.*]] = vector.multi_reduction <add>, %[[MUL]], %[[VEC_C]] [2, 5] : vector<16x16x16x8x8x1xf32> to vector<16x16x8x8xf32>
+// CHECK:           vector.transfer_write %[[RED]], %[[C]]{{.*}} : vector<16x16x8x8xf32>, memref<16x16x8x8xf32>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %mmt4d = transform.structured.match ops{["linalg.mmt4d"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %mmt4d : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+func.func @mmt4d_scalable(%A: memref<16x16x8x1xf32>, %B: memref<16x16x?x1xf32>, %C_in: memref<16x16x8x?xf32>) {
+  linalg.mmt4d ins(%A, %B: memref<16x16x8x1xf32>, memref<16x16x?x1xf32>)
+               outs(%C_in: memref<16x16x8x?xf32>)
+  return
+}
+// CHECK-LABEL:   func.func @mmt4d_scalable(
+// CHECK-SAME:      %[[A:.*]]: memref<16x16x8x1xf32>,
+// CHECK-SAME:      %[[B:.*]]: memref<16x16x?x1xf32>,
+// CHECK-SAME:      %[[C_IN:.*]]: memref<16x16x8x?xf32>) {
+// CHECK:           %[[VAL_0:.*]] = arith.constant 16 : index
+// CHECK:           %[[VAL_1:.*]] = arith.constant 16 : index
+// CHECK:           %[[VAL_2:.*]] = arith.constant 16 : index
+// CHECK:           %[[C8:.*]] = arith.constant 8 : index
+// CHECK:           %[[C2:.*]] = arith.constant 2 : index
+// CHECK:           %[[DIM_2:.*]] = memref.dim %[[B]], %[[C2]] : memref<16x16x?x1xf32>
+// CHECK:           %[[VAL_6:.*]] = arith.constant 1 : index
+// CHECK:           %[[VEC_A:.*]] = vector.transfer_read %[[A]]{{.*}} : memref<16x16x8x1xf32>, vector<16x16x16x8x[4]x1xf32>
+// CHECK:           %[[MASK_1:.*]] = vector.create_mask %[[VAL_1]], %[[VAL_2]], %[[DIM_2]], %[[VAL_6]] : vector<16x16x[4]x1xi1>
+// CHECK:           %[[VEC_B:.*]] = vector.mask %[[MASK_1]] { vector.transfer_read %[[B]]{{.*}} : memref<16x16x?x1xf32>, vector<16x16x16x8x[4]x1xf32> } : vector<16x16x[4]x1xi1> -> vector<16x16x16x8x[4]x1xf32>
+// CHECK:           %[[MASK_2:.*]] = vector.create_mask %[[VAL_0]], %[[VAL_1]], %[[C8]], %[[DIM_2]] : vector<16x16x8x[4]xi1>
+// CHECK:           %[[VAL_15:.*]] = vector.mask %[[MASK_2]] { vector.transfer_read %[[C_IN]]{{.*}} : memref<16x16x8x?xf32>, vector<16x16x8x[4]xf32> } : vector<16x16x8x[4]xi1> -> vector<16x16x8x[4]xf32>
+// CHECK:           %[[VAL_16:.*]] = arith.mulf %[[VEC_A]], %[[VEC_B]] : vector<16x16x16x8x[4]x1xf32>
+// CHECK:           %[[MASK_3:.*]] = vector.create_mask %[[VAL_0]], %[[VAL_1]], %[[VAL_2]], %[[C8]], %[[DIM_2]], %[[VAL_6]] : vector<16x16x16x8x[4]x1xi1>
+// CHECK:           %[[VAL_18:.*]] = vector.mask %[[MASK_3]] { vector.multi_reduction <add>, %[[VAL_16]], %[[VAL_15]] [2, 5] : vector<16x16x16x8x[4]x1xf32> to vector<16x16x8x[4]xf32> } : vector<16x16x16x8x[4]x1xi1> -> vector<16x16x8x[4]xf32>
+// CHECK:           vector.mask %[[MASK_2]] { vector.transfer_write %[[VAL_18]], %[[C_IN]]{{.*}} : vector<16x16x8x[4]xf32>, memref<16x16x8x?xf32> } : vector<16x16x8x[4]xi1>
+
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %mmt4d = transform.structured.match ops{["linalg.mmt4d"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %mmt4d vector_sizes [16, 16, 16, 8, [4], 1] : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+func.func @mmt4d_scalable_with_assume(%A: memref<16x16x8x1xf32>, %B: memref<16x16x?x1xf32>, %C_in: memref<16x16x8x?xf32>) {
+  linalg.mmt4d ins(%A, %B: memref<16x16x8x1xf32>, memref<16x16x?x1xf32>)
+               outs(%C_in: memref<16x16x8x?xf32>)
+  return
+}
+// CHECK-LABEL:   func.func @mmt4d_scalable_with_assume(
+// CHECK-SAME:      %[[A:.*]]: memref<16x16x8x1xf32>,
+// CHECK-SAME:      %[[B:.*]]: memref<16x16x?x1xf32>,
+// CHECK-SAME:      %[[C_IN:.*]]: memref<16x16x8x?xf32>) {
+// CHECK-NOT:       mask
+// CHECK:           %[[VEC_A:.*]] = vector.transfer_read %[[A]]{{.*}} : memref<16x16x8x1xf32>, vector<16x16x16x8x[4]x1xf32>
+// CHECK:           %[[VEC_B:.*]] = vector.transfer_read %[[B]]{{.*}} : memref<16x16x?x1xf32>, vector<16x16x16x8x[4]x1xf32>
+// CHECK:           %[[VAL_13:.*]] = vector.transfer_read %[[C_IN]]{{.*}} : memref<16x16x8x?xf32>, vector<16x16x8x[4]xf32>
+// CHECK:           %[[VAL_14:.*]] = arith.mulf %[[VEC_A]], %[[VEC_B]] : vector<16x16x16x8x[4]x1xf32>
+// CHECK:           %[[VAL_15:.*]] = vector.multi_reduction <add>, %[[VAL_14]], %[[VAL_13]] [2, 5] : vector<16x16x16x8x[4]x1xf32> to vector<16x16x8x[4]xf32>
+// CHECK:           vector.transfer_write %[[VAL_15]], %[[C_IN]]{{.*}} : vector<16x16x8x[4]xf32>, memref<16x16x8x?xf32>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %mmt4d = transform.structured.match ops{["linalg.mmt4d"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %mmt4d vector_sizes [16, 16, 16, 8, [4], 1] {assume_scalable_sizes_match_dim_size} : !transform.any_op
+    transform.yield
+  }
+}
+
 ///----------------------------------------------------------------------------------------
 /// Tests for other Ops
 ///----------------------------------------------------------------------------------------
@@ -1094,30 +1187,6 @@ module attributes {transform.with_named_sequence} {
   }
 }
 
-// -----
-
-func.func @mmt4d(%A: memref<16x16x8x1xf32>, %B: memref<16x16x8x1xf32>, %C_in: memref<16x16x8x8xf32>) {
-  linalg.mmt4d ins(%A, %B: memref<16x16x8x1xf32>, memref<16x16x8x1xf32>)
-               outs(%C_in: memref<16x16x8x8xf32>)
-  return
-}
-
-// CHECK-LABEL:   func.func @mmt4d(
-// CHECK-SAME:      %[[A:.*]]: memref<16x16x8x1xf32>, %[[B:.*]]: memref<16x16x8x1xf32>, %[[C:.*]]: memref<16x16x8x8xf32>) {
-// CHECK:           %[[VEC_A:.*]] = vector.transfer_read %[[A]]{{.*}} : memref<16x16x8x1xf32>, vector<16x16x16x8x8x1xf32>
-// CHECK:           %[[VEC_B:.*]] = vector.transfer_read %[[B]]{{.*}} : memref<16x16x8x1xf32>, vector<16x16x16x8x8x1xf32>
-// CHECK:           %[[VEC_C:.*]] = vector.transfer_read %[[C]]{{.*}} : memref<16x16x8x8xf32>, vector<16x16x8x8xf32>
-// CHECK:           %[[MUL:.*]] = arith.mulf %[[VEC_A]], %[[VEC_B]] : vector<16x16x16x8x8x1xf32>
-// CHECK:           %[[RED:.*]] = vector.multi_reduction <add>, %[[MUL]], %[[VEC_C]] [2, 5] : vector<16x16x16x8x8x1xf32> to vector<16x16x8x8xf32>
-// CHECK:           vector.transfer_write %[[RED]], %[[C]]{{.*}} : vector<16x16x8x8xf32>, memref<16x16x8x8xf32>
-
-module attributes {transform.with_named_sequence} {
-  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
-    %mmt4d = transform.structured.match ops{["linalg.mmt4d"]} in %arg1 : (!transform.any_op) -> !transform.any_op
-    transform.structured.vectorize %mmt4d : !transform.any_op
-    transform.yield
-  }
-}
 
 // -----
 

@llvmbot
Copy link
Member

llvmbot commented Jul 1, 2025

@llvm/pr-subscribers-mlir

Author: Andrzej Warzyński (banach-space)

Changes

This patch introduces support for scalable vectorization of linalg.mmt4d.
The key design addition is a new state variable in the Linalg vectorizer:

  • assumeScalableVecSizesMatchDimSize

This flag informs the vectorizer that the memref/tensor dimensions
corresponding to scalable vector sizes (typically dynamic) match the
vector sizes
at runtime.

While this assumption is not generally valid, it does hold for
linalg.mmt4d because inputs and outputs are explicitly packed (via
linalg.pack). Packing includes padding, which ensures that dimension
sizes align with the scalable vector lengths (*).

See discussion here:

(*) Provided that the tile sizes used for packing match the vector sizes used
during vectorization. Enforcing this is left to the user.


Full diff: https://github.com/llvm/llvm-project/pull/146531.diff

5 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td (+5-6)
  • (modified) mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h (+2-1)
  • (modified) mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp (+2-1)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp (+55-24)
  • (modified) mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir (+93-24)
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index d64f94a49f781..baa17f75e53b6 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -2440,12 +2440,11 @@ def VectorizeOp : Op<Transform_Dialect, "structured.vectorize",
   }];
 
   let arguments = (ins TransformHandleTypeInterface:$target,
-                       Variadic<TransformAnyParamTypeOrAnyHandle>:$vector_sizes,
-                       DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:
-                          $static_vector_sizes,
-                       OptionalAttr<UnitAttr>:$vectorize_nd_extract,
-                       DefaultValuedOptionalAttr<DenseBoolArrayAttr, "{}">:
-                          $scalable_sizes);
+      Variadic<TransformAnyParamTypeOrAnyHandle>:$vector_sizes,
+      DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_vector_sizes,
+      OptionalAttr<UnitAttr>:$vectorize_nd_extract,
+      OptionalAttr<UnitAttr>:$assume_scalable_sizes_match_dim_size,
+      DefaultValuedOptionalAttr<DenseBoolArrayAttr, "{}">:$scalable_sizes);
 
   let results = (outs);
 
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 2b4855f49695c..a6d697b43c0b7 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -871,7 +871,8 @@ FailureOr<VectorizationResult>
 vectorize(RewriterBase &rewriter, Operation *op,
           ArrayRef<int64_t> inputVectorSizes = {},
           ArrayRef<bool> inputScalableVecDims = {},
-          bool vectorizeNDExtract = false, bool flatten1DDepthwiseConv = false);
+          bool vectorizeNDExtract = false, bool flatten1DDepthwiseConv = false,
+          bool assumeScalableSizesMultipleOfDim = false);
 
 /// Emit a suitable vector form for a Copy op with fully static shape.
 LogicalResult vectorizeCopy(RewriterBase &builder, memref::CopyOp copyOp);
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 8571d641e26d1..49b9a41831fc6 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -3921,7 +3921,8 @@ DiagnosedSilenceableFailure transform::VectorizeOp::apply(
     }
     FailureOr<VectorizationResult> vectorResults =
         linalg::vectorize(rewriter, target, vectorSizes, getScalableSizes(),
-                          getVectorizeNdExtract().value_or(false));
+                          getVectorizeNdExtract().value_or(false), false,
+                          getAssumeScalableSizesMatchDimSize().value_or(false));
     if (failed(vectorResults)) {
       return mlir::emitSilenceableFailure(target->getLoc())
              << "Attempted to vectorize, but failed";
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index b467114c72f7d..3a533322a3c7f 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -222,9 +222,11 @@ struct VectorizationState {
   /// canonical vector shape for vectorization.
   LogicalResult initState(RewriterBase &rewriter, LinalgOp linalgOp,
                           ArrayRef<int64_t> inputVectorSizes,
-                          ArrayRef<bool> inputScalableVecDims);
+                          ArrayRef<bool> inputScalableVecDims,
+                          bool assumeScalableVecSizesMatchDimSize = false);
 
-  /// Returns the canonical vector shape used to vectorize the iteration space.
+  /// Returns the canonical vector shape used to vectorize the iteration
+  /// space.
   ArrayRef<int64_t> getCanonicalVecShape() const { return canonicalVecShape; }
 
   /// Returns the vector dimensions that are scalable in the canonical vector
@@ -233,8 +235,8 @@ struct VectorizationState {
 
   /// Returns a vector type of the provided `elementType` with the canonical
   /// vector shape and the corresponding fixed/scalable dimensions bit. If
-  /// `dimPermutation` is provided, the canonical vector dimensions are permuted
-  /// accordingly.
+  /// `dimPermutation` is provided, the canonical vector dimensions are
+  /// permuted accordingly.
   VectorType getCanonicalVecType(
       Type elementType,
       std::optional<AffineMap> dimPermutation = std::nullopt) const {
@@ -254,9 +256,9 @@ struct VectorizationState {
   }
 
   /// Masks an operation with the canonical vector mask if the operation needs
-  /// masking. Returns the masked operation or the original operation if masking
-  /// is not needed. If provided, the canonical mask for this operation is
-  /// permuted using `maybeIndexingMap`.
+  /// masking. Returns the masked operation or the original operation if
+  /// masking is not needed. If provided, the canonical mask for this
+  /// operation is permuted using `maybeIndexingMap`.
   Operation *
   maskOperation(RewriterBase &rewriter, Operation *opToMask, LinalgOp linalgOp,
                 std::optional<AffineMap> maybeIndexingMap = std::nullopt);
@@ -276,15 +278,15 @@ struct VectorizationState {
 
   /// Create or retrieve an existing mask value to mask `opToMask` in the
   /// canonical vector iteration space. If `maybeMaskingMap` the mask is
-  /// permuted using that permutation map. If a new mask is created, it will be
-  /// cached for future users.
+  /// permuted using that permutation map. If a new mask is created, it will
+  /// be cached for future users.
   Value getOrCreateMaskFor(RewriterBase &rewriter, Operation *opToMask,
                            LinalgOp linalgOp,
                            std::optional<AffineMap> maybeMaskingMap);
 
   /// Check whether this permutation map can be used for masking. At the
-  /// moment we only make sure that there are no broadcast dimensions, but this
-  /// might change if indexing maps evolve.
+  /// moment we only make sure that there are no broadcast dimensions, but
+  /// this might change if indexing maps evolve.
   bool isValidMaskingMap(AffineMap maskingMap) {
     return maskingMap.getBroadcastDims().size() == 0;
   }
@@ -324,13 +326,24 @@ struct VectorizationState {
   /// shape.
   SmallVector<bool> scalableVecDims;
 
-  /// Holds the active masks for permutations of the canonical vector iteration
-  /// space.
+  /// Holds the active masks for permutations of the canonical vector
+  /// iteration space.
   DenseMap<AffineMap, Value> activeMaskCache;
 
   /// Global vectorization guard for the incoming rewriter. It's initialized
   /// when the vectorization state is initialized.
   OpBuilder::InsertionGuard rewriterGuard;
+
+  /// Do all scalable vector sizes match the corresponding input dim sizes?
+  /// (tensor or memref)
+  ///
+  /// At the Tensor + MemRef levels, scalable sizes are modelled using
+  /// dynamic dimensions (i.e. `?`). In many cases these sizes result from
+  /// e.g. "scalable packing + tiling" and are known to always match the
+  /// scalable vector sizes. In such cases, masking can be safely skipped,
+  /// despite the presence of dynamic shapes. Use this flag with care and
+  /// only for cases where you are confident the assumption holds.
+  bool assumeScalableVecSizesMatchDimSize = false;
 };
 
 LogicalResult
@@ -367,10 +380,12 @@ VectorizationState::precomputeIterSpaceValueSizes(RewriterBase &rewriter,
 /// Initializes the vectorization state, including the computation of the
 /// canonical vector shape for vectorization.
 // TODO: Move this to the constructor when we can remove the failure cases.
-LogicalResult
-VectorizationState::initState(RewriterBase &rewriter, LinalgOp linalgOp,
-                              ArrayRef<int64_t> inputVectorSizes,
-                              ArrayRef<bool> inputScalableVecDims) {
+LogicalResult VectorizationState::initState(RewriterBase &rewriter,
+                                            LinalgOp linalgOp,
+                                            ArrayRef<int64_t> inputVectorSizes,
+                                            ArrayRef<bool> inputScalableVecDims,
+                                            bool assumeScalableSizes) {
+  assumeScalableVecSizesMatchDimSize = assumeScalableSizes;
   // Initialize the insertion point.
   rewriter.setInsertionPoint(linalgOp);
 
@@ -470,6 +485,21 @@ Value VectorizationState::getOrCreateMaskFor(
     return Value();
   }
 
+  if (assumeScalableVecSizesMatchDimSize) {
+    // Given that all _scalable vector sizes_ match the corresponding
+    // memref/tensor dim sizes, masking can be skipped provided that:
+    // * all vector sizes corresponding to dynamic dims are scalable.
+    if (llvm::all_of(llvm::zip(permutedStaticSizes, maskType.getScalableDims()),
+                     [](auto it) {
+                       return std::get<0>(it) == ShapedType::kDynamic
+                                  ? std::get<1>(it)
+                                  : false;
+                     }))
+      LDBG("Masking is not needed for masking map: " << maskingMap << "\n");
+    activeMaskCache[maskingMap] = Value();
+    return Value();
+  }
+
   // Permute the iteration space value sizes to compute the mask upper bounds.
   SmallVector<Value> upperBounds =
       applyPermutationMap(maskingMap, ArrayRef<Value>(iterSpaceValueSizes));
@@ -2479,7 +2509,8 @@ vectorizeScalableVectorPrecondition(Operation *op,
   return success(isElementwise(linalgOp) || isa<linalg::MatmulOp>(op) ||
                  isa<linalg::MatmulTransposeAOp>(op) ||
                  isa<linalg::DepthwiseConv1DNwcWcOp>(op) ||
-                 isa<linalg::MatvecOp>(op) || hasReductionIterator(linalgOp));
+                 isa<linalg::MatvecOp>(op) || isa<linalg::Mmt4DOp>(op) ||
+                 hasReductionIterator(linalgOp));
 }
 
 LogicalResult mlir::linalg::vectorizeOpPrecondition(
@@ -2535,11 +2566,10 @@ bool mlir::linalg::hasVectorizationImpl(Operation *op) {
              tensor::InsertSliceOp>(op);
 }
 
-FailureOr<VectorizationResult>
-mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
-                        ArrayRef<int64_t> inputVectorSizes,
-                        ArrayRef<bool> inputScalableVecDims,
-                        bool vectorizeNDExtract, bool flatten1DDepthwiseConv) {
+FailureOr<VectorizationResult> mlir::linalg::vectorize(
+    RewriterBase &rewriter, Operation *op, ArrayRef<int64_t> inputVectorSizes,
+    ArrayRef<bool> inputScalableVecDims, bool vectorizeNDExtract,
+    bool flatten1DDepthwiseConv, bool assumeScalableSizesMultipleOfDim) {
   LDBG("Attempting to vectorize:\n" << *op << "\n");
   LDBG("Input vector sizes: ");
   LLVM_DEBUG(llvm::interleaveComma(inputVectorSizes, llvm::dbgs()));
@@ -2559,7 +2589,8 @@ mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
   VectorizationState state(rewriter);
   if (auto linalgOp = dyn_cast<linalg::LinalgOp>(op)) {
     if (failed(state.initState(rewriter, linalgOp, inputVectorSizes,
-                               inputScalableVecDims))) {
+                               inputScalableVecDims,
+                               assumeScalableSizesMultipleOfDim))) {
       LDBG("Vectorization state couldn't be initialized\n");
       return failure();
     }
diff --git a/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir b/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir
index 6722de817f6bf..188f03069938f 100644
--- a/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir
@@ -840,6 +840,99 @@ module attributes {transform.with_named_sequence} {
   }
 }
 
+// -----
+
+///----------------------------------------------------------------------------------------
+/// Tests for linalg.mmt4d
+///----------------------------------------------------------------------------------------
+
+func.func @mmt4d(%A: memref<16x16x8x1xf32>, %B: memref<16x16x8x1xf32>, %C_in: memref<16x16x8x8xf32>) {
+  linalg.mmt4d ins(%A, %B: memref<16x16x8x1xf32>, memref<16x16x8x1xf32>)
+               outs(%C_in: memref<16x16x8x8xf32>)
+  return
+}
+
+// CHECK-LABEL:   func.func @mmt4d(
+// CHECK-SAME:      %[[A:.*]]: memref<16x16x8x1xf32>, %[[B:.*]]: memref<16x16x8x1xf32>, %[[C:.*]]: memref<16x16x8x8xf32>) {
+// CHECK:           %[[VEC_A:.*]] = vector.transfer_read %[[A]]{{.*}} : memref<16x16x8x1xf32>, vector<16x16x16x8x8x1xf32>
+// CHECK:           %[[VEC_B:.*]] = vector.transfer_read %[[B]]{{.*}} : memref<16x16x8x1xf32>, vector<16x16x16x8x8x1xf32>
+// CHECK:           %[[VEC_C:.*]] = vector.transfer_read %[[C]]{{.*}} : memref<16x16x8x8xf32>, vector<16x16x8x8xf32>
+// CHECK:           %[[MUL:.*]] = arith.mulf %[[VEC_A]], %[[VEC_B]] : vector<16x16x16x8x8x1xf32>
+// CHECK:           %[[RED:.*]] = vector.multi_reduction <add>, %[[MUL]], %[[VEC_C]] [2, 5] : vector<16x16x16x8x8x1xf32> to vector<16x16x8x8xf32>
+// CHECK:           vector.transfer_write %[[RED]], %[[C]]{{.*}} : vector<16x16x8x8xf32>, memref<16x16x8x8xf32>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %mmt4d = transform.structured.match ops{["linalg.mmt4d"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %mmt4d : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+func.func @mmt4d_scalable(%A: memref<16x16x8x1xf32>, %B: memref<16x16x?x1xf32>, %C_in: memref<16x16x8x?xf32>) {
+  linalg.mmt4d ins(%A, %B: memref<16x16x8x1xf32>, memref<16x16x?x1xf32>)
+               outs(%C_in: memref<16x16x8x?xf32>)
+  return
+}
+// CHECK-LABEL:   func.func @mmt4d_scalable(
+// CHECK-SAME:      %[[A:.*]]: memref<16x16x8x1xf32>,
+// CHECK-SAME:      %[[B:.*]]: memref<16x16x?x1xf32>,
+// CHECK-SAME:      %[[C_IN:.*]]: memref<16x16x8x?xf32>) {
+// CHECK:           %[[VAL_0:.*]] = arith.constant 16 : index
+// CHECK:           %[[VAL_1:.*]] = arith.constant 16 : index
+// CHECK:           %[[VAL_2:.*]] = arith.constant 16 : index
+// CHECK:           %[[C8:.*]] = arith.constant 8 : index
+// CHECK:           %[[C2:.*]] = arith.constant 2 : index
+// CHECK:           %[[DIM_2:.*]] = memref.dim %[[B]], %[[C2]] : memref<16x16x?x1xf32>
+// CHECK:           %[[VAL_6:.*]] = arith.constant 1 : index
+// CHECK:           %[[VEC_A:.*]] = vector.transfer_read %[[A]]{{.*}} : memref<16x16x8x1xf32>, vector<16x16x16x8x[4]x1xf32>
+// CHECK:           %[[MASK_1:.*]] = vector.create_mask %[[VAL_1]], %[[VAL_2]], %[[DIM_2]], %[[VAL_6]] : vector<16x16x[4]x1xi1>
+// CHECK:           %[[VEC_B:.*]] = vector.mask %[[MASK_1]] { vector.transfer_read %[[B]]{{.*}} : memref<16x16x?x1xf32>, vector<16x16x16x8x[4]x1xf32> } : vector<16x16x[4]x1xi1> -> vector<16x16x16x8x[4]x1xf32>
+// CHECK:           %[[MASK_2:.*]] = vector.create_mask %[[VAL_0]], %[[VAL_1]], %[[C8]], %[[DIM_2]] : vector<16x16x8x[4]xi1>
+// CHECK:           %[[VAL_15:.*]] = vector.mask %[[MASK_2]] { vector.transfer_read %[[C_IN]]{{.*}} : memref<16x16x8x?xf32>, vector<16x16x8x[4]xf32> } : vector<16x16x8x[4]xi1> -> vector<16x16x8x[4]xf32>
+// CHECK:           %[[VAL_16:.*]] = arith.mulf %[[VEC_A]], %[[VEC_B]] : vector<16x16x16x8x[4]x1xf32>
+// CHECK:           %[[MASK_3:.*]] = vector.create_mask %[[VAL_0]], %[[VAL_1]], %[[VAL_2]], %[[C8]], %[[DIM_2]], %[[VAL_6]] : vector<16x16x16x8x[4]x1xi1>
+// CHECK:           %[[VAL_18:.*]] = vector.mask %[[MASK_3]] { vector.multi_reduction <add>, %[[VAL_16]], %[[VAL_15]] [2, 5] : vector<16x16x16x8x[4]x1xf32> to vector<16x16x8x[4]xf32> } : vector<16x16x16x8x[4]x1xi1> -> vector<16x16x8x[4]xf32>
+// CHECK:           vector.mask %[[MASK_2]] { vector.transfer_write %[[VAL_18]], %[[C_IN]]{{.*}} : vector<16x16x8x[4]xf32>, memref<16x16x8x?xf32> } : vector<16x16x8x[4]xi1>
+
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %mmt4d = transform.structured.match ops{["linalg.mmt4d"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %mmt4d vector_sizes [16, 16, 16, 8, [4], 1] : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+func.func @mmt4d_scalable_with_assume(%A: memref<16x16x8x1xf32>, %B: memref<16x16x?x1xf32>, %C_in: memref<16x16x8x?xf32>) {
+  linalg.mmt4d ins(%A, %B: memref<16x16x8x1xf32>, memref<16x16x?x1xf32>)
+               outs(%C_in: memref<16x16x8x?xf32>)
+  return
+}
+// CHECK-LABEL:   func.func @mmt4d_scalable_with_assume(
+// CHECK-SAME:      %[[A:.*]]: memref<16x16x8x1xf32>,
+// CHECK-SAME:      %[[B:.*]]: memref<16x16x?x1xf32>,
+// CHECK-SAME:      %[[C_IN:.*]]: memref<16x16x8x?xf32>) {
+// CHECK-NOT:       mask
+// CHECK:           %[[VEC_A:.*]] = vector.transfer_read %[[A]]{{.*}} : memref<16x16x8x1xf32>, vector<16x16x16x8x[4]x1xf32>
+// CHECK:           %[[VEC_B:.*]] = vector.transfer_read %[[B]]{{.*}} : memref<16x16x?x1xf32>, vector<16x16x16x8x[4]x1xf32>
+// CHECK:           %[[VAL_13:.*]] = vector.transfer_read %[[C_IN]]{{.*}} : memref<16x16x8x?xf32>, vector<16x16x8x[4]xf32>
+// CHECK:           %[[VAL_14:.*]] = arith.mulf %[[VEC_A]], %[[VEC_B]] : vector<16x16x16x8x[4]x1xf32>
+// CHECK:           %[[VAL_15:.*]] = vector.multi_reduction <add>, %[[VAL_14]], %[[VAL_13]] [2, 5] : vector<16x16x16x8x[4]x1xf32> to vector<16x16x8x[4]xf32>
+// CHECK:           vector.transfer_write %[[VAL_15]], %[[C_IN]]{{.*}} : vector<16x16x8x[4]xf32>, memref<16x16x8x?xf32>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %mmt4d = transform.structured.match ops{["linalg.mmt4d"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %mmt4d vector_sizes [16, 16, 16, 8, [4], 1] {assume_scalable_sizes_match_dim_size} : !transform.any_op
+    transform.yield
+  }
+}
+
 ///----------------------------------------------------------------------------------------
 /// Tests for other Ops
 ///----------------------------------------------------------------------------------------
@@ -1094,30 +1187,6 @@ module attributes {transform.with_named_sequence} {
   }
 }
 
-// -----
-
-func.func @mmt4d(%A: memref<16x16x8x1xf32>, %B: memref<16x16x8x1xf32>, %C_in: memref<16x16x8x8xf32>) {
-  linalg.mmt4d ins(%A, %B: memref<16x16x8x1xf32>, memref<16x16x8x1xf32>)
-               outs(%C_in: memref<16x16x8x8xf32>)
-  return
-}
-
-// CHECK-LABEL:   func.func @mmt4d(
-// CHECK-SAME:      %[[A:.*]]: memref<16x16x8x1xf32>, %[[B:.*]]: memref<16x16x8x1xf32>, %[[C:.*]]: memref<16x16x8x8xf32>) {
-// CHECK:           %[[VEC_A:.*]] = vector.transfer_read %[[A]]{{.*}} : memref<16x16x8x1xf32>, vector<16x16x16x8x8x1xf32>
-// CHECK:           %[[VEC_B:.*]] = vector.transfer_read %[[B]]{{.*}} : memref<16x16x8x1xf32>, vector<16x16x16x8x8x1xf32>
-// CHECK:           %[[VEC_C:.*]] = vector.transfer_read %[[C]]{{.*}} : memref<16x16x8x8xf32>, vector<16x16x8x8xf32>
-// CHECK:           %[[MUL:.*]] = arith.mulf %[[VEC_A]], %[[VEC_B]] : vector<16x16x16x8x8x1xf32>
-// CHECK:           %[[RED:.*]] = vector.multi_reduction <add>, %[[MUL]], %[[VEC_C]] [2, 5] : vector<16x16x16x8x8x1xf32> to vector<16x16x8x8xf32>
-// CHECK:           vector.transfer_write %[[RED]], %[[C]]{{.*}} : vector<16x16x8x8xf32>, memref<16x16x8x8xf32>
-
-module attributes {transform.with_named_sequence} {
-  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
-    %mmt4d = transform.structured.match ops{["linalg.mmt4d"]} in %arg1 : (!transform.any_op) -> !transform.any_op
-    transform.structured.vectorize %mmt4d : !transform.any_op
-    transform.yield
-  }
-}
 
 // -----
 

Copy link
Contributor

@dcaballe dcaballe left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great! A few comments about the direction

@@ -871,7 +871,8 @@ FailureOr<VectorizationResult>
vectorize(RewriterBase &rewriter, Operation *op,
ArrayRef<int64_t> inputVectorSizes = {},
ArrayRef<bool> inputScalableVecDims = {},
bool vectorizeNDExtract = false, bool flatten1DDepthwiseConv = false);
bool vectorizeNDExtract = false, bool flatten1DDepthwiseConv = false,
bool assumeScalableSizesMultipleOfDim = false);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We have assumeScalableSizesMultipleOfDim and getAssumeScalableSizesMatchDimSize. Should we have only one?
Also, it should be "dim multiple of vector sizes"?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the suggestion, I went with assumeDynamicDimsMatchVecSizes, see this commit. As mentioned in my other comment, for now I am "assuming" equality rather than divisibility. Not sure whether we will need the latter?

/// scalable vector sizes. In such cases, masking can be safely skipped,
/// despite the presence of dynamic shapes. Use this flag with care and
/// only for cases where you are confident the assumption holds.
bool assumeScalableVecSizesMatchDimSize = false;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we make this generic, not only for scalable vectors? Assuming a dynamic dimension is multiple of a vector size, scalable or not, is useful in general.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should also add a TODO, as this eventually should be a list containing the divisibility information for each dimension...

Copy link
Contributor

@hanhanW hanhanW Jul 8, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Having the bool option could be beneficial to IREE because IREE may use TensorDynamicDimAnalysis and can use such information to determine if the assumption is true or not. The analysis uses two MLIR analysis and one IREE specific analysis, which may be upstreamable. I seldom work on this area, so I'm not pretty sure about it. I mainly wanna share a data-point about "assume" mechanism that I learned before, which seems related to @dcaballe's comment.

TensorDynamicDimAnalysis::TensorDynamicDimAnalysis(Operation *rootOp)
    : rootOperation(rootOp) {
  solver.load<mlir::dataflow::DeadCodeAnalysis>();
  solver.load<mlir::dataflow::IntegerRangeAnalysis>();
  solver.load<IREE::Util::IntegerDivisibilityAnalysis>();
}

In IREE, we have the hint in IR and we can apply above analysis to retrieve the information. It is IREE specific because the assume.int op is defined by IREE core dialects. E.g.,

   %0:2 = util.assume.int
       %m_in<umin = 16, umax = 4080, udiv = 16>,
       %k2_in<umin = 16, umax = 4080, udiv = 32>
     : index, index

However, IIUC, what's missing is something like ValueRangeAnalysis (not just for integers) because what you wonder is if the size if multiple of vscale or not?

Anyway, I think it is a good start, and I feel that we are missing some analysis to determine if we can make the value true or not. I don't suggest kick in an analysis during vectorization because it could be expensive; my understanding is that it is better to run it once in the beginning of a pass, and we use it vectorize all the ops later.

(I'm happy to move the discussion to the issue, if it is better.)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we make this generic, not only for scalable vectors?

Sorry, I over-indexed on "scalable vectors" .

Assuming a dynamic dimension is multiple of a vector size, scalable or not, is useful in general.

For my needs ATM, equality (rather than divisibility) is sufficient. Do you reckon that we will require divisibility in the future?

However, IIUC, what's missing is something like ValueRangeAnalysis (not just for integers) because what you wonder is if the size if multiple of vscale or not?

We do actually have and use ValueRangeAnalysis for "scalable vectors", see e.g. in Codegen/Transforms/Transforms.cpp in IREE. However, AFAIK, that analysis is costly and in the case of linalg.mmt4d, we just know up front that the assumption holds. Hence taking this approach rather through other means. But now I see that there's even more options in IREE that I should consider in the future, thanks for the pointers @hanhanW 🙏🏻

Copy link
Contributor

@egebeysel egebeysel left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks great! Just as an FYI: I actually integrated this to my prototype branch of IREE in the context of iree-org/iree#21304 and iree-org/iree#16162 and confirmed it works. Though I left some minor comments :)

Comment on lines +494 to +496
return std::get<0>(it) == ShapedType::kDynamic
? std::get<1>(it)
: false;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Am I misinterpreting this for some reason or does this imply: "if every size in permutedStaticSizes dynamic and the corresponding dim is scalable". Also, it just prints the debug print and does not guard the following lines that set the mask to nothing and return

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1

@@ -2479,7 +2509,8 @@ vectorizeScalableVectorPrecondition(Operation *op,
return success(isElementwise(linalgOp) || isa<linalg::MatmulOp>(op) ||
isa<linalg::MatmulTransposeAOp>(op) ||
isa<linalg::DepthwiseConv1DNwcWcOp>(op) ||
isa<linalg::MatvecOp>(op) || hasReductionIterator(linalgOp));
isa<linalg::MatvecOp>(op) || isa<linalg::Mmt4DOp>(op) ||
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was wondering if there is a particular reason why this wouldn't work for batch_mmt4d as well?

Copy link
Contributor

@hanhanW hanhanW left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm late to the party, and I added my point about analysis in my inline comment. It looks like we have something in our downstream project, but it is not enough for the scalable vectorization of mmt4d ops.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Some changes on comments are not relevant, can you revert them?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure how that crept it. Fixed in this commit.

/// scalable vector sizes. In such cases, masking can be safely skipped,
/// despite the presence of dynamic shapes. Use this flag with care and
/// only for cases where you are confident the assumption holds.
bool assumeScalableVecSizesMatchDimSize = false;
Copy link
Contributor

@hanhanW hanhanW Jul 8, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Having the bool option could be beneficial to IREE because IREE may use TensorDynamicDimAnalysis and can use such information to determine if the assumption is true or not. The analysis uses two MLIR analysis and one IREE specific analysis, which may be upstreamable. I seldom work on this area, so I'm not pretty sure about it. I mainly wanna share a data-point about "assume" mechanism that I learned before, which seems related to @dcaballe's comment.

TensorDynamicDimAnalysis::TensorDynamicDimAnalysis(Operation *rootOp)
    : rootOperation(rootOp) {
  solver.load<mlir::dataflow::DeadCodeAnalysis>();
  solver.load<mlir::dataflow::IntegerRangeAnalysis>();
  solver.load<IREE::Util::IntegerDivisibilityAnalysis>();
}

In IREE, we have the hint in IR and we can apply above analysis to retrieve the information. It is IREE specific because the assume.int op is defined by IREE core dialects. E.g.,

   %0:2 = util.assume.int
       %m_in<umin = 16, umax = 4080, udiv = 16>,
       %k2_in<umin = 16, umax = 4080, udiv = 32>
     : index, index

However, IIUC, what's missing is something like ValueRangeAnalysis (not just for integers) because what you wonder is if the size if multiple of vscale or not?

Anyway, I think it is a good start, and I feel that we are missing some analysis to determine if we can make the value true or not. I don't suggest kick in an analysis during vectorization because it could be expensive; my understanding is that it is better to run it once in the beginning of a pass, and we use it vectorize all the ops later.

(I'm happy to move the discussion to the issue, if it is better.)

Comment on lines +494 to +496
return std::get<0>(it) == ShapedType::kDynamic
? std::get<1>(it)
: false;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1

…g.mmt4d

Rename the bool to assumeDynamicDimsMatchVecSizes
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants