[mlir][linalg] Add support for scalable vectorization of linalg.mmt4d #146531


Status: Open. Wants to merge 3 commits into base: main.
@@ -2440,12 +2440,11 @@ def VectorizeOp : Op<Transform_Dialect, "structured.vectorize",
}];

   let arguments = (ins TransformHandleTypeInterface:$target,
-                       Variadic<TransformAnyParamTypeOrAnyHandle>:$vector_sizes,
-                       DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:
-                         $static_vector_sizes,
-                       OptionalAttr<UnitAttr>:$vectorize_nd_extract,
-                       DefaultValuedOptionalAttr<DenseBoolArrayAttr, "{}">:
-                         $scalable_sizes);
+                       Variadic<TransformAnyParamTypeOrAnyHandle>:$vector_sizes,
+                       DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_vector_sizes,
+                       OptionalAttr<UnitAttr>:$vectorize_nd_extract,
+                       OptionalAttr<UnitAttr>:$assume_scalable_sizes_match_dim_size,
+                       DefaultValuedOptionalAttr<DenseBoolArrayAttr, "{}">:$scalable_sizes);

let results = (outs);

3 changes: 2 additions & 1 deletion in mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h

@@ -871,7 +871,8 @@ FailureOr<VectorizationResult>
 vectorize(RewriterBase &rewriter, Operation *op,
           ArrayRef<int64_t> inputVectorSizes = {},
           ArrayRef<bool> inputScalableVecDims = {},
-          bool vectorizeNDExtract = false, bool flatten1DDepthwiseConv = false);
+          bool vectorizeNDExtract = false, bool flatten1DDepthwiseConv = false,
+          bool assumeScalableSizesMultipleOfDim = false);
Contributor:
We have assumeScalableSizesMultipleOfDim and getAssumeScalableSizesMatchDimSize. Should we have only one?
Also, shouldn't it be "dim multiple of vector sizes"?

Contributor (Author):
Thanks for the suggestion, I went with assumeDynamicDimsMatchVecSizes, see this commit. As mentioned in my other comment, for now I am "assuming" equality rather than divisibility. Not sure whether we will need the latter?


