From 487db47dbae6de69f174868b625fb7730612c09a Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski@arm.com>
Date: Tue, 1 Jul 2025 13:39:46 +0000
Subject: [PATCH 1/3] [mlir][linalg] Add support for scalable vectorization of
 linalg.mmt4d

This patch introduces support for scalable vectorization of `linalg.mmt4d`.

The key design addition is a new state variable in the Linalg vectorizer:

  * `assumeScalableVecSizesMatchDimSize`

This flag informs the vectorizer that the memref/tensor dimensions
corresponding to scalable vector sizes (typically dynamic) _match the vector
sizes_ at runtime. While this assumption is not generally valid, it does hold
for `linalg.mmt4d` because inputs and outputs are explicitly packed (via
`linalg.pack`). Packing includes padding, which ensures that dimension sizes
align with the scalable vector lengths (*).

See discussion here:

  * https://github.com/llvm/llvm-project/issues/143920

(*) Provided that the tile sizes used for packing match the vector sizes used
during vectorization. Enforcing this is left to the user.
---
 .../Linalg/TransformOps/LinalgTransformOps.td |  11 +-
 .../Dialect/Linalg/Transforms/Transforms.h    |   3 +-
 .../TransformOps/LinalgTransformOps.cpp       |   3 +-
 .../Linalg/Transforms/Vectorization.cpp       |  82 ++++++++----
 .../Linalg/vectorization/linalg-ops.mlir      | 117 ++++++++++++++----
 5 files changed, 160 insertions(+), 56 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index d64f94a49f781..baa17f75e53b6 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -2440,12 +2440,11 @@ def VectorizeOp : Op<Transform_Dialect, "structured.vectorize",
   let arguments = (ins TransformHandleTypeInterface:$target,
-                   Variadic<TransformAnyParamTypeOrAnyHandle>:$vector_sizes,
-                   DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:
-                       $static_vector_sizes,
-                   OptionalAttr<UnitAttr>:$vectorize_nd_extract,
-                   DefaultValuedOptionalAttr<DenseBoolArrayAttr, "{}">:
-                       $scalable_sizes);
+                   Variadic<TransformAnyParamTypeOrAnyHandle>:$vector_sizes,
+                   DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_vector_sizes,
+                   OptionalAttr<UnitAttr>:$vectorize_nd_extract,
+                   OptionalAttr<UnitAttr>:$assume_scalable_sizes_match_dim_size,
+                   DefaultValuedOptionalAttr<DenseBoolArrayAttr, "{}">:$scalable_sizes);
 
   let results = (outs);
 
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 2b4855f49695c..a6d697b43c0b7 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -871,7 +871,8 @@ FailureOr<VectorizationResult>
 vectorize(RewriterBase &rewriter, Operation *op,
           ArrayRef<int64_t> inputVectorSizes = {},
           ArrayRef<bool> inputScalableVecDims = {},
-          bool vectorizeNDExtract = false, bool flatten1DDepthwiseConv = false);
+          bool vectorizeNDExtract = false, bool flatten1DDepthwiseConv = false,
+          bool assumeScalableSizesMultipleOfDim = false);
 
 /// Emit a suitable vector form for a Copy op with fully static shape.
 LogicalResult vectorizeCopy(RewriterBase &builder, memref::CopyOp copyOp);
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 8571d641e26d1..49b9a41831fc6 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -3921,7 +3921,8 @@ DiagnosedSilenceableFailure transform::VectorizeOp::apply(
   }
   FailureOr<VectorizationResult> vectorResults =
       linalg::vectorize(rewriter, target, vectorSizes, getScalableSizes(),
-                        getVectorizeNdExtract().value_or(false));
+                        getVectorizeNdExtract().value_or(false), false,
+                        getAssumeScalableSizesMatchDimSize().value_or(false));
   if (failed(vectorResults)) {
     return mlir::emitSilenceableFailure(target->getLoc())
            << "Attempted to vectorize, but failed";
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index b467114c72f7d..3a533322a3c7f 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -222,9 +222,11 @@ struct VectorizationState {
   /// canonical vector shape for vectorization.
   LogicalResult initState(RewriterBase &rewriter, LinalgOp linalgOp,
                           ArrayRef<int64_t> inputVectorSizes,
-                          ArrayRef<bool> inputScalableVecDims);
+                          ArrayRef<bool> inputScalableVecDims,
+                          bool assumeScalableVecSizesMatchDimSize = false);
 
-  /// Returns the canonical vector shape used to vectorize the iteration space.
+  /// Returns the canonical vector shape used to vectorize the iteration
+  /// space.
   ArrayRef<int64_t> getCanonicalVecShape() const { return canonicalVecShape; }
 
   /// Returns the vector dimensions that are scalable in the canonical vector
@@ -233,8 +235,8 @@ struct VectorizationState {
 
   /// Returns a vector type of the provided `elementType` with the canonical
   /// vector shape and the corresponding fixed/scalable dimensions bit. If
-  /// `dimPermutation` is provided, the canonical vector dimensions are permuted
-  /// accordingly.
+  /// `dimPermutation` is provided, the canonical vector dimensions are
+  /// permuted accordingly.
   VectorType getCanonicalVecType(
       Type elementType,
       std::optional<AffineMap> dimPermutation = std::nullopt) const {
@@ -254,9 +256,9 @@ struct VectorizationState {
   }
 
   /// Masks an operation with the canonical vector mask if the operation needs
-  /// masking. Returns the masked operation or the original operation if masking
-  /// is not needed. If provided, the canonical mask for this operation is
-  /// permuted using `maybeIndexingMap`.
+  /// masking. Returns the masked operation or the original operation if
+  /// masking is not needed. If provided, the canonical mask for this
+  /// operation is permuted using `maybeIndexingMap`.
   Operation *
   maskOperation(RewriterBase &rewriter, Operation *opToMask, LinalgOp linalgOp,
                 std::optional<AffineMap> maybeIndexingMap = std::nullopt);
@@ -276,15 +278,15 @@ struct VectorizationState {
 
   /// Create or retrieve an existing mask value to mask `opToMask` in the
   /// canonical vector iteration space. If `maybeMaskingMap` the mask is
-  /// permuted using that permutation map. If a new mask is created, it will be
-  /// cached for future users.
+  /// permuted using that permutation map. If a new mask is created, it will
+  /// be cached for future users.
   Value getOrCreateMaskFor(RewriterBase &rewriter, Operation *opToMask,
                            LinalgOp linalgOp,
                            std::optional<AffineMap> maybeMaskingMap);
 
   /// Check whether this permutation map can be used for masking. At the
-  /// moment we only make sure that there are no broadcast dimensions, but this
-  /// might change if indexing maps evolve.
+  /// moment we only make sure that there are no broadcast dimensions, but
+  /// this might change if indexing maps evolve.
   bool isValidMaskingMap(AffineMap maskingMap) {
     return maskingMap.getBroadcastDims().size() == 0;
   }
@@ -324,13 +326,24 @@ struct VectorizationState {
   /// shape.
   SmallVector<bool> scalableVecDims;
 
-  /// Holds the active masks for permutations of the canonical vector iteration
-  /// space.
+  /// Holds the active masks for permutations of the canonical vector
+  /// iteration space.
   DenseMap<AffineMap, Value> activeMaskCache;
 
   /// Global vectorization guard for the incoming rewriter. It's initialized
   /// when the vectorization state is initialized.
   OpBuilder::InsertionGuard rewriterGuard;
+
+  /// Do all scalable vector sizes match the corresponding input dim sizes?
+  /// (tensor or memref)
+  ///
+  /// At the Tensor + MemRef levels, scalable sizes are modelled using
+  /// dynamic dimensions (i.e. `?`). In many cases these sizes result from
+  /// e.g. "scalable packing + tiling" and are known to always match the
+  /// scalable vector sizes. In such cases, masking can be safely skipped,
+  /// despite the presence of dynamic shapes. Use this flag with care and
+  /// only for cases where you are confident the assumption holds.
+  bool assumeScalableVecSizesMatchDimSize = false;
 };
 
 LogicalResult
@@ -367,10 +380,12 @@ VectorizationState::precomputeIterSpaceValueSizes(RewriterBase &rewriter,
 /// Initializes the vectorization state, including the computation of the
 /// canonical vector shape for vectorization.
 // TODO: Move this to the constructor when we can remove the failure cases.
-LogicalResult
-VectorizationState::initState(RewriterBase &rewriter, LinalgOp linalgOp,
-                              ArrayRef<int64_t> inputVectorSizes,
-                              ArrayRef<bool> inputScalableVecDims) {
+LogicalResult VectorizationState::initState(RewriterBase &rewriter,
+                                            LinalgOp linalgOp,
+                                            ArrayRef<int64_t> inputVectorSizes,
+                                            ArrayRef<bool> inputScalableVecDims,
+                                            bool assumeScalableSizes) {
+  assumeScalableVecSizesMatchDimSize = assumeScalableSizes;
   // Initialize the insertion point.
   rewriter.setInsertionPoint(linalgOp);
 
@@ -470,6 +485,24 @@ Value VectorizationState::getOrCreateMaskFor(
     return Value();
   }
 
+  if (assumeScalableVecSizesMatchDimSize) {
+    // Given that all _scalable vector sizes_ match the corresponding
+    // memref/tensor dim sizes, masking can be skipped provided that:
+    // * all vector sizes corresponding to dynamic dims are scalable,
+    // * all static dim sizes match the corresponding vector sizes.
+    if (llvm::all_of(llvm::zip(permutedStaticSizes, maskType.getShape(),
+                               maskType.getScalableDims()),
+                     [](auto it) {
+                       return std::get<0>(it) == ShapedType::kDynamic
+                                  ? std::get<2>(it)
+                                  : std::get<0>(it) == std::get<1>(it);
+                     })) {
+      LDBG("Masking is not needed for masking map: " << maskingMap << "\n");
+      activeMaskCache[maskingMap] = Value();
+      return Value();
+    }
+  }
+
   // Permute the iteration space value sizes to compute the mask upper bounds.
   SmallVector<Value> upperBounds =
       applyPermutationMap(maskingMap, ArrayRef<Value>(iterSpaceValueSizes));
@@ -2479,7 +2512,8 @@ vectorizeScalableVectorPrecondition(Operation *op,
   return success(isElementwise(linalgOp) || isa<linalg::MatmulOp>(op) ||
                  isa<linalg::MatmulTransposeAOp>(op) ||
                  isa<linalg::DepthwiseConv1DNwcWcOp>(op) ||
-                 isa<linalg::MatvecOp>(op) || hasReductionIterator(linalgOp));
+                 isa<linalg::MatvecOp>(op) || isa<linalg::Mmt4DOp>(op) ||
+                 hasReductionIterator(linalgOp));
 }
 
 LogicalResult mlir::linalg::vectorizeOpPrecondition(
@@ -2535,11 +2569,10 @@ bool mlir::linalg::hasVectorizationImpl(Operation *op) {
   return isa<linalg::LinalgOp, tensor::PadOp, linalg::PackOp, linalg::UnPackOp,
             tensor::InsertSliceOp>(op);
 }
 
-FailureOr<VectorizationResult>
-mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
-                        ArrayRef<int64_t> inputVectorSizes,
-                        ArrayRef<bool> inputScalableVecDims,
-                        bool vectorizeNDExtract, bool flatten1DDepthwiseConv) {
+FailureOr<VectorizationResult> mlir::linalg::vectorize(
+    RewriterBase &rewriter, Operation *op, ArrayRef<int64_t> inputVectorSizes,
+    ArrayRef<bool> inputScalableVecDims, bool vectorizeNDExtract,
+    bool flatten1DDepthwiseConv, bool assumeScalableSizesMultipleOfDim) {
   LDBG("Attempting to vectorize:\n" << *op << "\n");
   LDBG("Input vector sizes: ");
   LLVM_DEBUG(llvm::interleaveComma(inputVectorSizes, llvm::dbgs()));
@@ -2559,7 +2592,8 @@ mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
   VectorizationState state(rewriter);
   if (auto linalgOp = dyn_cast<LinalgOp>(op)) {
     if (failed(state.initState(rewriter, linalgOp, inputVectorSizes,
-                               inputScalableVecDims))) {
+                               inputScalableVecDims,
+                               assumeScalableSizesMultipleOfDim))) {
       LDBG("Vectorization state couldn't be initialized\n");
       return failure();
     }
diff --git a/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir b/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir
index 6722de817f6bf..188f03069938f 100644
--- a/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir
@@ -840,6 +840,99 @@ module attributes {transform.with_named_sequence} {
   }
 }
 
+// -----
+
+///----------------------------------------------------------------------------------------
+/// Tests for linalg.mmt4d
+///----------------------------------------------------------------------------------------
+
+func.func @mmt4d(%A: memref<16x16x8x1xf32>, %B: memref<16x16x8x1xf32>, %C_in: memref<16x16x8x8xf32>) {
+  linalg.mmt4d ins(%A, %B: memref<16x16x8x1xf32>, memref<16x16x8x1xf32>)
+               outs(%C_in: memref<16x16x8x8xf32>)
+  return
+}
+
+// CHECK-LABEL: func.func @mmt4d(
+// CHECK-SAME: %[[A:.*]]: memref<16x16x8x1xf32>, %[[B:.*]]: memref<16x16x8x1xf32>, %[[C:.*]]: memref<16x16x8x8xf32>) {
+// CHECK: %[[VEC_A:.*]] = vector.transfer_read %[[A]]{{.*}} : memref<16x16x8x1xf32>, vector<16x16x16x8x8x1xf32>
+// CHECK: %[[VEC_B:.*]] = vector.transfer_read %[[B]]{{.*}} : memref<16x16x8x1xf32>, vector<16x16x16x8x8x1xf32>
+// CHECK: %[[VEC_C:.*]] = vector.transfer_read %[[C]]{{.*}} : memref<16x16x8x8xf32>, vector<16x16x8x8xf32>
+// CHECK: %[[MUL:.*]] = arith.mulf %[[VEC_A]], %[[VEC_B]] : vector<16x16x16x8x8x1xf32>
+// CHECK: %[[RED:.*]] = vector.multi_reduction <add>, %[[MUL]], %[[VEC_C]] [2, 5] : vector<16x16x16x8x8x1xf32> to vector<16x16x8x8xf32>
+// CHECK: vector.transfer_write %[[RED]], %[[C]]{{.*}} : vector<16x16x8x8xf32>, memref<16x16x8x8xf32>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %mmt4d = transform.structured.match ops{["linalg.mmt4d"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %mmt4d : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+func.func @mmt4d_scalable(%A: memref<16x16x8x1xf32>, %B: memref<16x16x?x1xf32>,
+                          %C_in: memref<16x16x8x?xf32>) {
+  linalg.mmt4d ins(%A, %B: memref<16x16x8x1xf32>, memref<16x16x?x1xf32>)
+               outs(%C_in: memref<16x16x8x?xf32>)
+  return
+}
+// CHECK-LABEL: func.func @mmt4d_scalable(
+// CHECK-SAME: %[[A:.*]]: memref<16x16x8x1xf32>,
+// CHECK-SAME: %[[B:.*]]: memref<16x16x?x1xf32>,
+// CHECK-SAME: %[[C_IN:.*]]: memref<16x16x8x?xf32>) {
+// CHECK: %[[VAL_0:.*]] = arith.constant 16 : index
+// CHECK: %[[VAL_1:.*]] = arith.constant 16 : index
+// CHECK: %[[VAL_2:.*]] = arith.constant 16 : index
+// CHECK: %[[C8:.*]] = arith.constant 8 : index
+// CHECK: %[[C2:.*]] = arith.constant 2 : index
+// CHECK: %[[DIM_2:.*]] = memref.dim %[[B]], %[[C2]] : memref<16x16x?x1xf32>
+// CHECK: %[[VAL_6:.*]] = arith.constant 1 : index
+// CHECK: %[[VEC_A:.*]] = vector.transfer_read %[[A]]{{.*}} : memref<16x16x8x1xf32>, vector<16x16x16x8x[4]x1xf32>
+// CHECK: %[[MASK_1:.*]] = vector.create_mask %[[VAL_1]], %[[VAL_2]], %[[DIM_2]], %[[VAL_6]] : vector<16x16x[4]x1xi1>
+// CHECK: %[[VEC_B:.*]] = vector.mask %[[MASK_1]] { vector.transfer_read %[[B]]{{.*}} : memref<16x16x?x1xf32>, vector<16x16x16x8x[4]x1xf32> } : vector<16x16x[4]x1xi1> -> vector<16x16x16x8x[4]x1xf32>
+// CHECK: %[[MASK_2:.*]] = vector.create_mask %[[VAL_0]], %[[VAL_1]], %[[C8]], %[[DIM_2]] : vector<16x16x8x[4]xi1>
+// CHECK: %[[VAL_15:.*]] = vector.mask %[[MASK_2]] { vector.transfer_read %[[C_IN]]{{.*}} : memref<16x16x8x?xf32>, vector<16x16x8x[4]xf32> } : vector<16x16x8x[4]xi1> -> vector<16x16x8x[4]xf32>
+// CHECK: %[[VAL_16:.*]] = arith.mulf %[[VEC_A]], %[[VEC_B]] : vector<16x16x16x8x[4]x1xf32>
+// CHECK: %[[MASK_3:.*]] = vector.create_mask %[[VAL_0]], %[[VAL_1]], %[[VAL_2]], %[[C8]], %[[DIM_2]], %[[VAL_6]] : vector<16x16x16x8x[4]x1xi1>
+// CHECK: %[[VAL_18:.*]] = vector.mask %[[MASK_3]] { vector.multi_reduction <add>, %[[VAL_16]], %[[VAL_15]] [2, 5] : vector<16x16x16x8x[4]x1xf32> to vector<16x16x8x[4]xf32> } : vector<16x16x16x8x[4]x1xi1> -> vector<16x16x8x[4]xf32>
+// CHECK: vector.mask %[[MASK_2]] { vector.transfer_write %[[VAL_18]], %[[C_IN]]{{.*}} : vector<16x16x8x[4]xf32>, memref<16x16x8x?xf32> } : vector<16x16x8x[4]xi1>
+
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %mmt4d = transform.structured.match ops{["linalg.mmt4d"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %mmt4d vector_sizes [16, 16, 16, 8, [4], 1] : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+func.func @mmt4d_scalable_with_assume(%A: memref<16x16x8x1xf32>, %B: memref<16x16x?x1xf32>, %C_in: memref<16x16x8x?xf32>) {
+  linalg.mmt4d ins(%A, %B: memref<16x16x8x1xf32>, memref<16x16x?x1xf32>)
+               outs(%C_in: memref<16x16x8x?xf32>)
+  return
+}
+// CHECK-LABEL: func.func @mmt4d_scalable_with_assume(
+// CHECK-SAME: %[[A:.*]]: memref<16x16x8x1xf32>,
+// CHECK-SAME: %[[B:.*]]: memref<16x16x?x1xf32>,
+// CHECK-SAME: %[[C_IN:.*]]: memref<16x16x8x?xf32>) {
+// CHECK-NOT: mask
+// CHECK: %[[VEC_A:.*]] = vector.transfer_read %[[A]]{{.*}} : memref<16x16x8x1xf32>, vector<16x16x16x8x[4]x1xf32>
+// CHECK: %[[VEC_B:.*]] = vector.transfer_read %[[B]]{{.*}} : memref<16x16x?x1xf32>, vector<16x16x16x8x[4]x1xf32>
+// CHECK: %[[VAL_13:.*]] = vector.transfer_read %[[C_IN]]{{.*}} : memref<16x16x8x?xf32>, vector<16x16x8x[4]xf32>
+// CHECK: %[[VAL_14:.*]] = arith.mulf %[[VEC_A]], %[[VEC_B]] : vector<16x16x16x8x[4]x1xf32>
+// CHECK: %[[VAL_15:.*]] = vector.multi_reduction <add>, %[[VAL_14]], %[[VAL_13]] [2, 5] : vector<16x16x16x8x[4]x1xf32> to vector<16x16x8x[4]xf32>
+// CHECK: vector.transfer_write %[[VAL_15]], %[[C_IN]]{{.*}} : vector<16x16x8x[4]xf32>, memref<16x16x8x?xf32>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %mmt4d = transform.structured.match ops{["linalg.mmt4d"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %mmt4d vector_sizes [16, 16, 16, 8, [4], 1] {assume_scalable_sizes_match_dim_size} : !transform.any_op
+    transform.yield
+  }
+}
+
 ///----------------------------------------------------------------------------------------
 /// Tests for other Ops
 ///----------------------------------------------------------------------------------------
@@ -1094,30 +1187,6 @@ module attributes {transform.with_named_sequence} {
   }
 }
 
-// -----
-
-func.func @mmt4d(%A: memref<16x16x8x1xf32>, %B: memref<16x16x8x1xf32>, %C_in: memref<16x16x8x8xf32>) {
-  linalg.mmt4d ins(%A, %B: memref<16x16x8x1xf32>, memref<16x16x8x1xf32>)
-               outs(%C_in: memref<16x16x8x8xf32>)
-  return
-}
-
-// CHECK-LABEL: func.func @mmt4d(
-// CHECK-SAME: %[[A:.*]]: memref<16x16x8x1xf32>, %[[B:.*]]: memref<16x16x8x1xf32>, %[[C:.*]]: memref<16x16x8x8xf32>) {
-// CHECK: %[[VEC_A:.*]] = vector.transfer_read %[[A]]{{.*}} : memref<16x16x8x1xf32>, vector<16x16x16x8x8x1xf32>
-// CHECK: %[[VEC_B:.*]] = vector.transfer_read %[[B]]{{.*}} : memref<16x16x8x1xf32>, vector<16x16x16x8x8x1xf32>
-// CHECK: %[[VEC_C:.*]] = vector.transfer_read %[[C]]{{.*}} : memref<16x16x8x8xf32>, vector<16x16x8x8xf32>
-// CHECK: %[[MUL:.*]] = arith.mulf %[[VEC_A]], %[[VEC_B]] : vector<16x16x16x8x8x1xf32>
-// CHECK: %[[RED:.*]] = vector.multi_reduction <add>, %[[MUL]], %[[VEC_C]] [2, 5] : vector<16x16x16x8x8x1xf32> to vector<16x16x8x8xf32>
-// CHECK: vector.transfer_write %[[RED]], %[[C]]{{.*}} : vector<16x16x8x8xf32>, memref<16x16x8x8xf32>
-
-module attributes {transform.with_named_sequence} {
-  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
-    %mmt4d = transform.structured.match ops{["linalg.mmt4d"]} in %arg1 : (!transform.any_op) -> !transform.any_op
-    transform.structured.vectorize %mmt4d : !transform.any_op
-    transform.yield
-  }
-}
 
 // -----
 

From 8ef66618eeb45c2ad8a435f47715d0cf0b15fc86 Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski@arm.com>
Date: Thu, 10 Jul 2025 16:44:41 +0000
Subject: [PATCH 2/3] fixup! [mlir][linalg] Add support for scalable
 vectorization of linalg.mmt4d

Revert changes in comments
---
 .../Linalg/Transforms/Vectorization.cpp       | 25 +++++++++----------
 1 file changed, 12 insertions(+), 13 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 3a533322a3c7f..38bf37d844be4 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -225,8 +225,7 @@ struct VectorizationState {
                           ArrayRef<bool> inputScalableVecDims,
                           bool assumeScalableVecSizesMatchDimSize = false);
 
-  /// Returns the canonical vector shape used to vectorize the iteration
-  /// space.
+  /// Returns the canonical vector shape used to vectorize the iteration space.
   ArrayRef<int64_t> getCanonicalVecShape() const { return canonicalVecShape; }
 
   /// Returns the vector dimensions that are scalable in the canonical vector
@@ -235,8 +234,8 @@ struct VectorizationState {
 
   /// Returns a vector type of the provided `elementType` with the canonical
   /// vector shape and the corresponding fixed/scalable dimensions bit. If
-  /// `dimPermutation` is provided, the canonical vector dimensions are
-  /// permuted accordingly.
+  /// `dimPermutation` is provided, the canonical vector dimensions are permuted
+  /// accordingly.
   VectorType getCanonicalVecType(
       Type elementType,
       std::optional<AffineMap> dimPermutation = std::nullopt) const {
@@ -256,9 +255,9 @@ struct VectorizationState {
   }
 
   /// Masks an operation with the canonical vector mask if the operation needs
-  /// masking. Returns the masked operation or the original operation if
-  /// masking is not needed. If provided, the canonical mask for this
-  /// operation is permuted using `maybeIndexingMap`.
+  /// masking. Returns the masked operation or the original operation if masking
+  /// is not needed. If provided, the canonical mask for this operation is
+  /// permuted using `maybeIndexingMap`.
   Operation *
   maskOperation(RewriterBase &rewriter, Operation *opToMask, LinalgOp linalgOp,
                 std::optional<AffineMap> maybeIndexingMap = std::nullopt);
@@ -278,15 +277,15 @@ struct VectorizationState {
 
   /// Create or retrieve an existing mask value to mask `opToMask` in the
   /// canonical vector iteration space. If `maybeMaskingMap` the mask is
-  /// permuted using that permutation map. If a new mask is created, it will
-  /// be cached for future users.
+  /// permuted using that permutation map. If a new mask is created, it will be
+  /// cached for future users.
   Value getOrCreateMaskFor(RewriterBase &rewriter, Operation *opToMask,
                            LinalgOp linalgOp,
                            std::optional<AffineMap> maybeMaskingMap);
 
   /// Check whether this permutation map can be used for masking. At the
-  /// moment we only make sure that there are no broadcast dimensions, but
-  /// this might change if indexing maps evolve.
+  /// moment we only make sure that there are no broadcast dimensions, but this
+  /// might change if indexing maps evolve.
   bool isValidMaskingMap(AffineMap maskingMap) {
     return maskingMap.getBroadcastDims().size() == 0;
   }
@@ -326,8 +325,8 @@ struct VectorizationState {
   /// shape.
   SmallVector<bool> scalableVecDims;
 
-  /// Holds the active masks for permutations of the canonical vector
-  /// iteration space.
+  /// Holds the active masks for permutations of the canonical vector iteration
+  /// space.
   DenseMap<AffineMap, Value> activeMaskCache;
 
   /// Global vectorization guard for the incoming rewriter. It's initialized

From 2b6019caf04f37715ca5857b31b3f4fe9c48a417 Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski@arm.com>
Date: Thu, 10 Jul 2025 17:23:24 +0000
Subject: [PATCH 3/3] fixup! [mlir][linalg] Add support for scalable
 vectorization of linalg.mmt4d

Rename the bool to assumeDynamicDimsMatchVecSizes
---
 .../Linalg/TransformOps/LinalgTransformOps.td |  2 +-
 .../Dialect/Linalg/Transforms/Transforms.h    |  2 +-
 .../TransformOps/LinalgTransformOps.cpp       |  2 +-
 .../Linalg/Transforms/Vectorization.cpp       | 27 +++++++++----------
 4 files changed, 15 insertions(+), 18 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index baa17f75e53b6..472df21cb464e 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -2443,7 +2443,7 @@ def VectorizeOp : Op<Transform_Dialect, "structured.vectorize",
                    Variadic<TransformAnyParamTypeOrAnyHandle>:$vector_sizes,
                    DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_vector_sizes,
                    OptionalAttr<UnitAttr>:$vectorize_nd_extract,
-                   OptionalAttr<UnitAttr>:$assume_scalable_sizes_match_dim_size,
+                   OptionalAttr<UnitAttr>:$assume_dynamic_dims_match_vec_sizes,
                    DefaultValuedOptionalAttr<DenseBoolArrayAttr, "{}">:$scalable_sizes);
 
   let results = (outs);
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index a6d697b43c0b7..8ba4f8f218721 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -872,7 +872,7 @@ vectorize(RewriterBase &rewriter, Operation *op,
           ArrayRef<int64_t> inputVectorSizes = {},
           ArrayRef<bool> inputScalableVecDims = {},
           bool vectorizeNDExtract = false, bool flatten1DDepthwiseConv = false,
-          bool assumeScalableSizesMultipleOfDim = false);
+          bool assumeDynamicDimsMatchVecSizes = false);
 
 /// Emit a suitable vector form for a Copy op with fully static shape.
 LogicalResult vectorizeCopy(RewriterBase &builder, memref::CopyOp copyOp);
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 49b9a41831fc6..7e1911a56693f 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -3922,7 +3922,7 @@ DiagnosedSilenceableFailure transform::VectorizeOp::apply(
   FailureOr<VectorizationResult> vectorResults =
       linalg::vectorize(rewriter, target, vectorSizes, getScalableSizes(),
                         getVectorizeNdExtract().value_or(false), false,
-                        getAssumeScalableSizesMatchDimSize().value_or(false));
+                        getAssumeDynamicDimsMatchVecSizes().value_or(false));
   if (failed(vectorResults)) {
     return mlir::emitSilenceableFailure(target->getLoc())
            << "Attempted to vectorize, but failed";
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 38bf37d844be4..11ba3eebc128a 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -223,7 +223,7 @@ struct VectorizationState {
   LogicalResult initState(RewriterBase &rewriter, LinalgOp linalgOp,
                           ArrayRef<int64_t> inputVectorSizes,
                           ArrayRef<bool> inputScalableVecDims,
-                          bool assumeScalableVecSizesMatchDimSize = false);
+                          bool assumeDynamicDimsMatchVecSizes = false);
 
   /// Returns the canonical vector shape used to vectorize the iteration space.
   ArrayRef<int64_t> getCanonicalVecShape() const { return canonicalVecShape; }
@@ -333,16 +333,13 @@ struct VectorizationState {
   /// when the vectorization state is initialized.
   OpBuilder::InsertionGuard rewriterGuard;
 
-  /// Do all scalable vector sizes match the corresponding input dim sizes?
-  /// (tensor or memref)
+  /// Do all dynamic dims match the corresponding vector sizes?
   ///
-  /// At the Tensor + MemRef levels, scalable sizes are modelled using
-  /// dynamic dimensions (i.e. `?`). In many cases these sizes result from
-  /// e.g. "scalable packing + tiling" and are known to always match the
-  /// scalable vector sizes. In such cases, masking can be safely skipped,
-  /// despite the presence of dynamic shapes. Use this flag with care and
-  /// only for cases where you are confident the assumption holds.
-  bool assumeScalableVecSizesMatchDimSize = false;
+  /// When a dynamic tensor/memref dimension matches the corresponding vector
+  /// dimension, masking can be safely skipped, despite the presence of dynamic
+  /// shapes. Use this flag with care and only for cases where you are
+  /// confident the assumption holds.
+  bool assumeDynamicDimsMatchVecSizes = false;
 };
 
 LogicalResult
@@ -383,8 +380,8 @@ LogicalResult VectorizationState::initState(RewriterBase &rewriter,
                                             LinalgOp linalgOp,
                                             ArrayRef<int64_t> inputVectorSizes,
                                             ArrayRef<bool> inputScalableVecDims,
-                                            bool assumeScalableSizes) {
-  assumeScalableVecSizesMatchDimSize = assumeScalableSizes;
+                                            bool assumeDimsMatchVec) {
+  assumeDynamicDimsMatchVecSizes = assumeDimsMatchVec;
   // Initialize the insertion point.
   rewriter.setInsertionPoint(linalgOp);
 
@@ -487,7 +484,7 @@ Value VectorizationState::getOrCreateMaskFor(
     return Value();
   }
 
-  if (assumeScalableVecSizesMatchDimSize) {
+  if (assumeDynamicDimsMatchVecSizes) {
     // Given that all _scalable vector sizes_ match the corresponding
     // memref/tensor dim sizes, masking can be skipped provided that:
     // * all vector sizes corresponding to dynamic dims are scalable,
@@ -2571,7 +2568,7 @@ bool mlir::linalg::hasVectorizationImpl(Operation *op) {
 FailureOr<VectorizationResult> mlir::linalg::vectorize(
     RewriterBase &rewriter, Operation *op, ArrayRef<int64_t> inputVectorSizes,
     ArrayRef<bool> inputScalableVecDims, bool vectorizeNDExtract,
-    bool flatten1DDepthwiseConv, bool assumeScalableSizesMultipleOfDim) {
+    bool flatten1DDepthwiseConv, bool assumeDynamicDimsMatchVecSizes) {
   LDBG("Attempting to vectorize:\n" << *op << "\n");
   LDBG("Input vector sizes: ");
   LLVM_DEBUG(llvm::interleaveComma(inputVectorSizes, llvm::dbgs()));
@@ -2592,7 +2589,7 @@ FailureOr<VectorizationResult> mlir::linalg::vectorize(
   if (auto linalgOp = dyn_cast<LinalgOp>(op)) {
     if (failed(state.initState(rewriter, linalgOp, inputVectorSizes,
                                inputScalableVecDims,
-                               assumeScalableSizesMultipleOfDim))) {
+                               assumeDynamicDimsMatchVecSizes))) {
       LDBG("Vectorization state couldn't be initialized\n");
      return failure();
    }
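
Usage sketch (illustrative; not part of the diffs above). With the series
applied, the assumption is driven from the transform dialect as in the
@mmt4d_scalable_with_assume test, but using the post-rename attribute
spelling from PATCH 3 (assume_dynamic_dims_match_vec_sizes). The vector
sizes below are placeholders and are only sound when they match the tile
sizes used when packing via linalg.pack:

  module attributes {transform.with_named_sequence} {
    transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
      %mmt4d = transform.structured.match ops{["linalg.mmt4d"]} in %arg0
        : (!transform.any_op) -> !transform.any_op
      // [4] marks the N-tile vector dim as scalable; the corresponding memref
      // dim is dynamic and assumed to equal vscale * 4 at runtime.
      transform.structured.vectorize %mmt4d vector_sizes [16, 16, 16, 8, [4], 1]
        {assume_dynamic_dims_match_vec_sizes} : !transform.any_op
      transform.yield
    }
  }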
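
The same assumption can be opted into through the C++ entry point. A minimal
sketch, assuming an in-scope `rewriter` and a pre-matched `linalg.mmt4d`
operation `op`; the arguments follow the post-rename `linalg::vectorize`
declaration in Transforms.h:

  // Dim 4 is scalable (the [4] above); the final argument enables
  // assumeDynamicDimsMatchVecSizes, i.e. masking is skipped for dynamic dims
  // whose vector sizes are scalable.
  FailureOr<VectorizationResult> result = linalg::vectorize(
      rewriter, op,
      /*inputVectorSizes=*/{16, 16, 16, 8, 4, 1},
      /*inputScalableVecDims=*/{false, false, false, false, true, false},
      /*vectorizeNDExtract=*/false,
      /*flatten1DDepthwiseConv=*/false,
      /*assumeDynamicDimsMatchVecSizes=*/true);
  if (failed(result))
    return failure();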