diff --git a/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h b/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h index f6b09deb4e44c..cfb6cc313bc63 100644 --- a/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h +++ b/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h @@ -13,12 +13,6 @@ namespace mlir { class LLVMTypeConverter; -/// Collect a set of patterns to convert from Vector contractions to LLVM Matrix -/// Intrinsics. To lower to assembly, the LLVM flag -lower-matrix-intrinsics -/// will be needed when invoking LLVM. -void populateVectorToLLVMMatrixConversionPatterns( - const LLVMTypeConverter &converter, RewritePatternSet &patterns); - /// Collect a set of patterns to convert from the Vector dialect to LLVM. void populateVectorToLLVMConversionPatterns( const LLVMTypeConverter &converter, RewritePatternSet &patterns, diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td index ec2c87ca1cf44..d16f82edc9c65 100644 --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td @@ -2788,124 +2788,6 @@ def Vector_PrintOp : }]; } -//===----------------------------------------------------------------------===// -// Ops used for supporting progressive lowering and conversion type changes. -// The Ops are typically not used directly by higher level dialects, but are -// used by intra-dialect rewriting rules to bring vector operations closer -// to the hardware ISA. -//===----------------------------------------------------------------------===// - -/// Vector dialect matrix multiplication op that operates on flattened 1-D -/// MLIR vectors. This is the counterpart of llvm.matrix.multiply in MLIR. -/// This may seem redundant with vector.contract but it serves the purposes of -/// more progressive lowering and localized type conversion on the path: -/// `vector<...x...xf32> -> vector<...xf32> -> !llvm<... 
x float>`. -def Vector_MatmulOp : Vector_Op<"matrix_multiply", [Pure, - PredOpTrait<"lhs operand and result have same element type", - TCresVTEtIsSameAsOpBase<0, 0>>, - PredOpTrait<"rhs operand and result have same element type", - TCresVTEtIsSameAsOpBase<0, 1>>]>, - Arguments<( - // TODO: tighten vector element types that make sense. - ins FixedVectorOfRankAndType<[1], - [AnySignlessInteger, AnySignedInteger, Index, AnyFloat]>:$lhs, - FixedVectorOfRankAndType<[1], - [AnySignlessInteger, AnySignedInteger, Index, AnyFloat]>:$rhs, - I32Attr:$lhs_rows, I32Attr:$lhs_columns, I32Attr:$rhs_columns)>, - Results<( - outs FixedVectorOfRankAndType<[1], - [AnySignlessInteger, AnySignedInteger, Index, AnyFloat]>:$res)> -{ - let summary = "Vector matrix multiplication op that operates on flattened 1-D" - " MLIR vectors"; - let description = [{ - This is the counterpart of llvm.matrix.multiply in MLIR. It serves the - purposes of more progressive lowering and localized type conversion. - Higher levels typically lower matrix multiplications into 'vector.contract' - operations. Subsequent rewriting rule progressively lower these operations - into 'vector.matrix_multiply' operations to bring the operations closer - to the hardware ISA. - - The ‘vector.matrix_multiply’ op treats `lhs` as matrix with rows - and columns, `rhs` as matrix with rows and - and multiplies them. The result matrix is returned embedded in - the result vector. - - Note, the corresponding LLVM intrinsic, `@llvm.matrix.multiply.*`, does not - support scalable vectors. Hence, this Op is only available for fixed-width - vectors. 
Also see: - - http://llvm.org/docs/LangRef.html#llvm-matrix-multiply-intrinsic - - Example: - - ```mlir - %C = vector.matrix_multiply %A, %B - { lhs_rows = 4: i32, lhs_columns = 16: i32 , rhs_columns = 3: i32 } : - (vector<64xf64>, vector<48xf64>) -> vector<12xf64> - ``` - }]; - let builders = [ - OpBuilder<(ins "Value":$lhs, "Value":$rhs, "unsigned":$lhsRows, - "unsigned":$lhsColumns, "unsigned":$rhsColumns), - [{ - $_state.addOperands({lhs, rhs}); - $_state.addAttribute("lhs_rows",$_builder.getI32IntegerAttr(lhsRows)); - $_state.addAttribute("lhs_columns",$_builder.getI32IntegerAttr(lhsColumns)); - $_state.addAttribute("rhs_columns",$_builder.getI32IntegerAttr(rhsColumns)); - $_state.addTypes(VectorType::get(lhsRows * rhsColumns, - ::llvm::cast(lhs.getType()).getElementType())); - }]>, - ]; - let assemblyFormat = "$lhs `,` $rhs attr-dict " - "`:` `(` type($lhs) `,` type($rhs) `)` `->` type($res)"; -} - -/// Vector dialect matrix transposition op that operates on flattened 1-D -/// MLIR vectors. This is the counterpart of llvm.matrix.transpose in MLIR. -/// This may seem redundant with vector.transpose but it serves the purposes of -/// more progressive lowering and localized type conversion on the path: -/// `vector<...x...xf32> -> vector<...xf32> -> !llvm<... x float>`. -def Vector_FlatTransposeOp : Vector_Op<"flat_transpose", [Pure, - PredOpTrait<"source operand and result have same element type", - TCresVTEtIsSameAsOpBase<0, 0>>]>, - Arguments<( - // TODO: tighten vector element types that make sense. - ins FixedVectorOfRankAndType<[1], - [AnySignlessInteger, AnySignedInteger, Index, AnyFloat]>:$matrix, - I32Attr:$rows, I32Attr:$columns)>, - Results<( - outs FixedVectorOfRankAndType<[1], - [AnySignlessInteger, AnySignedInteger, Index, AnyFloat]>:$res)> { - let summary = "Vector matrix transposition on flattened 1-D MLIR vectors"; - let description = [{ - This is the counterpart of llvm.matrix.transpose in MLIR. 
It serves - the purposes of more progressive lowering and localized type conversion. - Higher levels typically lower matrix transpositions into 'vector.transpose' - operations. Subsequent rewriting rule progressively lower these operations - into 'vector.flat_transpose' operations to bring the operations closer - to the hardware ISA. - - The `vector.flat_transpose` op treats the 1-D input `matrix` as - a 2-D matrix with rows and columns, and returns the - transposed matrix in flattened form in 'res'. - - Note, the corresponding LLVM intrinsic, `@llvm.matrix.transpose.*`, does not - support scalable vectors. Hence, this Op is only available for fixed-width - vectors. Also see: - - http://llvm.org/docs/LangRef.html#llvm-matrix-transpose-intrinsic - - Example: - - ```mlir - %1 = vector.flat_transpose %0 {columns = 4 : i32, rows = 4 : i32} - : vector<16xf32> -> vector<16xf32> - ``` - }]; - let assemblyFormat = "$matrix attr-dict `:` type($matrix) `->` type($res)"; -} - //===----------------------------------------------------------------------===// // SplatOp //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h index 14cff4ff893b5..a82be92251b32 100644 --- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h +++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h @@ -297,6 +297,12 @@ void populateVectorBitCastLoweringPatterns(RewritePatternSet &patterns, /// n > 1. 
void populateVectorRankReducingFMAPattern(RewritePatternSet &patterns); +/// Collect patterns to lower `vector.contract` to `llvm.intr.matrix.multiply`. +void populateVectorContractToMatrixMultiply(RewritePatternSet &patterns); + +/// Collect patterns to lower 2-D `vector.transpose` to `llvm.intr.matrix.transpose`. +void populateVectorTransposeToFlatTranspose(RewritePatternSet &patterns); + } // namespace vector } // namespace mlir #endif // MLIR_DIALECT_VECTOR_TRANSFORMS_LOWERINGPATTERNS_H diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp index 501d98862672d..6e6400ba6e60e 100644 --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -184,41 +184,6 @@ class VectorBitCastOpConversion } }; -/// Conversion pattern for a vector.matrix_multiply. -/// This is lowered directly to the proper llvm.intr.matrix.multiply. -class VectorMatmulOpConversion - : public ConvertOpToLLVMPattern { -public: - using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; - - LogicalResult - matchAndRewrite(vector::MatmulOp matmulOp, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - rewriter.replaceOpWithNewOp( - matmulOp, typeConverter->convertType(matmulOp.getRes().getType()), - adaptor.getLhs(), adaptor.getRhs(), matmulOp.getLhsRows(), - matmulOp.getLhsColumns(), matmulOp.getRhsColumns()); - return success(); - } -}; - -/// Conversion pattern for a vector.flat_transpose. -/// This is lowered directly to the proper llvm.intr.matrix.transpose.
-class VectorFlatTransposeOpConversion - : public ConvertOpToLLVMPattern { -public: - using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; - - LogicalResult - matchAndRewrite(vector::FlatTransposeOp transOp, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - rewriter.replaceOpWithNewOp( - transOp, typeConverter->convertType(transOp.getRes().getType()), - adaptor.getMatrix(), transOp.getRows(), transOp.getColumns()); - return success(); - } -}; - /// Overloaded utility that replaces a vector.load, vector.store, /// vector.maskedload and vector.maskedstore with their respective LLVM /// couterparts. @@ -2035,6 +2000,190 @@ struct VectorScalableStepOpLowering } }; +/// Progressive lowering of a `vector.contract %a, %b, %c` with row-major matmul +/// semantics to: +/// ``` +/// %flattened_a = vector.shape_cast %a +/// %flattened_b = vector.shape_cast %b +/// %flattened_d = vector.matrix_multiply %flattened_a, %flattened_b +/// %d = vector.shape_cast %%flattened_d +/// %e = add %c, %d +/// ``` +/// `vector.matrix_multiply` later lowers to `llvm.matrix.multiply`. +// +/// This only kicks in when vectorContractLowering is set to Matmul and +/// the vector.contract op is a row-major matrix multiply. 
+class ContractionOpToMatmulOpLowering + : public vector::MaskableOpRewritePattern { +public: + using MaskableOpRewritePattern::MaskableOpRewritePattern; + + ContractionOpToMatmulOpLowering( + vector::VectorContractLowering vectorContractLowering, + MLIRContext *context, PatternBenefit benefit = 100) + : MaskableOpRewritePattern(context, benefit) {} + + FailureOr + matchAndRewriteMaskableOp(vector::ContractionOp op, MaskingOpInterface maskOp, + PatternRewriter &rewriter) const override; +}; + +/// Progressively lower a `vector.contract %a, %b, %c` with row-major matmul +/// semantics to: +/// ``` +/// %mta = maybe_transpose +/// %mtb = maybe_transpose +/// %flattened_a = vector.shape_cast %mta +/// %flattened_b = vector.shape_cast %mtb +/// %flattened_d = llvm.intr.matrix.multiply %flattened_a, %flattened_b +/// %mtd = vector.shape_cast %flattened_d +/// %d = maybe_untranspose %mtd +/// %e = add %c, %d +/// ``` +// +/// This only kicks in when vectorContractLowering is set to `Matmul`. +/// vector.transpose operations are inserted if the vector.contract op is not a +/// row-major matrix multiply. +/// +/// Scalable vectors are not supported. +FailureOr ContractionOpToMatmulOpLowering::matchAndRewriteMaskableOp( + vector::ContractionOp op, MaskingOpInterface maskOp, + PatternRewriter &rew) const { + // TODO: Support vector.mask. + if (maskOp) + return failure(); + + auto iteratorTypes = op.getIteratorTypes().getValue(); + if (!isParallelIterator(iteratorTypes[0]) || + !isParallelIterator(iteratorTypes[1]) || + !isReductionIterator(iteratorTypes[2])) + return failure(); + + Type opResType = op.getType(); + VectorType vecType = dyn_cast(opResType); + if (vecType && vecType.isScalable()) { + // Note - this is sufficient to reject all cases with scalable vectors. + return failure(); + } + + Type elementType = op.getLhsType().getElementType(); + if (!elementType.isIntOrFloat()) + return failure(); + + Type dstElementType = vecType ? 
vecType.getElementType() : opResType; + if (elementType != dstElementType) + return failure(); + + // Perform lhs + rhs transpositions to conform to matmul row-major semantics. + // Bail out if the contraction cannot be put in this form. + MLIRContext *ctx = op.getContext(); + Location loc = op.getLoc(); + AffineExpr m, n, k; + bindDims(rew.getContext(), m, n, k); + // LHS must be A(m, k) or A(k, m). + Value lhs = op.getLhs(); + auto lhsMap = op.getIndexingMapsArray()[0]; + if (lhsMap == AffineMap::get(3, 0, {k, m}, ctx)) + lhs = rew.create(loc, lhs, ArrayRef{1, 0}); + else if (lhsMap != AffineMap::get(3, 0, {m, k}, ctx)) + return failure(); + + // RHS must be B(k, n) or B(n, k). + Value rhs = op.getRhs(); + auto rhsMap = op.getIndexingMapsArray()[1]; + if (rhsMap == AffineMap::get(3, 0, {n, k}, ctx)) + rhs = rew.create(loc, rhs, ArrayRef{1, 0}); + else if (rhsMap != AffineMap::get(3, 0, {k, n}, ctx)) + return failure(); + + // At this point lhs and rhs are in row-major. + VectorType lhsType = cast(lhs.getType()); + VectorType rhsType = cast(rhs.getType()); + int64_t lhsRows = lhsType.getDimSize(0); + int64_t lhsColumns = lhsType.getDimSize(1); + int64_t rhsColumns = rhsType.getDimSize(1); + + Type flattenedLHSType = + VectorType::get(lhsType.getNumElements(), lhsType.getElementType()); + lhs = rew.create(loc, flattenedLHSType, lhs); + + Type flattenedRHSType = + VectorType::get(rhsType.getNumElements(), rhsType.getElementType()); + rhs = rew.create(loc, flattenedRHSType, rhs); + + Value mul = rew.create( + loc, + VectorType::get(lhsRows * rhsColumns, + cast(lhs.getType()).getElementType()), + lhs, rhs, lhsRows, lhsColumns, rhsColumns); + + mul = rew.create( + loc, + VectorType::get({lhsRows, rhsColumns}, + getElementTypeOrSelf(op.getAcc().getType())), + mul); + + // ACC must be C(m, n) or C(n, m). 
+ auto accMap = op.getIndexingMapsArray()[2]; + if (accMap == AffineMap::get(3, 0, {n, m}, ctx)) + mul = rew.create(loc, mul, ArrayRef{1, 0}); + else if (accMap != AffineMap::get(3, 0, {m, n}, ctx)) + llvm_unreachable("invalid contraction semantics"); + + Value res = + isa(elementType) + ? static_cast(rew.create(loc, op.getAcc(), mul)) + : static_cast( + rew.create(loc, op.getAcc(), mul)); + + return res; +} + +/// Progressive lowering of TransposeOp. +/// One: +/// %x = vector.transpose %y, [1, 0] +/// is replaced by: +/// %z = arith.constant dense<0.000000e+00> +/// %0 = vector.extract %y[0, 0] +/// %1 = vector.insert %0, %z [0, 0] +/// .. +/// %x = vector.insert .., .. [.., ..] +class TransposeOpLowering : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(vector::TransposeOp op, + PatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value input = op.getVector(); + VectorType inputType = op.getSourceVectorType(); + VectorType resType = op.getResultVectorType(); + + if (inputType.isScalable()) + return rewriter.notifyMatchFailure( + op, "This lowering does not support scalable vectors"); + + // Set up convenience transposition table. 
+ ArrayRef transp = op.getPermutation(); + + if (resType.getRank() != 2 || transp[0] != 1 || transp[1] != 0) { + return failure(); + } + + Type flattenedType = + VectorType::get(resType.getNumElements(), resType.getElementType()); + auto matrix = + rewriter.create(loc, flattenedType, input); + auto rows = rewriter.getI32IntegerAttr(resType.getShape()[0]); + auto columns = rewriter.getI32IntegerAttr(resType.getShape()[1]); + Value trans = rewriter.create( + loc, flattenedType, matrix, rows, columns); + rewriter.replaceOpWithNewOp(op, resType, trans); + return success(); + } +}; + } // namespace void mlir::vector::populateVectorRankReducingFMAPattern( @@ -2042,6 +2191,26 @@ void mlir::vector::populateVectorRankReducingFMAPattern( patterns.add(patterns.getContext()); } +/// Pattern to lower `vector.contract` to `llvm.intr.matrix.multiply`. +/// +/// Given the high benefit, this will be prioritised over other +/// contract-lowering patterns. As such, the convert-vector-to-llvm pass will +/// only run this registration conditionally. +void mlir::vector::populateVectorContractToMatrixMultiply( +    RewritePatternSet &patterns) { + patterns.add(patterns.getContext(), 100); +} + +/// Pattern to lower `vector.transpose` to `llvm.intr.matrix.transpose`. +/// +/// Given the high benefit, this will be prioritised over other +/// transpose-lowering patterns. As such, the convert-vector-to-llvm pass will +/// only run this registration conditionally. +void mlir::vector::populateVectorTransposeToFlatTranspose( +    RewritePatternSet &patterns) { + patterns.add(patterns.getContext(), 100); +} + /// Populate the given list with patterns that convert from Vector to LLVM.
void mlir::populateVectorToLLVMConversionPatterns( const LLVMTypeConverter &converter, RewritePatternSet &patterns, @@ -2071,12 +2240,6 @@ void mlir::populateVectorToLLVMConversionPatterns( converter); } -void mlir::populateVectorToLLVMMatrixConversionPatterns( - const LLVMTypeConverter &converter, RewritePatternSet &patterns) { - patterns.add(converter); - patterns.add(converter); -} - namespace { struct VectorToLLVMDialectInterface : public ConvertToLLVMPatternInterface { using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface; diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp index 549d0210af7ad..b8eb0f58e98ca 100644 --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp @@ -70,10 +70,16 @@ void ConvertVectorToLLVMPass::runOnOperation() { populateVectorBitCastLoweringPatterns(patterns); populateVectorBroadcastLoweringPatterns(patterns); populateVectorContractLoweringPatterns(patterns, vectorContractLowering); + if (vectorContractLowering == vector::VectorContractLowering::Matmul) { + populateVectorContractToMatrixMultiply(patterns); + } populateVectorMaskOpLoweringPatterns(patterns); populateVectorShapeCastLoweringPatterns(patterns); populateVectorInterleaveLoweringPatterns(patterns); populateVectorTransposeLoweringPatterns(patterns, vectorTransposeLowering); + if (vectorTransposeLowering == vector::VectorTransposeLowering::Flat) { + populateVectorTransposeToFlatTranspose(patterns); + } // Vector transfer ops with rank > 1 should be lowered with VectorToSCF.
populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1); populateVectorMaskMaterializationPatterns(patterns, @@ -96,11 +103,9 @@ void ConvertVectorToLLVMPass::runOnOperation() { LLVMTypeConverter converter(&getContext(), options); RewritePatternSet patterns(&getContext()); populateVectorTransferLoweringPatterns(patterns); - populateVectorToLLVMMatrixConversionPatterns(converter, patterns); populateVectorToLLVMConversionPatterns( converter, patterns, reassociateFPReductions, force32BitVectorIndices, useVectorAlignment); - populateVectorToLLVMMatrixConversionPatterns(converter, patterns); // Architecture specific augmentations. LLVMConversionTarget target(getContext()); diff --git a/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp b/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp index 62022bfb7df1e..f14264e2f55f3 100644 --- a/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp +++ b/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp @@ -118,9 +118,9 @@ void mlir::arith::populateEmulateUnsupportedFloatsLegality( return converter.isLegal(op); }); // Manually mark arithmetic-performing vector instructions. 
- target.addDynamicallyLegalOp< - vector::ContractionOp, vector::ReductionOp, vector::MultiDimReductionOp, - vector::FMAOp, vector::OuterProductOp, vector::MatmulOp, vector::ScanOp>( + target.addDynamicallyLegalOp( [&](Operation *op) { return converter.isLegal(op); }); target.addLegalOp(); diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp index c6627b5ec0d77..f3d438f78da66 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp @@ -14,6 +14,7 @@ #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Arith/Utils/Utils.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/IR/SCF.h" @@ -210,49 +211,6 @@ static Value createMul(Location loc, Value x, Value y, bool isInt, namespace { -/// Progressive lowering of a `vector.contract %a, %b, %c` with row-major matmul -/// semantics to: -/// ``` -/// %flattened_a = vector.shape_cast %a -/// %flattened_b = vector.shape_cast %b -/// %flattened_d = vector.matrix_multiply %flattened_a, %flattened_b -/// %d = vector.shape_cast %%flattened_d -/// %e = add %c, %d -/// ``` -/// `vector.matrix_multiply` later lowers to `llvm.matrix.multiply`. -// -/// This only kicks in when vectorContractLowering is set to Matmul and -/// the vector.contract op is a row-major matrix multiply. 
-class ContractionOpToMatmulOpLowering - : public vector::MaskableOpRewritePattern { -public: - using MaskableOpRewritePattern::MaskableOpRewritePattern; - - using FilterConstraintType = - std::function; - - static LogicalResult defaultFilter(vector::ContractionOp op) { - return success(); - } - - ContractionOpToMatmulOpLowering( - vector::VectorContractLowering vectorContractLowering, - MLIRContext *context, PatternBenefit benefit = 1, - FilterConstraintType constraint = defaultFilter) - : MaskableOpRewritePattern(context, benefit), - vectorContractLowering(vectorContractLowering), - filter(std::move(constraint)) {} - - FailureOr - matchAndRewriteMaskableOp(vector::ContractionOp op, MaskingOpInterface maskOp, - PatternRewriter &rewriter) const override; - -private: - /// Options to control the vector patterns. - vector::VectorContractLowering vectorContractLowering; - FilterConstraintType filter; -}; - /// Progressive lowering of a `vector.contract %a, %b, %c` with row-major matmul /// semantics to a reduction_size-unrolled sequence: /// ``` @@ -948,24 +906,18 @@ FailureOr ContractionOpLowering::matchAndRewriteMaskableOp( // TODO: implement benefits, cost models. 
MLIRContext *ctx = op.getContext(); - ContractionOpToMatmulOpLowering pat1(vectorContractLoweringOption, ctx); + ContractionOpToOuterProductOpLowering pat1(vectorContractLoweringOption, ctx); FailureOr newVal1 = pat1.matchAndRewriteMaskableOp(op, maskOp, rewriter); if (!failed(newVal1)) return newVal1; - ContractionOpToOuterProductOpLowering pat2(vectorContractLoweringOption, ctx); + ContractionOpToDotLowering pat2(vectorContractLoweringOption, ctx); FailureOr newVal2 = pat2.matchAndRewriteMaskableOp(op, maskOp, rewriter); if (!failed(newVal2)) return newVal2; - ContractionOpToDotLowering pat3(vectorContractLoweringOption, ctx); - FailureOr newVal3 = - pat3.matchAndRewriteMaskableOp(op, maskOp, rewriter); - if (!failed(newVal3)) - return newVal3; - ContractOpToElementwise pat4(vectorContractLoweringOption, ctx); FailureOr newVal4 = pat4.matchAndRewriteMaskableOp(op, maskOp, rewriter); @@ -1273,118 +1225,6 @@ class OuterProductOpLowering : public OpRewritePattern { } }; -/// Progressively lower a `vector.contract %a, %b, %c` with row-major matmul -/// semantics to: -/// ``` -/// %mta = maybe_transpose -/// %mtb = maybe_transpose -/// %flattened_a = vector.shape_cast %mta -/// %flattened_b = vector.shape_cast %mtb -/// %flattened_d = vector.matrix_multiply %flattened_a, %flattened_b -/// %mtd = vector.shape_cast %flattened_d -/// %d = maybe_untranspose %mtd -/// %e = add %c, %d -/// ``` -/// `vector.matrix_multiply` later lowers to `llvm.matrix.multiply`. -// -/// This only kicks in when vectorContractLowering is set to `Matmul`. -/// vector.transpose operations are inserted if the vector.contract op is not a -/// row-major matrix multiply. -/// -/// Scalable vectors are not supported. -FailureOr ContractionOpToMatmulOpLowering::matchAndRewriteMaskableOp( - vector::ContractionOp op, MaskingOpInterface maskOp, - PatternRewriter &rew) const { - // TODO: Support vector.mask. 
- if (maskOp) - return failure(); - - if (vectorContractLowering != vector::VectorContractLowering::Matmul) - return failure(); - if (failed(filter(op))) - return failure(); - - auto iteratorTypes = op.getIteratorTypes().getValue(); - if (!isParallelIterator(iteratorTypes[0]) || - !isParallelIterator(iteratorTypes[1]) || - !isReductionIterator(iteratorTypes[2])) - return failure(); - - Type opResType = op.getType(); - VectorType vecType = dyn_cast(opResType); - if (vecType && vecType.isScalable()) { - // Note - this is sufficient to reject all cases with scalable vectors. - return failure(); - } - - Type elementType = op.getLhsType().getElementType(); - if (!elementType.isIntOrFloat()) - return failure(); - - Type dstElementType = vecType ? vecType.getElementType() : opResType; - if (elementType != dstElementType) - return failure(); - - // Perform lhs + rhs transpositions to conform to matmul row-major semantics. - // Bail out if the contraction cannot be put in this form. - MLIRContext *ctx = op.getContext(); - Location loc = op.getLoc(); - AffineExpr m, n, k; - bindDims(rew.getContext(), m, n, k); - // LHS must be A(m, k) or A(k, m). - Value lhs = op.getLhs(); - auto lhsMap = op.getIndexingMapsArray()[0]; - if (lhsMap == AffineMap::get(3, 0, {k, m}, ctx)) - lhs = rew.create(loc, lhs, ArrayRef{1, 0}); - else if (lhsMap != AffineMap::get(3, 0, {m, k}, ctx)) - return failure(); - - // RHS must be B(k, n) or B(n, k). - Value rhs = op.getRhs(); - auto rhsMap = op.getIndexingMapsArray()[1]; - if (rhsMap == AffineMap::get(3, 0, {n, k}, ctx)) - rhs = rew.create(loc, rhs, ArrayRef{1, 0}); - else if (rhsMap != AffineMap::get(3, 0, {k, n}, ctx)) - return failure(); - - // At this point lhs and rhs are in row-major. 
- VectorType lhsType = cast(lhs.getType()); - VectorType rhsType = cast(rhs.getType()); - int64_t lhsRows = lhsType.getDimSize(0); - int64_t lhsColumns = lhsType.getDimSize(1); - int64_t rhsColumns = rhsType.getDimSize(1); - - Type flattenedLHSType = - VectorType::get(lhsType.getNumElements(), lhsType.getElementType()); - lhs = rew.create(loc, flattenedLHSType, lhs); - - Type flattenedRHSType = - VectorType::get(rhsType.getNumElements(), rhsType.getElementType()); - rhs = rew.create(loc, flattenedRHSType, rhs); - - Value mul = rew.create(loc, lhs, rhs, lhsRows, lhsColumns, - rhsColumns); - mul = rew.create( - loc, - VectorType::get({lhsRows, rhsColumns}, - getElementTypeOrSelf(op.getAcc().getType())), - mul); - - // ACC must be C(m, n) or C(n, m). - auto accMap = op.getIndexingMapsArray()[2]; - if (accMap == AffineMap::get(3, 0, {n, m}, ctx)) - mul = rew.create(loc, mul, ArrayRef{1, 0}); - else if (accMap != AffineMap::get(3, 0, {m, n}, ctx)) - llvm_unreachable("invalid contraction semantics"); - - Value res = - isa(elementType) - ? 
static_cast(rew.create(loc, op.getAcc(), mul)) - : static_cast( - rew.create(loc, op.getAcc(), mul)); - - return res; -} } // namespace void mlir::vector::populateVectorContractLoweringPatterns( @@ -1393,8 +1233,7 @@ void mlir::vector::populateVectorContractLoweringPatterns( bool disableOuterProductLowering) { if (!disableOuterProductLowering) patterns.add(patterns.getContext(), benefit); - patterns.add( + patterns.add( vectorContractLoweringOption, patterns.getContext(), benefit); } diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp index 732e316c93381..34aceff2d46ff 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp @@ -12,6 +12,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/UB/IR/UBOps.h" #include "mlir/Dialect/Utils/IndexingUtils.h" @@ -329,21 +330,6 @@ class TransposeOpLowering : public OpRewritePattern { return rewriter.notifyMatchFailure( op, "Options specifies lowering to shuffle"); - // Handle a true 2-D matrix transpose differently when requested. - if (vectorTransposeLowering == vector::VectorTransposeLowering::Flat && - resType.getRank() == 2 && transp[0] == 1 && transp[1] == 0) { - Type flattenedType = - VectorType::get(resType.getNumElements(), resType.getElementType()); - auto matrix = - rewriter.create(loc, flattenedType, input); - auto rows = rewriter.getI32IntegerAttr(resType.getShape()[0]); - auto columns = rewriter.getI32IntegerAttr(resType.getShape()[1]); - Value trans = rewriter.create( - loc, flattenedType, matrix, rows, columns); - rewriter.replaceOpWithNewOp(op, resType, trans); - return success(); - } - // Generate unrolled extract/insert ops. 
We do not unroll the rightmost // (i.e., highest-order) dimensions that are not transposed and leave them // in vector form to improve performance. Therefore, we prune those diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir index 64e51f5554628..72810b5dddaa3 100644 --- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir +++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir @@ -1424,36 +1424,6 @@ func.func @fma_scalable(%vec_1d: vector<[8]xf32>, %vec_2d: vector<2x[4]xf32>, %v return %0, %1, %2: vector<[8]xf32>, vector<2x[4]xf32>, vector<1x1x[1]xf32> } -// ----- - -//===----------------------------------------------------------------------===// -// vector.matrix_multiply -//===----------------------------------------------------------------------===// - -// 4x16 16x3 4x3 -func.func @matrix_ops(%A: vector<64xf64>, %B: vector<48xf64>) -> vector<12xf64> { - %C = vector.matrix_multiply %A, %B - { lhs_rows = 4: i32, lhs_columns = 16: i32 , rhs_columns = 3: i32 } : - (vector<64xf64>, vector<48xf64>) -> vector<12xf64> - return %C: vector<12xf64> -} -// CHECK-LABEL: @matrix_ops -// CHECK: llvm.intr.matrix.multiply %{{.*}}, %{{.*}} { -// CHECK-SAME: lhs_columns = 16 : i32, lhs_rows = 4 : i32, rhs_columns = 3 : i32 -// CHECK-SAME: } : (vector<64xf64>, vector<48xf64>) -> vector<12xf64> - -// ----- - -func.func @matrix_ops_index(%A: vector<64xindex>, %B: vector<48xindex>) -> vector<12xindex> { - %C = vector.matrix_multiply %A, %B - { lhs_rows = 4: i32, lhs_columns = 16: i32 , rhs_columns = 3: i32 } : - (vector<64xindex>, vector<48xindex>) -> vector<12xindex> - return %C: vector<12xindex> -} -// CHECK-LABEL: @matrix_ops_index -// CHECK: llvm.intr.matrix.multiply %{{.*}}, %{{.*}} { -// CHECK-SAME: lhs_columns = 16 : i32, lhs_rows = 4 : i32, rhs_columns = 3 : i32 -// CHECK-SAME: } : (vector<64xi64>, vector<48xi64>) -> vector<12xi64> // ----- @@ -1602,56 +1572,6 @@ func.func 
@create_mask_1d_scalable(%num_elems : index) -> vector<[4]xi1> { // ----- -//===----------------------------------------------------------------------===// -// vector.flat_transpose -//===----------------------------------------------------------------------===// - -func.func @flat_transpose(%arg0: vector<16xf32>) -> vector<16xf32> { - %0 = vector.flat_transpose %arg0 { rows = 4: i32, columns = 4: i32 } - : vector<16xf32> -> vector<16xf32> - return %0 : vector<16xf32> -} - -// CHECK-LABEL: func @flat_transpose -// CHECK-SAME: %[[A:.*]]: vector<16xf32> -// CHECK: %[[T:.*]] = llvm.intr.matrix.transpose %[[A]] -// CHECK-SAME: {columns = 4 : i32, rows = 4 : i32} : -// CHECK-SAME: vector<16xf32> into vector<16xf32> -// CHECK: return %[[T]] : vector<16xf32> - -// ----- - -func.func @flat_transpose_index(%arg0: vector<16xindex>) -> vector<16xindex> { - %0 = vector.flat_transpose %arg0 { rows = 4: i32, columns = 4: i32 } - : vector<16xindex> -> vector<16xindex> - return %0 : vector<16xindex> -} -// CHECK-LABEL: func @flat_transpose_index -// CHECK-SAME: %[[A:.*]]: vector<16xindex> -// CHECK: %[[T0:.*]] = builtin.unrealized_conversion_cast %[[A]] : vector<16xindex> to vector<16xi64> -// CHECK: %[[T1:.*]] = llvm.intr.matrix.transpose %[[T0]] -// CHECK-SAME: {columns = 4 : i32, rows = 4 : i32} : -// CHECK-SAME: vector<16xi64> into vector<16xi64> -// CHECK: %[[T2:.*]] = builtin.unrealized_conversion_cast %[[T1]] : vector<16xi64> to vector<16xindex> -// CHECK: return %[[T2]] : vector<16xindex> - -// ----- - -func.func @flat_transpose(%arg0: vector<16xf32>) -> vector<16xf32> { - %0 = vector.flat_transpose %arg0 { rows = 4: i32, columns = 4: i32 } - : vector<16xf32> -> vector<16xf32> - return %0 : vector<16xf32> -} - -// CHECK-LABEL: func @flat_transpose -// CHECK-SAME: %[[A:.*]]: vector<16xf32> -// CHECK: %[[T:.*]] = llvm.intr.matrix.transpose %[[A]] -// CHECK-SAME: {columns = 4 : i32, rows = 4 : i32} : -// CHECK-SAME: vector<16xf32> into vector<16xf32> -// CHECK: return %[[T]] 
: vector<16xf32> - -// ----- - //===----------------------------------------------------------------------===// // vector.gather // diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir index 5038646e1f026..dd00fd70ccc21 100644 --- a/mlir/test/Dialect/Vector/invalid.mlir +++ b/mlir/test/Dialect/Vector/invalid.mlir @@ -1321,13 +1321,6 @@ func.func @transpose_dim_size_mismatch(%arg0: vector<11x7x3x2xi32>) { // ----- -func.func @flat_transpose_type_mismatch(%arg0: vector<16xf32>) { - // expected-error@+1 {{'vector.flat_transpose' op failed to verify that source operand and result have same element type}} - %0 = vector.flat_transpose %arg0 { rows = 4: i32, columns = 4: i32 } : vector<16xf32> -> vector<16xf64> -} - -// ----- - func.func @type_cast_layout(%arg0: memref<4x3xf32, affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s0 + d1 * s1 + s2)>>) { // expected-error@+1 {{expects operand to be a memref with identity layout}} %0 = vector.type_cast %arg0: memref<4x3xf32, affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s0 + d1 * s1 + s2)>> to memref> @@ -1939,26 +1932,6 @@ func.func @invalid_step_2d() { // ----- -func.func @matrix_multiply_scalable(%a: vector<[4]xf64>, %b: vector<4xf64>) { - // expected-error @+1 {{'vector.matrix_multiply' op operand #0 must be fixed-length vector of signless integer or signed integer or index or floating-point values of ranks 1, but got 'vector<[4]xf64>'}} - %c = vector.matrix_multiply %a, %b { - lhs_rows = 2: i32, - lhs_columns = 2: i32 , - rhs_columns = 2: i32 } - : (vector<[4]xf64>, vector<4xf64>) -> vector<4xf64> - - return -} - -// ----- - -func.func @flat_transpose_scalable(%arg0: vector<[16]xf32>) -> vector<[16]xf32> { - // expected-error @+1 {{'vector.flat_transpose' op operand #0 must be fixed-length vector of signless integer or signed integer or index or floating-point values of ranks 1, but got 'vector<[16]xf32>'}} - %0 = vector.flat_transpose %arg0 { rows = 4: i32, columns = 4: i32 } - : vector<[16]xf32> 
-> vector<[16]xf32> - return %0 : vector<[16]xf32> -} - //===----------------------------------------------------------------------===// // vector.splat //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir index 10bf0f1620568..a24fe410c1440 100644 --- a/mlir/test/Dialect/Vector/ops.mlir +++ b/mlir/test/Dialect/Vector/ops.mlir @@ -738,22 +738,6 @@ func.func @transpose_int_0d(%arg0: vector) -> vector { return %0 : vector } -// CHECK-LABEL: @flat_transpose_fp -func.func @flat_transpose_fp(%arg0: vector<16xf32>) -> vector<16xf32> { - // CHECK: %[[X:.*]] = vector.flat_transpose %{{.*}} {columns = 4 : i32, rows = 4 : i32} : vector<16xf32> -> vector<16xf32> - %0 = vector.flat_transpose %arg0 { rows = 4: i32, columns = 4: i32 } : vector<16xf32> -> vector<16xf32> - // CHECK: return %[[X]] : vector<16xf32> - return %0 : vector<16xf32> -} - -// CHECK-LABEL: @flat_transpose_int -func.func @flat_transpose_int(%arg0: vector<16xi32>) -> vector<16xi32> { - // CHECK: %[[X:.*]] = vector.flat_transpose %{{.*}} {columns = 8 : i32, rows = 2 : i32} : vector<16xi32> -> vector<16xi32> - %0 = vector.flat_transpose %arg0 { rows = 2: i32, columns = 8: i32 } : vector<16xi32> -> vector<16xi32> - // CHECK: return %[[X]] : vector<16xi32> - return %0 : vector<16xi32> -} - // CHECK-LABEL: @vector_load_and_store_0d_scalar_memref func.func @vector_load_and_store_0d_scalar_memref(%memref : memref<200x100xf32>, %i : index, %j : index) { diff --git a/mlir/test/Dialect/Vector/vector-contract-to-matrix-intrinsics-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-to-matrix-intrinsics-transforms.mlir index 08ac2ac5bb7d5..3950e54006eec 100644 --- a/mlir/test/Dialect/Vector/vector-contract-to-matrix-intrinsics-transforms.mlir +++ b/mlir/test/Dialect/Vector/vector-contract-to-matrix-intrinsics-transforms.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s --transform-interpreter --split-input-file | 
FileCheck %s +// RUN: mlir-opt %s --convert-vector-to-llvm='vector-contract-lowering=matmul' | FileCheck %s #matmat_accesses = [ affine_map<(i, j, k) -> (i, k)>, @@ -10,31 +10,54 @@ iterator_types = ["parallel", "parallel", "reduction"] } -// CHECK-LABEL: func @matmul -// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x4xf32>, -// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<4x3xf32>, -// CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x3xf32> -// CHECK-DAG: %[[ub:.*]] = ub.poison : vector<8xf32> -// CHECK-DAG: %[[ub_0:.*]] = ub.poison : vector<12xf32> -// CHECK-DAG: %[[ub_1:.*]] = ub.poison : vector<2x3xf32> -// CHECK: %[[a0:.*]] = vector.extract %[[A]][0] : vector<4xf32> from vector<2x4xf32> -// CHECK: %[[a1:.*]] = vector.insert_strided_slice %[[a0]], %[[ub]] {offsets = [0], strides = [1]} : vector<4xf32> into vector<8xf32> -// CHECK: %[[a2:.*]] = vector.extract %[[A]][1] : vector<4xf32> from vector<2x4xf32> -// CHECK: %[[a3:.*]] = vector.insert_strided_slice %[[a2]], %[[a1]] {offsets = [4], strides = [1]} : vector<4xf32> into vector<8xf32> -// CHECK: %[[b0:.*]] = vector.extract %[[B]][0] : vector<3xf32> from vector<4x3xf32> -// CHECK: %[[b1:.*]] = vector.insert_strided_slice %[[b0]], %[[ub_0]] {offsets = [0], strides = [1]} : vector<3xf32> into vector<12xf32> -// CHECK: %[[b2:.*]] = vector.extract %[[B]][1] : vector<3xf32> from vector<4x3xf32> -// CHECK: %[[b3:.*]] = vector.insert_strided_slice %[[b2]], %[[b1]] {offsets = [3], strides = [1]} : vector<3xf32> into vector<12xf32> -// CHECK: %[[b4:.*]] = vector.extract %[[B]][2] : vector<3xf32> from vector<4x3xf32> -// CHECK: %[[b5:.*]] = vector.insert_strided_slice %[[b4]], %[[b3]] {offsets = [6], strides = [1]} : vector<3xf32> into vector<12xf32> -// CHECK: %[[b6:.*]] = vector.extract %[[B]][3] : vector<3xf32> from vector<4x3xf32> -// CHECK: %[[b7:.*]] = vector.insert_strided_slice %[[b6]], %[[b5]] {offsets = [9], strides = [1]} : vector<3xf32> into vector<12xf32> -// CHECK: %[[mm1:.*]] = vector.matrix_multiply %[[a3]], %[[b7]] 
{lhs_columns = 4 : i32, lhs_rows = 2 : i32, rhs_columns = 3 : i32} : (vector<8xf32>, vector<12xf32>) -> vector<6xf32> -// CHECK: %[[mm2:.*]] = vector.extract_strided_slice %[[mm1]] {offsets = [0], sizes = [3], strides = [1]} : vector<6xf32> to vector<3xf32> -// CHECK: %[[mm3:.*]] = vector.insert %[[mm2]], %[[ub_1]] [0] : vector<3xf32> into vector<2x3xf32> -// CHECK: %[[mm4:.*]] = vector.extract_strided_slice %[[mm1]] {offsets = [3], sizes = [3], strides = [1]} : vector<6xf32> to vector<3xf32> -// CHECK: %[[mm5:.*]] = vector.insert %[[mm4]], %[[mm3]] [1] : vector<3xf32> into vector<2x3xf32> -// CHECK: %[[mm6:.*]] = arith.addf %[[C]], %[[mm5]] : vector<2x3xf32> +// CHECK-LABEL: func.func @matmul( +// CHECK-SAME: %[[ARG0:.*]]: vector<2x4xf32>, +// CHECK-SAME: %[[ARG1:.*]]: vector<4x3xf32>, +// CHECK-SAME: %[[ARG2:.*]]: vector<2x3xf32>) -> vector<2x3xf32> { +// CHECK: %[[VAL_0:.*]] = builtin.unrealized_conversion_cast %[[ARG1]] : vector<4x3xf32> to !llvm.array<4 x vector<3xf32>> +// CHECK: %[[VAL_1:.*]] = builtin.unrealized_conversion_cast %[[ARG0]] : vector<2x4xf32> to !llvm.array<2 x vector<4xf32>> +// CHECK: %[[VAL_2:.*]] = ub.poison : vector<2x3xf32> +// CHECK: %[[VAL_3:.*]] = builtin.unrealized_conversion_cast %[[VAL_2]] : vector<2x3xf32> to !llvm.array<2 x vector<3xf32>> +// CHECK: %[[POISON_RHS:.*]] = ub.poison : vector<12xf32> +// CHECK: %[[POISON_LHS:.*]] = ub.poison : vector<8xf32> + +// ===> Extract LHS +// | ROW_1 | +// | ----- | --> | ROW_1 | ROW_2 | +// | ROW_2 | +// +// CHECK: %[[LHS_ROW_1:.*]] = llvm.extractvalue %[[VAL_1]][0] : !llvm.array<2 x vector<4xf32>> +// CHECK: %[[TP_1:.*]] = llvm.shufflevector %[[LHS_ROW_1]], %[[LHS_ROW_1]] [0, 1, 2, 3, 0, 0, 0, 0] : vector<4xf32> +// CHECK: %[[TP_2:.*]] = llvm.shufflevector %[[TP_1]], %[[POISON_LHS]] [0, 1, 2, 3, 12, 13, 14, 15] : vector<8xf32> +// CHECK: %[[LHS_ROW_2:.*]] = llvm.extractvalue %[[VAL_1]][1] : !llvm.array<2 x vector<4xf32>> +// CHECK: %[[TP_3:.*]] = llvm.shufflevector %[[LHS_ROW_2]], 
%[[LHS_ROW_2]] [0, 1, 2, 3, 0, 0, 0, 0] : vector<4xf32> +// CHECK: %[[LHS:.*]] = llvm.shufflevector %[[TP_3]], %[[TP_2]] [8, 9, 10, 11, 0, 1, 2, 3] : vector<8xf32> + +// == Extract RHS +// | ROW_1 | +// | ----- | +// | ROW_2 | +// | ----- | --> | ROW_1 | ROW_2 | ROW_3 | ROW_4 | +// | ROW_3 | +// | ----- | +// | ROW_4 | +// CHECK: %[[RHS_ROW_1:.*]] = llvm.extractvalue %[[VAL_0]][0] : !llvm.array<4 x vector<3xf32>> +// CHECK: %[[TP_4:.*]] = llvm.shufflevector %[[RHS_ROW_1]], %[[RHS_ROW_1]] [0, 1, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0] : vector<3xf32> +// CHECK: %[[TP_5:.*]] = llvm.shufflevector %[[TP_4]], %[[POISON_RHS]] [0, 1, 2, 15, 16, 17, 18, 19, 20, 21, 22, 23] : vector<12xf32> +// CHECK: %[[RHS_ROW_2:.*]] = llvm.extractvalue %[[VAL_0]][1] : !llvm.array<4 x vector<3xf32>> +// CHECK: %[[TP_6:.*]] = llvm.shufflevector %[[RHS_ROW_2]], %[[RHS_ROW_2]] [0, 1, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0] : vector<3xf32> +// CHECK: %[[TP_7:.*]] = llvm.shufflevector %[[TP_6]], %[[TP_5]] [12, 13, 14, 0, 1, 2, 18, 19, 20, 21, 22, 23] : vector<12xf32> +// CHECK: %[[RHS_ROW_3:.*]] = llvm.extractvalue %[[VAL_0]][2] : !llvm.array<4 x vector<3xf32>> +// CHECK: %[[TP_8:.*]] = llvm.shufflevector %[[RHS_ROW_3]], %[[RHS_ROW_3]] [0, 1, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0] : vector<3xf32> +// CHECK: %[[TP_9:.*]] = llvm.shufflevector %[[TP_8]], %[[TP_7]] [12, 13, 14, 15, 16, 17, 0, 1, 2, 21, 22, 23] : vector<12xf32> +// CHECK: %[[RHS_ROW_4:.*]] = llvm.extractvalue %[[VAL_0]][3] : !llvm.array<4 x vector<3xf32>> +// CHECK: %[[TP_10:.*]] = llvm.shufflevector %[[RHS_ROW_4]], %[[RHS_ROW_4]] [0, 1, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0] : vector<3xf32> +// CHECK: %[[RHS:.*]] = llvm.shufflevector %[[TP_10]], %[[TP_9]] [12, 13, 14, 15, 16, 17, 18, 19, 20, 0, 1, 2] : vector<12xf32> + +// ===> Matrix multiply +// CHECK: %[[MM:.*]] = llvm.intr.matrix.multiply %[[LHS]], %[[RHS]] {lhs_columns = 4 : i32, lhs_rows = 2 : i32, rhs_columns = 3 : i32} : (vector<8xf32>, vector<12xf32>) -> vector<6xf32> +// CHECK: %[[RES:.*]] = arith.addf 
%[[ARG2]], %{{.*}} : vector<2x3xf32> +// CHECK: return %[[RES]] : vector<2x3xf32> func.func @matmul(%arg0: vector<2x4xf32>, %arg1: vector<4x3xf32>, %arg2: vector<2x3xf32>) -> vector<2x3xf32> { @@ -44,7 +67,7 @@ func.func @matmul(%arg0: vector<2x4xf32>, } // CHECK-LABEL: func @matmul_scalable -// CHECK-NOT: vector.matrix_multiply +// CHECK-NOT: llvm.intr.matrix.multiply func.func @matmul_scalable(%arg0: vector<2x4xf32>, %arg1: vector<4x[3]xf32>, %arg2: vector<2x[3]xf32>) -> vector<2x[3]xf32> { @@ -52,19 +75,3 @@ func.func @matmul_scalable(%arg0: vector<2x4xf32>, : vector<2x4xf32>, vector<4x[3]xf32> into vector<2x[3]xf32> return %0 : vector<2x[3]xf32> } - -module attributes {transform.with_named_sequence} { - transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) { - %f = transform.structured.match ops{["func.func"]} in %module_op - : (!transform.any_op) -> !transform.any_op - - transform.apply_patterns to %f { - transform.apply_patterns.vector.lower_contraction lowering_strategy = "matmulintrinsics" - } : !transform.any_op - - transform.apply_patterns to %f { - transform.apply_patterns.vector.lower_shape_cast - } : !transform.any_op - transform.yield - } -} diff --git a/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir b/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir index a730f217f027d..7838aad1825bc 100644 --- a/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir +++ b/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir @@ -136,38 +136,6 @@ module attributes {transform.with_named_sequence} { // ----- -// CHECK-LABEL: func @transpose( -func.func @transpose(%arg0: vector<2x4xf32>) -> vector<4x2xf32> { - // CHECK: vector.shape_cast {{.*}} : vector<2x4xf32> to vector<8xf32> - // CHECK: vector.flat_transpose %{{.*}} {columns = 2 : i32, rows = 4 : i32} : vector<8xf32> -> vector<8xf32> - // CHECK: vector.shape_cast {{.*}} : vector<8xf32> to vector<4x2xf32> - %0 = vector.transpose %arg0, [1, 0] : 
vector<2x4xf32> to vector<4x2xf32> - return %0 : vector<4x2xf32> -} - -/// Scalable vectors are not supported - -// CHECK-LABEL: func @transpose_scalable( -func.func @transpose_scalable(%arg0: vector<2x[4]xf32>) -> vector<[4]x2xf32> { - // CHECK-NOT: vector.shape_cast - // CHECK-NOT: vector.flat_transpose - // CHECK: vector.transpose - %0 = vector.transpose %arg0, [1, 0] : vector<2x[4]xf32> to vector<[4]x2xf32> - return %0 : vector<[4]x2xf32> -} - -module attributes {transform.with_named_sequence} { - transform.named_sequence @__transform_main(%root : !transform.any_op {transform.readonly}) { - %func_op = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.op<"func.func"> - transform.apply_patterns to %func_op { - transform.apply_patterns.vector.lower_transpose lowering_strategy = "flat_transpose" - } : !transform.op<"func.func"> - transform.yield - } -} - -// ----- - // CHECK-LABEL: @transpose_shuffle16x16xf32( func.func @transpose_shuffle16x16xf32(%arg0: vector<16x16xf32>) -> vector<16x16xf32> { // CHECK: vector.shuffle {{.*}} [0, 16, 1, 17, 4, 20, 5, 21, 8, 24, 9, 25, 12, 28, 13, 29] : vector<16xf32>, vector<16xf32> diff --git a/mlir/test/Dialect/Vector/vector-transpose-to-matrix-intrinsics-transform.mlir b/mlir/test/Dialect/Vector/vector-transpose-to-matrix-intrinsics-transform.mlir new file mode 100644 index 0000000000000..94689fa0dfb88 --- /dev/null +++ b/mlir/test/Dialect/Vector/vector-transpose-to-matrix-intrinsics-transform.mlir @@ -0,0 +1,18 @@ +// RUN: mlir-opt %s --convert-vector-to-llvm='vector-transpose-lowering=flat' --split-input-file | FileCheck %s + +// CHECK-LABEL: func @transpose( +func.func @transpose(%arg0: vector<2x4xf32>) -> vector<4x2xf32> { + // CHECK: llvm.intr.matrix.transpose %{{.*}} {columns = 2 : i32, rows = 4 : i32} : vector<8xf32> into vector<8xf32> + %0 = vector.transpose %arg0, [1, 0] : vector<2x4xf32> to vector<4x2xf32> + return %0 : vector<4x2xf32> +} + +/// Scalable vectors are not 
supported + +// CHECK-LABEL: func @transpose_scalable( +func.func @transpose_scalable(%arg0: vector<2x[4]xf32>) -> vector<[4]x2xf32> { + // CHECK-NOT: llvm.intr.matrix.transpose + // CHECK: vector.transpose + %0 = vector.transpose %arg0, [1, 0] : vector<2x[4]xf32> to vector<[4]x2xf32> + return %0 : vector<[4]x2xf32> +} diff --git a/mlir/test/Integration/Dialect/Vector/CPU/flat-transpose-col.mlir b/mlir/test/Integration/Dialect/Vector/CPU/flat-transpose-col.mlir index b414242b34cc0..86bd0b1e09763 100644 --- a/mlir/test/Integration/Dialect/Vector/CPU/flat-transpose-col.mlir +++ b/mlir/test/Integration/Dialect/Vector/CPU/flat-transpose-col.mlir @@ -57,10 +57,10 @@ func.func @entry() { // ( 1, 4 ) -> ( 3, 4, 5 ) // ( 2, 5 ) // - %d = vector.flat_transpose %a { rows = 2: i32, columns = 2: i32 } : vector<4xf64> -> vector<4xf64> - %e = vector.flat_transpose %b { rows = 2: i32, columns = 2: i32 } : vector<4xf64> -> vector<4xf64> - %f = vector.flat_transpose %c { rows = 2: i32, columns = 3: i32 } : vector<6xf64> -> vector<6xf64> - %g = vector.flat_transpose %c { rows = 3: i32, columns = 2: i32 } : vector<6xf64> -> vector<6xf64> + %d = llvm.intr.matrix.transpose %a { rows = 2: i32, columns = 2: i32 } : vector<4xf64> into vector<4xf64> + %e = llvm.intr.matrix.transpose %b { rows = 2: i32, columns = 2: i32 } : vector<4xf64> into vector<4xf64> + %f = llvm.intr.matrix.transpose %c { rows = 2: i32, columns = 3: i32 } : vector<6xf64> into vector<6xf64> + %g = llvm.intr.matrix.transpose %c { rows = 3: i32, columns = 2: i32 } : vector<6xf64> into vector<6xf64> vector.print %d : vector<4xf64> vector.print %e : vector<4xf64> diff --git a/mlir/test/Integration/Dialect/Vector/CPU/flat-transpose-row.mlir b/mlir/test/Integration/Dialect/Vector/CPU/flat-transpose-row.mlir index 95b178e04a2bb..55103bc686fb2 100644 --- a/mlir/test/Integration/Dialect/Vector/CPU/flat-transpose-row.mlir +++ b/mlir/test/Integration/Dialect/Vector/CPU/flat-transpose-row.mlir @@ -57,10 +57,10 @@ func.func 
@entry() { // ( 2, 3 ) -> ( 1, 3, 5 ) // ( 4, 5 ) // - %d = vector.flat_transpose %a { rows = 2: i32, columns = 2: i32 } : vector<4xf64> -> vector<4xf64> - %e = vector.flat_transpose %b { rows = 2: i32, columns = 2: i32 } : vector<4xf64> -> vector<4xf64> - %f = vector.flat_transpose %c { rows = 2: i32, columns = 3: i32 } : vector<6xf64> -> vector<6xf64> - %g = vector.flat_transpose %c { rows = 3: i32, columns = 2: i32 } : vector<6xf64> -> vector<6xf64> + %d = llvm.intr.matrix.transpose %a { rows = 2: i32, columns = 2: i32 } : vector<4xf64> into vector<4xf64> + %e = llvm.intr.matrix.transpose %b { rows = 2: i32, columns = 2: i32 } : vector<4xf64> into vector<4xf64> + %f = llvm.intr.matrix.transpose %c { rows = 2: i32, columns = 3: i32 } : vector<6xf64> into vector<6xf64> + %g = llvm.intr.matrix.transpose %c { rows = 3: i32, columns = 2: i32 } : vector<6xf64> into vector<6xf64> vector.print %d : vector<4xf64> vector.print %e : vector<4xf64> diff --git a/mlir/test/Integration/Dialect/Vector/CPU/matrix-multiply-col.mlir b/mlir/test/Integration/Dialect/Vector/CPU/matrix-multiply-col.mlir index 8f75ec98465ca..09941192cbc42 100644 --- a/mlir/test/Integration/Dialect/Vector/CPU/matrix-multiply-col.mlir +++ b/mlir/test/Integration/Dialect/Vector/CPU/matrix-multiply-col.mlir @@ -39,7 +39,7 @@ func.func @entry() { // x = |/ | column-major! 
// ( 1, 3 ) (5, 7) ( 19, 27 ) // - %c = vector.matrix_multiply %a, %b + %c = llvm.intr.matrix.multiply %a, %b { lhs_rows = 2: i32, lhs_columns = 2: i32 , rhs_columns = 2: i32 } : (vector<4xf64>, vector<4xf64>) -> vector<4xf64> diff --git a/mlir/test/Integration/Dialect/Vector/CPU/matrix-multiply-row.mlir b/mlir/test/Integration/Dialect/Vector/CPU/matrix-multiply-row.mlir index b7d27c45226ef..d5f511c8ac119 100644 --- a/mlir/test/Integration/Dialect/Vector/CPU/matrix-multiply-row.mlir +++ b/mlir/test/Integration/Dialect/Vector/CPU/matrix-multiply-row.mlir @@ -39,7 +39,7 @@ func.func @entry() { // x = // ( 2, 3 ) (6, 7) ( 26, 31 ) // - %c = vector.matrix_multiply %a, %b + %c = llvm.intr.matrix.multiply %a, %b { lhs_rows = 2: i32, lhs_columns = 2: i32 , rhs_columns = 2: i32 } : (vector<4xf64>, vector<4xf64>) -> vector<4xf64>