/// Emit a suitable vector form for a Copy op with fully static shape.
LogicalResult vectorizeCopy(RewriterBase &builder, memref::CopyOp copyOp);
@@ -3921,7 +3921,8 @@ DiagnosedSilenceableFailure transform::VectorizeOp::apply(
}
   FailureOr<VectorizationResult> vectorResults =
       linalg::vectorize(rewriter, target, vectorSizes, getScalableSizes(),
-                        getVectorizeNdExtract().value_or(false));
+                        getVectorizeNdExtract().value_or(false), false,
+                        getAssumeScalableSizesMatchDimSize().value_or(false));
if (failed(vectorResults)) {
return mlir::emitSilenceableFailure(target->getLoc())
<< "Attempted to vectorize, but failed";
79 changes: 55 additions & 24 deletions mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
Contributor:
Some of the comment changes are not relevant; can you revert them?

Contributor (Author):
Not sure how that crept in. Fixed in this commit.

@@ -222,9 +222,11 @@ struct VectorizationState {
/// canonical vector shape for vectorization.
   LogicalResult initState(RewriterBase &rewriter, LinalgOp linalgOp,
                           ArrayRef<int64_t> inputVectorSizes,
-                          ArrayRef<bool> inputScalableVecDims);
+                          ArrayRef<bool> inputScalableVecDims,
+                          bool assumeScalableVecSizesMatchDimSize = false);
 
-  /// Returns the canonical vector shape used to vectorize the iteration space.
+  /// Returns the canonical vector shape used to vectorize the iteration
+  /// space.
ArrayRef<int64_t> getCanonicalVecShape() const { return canonicalVecShape; }

/// Returns the vector dimensions that are scalable in the canonical vector
@@ -233,8 +235,8 @@

   /// Returns a vector type of the provided `elementType` with the canonical
   /// vector shape and the corresponding fixed/scalable dimensions bit. If
-  /// `dimPermutation` is provided, the canonical vector dimensions are permuted
-  /// accordingly.
+  /// `dimPermutation` is provided, the canonical vector dimensions are
+  /// permuted accordingly.
VectorType getCanonicalVecType(
Type elementType,
std::optional<AffineMap> dimPermutation = std::nullopt) const {
@@ -254,9 +256,9 @@
}

   /// Masks an operation with the canonical vector mask if the operation needs
-  /// masking. Returns the masked operation or the original operation if masking
-  /// is not needed. If provided, the canonical mask for this operation is
-  /// permuted using `maybeIndexingMap`.
+  /// masking. Returns the masked operation or the original operation if
+  /// masking is not needed. If provided, the canonical mask for this
+  /// operation is permuted using `maybeIndexingMap`.
Operation *
maskOperation(RewriterBase &rewriter, Operation *opToMask, LinalgOp linalgOp,
std::optional<AffineMap> maybeIndexingMap = std::nullopt);
@@ -276,15 +278,15 @@

   /// Create or retrieve an existing mask value to mask `opToMask` in the
   /// canonical vector iteration space. If `maybeMaskingMap` the mask is
-  /// permuted using that permutation map. If a new mask is created, it will be
-  /// cached for future users.
+  /// permuted using that permutation map. If a new mask is created, it will
+  /// be cached for future users.
Value getOrCreateMaskFor(RewriterBase &rewriter, Operation *opToMask,
LinalgOp linalgOp,
std::optional<AffineMap> maybeMaskingMap);

   /// Check whether this permutation map can be used for masking. At the
-  /// moment we only make sure that there are no broadcast dimensions, but this
-  /// might change if indexing maps evolve.
+  /// moment we only make sure that there are no broadcast dimensions, but
+  /// this might change if indexing maps evolve.
bool isValidMaskingMap(AffineMap maskingMap) {
return maskingMap.getBroadcastDims().size() == 0;
}
@@ -324,13 +326,24 @@ struct VectorizationState {
/// shape.
SmallVector<bool> scalableVecDims;

-  /// Holds the active masks for permutations of the canonical vector iteration
-  /// space.
+  /// Holds the active masks for permutations of the canonical vector
+  /// iteration space.
DenseMap<AffineMap, Value> activeMaskCache;

/// Global vectorization guard for the incoming rewriter. It's initialized
/// when the vectorization state is initialized.
OpBuilder::InsertionGuard rewriterGuard;

+  /// Do all scalable vector sizes match the corresponding input dim sizes?
+  /// (tensor or memref)
+  ///
+  /// At the Tensor + MemRef levels, scalable sizes are modelled using
+  /// dynamic dimensions (i.e. `?`). In many cases these sizes result from
+  /// e.g. "scalable packing + tiling" and are known to always match the
+  /// scalable vector sizes. In such cases, masking can be safely skipped,
+  /// despite the presence of dynamic shapes. Use this flag with care and
+  /// only for cases where you are confident the assumption holds.
+  bool assumeScalableVecSizesMatchDimSize = false;
Contributor:
Could we make this generic, not only for scalable vectors? Assuming that a dynamic dimension is a multiple of a vector size, scalable or not, is useful in general.

Contributor:
I think we should also add a TODO, as this eventually should be a list containing the divisibility information for each dimension...

Contributor (@hanhanW, Jul 8, 2025):
Having the bool option could be beneficial to IREE, because IREE may use TensorDynamicDimAnalysis and can use such information to determine whether the assumption is true or not. The analysis uses two MLIR analyses and one IREE-specific analysis, which may be upstreamable. I seldom work in this area, so I'm not entirely sure about it. I mainly want to share a data point about the "assume" mechanism that I learned before, which seems related to @dcaballe's comment.

TensorDynamicDimAnalysis::TensorDynamicDimAnalysis(Operation *rootOp)
    : rootOperation(rootOp) {
  solver.load<mlir::dataflow::DeadCodeAnalysis>();
  solver.load<mlir::dataflow::IntegerRangeAnalysis>();
  solver.load<IREE::Util::IntegerDivisibilityAnalysis>();
}

In IREE, we have the hint in IR and we can apply above analysis to retrieve the information. It is IREE specific because the assume.int op is defined by IREE core dialects. E.g.,

   %0:2 = util.assume.int
       %m_in<umin = 16, umax = 4080, udiv = 16>,
       %k2_in<umin = 16, umax = 4080, udiv = 32>
     : index, index

However, IIUC, what's missing is something like ValueRangeAnalysis (not just for integers), because what you wonder is whether the size is a multiple of vscale or not?

Anyway, I think it is a good start, and I feel that we are missing some analysis to determine whether we can set the value to true or not. I don't suggest kicking off an analysis during vectorization because it could be expensive; my understanding is that it is better to run it once at the beginning of a pass and then use it to vectorize all the ops later.

(I'm happy to move the discussion to the issue, if it is better.)

Contributor (Author):

> Could we make this generic, not only for scalable vectors?

Sorry, I over-indexed on "scalable vectors".

> Assuming a dynamic dimension is a multiple of a vector size, scalable or not, is useful in general.

For my needs ATM, equality (rather than divisibility) is sufficient. Do you reckon that we will require divisibility in the future?

> However, IIUC, what's missing is something like ValueRangeAnalysis (not just for integers), because what you wonder is whether the size is a multiple of vscale or not?

We do actually have and use ValueRangeAnalysis for "scalable vectors", see e.g. Codegen/Transforms/Transforms.cpp in IREE. However, AFAIK, that analysis is costly, and in the case of linalg.mmt4d we just know up front that the assumption holds. Hence taking this approach rather than going through other means. But now I see that there are even more options in IREE that I should consider in the future, thanks for the pointers @hanhanW 🙏🏻

};

LogicalResult
@@ -367,10 +380,12 @@ VectorizationState::precomputeIterSpaceValueSizes(RewriterBase &rewriter,
/// Initializes the vectorization state, including the computation of the
/// canonical vector shape for vectorization.
// TODO: Move this to the constructor when we can remove the failure cases.
-LogicalResult
-VectorizationState::initState(RewriterBase &rewriter, LinalgOp linalgOp,
-                              ArrayRef<int64_t> inputVectorSizes,
-                              ArrayRef<bool> inputScalableVecDims) {
+LogicalResult VectorizationState::initState(RewriterBase &rewriter,
+                                            LinalgOp linalgOp,
+                                            ArrayRef<int64_t> inputVectorSizes,
+                                            ArrayRef<bool> inputScalableVecDims,
+                                            bool assumeScalableSizes) {
+  assumeScalableVecSizesMatchDimSize = assumeScalableSizes;
// Initialize the insertion point.
rewriter.setInsertionPoint(linalgOp);

@@ -470,6 +485,21 @@ Value VectorizationState::getOrCreateMaskFor(
return Value();
}

+  if (assumeScalableVecSizesMatchDimSize) {
+    // Given that all _scalable vector sizes_ match the corresponding
+    // memref/tensor dim sizes, masking can be skipped provided that:
+    //  * all vector sizes corresponding to dynamic dims are scalable.
+    if (llvm::all_of(llvm::zip(permutedStaticSizes, maskType.getScalableDims()),
+                     [](auto it) {
+                       return std::get<0>(it) == ShapedType::kDynamic
+                                  ? std::get<1>(it)
+                                  : false;
Comment on lines +490 to +492
Contributor:
Am I misinterpreting this, or does this imply: "every size in permutedStaticSizes is dynamic and the corresponding dim is scalable"? Also, the if only guards the debug print; it does not guard the following lines that set the mask to nothing and return.

Contributor:
+1

+                     }))
+      LDBG("Masking is not needed for masking map: " << maskingMap << "\n");
+    activeMaskCache[maskingMap] = Value();
+    return Value();
+  }

// Permute the iteration space value sizes to compute the mask upper bounds.
SmallVector<Value> upperBounds =
applyPermutationMap(maskingMap, ArrayRef<Value>(iterSpaceValueSizes));
@@ -2479,7 +2509,8 @@ vectorizeScalableVectorPrecondition(Operation *op,
   return success(isElementwise(linalgOp) || isa<linalg::MatmulOp>(op) ||
                  isa<linalg::MatmulTransposeAOp>(op) ||
                  isa<linalg::DepthwiseConv1DNwcWcOp>(op) ||
-                 isa<linalg::MatvecOp>(op) || hasReductionIterator(linalgOp));
+                 isa<linalg::MatvecOp>(op) || isa<linalg::Mmt4DOp>(op) ||
Contributor:
I was wondering if there is a particular reason why this wouldn't work for batch_mmt4d as well?

+                 hasReductionIterator(linalgOp));
}

LogicalResult mlir::linalg::vectorizeOpPrecondition(
@@ -2535,11 +2566,10 @@ bool mlir::linalg::hasVectorizationImpl(Operation *op) {
tensor::InsertSliceOp>(op);
}

-FailureOr<VectorizationResult>
-mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
-                        ArrayRef<int64_t> inputVectorSizes,
-                        ArrayRef<bool> inputScalableVecDims,
-                        bool vectorizeNDExtract, bool flatten1DDepthwiseConv) {
+FailureOr<VectorizationResult> mlir::linalg::vectorize(
+    RewriterBase &rewriter, Operation *op, ArrayRef<int64_t> inputVectorSizes,
+    ArrayRef<bool> inputScalableVecDims, bool vectorizeNDExtract,
+    bool flatten1DDepthwiseConv, bool assumeScalableSizesMultipleOfDim) {
LDBG("Attempting to vectorize:\n" << *op << "\n");
LDBG("Input vector sizes: ");
LLVM_DEBUG(llvm::interleaveComma(inputVectorSizes, llvm::dbgs()));
@@ -2559,7 +2589,8 @@ mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
VectorizationState state(rewriter);
if (auto linalgOp = dyn_cast<linalg::LinalgOp>(op)) {
     if (failed(state.initState(rewriter, linalgOp, inputVectorSizes,
-                               inputScalableVecDims))) {
+                               inputScalableVecDims,
+                               assumeScalableSizesMultipleOfDim))) {
LDBG("Vectorization state couldn't be initialized\n");
return failure();
}
117 changes: 93 additions & 24 deletions in mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir

@@ -840,6 +840,99 @@ module attributes {transform.with_named_sequence} {
}
}

// -----

///----------------------------------------------------------------------------------------
/// Tests for linalg.mmt4d
///----------------------------------------------------------------------------------------

func.func @mmt4d(%A: memref<16x16x8x1xf32>, %B: memref<16x16x8x1xf32>, %C_in: memref<16x16x8x8xf32>) {
linalg.mmt4d ins(%A, %B: memref<16x16x8x1xf32>, memref<16x16x8x1xf32>)
outs(%C_in: memref<16x16x8x8xf32>)
return
}

// CHECK-LABEL: func.func @mmt4d(
// CHECK-SAME: %[[A:.*]]: memref<16x16x8x1xf32>, %[[B:.*]]: memref<16x16x8x1xf32>, %[[C:.*]]: memref<16x16x8x8xf32>) {
// CHECK: %[[VEC_A:.*]] = vector.transfer_read %[[A]]{{.*}} : memref<16x16x8x1xf32>, vector<16x16x16x8x8x1xf32>
// CHECK: %[[VEC_B:.*]] = vector.transfer_read %[[B]]{{.*}} : memref<16x16x8x1xf32>, vector<16x16x16x8x8x1xf32>
// CHECK: %[[VEC_C:.*]] = vector.transfer_read %[[C]]{{.*}} : memref<16x16x8x8xf32>, vector<16x16x8x8xf32>
// CHECK: %[[MUL:.*]] = arith.mulf %[[VEC_A]], %[[VEC_B]] : vector<16x16x16x8x8x1xf32>
// CHECK: %[[RED:.*]] = vector.multi_reduction <add>, %[[MUL]], %[[VEC_C]] [2, 5] : vector<16x16x16x8x8x1xf32> to vector<16x16x8x8xf32>
// CHECK: vector.transfer_write %[[RED]], %[[C]]{{.*}} : vector<16x16x8x8xf32>, memref<16x16x8x8xf32>

module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%mmt4d = transform.structured.match ops{["linalg.mmt4d"]} in %arg1 : (!transform.any_op) -> !transform.any_op
transform.structured.vectorize %mmt4d : !transform.any_op
transform.yield
}
}

// -----

func.func @mmt4d_scalable(%A: memref<16x16x8x1xf32>, %B: memref<16x16x?x1xf32>, %C_in: memref<16x16x8x?xf32>) {
linalg.mmt4d ins(%A, %B: memref<16x16x8x1xf32>, memref<16x16x?x1xf32>)
outs(%C_in: memref<16x16x8x?xf32>)
return
}
// CHECK-LABEL: func.func @mmt4d_scalable(
// CHECK-SAME: %[[A:.*]]: memref<16x16x8x1xf32>,
// CHECK-SAME: %[[B:.*]]: memref<16x16x?x1xf32>,
// CHECK-SAME: %[[C_IN:.*]]: memref<16x16x8x?xf32>) {
// CHECK: %[[VAL_0:.*]] = arith.constant 16 : index
// CHECK: %[[VAL_1:.*]] = arith.constant 16 : index
// CHECK: %[[VAL_2:.*]] = arith.constant 16 : index
// CHECK: %[[C8:.*]] = arith.constant 8 : index
// CHECK: %[[C2:.*]] = arith.constant 2 : index
// CHECK: %[[DIM_2:.*]] = memref.dim %[[B]], %[[C2]] : memref<16x16x?x1xf32>
// CHECK: %[[VAL_6:.*]] = arith.constant 1 : index
// CHECK: %[[VEC_A:.*]] = vector.transfer_read %[[A]]{{.*}} : memref<16x16x8x1xf32>, vector<16x16x16x8x[4]x1xf32>
// CHECK: %[[MASK_1:.*]] = vector.create_mask %[[VAL_1]], %[[VAL_2]], %[[DIM_2]], %[[VAL_6]] : vector<16x16x[4]x1xi1>
// CHECK: %[[VEC_B:.*]] = vector.mask %[[MASK_1]] { vector.transfer_read %[[B]]{{.*}} : memref<16x16x?x1xf32>, vector<16x16x16x8x[4]x1xf32> } : vector<16x16x[4]x1xi1> -> vector<16x16x16x8x[4]x1xf32>
// CHECK: %[[MASK_2:.*]] = vector.create_mask %[[VAL_0]], %[[VAL_1]], %[[C8]], %[[DIM_2]] : vector<16x16x8x[4]xi1>
// CHECK: %[[VAL_15:.*]] = vector.mask %[[MASK_2]] { vector.transfer_read %[[C_IN]]{{.*}} : memref<16x16x8x?xf32>, vector<16x16x8x[4]xf32> } : vector<16x16x8x[4]xi1> -> vector<16x16x8x[4]xf32>
// CHECK: %[[VAL_16:.*]] = arith.mulf %[[VEC_A]], %[[VEC_B]] : vector<16x16x16x8x[4]x1xf32>
// CHECK: %[[MASK_3:.*]] = vector.create_mask %[[VAL_0]], %[[VAL_1]], %[[VAL_2]], %[[C8]], %[[DIM_2]], %[[VAL_6]] : vector<16x16x16x8x[4]x1xi1>
// CHECK: %[[VAL_18:.*]] = vector.mask %[[MASK_3]] { vector.multi_reduction <add>, %[[VAL_16]], %[[VAL_15]] [2, 5] : vector<16x16x16x8x[4]x1xf32> to vector<16x16x8x[4]xf32> } : vector<16x16x16x8x[4]x1xi1> -> vector<16x16x8x[4]xf32>
// CHECK: vector.mask %[[MASK_2]] { vector.transfer_write %[[VAL_18]], %[[C_IN]]{{.*}} : vector<16x16x8x[4]xf32>, memref<16x16x8x?xf32> } : vector<16x16x8x[4]xi1>


module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%mmt4d = transform.structured.match ops{["linalg.mmt4d"]} in %arg1 : (!transform.any_op) -> !transform.any_op
transform.structured.vectorize %mmt4d vector_sizes [16, 16, 16, 8, [4], 1] : !transform.any_op
transform.yield
}
}

