diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 4360055e78691..3bbde1240286c 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -2445,6 +2445,8 @@ def VectorizeOp : Op<Transform_Dialect, "structured.vectorize",
                      DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:
                         $static_vector_sizes,
                      OptionalAttr<UnitAttr>:$vectorize_nd_extract,
+                     OptionalAttr<UnitAttr>:$flatten1D_depthwise_conv,
+                     OptionalAttr<UnitAttr>:$create_named_contraction,
                      DefaultValuedOptionalAttr<DenseBoolArrayAttr, "{}">:
                         $scalable_sizes);
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 74280fdd82f4e..eee856f3eba68 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -876,11 +876,14 @@ struct VectorizationResult {
 /// greater than or equal to their counterpart iteration space sizes, if static.
 /// `inputVectorShapes` also allows the vectorization of operations with dynamic
 /// shapes.
+/// Optionally, `createNamedContraction` can force compatible contractions to be
+/// vectorized directly to a vector.contract operation.
 FailureOr<VectorizationResult>
 vectorize(RewriterBase &rewriter, Operation *op,
           ArrayRef<int64_t> inputVectorSizes = {},
           ArrayRef<bool> inputScalableVecDims = {},
-          bool vectorizeNDExtract = false, bool flatten1DDepthwiseConv = false);
+          bool vectorizeNDExtract = false, bool flatten1DDepthwiseConv = false,
+          bool createNamedContraction = false);
 
 /// Emit a suitable vector form for a Copy op with fully static shape.
 LogicalResult vectorizeCopy(RewriterBase &builder, memref::CopyOp copyOp);
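Note: the new `createNamedContraction` parameter is a trailing flag defaulting to `false`, so existing callers of `linalg::vectorize` keep their behavior. A minimal sketch of a C++ caller that opts in, using only the signature declared above; the `linalgOp` handle is assumed to exist and handling of the returned `VectorizationResult` is elided:

    // Sketch (not part of the patch): vectorize a statically shaped named
    // contraction straight to vector.contract. Empty vector sizes are valid
    // for static shapes, as in the tests added by this patch.
    static LogicalResult vectorizeToContract(linalg::LinalgOp linalgOp) {
      IRRewriter rewriter(linalgOp->getContext());
      rewriter.setInsertionPoint(linalgOp);
      FailureOr<linalg::VectorizationResult> maybeVectorized =
          linalg::vectorize(rewriter, linalgOp.getOperation(),
                            /*inputVectorSizes=*/{},
                            /*inputScalableVecDims=*/{},
                            /*vectorizeNDExtract=*/false,
                            /*flatten1DDepthwiseConv=*/false,
                            /*createNamedContraction=*/true);
      return success(succeeded(maybeVectorized));
    }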
diff --git a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
index cc8421b23a074..9b765d0b8ede6 100644
--- a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
+++ b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
@@ -226,7 +226,8 @@ bool isLinearizableVector(VectorType type);
 /// Note: all read offsets are set to 0.
 Value createReadOrMaskedRead(OpBuilder &builder, Location loc, Value source,
                              ArrayRef<int64_t> inputVectorSizes, Value padValue,
-                             bool useInBoundsInsteadOfMasking = false);
+                             bool useInBoundsInsteadOfMasking = false,
+                             ArrayRef<bool> scalableDims = {});
 
 /// Returns success if `inputVectorSizes` is a valid masking configuration for
 /// given `shape`, i.e., it meets:
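The extra `scalableDims` parameter mirrors `VectorType`'s per-dimension scalability flags: it is applied to both the read vector type and, when masking is required, the `vector.create_mask` type (see the VectorUtils.cpp change below). A hedged usage sketch, assuming existing `builder`, `loc`, and `src` (a `tensor<?x?xf32>` value), requesting a `vector<4x[16]xf32>` read:

    // Sketch (not part of the patch): masked read with a scalable trailing dim.
    Value pad = arith::getZeroConstant(builder, loc, builder.getF32Type());
    Value read = vector::createReadOrMaskedRead(
        builder, loc, src, /*inputVectorSizes=*/{4, 16}, /*padValue=*/pad,
        /*useInBoundsInsteadOfMasking=*/false,
        /*scalableDims=*/{false, true});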
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 7ec3f3445281a..1038615f21b17 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -3920,7 +3920,9 @@ DiagnosedSilenceableFailure transform::VectorizeOp::apply(
   }
   FailureOr<VectorizationResult> vectorResults =
       linalg::vectorize(rewriter, target, vectorSizes, getScalableSizes(),
-                        getVectorizeNdExtract().value_or(false));
+                        getVectorizeNdExtract().value_or(false),
+                        getFlatten1DDepthwiseConv().value_or(false),
+                        getCreateNamedContraction().value_or(false));
   if (failed(vectorResults)) {
     return mlir::emitSilenceableFailure(target->getLoc())
            << "Attempted to vectorize, but failed";
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 2864c5b807d69..bb1454cf43932 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -25,6 +25,7 @@
 #include "mlir/Dialect/Vector/Interfaces/MaskableOpInterface.h"
 #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
 #include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/AffineMap.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinTypeInterfaces.h"
 #include "mlir/IR/BuiltinTypes.h"
@@ -1681,10 +1682,13 @@ createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vecToStore,
     return write;
 
   // Compute the mask and mask the write Op.
-  auto writeMaskType = VectorType::get(vecToStoreShape, builder.getI1Type());
+  auto writeMaskType = VectorType::get(vecToStoreShape, builder.getI1Type(),
+                                       vecToStoreType.getScalableDims());
 
   SmallVector<OpFoldResult> destSizes =
-      tensor::getMixedSizes(builder, loc, dest);
+      isa<MemRefType>(dest.getType())
+          ? memref::getMixedSizes(builder, loc, dest)
+          : tensor::getMixedSizes(builder, loc, dest);
   SmallVector<OpFoldResult> maskSizes(destSizes.end() - vecToStoreRank,
                                       destSizes.end());
 
@@ -2093,6 +2097,84 @@ vectorizeInsertSliceOpPrecondition(tensor::InsertSliceOp sliceOp,
   return success();
 }
 
+/// Vectorize a named linalg contraction op into:
+///   vector::TransferReadOp - Reads vectors from the operands
+///   vector::ContractionOp - Performs contraction
+///   vector::TransferWriteOp - Writes the result vector back to the
+///                             destination
+/// The operands' shapes are preserved and loaded directly into vectors.
+/// Any further permutations or numerical casting remain within contraction.
+static LogicalResult
+vectorizeAsLinalgContraction(RewriterBase &rewriter, VectorizationState &state,
+                             LinalgOp linalgOp,
+                             SmallVectorImpl<Value> &newResults) {
+  Location loc = linalgOp.getLoc();
+  MLIRContext *ctx = linalgOp.getContext();
+
+  if (!isa<ContractionOpInterface>(linalgOp.getOperation()))
+    return failure();
+
+  OpOperand *outOperand = linalgOp.getDpsInitOperand(0);
+  Operation *reduceOp = matchLinalgReduction(outOperand);
+  auto maybeKind = getCombinerOpKind(reduceOp);
+  if (!maybeKind)
+    return failure();
+
+  // Check that all dimensions are present in the input operands.
+  // Arbitrary broadcasts are not supported by the vector contraction.
+  // Broadcasts are expected to be materialized before vectorization.
+  AffineMap lhsMap = linalgOp.getIndexingMapsArray()[0];
+  AffineMap rhsMap = linalgOp.getIndexingMapsArray()[1];
+  if (getUnusedDimsBitVector({lhsMap, rhsMap}).any())
+    return failure();
+
+  // Load operands.
+  SmallVector<Value> vecOperands;
+  for (OpOperand &opOperand : linalgOp->getOpOperands()) {
+    // The operand vector shape is computed by mapping the canonical vector
+    // shape to the operand's domain. Further permutations are left as a part
+    // of the contraction.
+    AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);
+    AffineMap readMap = AffineMap::getMultiDimIdentityMap(
+        indexingMap.getNumResults(), rewriter.getContext());
+    Type elemType = getElementTypeOrSelf(opOperand.get());
+    VectorType readType =
+        state.getCanonicalVecType(elemType, readMap.compose(indexingMap));
+
+    Value read = mlir::vector::createReadOrMaskedRead(
+        rewriter, loc, opOperand.get(), readType.getShape(),
+        /*padding=*/arith::getZeroConstant(rewriter, loc, elemType),
+        /*useInBoundsInsteadOfMasking=*/false, readType.getScalableDims());
+    vecOperands.push_back(read);
+  }
+
+  // Remap iterators from linalg to vector.
+  SmallVector<Attribute> iterAttrs;
+  auto iterators = linalgOp.getIteratorTypesArray();
+  for (utils::IteratorType iter : iterators) {
+    auto vecIter = iter == utils::IteratorType::parallel
+                       ? vector::IteratorType::parallel
+                       : vector::IteratorType::reduction;
+    iterAttrs.push_back(vector::IteratorTypeAttr::get(ctx, vecIter));
+  }
+
+  // Create contraction.
+  Value contractOp = rewriter.create<vector::ContractionOp>(
+      loc, /*lhs=*/vecOperands[0],
+      /*rhs=*/vecOperands[1], /*acc=*/vecOperands[2],
+      linalgOp.getIndexingMaps(), rewriter.getArrayAttr(iterAttrs), *maybeKind);
+
+  // Store result.
+  Operation *write =
+      createWriteOrMaskedWrite(rewriter, loc, contractOp, outOperand->get());
+
+  // Finalize.
+  if (!write->getResults().empty())
+    newResults.push_back(write->getResult(0));
+
+  return success();
+}
+
 namespace {
 enum class ConvOperationKind { Conv, Pool };
 } // namespace
@@ -2528,11 +2610,10 @@ bool mlir::linalg::hasVectorizationImpl(Operation *op) {
                    tensor::InsertSliceOp>(op);
 }
 
-FailureOr<VectorizationResult>
-mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
-                        ArrayRef<int64_t> inputVectorSizes,
-                        ArrayRef<bool> inputScalableVecDims,
-                        bool vectorizeNDExtract, bool flatten1DDepthwiseConv) {
+FailureOr<VectorizationResult> mlir::linalg::vectorize(
+    RewriterBase &rewriter, Operation *op, ArrayRef<int64_t> inputVectorSizes,
+    ArrayRef<bool> inputScalableVecDims, bool vectorizeNDExtract,
+    bool flatten1DDepthwiseConv, bool createNamedContraction) {
   LDBG("Attempting to vectorize:\n" << *op << "\n");
   LDBG("Input vector sizes: ");
   LLVM_DEBUG(llvm::interleaveComma(inputVectorSizes, llvm::dbgs()));
@@ -2578,6 +2659,21 @@ mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
           return failure();
         }
 
+        // For simplicity, contraction vectorization is limited to linalg
+        // named ops. Generic ops are ignored, as not every arbitrary
+        // contraction body can be expressed by a vector.contract.
+        if (createNamedContraction &&
+            isa<ContractionOpInterface>(linalgOp.getOperation())) {
+          // Attempt vectorizing directly into a named contraction.
+          // In case of failure, fall back to the generic path.
+          LogicalResult res = vectorizeAsLinalgContraction(
+              rewriter, state, linalgOp, results);
+          if (succeeded(res))
+            return success();
+
+          LDBG("Failed to vectorize as a named contraction.\n");
+        }
+
         LDBG("Vectorize generic by broadcasting to the canonical vector "
             "shape\n");
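Worked example for the operand-read logic above: the identity `readMap` composed with an operand's indexing map leaves that map unchanged, so each operand is read in its own layout and any transpose is folded into `vector.contract`'s indexing maps rather than materialized. A standalone illustration (not patch code) for matmul's RHS map `(m, n, k) -> (k, n)`:

    #include "mlir/IR/AffineExpr.h"
    #include "mlir/IR/AffineMap.h"
    #include "mlir/IR/MLIRContext.h"
    #include "llvm/Support/raw_ostream.h"

    int main() {
      mlir::MLIRContext ctx;
      // Matmul RHS indexing map: (m, n, k) -> (k, n).
      mlir::AffineMap rhsMap = mlir::AffineMap::get(
          /*dimCount=*/3, /*symbolCount=*/0,
          {mlir::getAffineDimExpr(2, &ctx), mlir::getAffineDimExpr(1, &ctx)},
          &ctx);
      // Identity over the operand rank, as in vectorizeAsLinalgContraction.
      mlir::AffineMap readMap = mlir::AffineMap::getMultiDimIdentityMap(2, &ctx);
      // identity.compose(f) == f, so canonical sizes (m, n, k) = (8, 16, 4)
      // are picked through (k, n), giving a vector<4x16xf32> read for the RHS.
      readMap.compose(rhsMap).print(llvm::outs()); // (d0, d1, d2) -> (d2, d1)
      llvm::outs() << "\n";
      return 0;
    }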
diff --git a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
index 674a93d7c520e..594bfce57e598 100644
--- a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
+++ b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
@@ -320,14 +320,16 @@ Value vector::createReadOrMaskedRead(OpBuilder &builder, Location loc,
                                      Value source,
                                      ArrayRef<int64_t> inputVectorSizes,
                                      Value padValue,
-                                     bool useInBoundsInsteadOfMasking) {
+                                     bool useInBoundsInsteadOfMasking,
+                                     ArrayRef<bool> scalableDims) {
   assert(!llvm::is_contained(inputVectorSizes, ShapedType::kDynamic) &&
          "invalid input vector sizes");
   auto sourceShapedType = cast<ShapedType>(source.getType());
   auto sourceShape = sourceShapedType.getShape();
   assert(sourceShape.size() == inputVectorSizes.size() &&
          "expected same ranks.");
-  auto vectorType = VectorType::get(inputVectorSizes, padValue.getType());
+  auto vectorType =
+      VectorType::get(inputVectorSizes, padValue.getType(), scalableDims);
   assert(padValue.getType() == sourceShapedType.getElementType() &&
          "expected same pad element type to match source element type");
   int64_t readRank = inputVectorSizes.size();
@@ -352,9 +354,12 @@ Value vector::createReadOrMaskedRead(OpBuilder &builder, Location loc,
   if (llvm::equal(inputVectorSizes, sourceShape) || useInBoundsInsteadOfMasking)
     return transferReadOp;
   SmallVector<OpFoldResult> mixedSourceDims =
-      tensor::getMixedSizes(builder, loc, source);
+      isa<MemRefType>(source.getType())
+          ? memref::getMixedSizes(builder, loc, source)
+          : tensor::getMixedSizes(builder, loc, source);
 
-  auto maskType = VectorType::get(inputVectorSizes, builder.getI1Type());
+  auto maskType =
+      VectorType::get(inputVectorSizes, builder.getI1Type(), scalableDims);
   Value mask =
       builder.create<vector::CreateMaskOp>(loc, maskType, mixedSourceDims);
   return mlir::vector::maskOperation(builder, transferReadOp, mask)
diff --git a/mlir/test/Dialect/Linalg/vectorization/contraction-named.mlir b/mlir/test/Dialect/Linalg/vectorization/contraction-named.mlir
new file mode 100644
index 0000000000000..1831acf092afb
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/vectorization/contraction-named.mlir
@@ -0,0 +1,400 @@
+// RUN: mlir-opt %s -transform-interpreter -split-input-file | FileCheck %s
+
+func.func @matmul(%A: tensor<8x4xf32>, %B: tensor<4x16xf32>,
+    %C: tensor<8x16xf32>) -> tensor<8x16xf32> {
+  %0 = linalg.matmul
+    ins(%A, %B : tensor<8x4xf32>, tensor<4x16xf32>)
+    outs(%C: tensor<8x16xf32>) -> tensor<8x16xf32>
+  return %0 : tensor<8x16xf32>
+}
+
+// CHECK: #[[$MAP_A:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK: #[[$MAP_B:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+// CHECK: #[[$MAP_C:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-LABEL: func.func @matmul(
+// CHECK-SAME: %[[A:.*]]: tensor<8x4xf32>, %[[B:.*]]: tensor<4x16xf32>,
+// CHECK-SAME: %[[C:.*]]: tensor<8x16xf32>)
+// CHECK: %[[LOAD_A:.*]] = vector.transfer_read %[[A]]{{.*}}: tensor<8x4xf32>, vector<8x4xf32>
+// CHECK: %[[LOAD_B:.*]] = vector.transfer_read %[[B]]{{.*}}: tensor<4x16xf32>, vector<4x16xf32>
+// CHECK: %[[LOAD_C:.*]] = vector.transfer_read %[[C]]{{.*}}: tensor<8x16xf32>, vector<8x16xf32>
+// CHECK: %[[CONTRACT:.*]] = vector.contract
+// CHECK-SAME: indexing_maps = [#[[$MAP_A]], #[[$MAP_B]], #[[$MAP_C]]]
+// CHECK-SAME: kind = #vector.kind<add>
+// CHECK-SAME: %[[LOAD_A]], %[[LOAD_B]], %[[LOAD_C]]
+// CHECK: vector.transfer_write %[[CONTRACT]], %[[C]]{{.*}}: vector<8x16xf32>, tensor<8x16xf32>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 {create_named_contraction} : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+func.func @matmul_dynamic(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>,
+    %C: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %0 = linalg.matmul
+    ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>)
+    outs(%C: tensor<?x?xf32>) -> tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}
+
+// CHECK: #[[$MAP_A:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK: #[[$MAP_B:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+// CHECK: #[[$MAP_C:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-LABEL: func.func @matmul_dynamic(
+// CHECK-SAME: %[[A:.*]]: tensor<?x?xf32>, %[[B:.*]]: tensor<?x?xf32>,
+// CHECK-SAME: %[[C:.*]]: tensor<?x?xf32>)
+// CHECK: %[[LOAD_A:.*]] = vector.mask{{.*}}{ vector.transfer_read %[[A]]{{.*}}: tensor<?x?xf32>, vector<8x4xf32>
+// CHECK: %[[LOAD_B:.*]] = vector.mask{{.*}}{ vector.transfer_read %[[B]]{{.*}}: tensor<?x?xf32>, vector<4x16xf32>
+// CHECK: %[[LOAD_C:.*]] = vector.mask{{.*}}{ vector.transfer_read %[[C]]{{.*}}: tensor<?x?xf32>, vector<8x16xf32>
+// CHECK: %[[CONTRACT:.*]] = vector.contract
+// CHECK-SAME: indexing_maps = [#[[$MAP_A]], #[[$MAP_B]], #[[$MAP_C]]]
+// CHECK-SAME: kind = #vector.kind<add>
+// CHECK-SAME: %[[LOAD_A]], %[[LOAD_B]], %[[LOAD_C]]
+// CHECK: vector.mask{{.*}}{ vector.transfer_write %[[CONTRACT]], %[[C]]{{.*}}: vector<8x16xf32>, tensor<?x?xf32>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 vector_sizes [8, 16, 4]
+      {create_named_contraction} : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+func.func @matmul_dynamic_memref(%A: memref<?x?xf32>, %B: memref<?x?xf32>,
+    %C: memref<?x?xf32>) {
+  linalg.matmul
+    ins(%A, %B : memref<?x?xf32>, memref<?x?xf32>)
+    outs(%C: memref<?x?xf32>)
+  return
+}
+
+// CHECK: #[[$MAP_A:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK: #[[$MAP_B:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+// CHECK: #[[$MAP_C:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-LABEL: func.func @matmul_dynamic_memref(
+// CHECK-SAME: %[[A:.*]]: memref<?x?xf32>, %[[B:.*]]: memref<?x?xf32>,
+// CHECK-SAME: %[[C:.*]]: memref<?x?xf32>)
+// CHECK: %[[LOAD_A:.*]] = vector.mask{{.*}}{ vector.transfer_read %[[A]]{{.*}}: memref<?x?xf32>, vector<8x4xf32>
+// CHECK: %[[LOAD_B:.*]] = vector.mask{{.*}}{ vector.transfer_read %[[B]]{{.*}}: memref<?x?xf32>, vector<4x16xf32>
+// CHECK: %[[LOAD_C:.*]] = vector.mask{{.*}}{ vector.transfer_read %[[C]]{{.*}}: memref<?x?xf32>, vector<8x16xf32>
+// CHECK: %[[CONTRACT:.*]] = vector.contract
+// CHECK-SAME: indexing_maps = [#[[$MAP_A]], #[[$MAP_B]], #[[$MAP_C]]]
+// CHECK-SAME: kind = #vector.kind<add>
+// CHECK-SAME: %[[LOAD_A]], %[[LOAD_B]], %[[LOAD_C]]
+// CHECK: vector.mask{{.*}}{ vector.transfer_write %[[CONTRACT]], %[[C]]{{.*}}: vector<8x16xf32>, memref<?x?xf32>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 vector_sizes [8, 16, 4]
+      {create_named_contraction} : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+func.func @matmul_scalable(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>,
+    %C: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %0 = linalg.matmul
+    ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>)
+    outs(%C: tensor<?x?xf32>) -> tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}
+
+// CHECK: #[[$MAP_A:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK: #[[$MAP_B:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+// CHECK: #[[$MAP_C:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-LABEL: func.func @matmul_scalable(
+// CHECK-SAME: %[[A:.*]]: tensor<?x?xf32>, %[[B:.*]]: tensor<?x?xf32>,
+// CHECK-SAME: %[[C:.*]]: tensor<?x?xf32>)
+// CHECK: %[[LOAD_A:.*]] = vector.mask{{.*}}{ vector.transfer_read %[[A]]{{.*}}: tensor<?x?xf32>, vector<8x4xf32>
+// CHECK: %[[LOAD_B:.*]] = vector.mask{{.*}}{ vector.transfer_read %[[B]]{{.*}}: tensor<?x?xf32>, vector<4x[16]xf32>
+// CHECK: %[[LOAD_C:.*]] = vector.mask{{.*}}{ vector.transfer_read %[[C]]{{.*}}: tensor<?x?xf32>, vector<8x[16]xf32>
+// CHECK: %[[CONTRACT:.*]] = vector.contract
+// CHECK-SAME: indexing_maps = [#[[$MAP_A]], #[[$MAP_B]], #[[$MAP_C]]]
+// CHECK-SAME: kind = #vector.kind<add>
+// CHECK-SAME: %[[LOAD_A]], %[[LOAD_B]], %[[LOAD_C]]
+// CHECK: vector.mask{{.*}}{ vector.transfer_write %[[CONTRACT]], %[[C]]{{.*}}: vector<8x[16]xf32>, tensor<?x?xf32>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 vector_sizes [8, [16], 4]
+      {create_named_contraction} : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+func.func @matmul_transpose(%A: tensor<4x8xf32>, %B: tensor<16x4xf32>,
+    %C: tensor<8x16xf32>) -> tensor<8x16xf32> {
+  %0 = linalg.matmul
+    indexing_maps = [affine_map<(m, n, k) -> (k, m)>, // transpose A
+                     affine_map<(m, n, k) -> (n, k)>, // transpose B
+                     affine_map<(m, n, k) -> (m, n)>]
+    ins(%A, %B : tensor<4x8xf32>, tensor<16x4xf32>)
+    outs(%C: tensor<8x16xf32>) -> tensor<8x16xf32>
+  return %0 : tensor<8x16xf32>
+}
+
+// CHECK: #[[$MAP_A:.+]] = affine_map<(d0, d1, d2) -> (d2, d0)>
+// CHECK: #[[$MAP_B:.+]] = affine_map<(d0, d1, d2) -> (d1, d2)>
+// CHECK: #[[$MAP_C:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-LABEL: func.func @matmul_transpose(
+// CHECK-SAME: %[[A:.*]]: tensor<4x8xf32>, %[[B:.*]]: tensor<16x4xf32>,
+// CHECK-SAME: %[[C:.*]]: tensor<8x16xf32>)
+// CHECK: %[[LOAD_A:.*]] = vector.transfer_read %[[A]]{{.*}}: tensor<4x8xf32>, vector<4x8xf32>
+// CHECK: %[[LOAD_B:.*]] = vector.transfer_read %[[B]]{{.*}}: tensor<16x4xf32>, vector<16x4xf32>
+// CHECK: %[[LOAD_C:.*]] = vector.transfer_read %[[C]]{{.*}}: tensor<8x16xf32>, vector<8x16xf32>
+// CHECK: %[[CONTRACT:.*]] = vector.contract
+// CHECK-SAME: indexing_maps = [#[[$MAP_A]], #[[$MAP_B]], #[[$MAP_C]]]
+// CHECK-SAME: kind = #vector.kind<add>
+// CHECK-SAME: %[[LOAD_A]], %[[LOAD_B]], %[[LOAD_C]]
+// CHECK: vector.transfer_write %[[CONTRACT]], %[[C]]{{.*}}: vector<8x16xf32>, tensor<8x16xf32>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 {create_named_contraction} : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+func.func @matmul_dynamic_transpose(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>,
+    %C: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %0 = linalg.matmul
+    indexing_maps = [affine_map<(m, n, k) -> (k, m)>, // transpose A
+                     affine_map<(m, n, k) -> (n, k)>, // transpose B
+                     affine_map<(m, n, k) -> (m, n)>]
+    ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>)
+    outs(%C: tensor<?x?xf32>) -> tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}
+
+// CHECK: #[[$MAP_A:.+]] = affine_map<(d0, d1, d2) -> (d2, d0)>
+// CHECK: #[[$MAP_B:.+]] = affine_map<(d0, d1, d2) -> (d1, d2)>
+// CHECK: #[[$MAP_C:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-LABEL: func.func @matmul_dynamic_transpose(
+// CHECK-SAME: %[[A:.*]]: tensor<?x?xf32>, %[[B:.*]]: tensor<?x?xf32>,
+// CHECK-SAME: %[[C:.*]]: tensor<?x?xf32>)
+// CHECK: %[[LOAD_A:.*]] = vector.mask{{.*}}{ vector.transfer_read %[[A]]{{.*}}: tensor<?x?xf32>, vector<4x8xf32>
+// CHECK: %[[LOAD_B:.*]] = vector.mask{{.*}}{ vector.transfer_read %[[B]]{{.*}}: tensor<?x?xf32>, vector<16x4xf32>
+// CHECK: %[[LOAD_C:.*]] = vector.mask{{.*}}{ vector.transfer_read %[[C]]{{.*}}: tensor<?x?xf32>, vector<8x16xf32>
+// CHECK: %[[CONTRACT:.*]] = vector.contract
+// CHECK-SAME: indexing_maps = [#[[$MAP_A]], #[[$MAP_B]], #[[$MAP_C]]]
+// CHECK-SAME: kind = #vector.kind<add>
+// CHECK-SAME: %[[LOAD_A]], %[[LOAD_B]], %[[LOAD_C]]
+// CHECK: vector.mask{{.*}}{ vector.transfer_write %[[CONTRACT]], %[[C]]{{.*}}: vector<8x16xf32>, tensor<?x?xf32>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 vector_sizes [8, 16, 4]
+      {create_named_contraction} : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+func.func @negative_matmul_broadcast(%A: tensor<4xf32>, %B: tensor<4x16xf32>,
+    %C: tensor<8x16xf32>) -> tensor<8x16xf32> {
+  %0 = linalg.matmul
+    indexing_maps = [affine_map<(m, n, k) -> (k)>, // broadcast
+                     affine_map<(m, n, k) -> (k, n)>,
+                     affine_map<(m, n, k) -> (m, n)>]
+    ins(%A, %B : tensor<4xf32>, tensor<4x16xf32>)
+    outs(%C: tensor<8x16xf32>) -> tensor<8x16xf32>
+  return %0 : tensor<8x16xf32>
+}
+
+// CHECK-LABEL: func.func @negative_matmul_broadcast(
+// CHECK-NOT: vector.contract
+// CHECK: vector.multi_reduction
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 {create_named_contraction} : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+func.func @matmul_mixed_precision(%A: tensor<8x4xf16>, %B: tensor<4x16xf16>,
+    %C: tensor<8x16xf32>) -> tensor<8x16xf32> {
+  %0 = linalg.matmul
+    ins(%A, %B : tensor<8x4xf16>, tensor<4x16xf16>)
+    outs(%C: tensor<8x16xf32>) -> tensor<8x16xf32>
+  return %0 : tensor<8x16xf32>
+}
+
+// CHECK: #[[$MAP_A:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK: #[[$MAP_B:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+// CHECK: #[[$MAP_C:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-LABEL: func.func @matmul_mixed_precision(
+// CHECK-SAME: %[[A:.*]]: tensor<8x4xf16>, %[[B:.*]]: tensor<4x16xf16>,
+// CHECK-SAME: %[[C:.*]]: tensor<8x16xf32>)
+// CHECK: %[[LOAD_A:.*]] = vector.transfer_read %[[A]]{{.*}}: tensor<8x4xf16>, vector<8x4xf16>
+// CHECK: %[[LOAD_B:.*]] = vector.transfer_read %[[B]]{{.*}}: tensor<4x16xf16>, vector<4x16xf16>
+// CHECK: %[[LOAD_C:.*]] = vector.transfer_read %[[C]]{{.*}}: tensor<8x16xf32>, vector<8x16xf32>
+// CHECK: %[[CONTRACT:.*]] = vector.contract
+// CHECK-SAME: indexing_maps = [#[[$MAP_A]], #[[$MAP_B]], #[[$MAP_C]]]
+// CHECK-SAME: kind = #vector.kind<add>
+// CHECK-SAME: %[[LOAD_A]], %[[LOAD_B]], %[[LOAD_C]]
+// CHECK: vector.transfer_write %[[CONTRACT]], %[[C]]{{.*}}: vector<8x16xf32>, tensor<8x16xf32>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 {create_named_contraction} : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+func.func @batch_matmul(%A: tensor<3x8x4xf16>, %B: tensor<3x4x16xf16>,
+    %C: tensor<3x8x16xf32>) -> tensor<3x8x16xf32> {
+  %0 = linalg.batch_matmul
+    ins(%A, %B : tensor<3x8x4xf16>, tensor<3x4x16xf16>)
+    outs(%C: tensor<3x8x16xf32>) -> tensor<3x8x16xf32>
+  return %0 : tensor<3x8x16xf32>
+}
+
+// CHECK: #[[$MAP_A:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+// CHECK: #[[$MAP_B:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+// CHECK: #[[$MAP_C:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+// CHECK-LABEL: func.func @batch_matmul(
+// CHECK-SAME: %[[A:.*]]: tensor<3x8x4xf16>, %[[B:.*]]: tensor<3x4x16xf16>,
+// CHECK-SAME: %[[C:.*]]: tensor<3x8x16xf32>)
+// CHECK: %[[LOAD_A:.*]] = vector.transfer_read %[[A]]{{.*}}: tensor<3x8x4xf16>, vector<3x8x4xf16>
+// CHECK: %[[LOAD_B:.*]] = vector.transfer_read %[[B]]{{.*}}: tensor<3x4x16xf16>, vector<3x4x16xf16>
+// CHECK: %[[LOAD_C:.*]] = vector.transfer_read %[[C]]{{.*}}: tensor<3x8x16xf32>, vector<3x8x16xf32>
+// CHECK: %[[CONTRACT:.*]] = vector.contract
+// CHECK-SAME: indexing_maps = [#[[$MAP_A]], #[[$MAP_B]], #[[$MAP_C]]]
+// CHECK-SAME: kind = #vector.kind<add>
+// CHECK-SAME: %[[LOAD_A]], %[[LOAD_B]], %[[LOAD_C]]
+// CHECK: vector.transfer_write %[[CONTRACT]], %[[C]]{{.*}}: vector<3x8x16xf32>, tensor<3x8x16xf32>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.batch_matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 {create_named_contraction} : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+func.func @batch_reduce_matmul(%A: tensor<3x8x4xf16>, %B: tensor<3x4x16xf16>,
+    %C: tensor<8x16xf32>) -> tensor<8x16xf32> {
+  %0 = linalg.batch_reduce_matmul
+    ins(%A, %B : tensor<3x8x4xf16>, tensor<3x4x16xf16>)
+    outs(%C: tensor<8x16xf32>) -> tensor<8x16xf32>
+  return %0 : tensor<8x16xf32>
+}
+
+// CHECK: #[[$MAP_A:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+// CHECK: #[[$MAP_B:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+// CHECK: #[[$MAP_C:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+// CHECK-LABEL: func.func @batch_reduce_matmul(
+// CHECK-SAME: %[[A:.*]]: tensor<3x8x4xf16>, %[[B:.*]]: tensor<3x4x16xf16>,
+// CHECK-SAME: %[[C:.*]]: tensor<8x16xf32>)
+// CHECK: %[[LOAD_A:.*]] = vector.transfer_read %[[A]]{{.*}}: tensor<3x8x4xf16>, vector<3x8x4xf16>
+// CHECK: %[[LOAD_B:.*]] = vector.transfer_read %[[B]]{{.*}}: tensor<3x4x16xf16>, vector<3x4x16xf16>
+// CHECK: %[[LOAD_C:.*]] = vector.transfer_read %[[C]]{{.*}}: tensor<8x16xf32>, vector<8x16xf32>
+// CHECK: %[[CONTRACT:.*]] = vector.contract
+// CHECK-SAME: indexing_maps = [#[[$MAP_A]], #[[$MAP_B]], #[[$MAP_C]]]
+// CHECK-SAME: kind = #vector.kind<add>
+// CHECK-SAME: %[[LOAD_A]], %[[LOAD_B]], %[[LOAD_C]]
+// CHECK: vector.transfer_write %[[CONTRACT]], %[[C]]{{.*}}: vector<8x16xf32>, tensor<8x16xf32>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.batch_reduce_matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 {create_named_contraction} : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+func.func @contract(%A: tensor<4x8x2xf16>, %B: tensor<8x16x2xf16>,
+    %C: tensor<4x16xf32>) -> tensor<4x16xf32> {
+  %0 = linalg.contract
+    indexing_maps = [affine_map<(m, n, k, vnni) -> (m, k, vnni)>,
+                     affine_map<(m, n, k, vnni) -> (k, n, vnni)>,
+                     affine_map<(m, n, k, vnni) -> (m, n)>]
+    ins(%A, %B : tensor<4x8x2xf16>, tensor<8x16x2xf16>)
+    outs(%C : tensor<4x16xf32>) -> tensor<4x16xf32>
+  return %0 : tensor<4x16xf32>
+}
+
+// CHECK: #[[$MAP_A:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
+// CHECK: #[[$MAP_B:.+]] = affine_map<(d0, d1, d2, d3) -> (d2, d1, d3)>
+// CHECK: #[[$MAP_C:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1)>
+// CHECK-LABEL: func.func @contract(
+// CHECK-SAME: %[[A:.*]]: tensor<4x8x2xf16>, %[[B:.*]]: tensor<8x16x2xf16>,
+// CHECK-SAME: %[[C:.*]]: tensor<4x16xf32>)
+// CHECK: %[[LOAD_A:.*]] = vector.transfer_read %[[A]]{{.*}}: tensor<4x8x2xf16>, vector<4x8x2xf16>
+// CHECK: %[[LOAD_B:.*]] = vector.transfer_read %[[B]]{{.*}}: tensor<8x16x2xf16>, vector<8x16x2xf16>
+// CHECK: %[[LOAD_C:.*]] = vector.transfer_read %[[C]]{{.*}}: tensor<4x16xf32>, vector<4x16xf32>
+// CHECK: %[[CONTRACT:.*]] = vector.contract
+// CHECK-SAME: indexing_maps = [#[[$MAP_A]], #[[$MAP_B]], #[[$MAP_C]]]
+// CHECK-SAME: kind = #vector.kind<add>
+// CHECK-SAME: %[[LOAD_A]], %[[LOAD_B]], %[[LOAD_C]]
+// CHECK: vector.transfer_write %[[CONTRACT]], %[[C]]{{.*}}: vector<4x16xf32>, tensor<4x16xf32>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.contract"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 {create_named_contraction} : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// Generic ops are currently ignored in direct lowering to a named contraction.
+
+func.func @negative_generic(%A: tensor<8x4xf32>, %B: tensor<4x16xf32>,
+    %C: tensor<8x16xf32>) -> tensor<8x16xf32> {
+  %0 = linalg.generic {
+    indexing_maps = [affine_map<(m, n, k) -> (m, k)>,
+                     affine_map<(m, n, k) -> (k, n)>,
+                     affine_map<(m, n, k) -> (m, n)>],
+    iterator_types = ["parallel", "parallel", "reduction"]}
+    ins(%A, %B : tensor<8x4xf32>, tensor<4x16xf32>)
+    outs(%C : tensor<8x16xf32>) {
+  ^bb0(%in: f32, %in_0: f32, %out: f32):
+    %1 = arith.mulf %in, %in_0 : f32
+    %2 = arith.addf %out, %1 : f32
+    linalg.yield %2 : f32
+  } -> tensor<8x16xf32>
+  return %0 : tensor<8x16xf32>
+}
+
+// CHECK-LABEL: func.func @negative_generic(
+// CHECK-NOT: vector.contract
+// CHECK: vector.multi_reduction
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 {create_named_contraction} : !transform.any_op
+    transform.yield
+  }
+}