// -----

func.func @mmt4d_scalable_with_assume(%A: memref<16x16x8x1xf32>, %B: memref<16x16x?x1xf32>, %C_in: memref<16x16x8x?xf32>) {
linalg.mmt4d ins(%A, %B: memref<16x16x8x1xf32>, memref<16x16x?x1xf32>)
outs(%C_in: memref<16x16x8x?xf32>)
return
}
// CHECK-LABEL: func.func @mmt4d_scalable_with_assume(
// CHECK-SAME: %[[A:.*]]: memref<16x16x8x1xf32>,
// CHECK-SAME: %[[B:.*]]: memref<16x16x?x1xf32>,
// CHECK-SAME: %[[C_IN:.*]]: memref<16x16x8x?xf32>) {
// CHECK-NOT: mask
// CHECK: %[[VEC_A:.*]] = vector.transfer_read %[[A]]{{.*}} : memref<16x16x8x1xf32>, vector<16x16x16x8x[4]x1xf32>
// CHECK: %[[VEC_B:.*]] = vector.transfer_read %[[B]]{{.*}} : memref<16x16x?x1xf32>, vector<16x16x16x8x[4]x1xf32>
// CHECK: %[[VAL_13:.*]] = vector.transfer_read %[[C_IN]]{{.*}} : memref<16x16x8x?xf32>, vector<16x16x8x[4]xf32>
// CHECK: %[[VAL_14:.*]] = arith.mulf %[[VEC_A]], %[[VEC_B]] : vector<16x16x16x8x[4]x1xf32>
// CHECK: %[[VAL_15:.*]] = vector.multi_reduction <add>, %[[VAL_14]], %[[VAL_13]] [2, 5] : vector<16x16x16x8x[4]x1xf32> to vector<16x16x8x[4]xf32>
// CHECK: vector.transfer_write %[[VAL_15]], %[[C_IN]]{{.*}} : vector<16x16x8x[4]xf32>, memref<16x16x8x?xf32>

module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%mmt4d = transform.structured.match ops{["linalg.mmt4d"]} in %arg1 : (!transform.any_op) -> !transform.any_op
transform.structured.vectorize %mmt4d vector_sizes [16, 16, 16, 8, [4], 1] {assume_scalable_sizes_match_dim_size} : !transform.any_op
transform.yield
}
}

///----------------------------------------------------------------------------------------
/// Tests for other Ops
///----------------------------------------------------------------------------------------
@@ -1094,30 +1187,6 @@ }
}
}

-// -----
-
-func.func @mmt4d(%A: memref<16x16x8x1xf32>, %B: memref<16x16x8x1xf32>, %C_in: memref<16x16x8x8xf32>) {
-  linalg.mmt4d ins(%A, %B: memref<16x16x8x1xf32>, memref<16x16x8x1xf32>)
-               outs(%C_in: memref<16x16x8x8xf32>)
-  return
-}
-
-// CHECK-LABEL: func.func @mmt4d(
-// CHECK-SAME: %[[A:.*]]: memref<16x16x8x1xf32>, %[[B:.*]]: memref<16x16x8x1xf32>, %[[C:.*]]: memref<16x16x8x8xf32>) {
-// CHECK: %[[VEC_A:.*]] = vector.transfer_read %[[A]]{{.*}} : memref<16x16x8x1xf32>, vector<16x16x16x8x8x1xf32>
-// CHECK: %[[VEC_B:.*]] = vector.transfer_read %[[B]]{{.*}} : memref<16x16x8x1xf32>, vector<16x16x16x8x8x1xf32>
-// CHECK: %[[VEC_C:.*]] = vector.transfer_read %[[C]]{{.*}} : memref<16x16x8x8xf32>, vector<16x16x8x8xf32>
-// CHECK: %[[MUL:.*]] = arith.mulf %[[VEC_A]], %[[VEC_B]] : vector<16x16x16x8x8x1xf32>
-// CHECK: %[[RED:.*]] = vector.multi_reduction <add>, %[[MUL]], %[[VEC_C]] [2, 5] : vector<16x16x16x8x8x1xf32> to vector<16x16x8x8xf32>
-// CHECK: vector.transfer_write %[[RED]], %[[C]]{{.*}} : vector<16x16x8x8xf32>, memref<16x16x8x8xf32>
-
-module attributes {transform.with_named_sequence} {
-  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
-    %mmt4d = transform.structured.match ops{["linalg.mmt4d"]} in %arg1 : (!transform.any_op) -> !transform.any_op
-    transform.structured.vectorize %mmt4d : !transform.any_op
-    transform.yield
-  }
-}

// -----
