From c87e3c0bfc75d941af4748ee4d804ddb4f08ced3 Mon Sep 17 00:00:00 2001 From: Momchil Velikov Date: Fri, 4 Jul 2025 12:53:05 +0000 Subject: [PATCH 1/2] [MLIR][AArch64] Lower `vector.contract` to SVE FEAT_BF16 operations This patch adds lowering of Bfloat16 widening matrix multiply and accumulate `vector.contract`, by parametrising and refactoring the pattern for 8-bit integers. --- mlir/include/mlir/Conversion/Passes.td | 4 + .../TransformOps/ArmSVEVectorTransformOps.td | 12 +- .../Dialect/ArmSVE/Transforms/Transforms.h | 2 + .../VectorToLLVM/ConvertVectorToLLVMPass.cpp | 3 + .../LowerContractionToNeonI8MMPattern.cpp | 2 +- .../TransformOps/ArmSVEVectorTransformOps.cpp | 7 +- .../Dialect/ArmSVE/Transforms/CMakeLists.txt | 2 +- .../Transforms/LowerContractToSVEPatterns.cpp | 607 ++++++++++++++++++ .../LowerContractionToSVEI8MMPattern.cpp | 366 ----------- .../Vector/CPU/ArmSVE/vector-bfmmla.mlir | 105 +++ .../CPU/ArmSVE/vector-contract-bfmmla.mlir | 201 ++++++ .../CPU/ArmSVE/vector-contract-i8mm.mlir | 6 +- 12 files changed, 944 insertions(+), 373 deletions(-) create mode 100644 mlir/lib/Dialect/ArmSVE/Transforms/LowerContractToSVEPatterns.cpp delete mode 100644 mlir/lib/Dialect/ArmSVE/Transforms/LowerContractionToSVEI8MMPattern.cpp create mode 100644 mlir/test/Dialect/Vector/CPU/ArmSVE/vector-bfmmla.mlir create mode 100644 mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/vector-contract-bfmmla.mlir diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td index 5a864865adffc..4f304b39a0528 100644 --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -1437,6 +1437,10 @@ def ConvertVectorToLLVMPass : Pass<"convert-vector-to-llvm"> { "bool", /*default=*/"false", "Enables the use of Arm FEAT_I8MM instructions while lowering " "the vector dialect.">, + Option<"armBF16", "enable-arm-bf16", + "bool", /*default=*/"false", + "Enables the use of Arm FEAT_BF16 instructions while lowering " + "the vector dialect.">, Option<"x86Vector", "enable-x86vector", "bool", /*default=*/"false", "Enables the use of X86Vector dialect while lowering the vector " diff --git a/mlir/include/mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.td b/mlir/include/mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.td index 53784982be6dc..81b3c736b93f3 100644 --- a/mlir/include/mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.td +++ b/mlir/include/mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.td @@ -12,7 +12,7 @@ include "mlir/Dialect/Transform/IR/TransformAttrs.td" include "mlir/Dialect/Transform/IR/TransformDialect.td" include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td" -def ApplyArmSVELowerContractionPatternsOp +def ApplyArmSVELowerContractionToI8MMPatternsOp : Op]> { let description = [{ @@ -23,4 +23,14 @@ def ApplyArmSVELowerContractionPatternsOp let assemblyFormat = "attr-dict"; } +def ApplyArmSVELowerContractionToBFMMLAPatternsOp + : Op]> { + let description = [{ + Indicates that vector contraction-like operations should be lowered to + finer-grained vector primitives using the ArmSVE dialect. 
+ }]; + + let assemblyFormat = "attr-dict"; +} #endif // ARMSVE_VECTOR_TRANSFORM_OPS diff --git a/mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h b/mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h index 232e2be29e574..de160dbf8ed94 100644 --- a/mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h @@ -23,6 +23,8 @@ void populateArmSVELegalizeForLLVMExportPatterns( void populateLowerContractionToSVEI8MMPatternPatterns( RewritePatternSet &patterns); +void populateLowerContractionToSVEBFMMLAPatterns(RewritePatternSet &patterns); + /// Configure the target to support lowering ArmSVE ops to ops that map to LLVM /// intrinsics. void configureArmSVELegalizeForExportTarget(LLVMConversionTarget &target); diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp index 67c0eca15638a..4d74aabcaa50d 100644 --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp @@ -89,6 +89,9 @@ void ConvertVectorToLLVMPass::runOnOperation() { if (armSVE) populateLowerContractionToSVEI8MMPatternPatterns(patterns); } + if (armBF16) + populateLowerContractionToSVEBFMMLAPatterns(patterns); + (void)applyPatternsGreedily(getOperation(), std::move(patterns)); } diff --git a/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToNeonI8MMPattern.cpp b/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToNeonI8MMPattern.cpp index 7180884c77e98..a95fc51d562c2 100644 --- a/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToNeonI8MMPattern.cpp +++ b/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToNeonI8MMPattern.cpp @@ -12,7 +12,7 @@ // TODO: There may be opportunities to unify this with a similar pattern // for SVE. 
See: // https://github.com/llvm/llvm-project/issues/145559 -// LowerContractionToSVEI8MMPattern.cpp +// LowerContractToSVEPatterns.cpp // //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.cpp b/mlir/lib/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.cpp index b2ca4fc1eaa8c..8572c34c8b12b 100644 --- a/mlir/lib/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.cpp +++ b/mlir/lib/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.cpp @@ -18,11 +18,16 @@ using namespace mlir; // Apply...PatternsOp //===----------------------------------------------------------------------===// -void transform::ApplyArmSVELowerContractionPatternsOp::populatePatterns( +void transform::ApplyArmSVELowerContractionToI8MMPatternsOp::populatePatterns( RewritePatternSet &patterns) { mlir::populateLowerContractionToSVEI8MMPatternPatterns(patterns); } +void transform::ApplyArmSVELowerContractionToBFMMLAPatternsOp::populatePatterns( + RewritePatternSet &patterns) { + mlir::populateLowerContractionToSVEBFMMLAPatterns(patterns); +} + //===----------------------------------------------------------------------===// // Transform op registration //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/CMakeLists.txt b/mlir/lib/Dialect/ArmSVE/Transforms/CMakeLists.txt index 65f98b44b1b69..c29eaca244b4a 100644 --- a/mlir/lib/Dialect/ArmSVE/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/ArmSVE/Transforms/CMakeLists.txt @@ -1,7 +1,7 @@ add_mlir_dialect_library(MLIRArmSVETransforms LegalizeForLLVMExport.cpp LegalizeVectorStorage.cpp - LowerContractionToSVEI8MMPattern.cpp + LowerContractToSVEPatterns.cpp DEPENDS MLIRArmSVEConversionsIncGen diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/LowerContractToSVEPatterns.cpp b/mlir/lib/Dialect/ArmSVE/Transforms/LowerContractToSVEPatterns.cpp new file mode 100644 index 0000000000000..2987287afe9cd --- /dev/null +++ b/mlir/lib/Dialect/ArmSVE/Transforms/LowerContractToSVEPatterns.cpp @@ -0,0 +1,607 @@ +//===- LowerContractToSVEPatterns.cpp - Contract to I8MM/BF16 ---*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements lowering patterns from vector.contract to operations +// that map to instructions from the SVE FEAT_I8MM and FEAT_BF16 extensions. +// +// TODO: There may be opportunities to unify this with a similar pattern +// for Neon. 
See:
+//    https://github.com/llvm/llvm-project/issues/145559
+//    LowerContractionToNeonI8MMPattern.cpp
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h"
+#include "mlir/Dialect/ArmSVE/Transforms/Transforms.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/UB/IR/UBOps.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+#include <numeric>
+
+#define DEBUG_TYPE "lower-contract-to-arm-sve-i8mm"
+
+using namespace mlir;
+
+namespace {
+// Get the operand of a `vector.contract`. This function is intended to
+// abstract away from the particular way a value is extended before feeding it
+// into the `vector.contract` - via zero-extend or an explicit or implicit
+// sign-extend (for implicit sign-extension see `vector.contract`
+// documentation).
+//
+// The template parameter `Op` indicates the extension operation (explicit or
+// implicit) for which we are checking.
+//
+// Return success only for extensions from `i8` to `i32`.
+template <typename Op>
+std::optional<Value> getExtOperand(Value v) {
+
+  static_assert(llvm::is_one_of<Op, arith::ExtSIOp, arith::ExtUIOp>::value,
+                "Must be instantiated with either sign- or zero- extension op");
+
+  // If the operand is not defined by an explicit extend operation of the
+  // accepted operation type, allow for an implicit sign-extension.
+  auto extOp = dyn_cast_or_null<Op>(v.getDefiningOp());
+  if (!extOp) {
+    if constexpr (std::is_same<Op, arith::ExtSIOp>::value) {
+      auto vTy = cast<VectorType>(v.getType());
+      if (!vTy.getElementType().isSignlessInteger(8))
+        return {};
+      return v;
+    }
+    return {};
+  }
+
+  // If the operand is defined by an explicit extend operation of the accepted
+  // operation type, check it's extended from `i8` to `i32`.
+  auto inOp = extOp.getIn();
+  auto inTy = dyn_cast<VectorType>(inOp.getType());
+  if (!inTy || !inTy.getElementType().isSignlessInteger(8))
+    return {};
+
+  auto outTy = dyn_cast<VectorType>(extOp.getType());
+  if (!outTy || !outTy.getElementType().isSignlessInteger(32))
+    return {};
+
+  return inOp;
+}
+
+/// This class encapsulates the algorithm and parametrisation (in terms of
+/// types and dimensions) of lowering a `vector.contract` to "primitive" matrix
+/// multiplication operations of the SVE dialect (here "primitive" means
+/// corresponding to a single target instruction).
+///
+/// Supported are lowerings to FEAT_I8MM `smmla`, `ummla`, and `usmmla`, and to
+/// FEAT_BF16 `bfmmla`. All the transformations are very similar to each other;
+/// for concreteness, the description below is given for `smmla`.
+///
+/// The lowering triggers for a contraction operation that performs a matrix
+/// multiply of two 8-bit integer matrix tiles with logical dimensions <Mx8>
+/// and <8x[N]> for the left-hand side (LHS) and the right-hand side
+/// (RHS), respectively, added to a 32-bit integer accumulator operand (ACC)
+/// with dimensions <Mx[N]>, yielding a 32-bit integer result (OUT).
+///
+/// The operands' shapes are such that the operands can be evenly split into
+/// sub-tiles with dimensions as expected by the targeted FEAT_I8MM
+/// instructions. The intent is that M and N are chosen (by higher-level
+/// transforms) in such a way as to maximise register usage. The main use case
+/// we envision as of now is MMT4D, thus the RHS operand is expected
+/// pre-transposed.
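+///
+/// For illustration, with M = 4 and N = 4, and relying on the implicit
+/// sign-extension of the `i8` operands (values and shapes chosen here purely
+/// as a sketch), a contraction this pattern matches could look like:
+///
+///   %out = vector.contract {indexing_maps = [
+///              affine_map<(m, n, k) -> (m, k)>,
+///              affine_map<(m, n, k) -> (n, k)>,
+///              affine_map<(m, n, k) -> (m, n)>],
+///            iterator_types = ["parallel", "parallel", "reduction"],
+///            kind = #vector.kind<add>} %lhs, %rhs, %acc
+///     : vector<4x8xi8>, vector<[4]x8xi8> into vector<4x[4]xi32>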
+///
+/// The matrix multiplication is performed by unrolling the usual tiled matrix
+/// multiplication algorithm using sub-tiles with dimensions <2x8> for the
+/// LHS, <8x[2]> for the RHS, and <2x[2]> for the result and the input
+/// accumulator.
+///
+/// One way to illustrate the operation is as follows:
+///
+///   RHS<8x[N]>:       <8x[2]> <8x[2]> ... <8x[2]>
+///                   +-----------------------------
+/// LHS:        <2x8> | <2x[2]> <2x[2]> ... <2x[2]>
+///             <2x8> | <2x[2]> <2x[2]> ... <2x[2]>
+///              ...  |   ...     ...   ...   ...
+///             <2x8> | <2x[2]> <2x[2]> ... <2x[2]>
+///
+/// The RHS operand is unpacked into N/2 values, each representing a sequence
+/// of VSCALE number of sub-tiles with dimensions <8x2>.
+/// The LHS operand is initially unpacked into M/2 values, each representing a
+/// sub-tile with dimensions <2x8>, and then each such sub-tile is replicated
+/// VSCALE times. Multiplying a thus replicated LHS sub-tile by the
+/// corresponding RHS sub-tile correctly computes an entire result sub-tile.
+/// The 2x2 sub-tiles of the ACC and OUT have rows that are not adjacent
+/// (in memory or when imposing a row-major layout on the 2D vector value).
+/// Reading the ACC is implemented as reading two consecutive rows and
+/// interleaving them by pairs to obtain a vector having twice the length
+/// of an ACC row. This vector is now a sequence of one-dimensional tiles with
+/// the exact layout needed by the `smmla`/`bfmmla`/etc. instructions, and
+/// these tiles are extracted one by one. For illustration, if we have a 2x4
+/// ACC tile
+///   a0 a1 b0 b1
+///   a2 a3 b2 b3
+/// we read the two rows as separate values and then interleave by pairs
+/// to obtain
+///   a0 a1 a2 a3 b0 b1 b2 b3
+/// from which we extract `a0 a1 a2 a3` and `b0 b1 b2 b3`.
+///
+/// Writing the OUT tile is done by reversing the above procedure: concatenate
+/// two "flattened" sub-tiles into
+///   c0 c1 c2 c3 d0 d1 d2 d3
+/// and deinterleave by pairs to obtain, as separate values,
+///   c0 c1 d0 d1
+///   c2 c3 d2 d3
+/// which are then inserted into the final result.
+///
+/// Multiplication of a signed LHS by an unsigned RHS is performed by
+/// swapping the order of the operands and emitting an `usmmla` (since there
+/// isn't an `summla` instruction). Therefore each ACC sub-tile needs
+/// to be transposed before the addition, and the sum, an OUT sub-tile,
+/// needs to be transposed before insertion into the final result.
+/// This is done very elegantly by a modification of the above scheme:
+/// interleave/deinterleave not by pairs, but by individual elements, e.g.
+/// after the ordinary interleave we obtain
+///   a0 a2 a1 a3 b0 b2 b1 b3
+/// which is exactly the desired layout of having each individual 2x2 tile
+/// transposed.
+///
+/// All of the above readily applies to FEAT_BF16 `bfmmla`, with the
+/// difference that the logical shapes of the LHS and RHS are <Mx4> and
+/// <4x[N]>, respectively, that is, the "K" dimension is fixed to 4 instead
+/// of 8 (like for the integer case).
+class VectorContractRewriter {
+protected:
+  // Designate the operation (resp. instruction) used to do sub-tile matrix
+  // multiplications.
+  enum class MMLA {
+    Nop,
+    SignedInt,   // smmla
+    UnsignedInt, // ummla
+    MixedInt,    // usmmla
+    Bfloat       // bfmmla
+  };
+
+  // Lower-level operation to be emitted.
+  MMLA mmlaOp = MMLA::Nop;
+
+  // The operand tiles. These are not necessarily the operands of the
+  // `vector.contract`; for example, they could be the operands of an
+  // `arith.extsi` that in turn feeds the `vector.contract`.
+ Value lhs; + Value rhs; + Value acc; + + // Conventional names for matrix dimensions. + int64_t M = 0; + int64_t N = 0; + int64_t K = 0; + + // Single-dimensional vector types for the operands of the ArmSVE dialect + // op. + VectorType flatLhsType; + VectorType flatRhsType; + VectorType flatAccType; + + // Single-dimension vector type for the entire RHS tile. + VectorType flatRhsTileType; + + // Vector type having the same number of elements as a row in the + // accumulator/output tile and the same element type. + VectorType accRowTy; + + // Vector type having twice the number of elements as a row in the + // accumulator/output tile the same element type. + VectorType accRowX2Ty; + + // Vector type having half the number of elements as a row in the + // accumulator/output tile and an integer element type with twice the bit + // width. + VectorType accRow64Ty; + VectorType accRowX264Ty; + + // Indicate if the operands for the ArmSVE dialect operation need to be + // swapped. Currently this is needed in order to emulate an "summla" + // operation. + bool swapOperands = false; + + // Create the matrix mulitply and accumulate operation according to + // `mmlaOp`. + Value createMMLA(PatternRewriter &rewriter, Location loc, Value acc, + Value lhs, Value rhs); + + // Check general preconditions for applying the transformation, common to the + // integer and the bfloat16 case. + LogicalResult match(vector::ContractionOp op, PatternRewriter &rewriter); + +public: + VectorContractRewriter() = default; + + // Do the actuall rewrite. This member function is shared by both integer and + // bfloat16 rewrites. + Value rewrite(vector::ContractionOp op, PatternRewriter &rewriter); +}; + +Value VectorContractRewriter::createMMLA(PatternRewriter &rewriter, + Location loc, Value acc, Value lhs, + Value rhs) { + if (swapOperands) + std::swap(lhs, rhs); + + switch (mmlaOp) { + case MMLA::SignedInt: + return rewriter.create(loc, flatAccType, acc, lhs, rhs); + case MMLA::UnsignedInt: + return rewriter.create(loc, flatAccType, acc, lhs, rhs); + case MMLA::MixedInt: + return rewriter.create(loc, flatAccType, acc, lhs, rhs); + case MMLA::Bfloat: + return rewriter.create(loc, flatAccType, acc, lhs, rhs); + default: + llvm_unreachable("Uninitialized operation kind"); + } +} + +LogicalResult VectorContractRewriter::match(vector::ContractionOp op, + PatternRewriter &rewriter) { + // Check iterator types for matrix multiplication. + auto itTypes = op.getIteratorTypesArray(); + if (itTypes.size() != 3 || itTypes[0] != vector::IteratorType::parallel || + itTypes[1] != vector::IteratorType::parallel || + itTypes[2] != vector::IteratorType::reduction) + return rewriter.notifyMatchFailure( + op, "iterator types do not correspond to matrix multiplication"); + + // Check permutation maps. For now only accept + // lhs: (d0, d1, d2) -> (d0, d2) + // rhs: (d0, d1, d2) -> (d1, d2) + // acc: (d0, d1, d2) -> (d0, d1) + // This corresponds to matrix multiplication with transposed RHS. + if (op.getIndexingMapsArray()[0] != + AffineMap::getMultiDimMapWithTargets(3, ArrayRef{0u, 2u}, + op.getContext()) || + op.getIndexingMapsArray()[1] != + AffineMap::getMultiDimMapWithTargets(3, ArrayRef{1u, 2u}, + op.getContext()) || + op.getIndexingMapsArray()[2] != AffineMap::getMultiDimMapWithTargets( + 3, ArrayRef{0u, 1u}, op.getContext())) + return rewriter.notifyMatchFailure(op, "non-matching permutation maps"); + + // Check the combining kind is addition. 
+ if (op.getKind() != vector::CombiningKind::ADD) + return rewriter.notifyMatchFailure(op, "combining kind is not an addition"); + + return success(); +} + +Value VectorContractRewriter::rewrite(vector::ContractionOp op, + PatternRewriter &rewriter) { + Location loc = op.getLoc(); + + // Extract LHS sub-tiles with logical shape <2xK>. + SmallVector lhsTile; + for (int64_t i = 0; i < M; i += 2) { + // Extract two consecutive rows of the LHS tile. + auto r0 = + rewriter.create(loc, lhs, ArrayRef{i}); + auto r1 = + rewriter.create(loc, lhs, ArrayRef{i + 1}); + // Concatenate to obtain a 2 x K x flattened sub-tile. + SmallVector shuffleIdx(2 * K); + std::iota(shuffleIdx.begin(), shuffleIdx.end(), 0); + auto t = rewriter.create(loc, r0, r1, shuffleIdx); + // Turn it into a scalable vector. + auto s = rewriter.create( + loc, t, rewriter.create(loc, flatLhsType), 0); + // Replicate the sub-tile VSCALE times to fill the entire vector. + auto r = rewriter.create(loc, s, 0); + lhsTile.push_back(r); + } + + // "Flatten" the RHS tile from <[N]xK> to <[N*K]>. + auto rhs = rewriter.create(this->rhs.getLoc(), + flatRhsTileType, this->rhs); + + // Extract the RHS sub-tiles with logical shape . + SmallVector rhsTile; + for (int64_t j = 0; j < N; j += 2) + rhsTile.push_back(rewriter.create( + loc, flatRhsType, rhs, j * K)); + + // Extract and pack the ACC sub-tiles. + SmallVector accTile; + for (int64_t i = 0; i < M; i += 2) { + // Extract two consecutive rows of the accumulator tile. + auto r0 = rewriter.create(loc, op.getAcc(), + ArrayRef{i}); + auto r1 = rewriter.create(loc, op.getAcc(), + ArrayRef{i + 1}); + Value accTileVec; + if (swapOperands) { + // We are performing the operation with swapped LHS and RHS we need to + // transpose each individual 2x2 tile of the accumulator and (later) the + // final result. + accTileVec = rewriter.create(loc, r0, r1); + } else { + // Bitcast accumulator rows to double-width integer elements, so + // subsequent interleave/deinterleave work on pairs of elements. + auto r0I64 = rewriter.create(loc, accRow64Ty, r0); + auto r1I64 = rewriter.create(loc, accRow64Ty, r1); + + // Interleave the rows, effectively flattening each 2x2 tile into 4 + // consecutive elements. + auto intrI64 = rewriter.create(loc, r0I64, r1I64); + + // Bitcast back to original element type. + accTileVec = rewriter.create(loc, accRowX2Ty, intrI64); + } + // Extract ACC sub-tiles. + for (int64_t j = 0; j < N; j += 2) + accTile.push_back(rewriter.create( + loc, flatAccType, accTileVec, j * 2)); + } + + // Emit sub-tile matrix multiplications. + SmallVector outTile; + for (int64_t i = 0; i < M / 2; ++i) + for (int64_t j = 0; j < N / 2; ++j) { + Value mmla = createMMLA(rewriter, loc, accTile[i * N / 2 + j], lhsTile[i], + rhsTile[j]); + outTile.push_back(mmla); + } + + // Unpack the OUT sub-tiles and insert into the result. + Value result = rewriter.create(loc, op.getResultType()); + for (int64_t i = 0; i < M / 2; ++i) { + // Collect a number of sub-tiles in a row. + Value row = rewriter.create(loc, accRowX2Ty); + for (int64_t j = 0; j < N / 2; ++j) + row = rewriter.create( + loc, outTile[i * N / 2 + j], row, j * 4); + + // Unpack the row to obtain two rows of the output. If we have the out + // sub-tiles transposed we obtain two consecutive output rows by + // separating even and odd elements, i.e. a simple deinterleave. + // Otherwise, the interleave is by pairs. 
+ Value out0, out1; + if (swapOperands) { + auto tmp = rewriter.create(loc, row); + out0 = tmp.getRes1(); + out1 = tmp.getRes2(); + } else { + // Deinterleave by pairs. + auto row64 = rewriter.create(loc, accRowX264Ty, row); + auto deintr64 = rewriter.create(loc, row64); + + // Bitcast back into original element type and insert into the result. + out0 = + rewriter.create(loc, accRowTy, deintr64.getRes1()); + out1 = + rewriter.create(loc, accRowTy, deintr64.getRes2()); + } + result = rewriter.create(loc, out0, result, i * 2); + result = rewriter.create(loc, out1, result, i * 2 + 1); + } + + return result; +} + +class VectorContractRewriterI8MM : public VectorContractRewriter { +public: + // Check the specific preconditions for the integer case. Initialise + // parametrisation types and dimensions. + LogicalResult match(vector::ContractionOp op, PatternRewriter &rewriter) { + + if (failed(VectorContractRewriter::match(op, rewriter))) + return failure(); + + VectorType lhsType = op.getLhsType(); + VectorType rhsType = op.getRhsType(); + + M = lhsType.getDimSize(0); + N = rhsType.getDimSize(0); + K = rhsType.getDimSize(1); + + // Check the operands have the expected shape: + // * for LHS: fixed vector MxK + // * for RHS: scalable vector [N]xK + // * K == 8 + // * M and N even and at least 2 + if (lhsType.isScalable() || !rhsType.getScalableDims()[0] || + rhsType.getScalableDims()[1] || lhsType.getDimSize(1) != K || K != 8 || + M < 2 || M % 2 != 0 || N < 2 || N % 2 != 0 || + !rhsType.getScalableDims()[0]) + return rewriter.notifyMatchFailure(op, "non-matching operand shape"); + + // Check the output is a vector of i32 elements. + auto outTy = dyn_cast(op.getResultType()); + if (!outTy || outTy.getElementType() != rewriter.getI32Type()) + return rewriter.notifyMatchFailure(op, + "output type is not a vector of i32"); + + // Check inputs are sign-/zero- extensions from i8 to i32. Get the values + // before the extension. All four signed/unsigned combinations for input + // operands are supported, but they are lowered to different operations. + // Determine which is the appropriate operation to lower to. + mmlaOp = MMLA::SignedInt; + swapOperands = false; + auto maybeLhs = getExtOperand(op.getLhs()); + if (!maybeLhs) { + mmlaOp = MMLA::UnsignedInt; + maybeLhs = getExtOperand(op.getLhs()); + } + if (!maybeLhs) + return rewriter.notifyMatchFailure( + op, "LHS is not a sign- or zero- extended i8"); + + auto maybeRhs = getExtOperand(op.getRhs()); + if (maybeRhs) { + if (mmlaOp == MMLA::UnsignedInt) + mmlaOp = MMLA::MixedInt; + } else { + if (mmlaOp == MMLA::SignedInt) { + mmlaOp = MMLA::MixedInt; + swapOperands = true; + } + maybeRhs = getExtOperand(op.getRhs()); + } + if (!maybeRhs) + return rewriter.notifyMatchFailure( + op, "RHS is not a sign- or zero- extended i8"); + + // Initialise algorithm parameters. 
+ lhs = *maybeLhs; + rhs = *maybeRhs; + acc = op.getAcc(); + + flatLhsType = VectorType::get(/*shape=*/16, rewriter.getI8Type(), + /*scalableDims=*/{true}); + flatRhsType = VectorType::get(/*shape=*/16, rewriter.getI8Type(), + /*scalableDims=*/{true}); + + flatAccType = VectorType::get(/*shape=*/4, rewriter.getI32Type(), + /*scalableDims=*/{true}); + + flatRhsTileType = VectorType::get(/*shape=*/8 * N, rewriter.getI8Type(), + /*scalableDims=*/{true}); + + accRowTy = VectorType::get(/*shape=*/N, rewriter.getI32Type(), + /*scalableDims=*/{true}); + accRowX2Ty = VectorType::get(/*shape=*/2 * N, rewriter.getI32Type(), + /*scalableDims=*/{true}); + accRow64Ty = VectorType::get(/*shape=*/N / 2, rewriter.getI64Type(), + /*scalableDims=*/{true}); + accRowX264Ty = VectorType::get(/*shape=*/N, rewriter.getI64Type(), + /*scalableDims=*/{true}); + + return success(); + } +}; + +class VectorContractRewriterBfloat : public VectorContractRewriter { +public: + // Check the specific preconditions for the bfloat16 case. Initialise + // parametrisation types and dimensions. + LogicalResult match(vector::ContractionOp op, PatternRewriter &rewriter) { + + if (failed(VectorContractRewriter::match(op, rewriter))) + return failure(); + + VectorType lhsType = op.getLhsType(); + VectorType rhsType = op.getRhsType(); + + M = lhsType.getDimSize(0); + N = rhsType.getDimSize(0); + K = rhsType.getDimSize(1); + + // Check the operands have the expected shape: + // * for LHS: fixed vector MxK + // * for RHS: scalable vector [N]xK + // * K == 4 + // * M and N even and at least 2 + if (lhsType.isScalable() || !rhsType.getScalableDims()[0] || + rhsType.getScalableDims()[1] || lhsType.getDimSize(1) != K || K != 4 || + M < 2 || M % 2 != 0 || N < 2 || N % 2 != 0 || + !rhsType.getScalableDims()[0]) + return rewriter.notifyMatchFailure(op, "non-matching operand shape"); + + // Check the output is a vector of Float32 elements. + auto outTy = dyn_cast(op.getResultType()); + if (!outTy || outTy.getElementType() != rewriter.getF32Type()) + return rewriter.notifyMatchFailure(op, + "output type is not a vector of f32"); + + // Check the inputs are vectors of BFloat16 elements. + if (lhsType.getElementType() != rewriter.getBF16Type()) + return rewriter.notifyMatchFailure(op, + "input type is not a vector of bf16"); + + // Initialise algorithm parameters. 
+ mmlaOp = MMLA::Bfloat; + swapOperands = false; + lhs = op.getLhs(); + rhs = op.getRhs(); + acc = op.getAcc(); + + flatLhsType = VectorType::get(/*shape=*/8, rewriter.getBF16Type(), + /*scalableDims=*/{true}); + flatRhsType = VectorType::get(/*shape=*/8, rewriter.getBF16Type(), + /*scalableDims=*/{true}); + + flatAccType = VectorType::get(/*shape=*/4, rewriter.getF32Type(), + /*scalableDims=*/{true}); + + flatRhsTileType = VectorType::get(/*shape=*/4 * N, rewriter.getBF16Type(), + /*scalableDims=*/{true}); + + accRowTy = VectorType::get(/*shape=*/N, rewriter.getF32Type(), + /*scalableDims=*/{true}); + accRowX2Ty = VectorType::get(/*shape=*/2 * N, rewriter.getF32Type(), + /*scalableDims=*/{true}); + accRow64Ty = VectorType::get(/*shape=*/N / 2, rewriter.getI64Type(), + /*scalableDims=*/{true}); + accRowX264Ty = VectorType::get(/*shape=*/N, rewriter.getI64Type(), + /*scalableDims=*/{true}); + + return success(); + } +}; + +class LowerContractionToSVEI8MMPattern + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(vector::ContractionOp op, + PatternRewriter &rewriter) const override { + + // Match i8xi8 -> i32 matrix multiply and accumulate. + VectorContractRewriterI8MM vcr; + if (failed(vcr.match(op, rewriter))) + return failure(); + + Value result = vcr.rewrite(op, rewriter); + rewriter.replaceOp(op, result); + + return success(); + } +}; + +class LowerContractionToSVEBFMMLAPattern + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(vector::ContractionOp op, + PatternRewriter &rewriter) const override { + + // Match bf16xbf16 -> f32 matrix multiply and accumulate. + VectorContractRewriterBfloat vcr; + if (failed(vcr.match(op, rewriter))) + return failure(); + + Value result = vcr.rewrite(op, rewriter); + rewriter.replaceOp(op, result); + + return success(); + } +}; + +} // namespace + +void mlir::populateLowerContractionToSVEI8MMPatternPatterns( + RewritePatternSet &patterns) { + MLIRContext *context = patterns.getContext(); + patterns.add(context, /*benefit=*/2); +} + +void mlir::populateLowerContractionToSVEBFMMLAPatterns( + RewritePatternSet &patterns) { + MLIRContext *context = patterns.getContext(); + patterns.add(context, /*benefit=*/2); +} diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/LowerContractionToSVEI8MMPattern.cpp b/mlir/lib/Dialect/ArmSVE/Transforms/LowerContractionToSVEI8MMPattern.cpp deleted file mode 100644 index b7703ff0393eb..0000000000000 --- a/mlir/lib/Dialect/ArmSVE/Transforms/LowerContractionToSVEI8MMPattern.cpp +++ /dev/null @@ -1,366 +0,0 @@ -//===- LowerContractionToSVEI8MMPattern.cpp - Contract to I8MM --*- C++ -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This file implements lowering patterns from vector.contract to operations -// that map to instructions from the SVE FEAT_I8MM extension. -// -// TODO: There may be opportunities to unify this with a similar pattern -// for Neon. 
See: -// https://github.com/llvm/llvm-project/issues/145559 -// LowerContractionToNeonI8MMPattern.cpp -// -//===----------------------------------------------------------------------===// - -#include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h" -#include "mlir/Dialect/ArmSVE/Transforms/Transforms.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/LLVMIR/LLVMDialect.h" -#include "mlir/Dialect/Utils/IndexingUtils.h" -#include "mlir/Dialect/Vector/IR/VectorOps.h" -#include "mlir/IR/AffineMap.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" - -#include "mlir/Dialect/UB/IR/UBOps.h" - -#define DEBUG_TYPE "lower-contract-to-arm-sve-i8mm" - -using namespace mlir; - -namespace { -// Get the operand of a `vector.contract`. This function is intended to abstract -// away from the particular way a value is extended before feeding it into the -// `vector.contract` - via zero-extend or an explicit or implicit sign-extend -// (for implicit sign-extension see `vector.contract` documentation). -// -// The template parameter `Op` indicates the extension operation (explicit or -// implicit) for which we are checking. -// -// Return success only for extensions from `i8` to `i32`. -template -std::optional getExtOperand(Value v) { - - static_assert(llvm::is_one_of::value, - "Must be instantiated with either sign- or zero- extension op"); - - // If the operand is not defined by an explicit extend operation of the - // accepted operation type allow for an implicit sign-extension. - auto extOp = dyn_cast_or_null(v.getDefiningOp()); - if (!extOp) { - if constexpr (std::is_same::value) { - auto vTy = cast(v.getType()); - if (!vTy.getElementType().isSignlessInteger(8)) - return {}; - return v; - } - return {}; - } - - // If the operand is defined by an explicit extend operation of the accepted - // operation type, check it's extended from `i8` to `i32`. - auto inOp = extOp.getIn(); - auto inTy = dyn_cast(inOp.getType()); - if (!inTy || !inTy.getElementType().isSignlessInteger(8)) - return {}; - - auto outTy = dyn_cast(extOp.getType()); - if (!outTy || !outTy.getElementType().isSignlessInteger(32)) - return {}; - - return inOp; -} - -// Designate the operation (resp. instruction) used to do sub-tile matrix -// multiplications. -enum class MMLA { - Signed, // smmla - Unsigned, // ummla - Mixed, // usmmla - MixedSwapped // usmmla with LHS and RHS swapped -}; - -// Create the matrix mulitply and accumulate operation according to `op`. -Value createMMLA(PatternRewriter &rewriter, MMLA op, Location loc, - mlir::VectorType accType, Value acc, Value lhs, Value rhs) { - switch (op) { - case MMLA::Signed: - return rewriter.create(loc, accType, acc, lhs, rhs); - case MMLA::Unsigned: - return rewriter.create(loc, accType, acc, lhs, rhs); - case MMLA::Mixed: - return rewriter.create(loc, accType, acc, lhs, rhs); - case MMLA::MixedSwapped: - // The accumulator comes transposed and the result will be transposed - // later, so all we have to do here is swap the operands. - return rewriter.create(loc, accType, acc, rhs, lhs); - } -} - -/// Lower a contraction operation that performs a matrix multiplication -/// of two 8-bit integer matrix tiles with logical dimensions and <8x[N]> -/// for the left-hand side and the right-hand side, respectively, -/// yielding a 32-bit integer result. 
-/// -/// The operands' shapes are such that the operands can be evenly split into -/// sub-tiles with dimensions as expected by the targeted FEAT_I8MM -/// instructions. The intent is that M and N are chosen (by higher level -/// transforms) in such a way as to maximise register usage. The main use case -/// we envision as of now is MMT4D, thus the RHS operand is expected -/// pre-transposed. -/// -/// The matrix multiplication is performed by unrolling the usual tiled matrix -/// multiplication algorithm using sub-tiles with dimensions <2x8> for the LHS, -/// <8x[2]> for the RHS, and <2x[2]> for the result and the input accumulator. -/// -/// One way to illustrate the operation is as follows: -/// -/// RHS<8x[N]>: <8x[2]> <8x[2]> ... <8x[2]> -/// +----------------------------- -/// LHS: <2x8> | <2x[2]> <2x[2]> ... <2x[2]> -/// <2x8> | <2x[2]> <2x[2]> ... <2x[2]> -/// ... | ... ... ... ... -/// <2x8> | <2x[2]> <2x[2]> ... <2x[2]> -/// -/// The RHS operand is unpacked into N/2 values, each representing a sequence of -/// VSCALE number of sub-tiles with dimensions <8x2>. -/// The LHS operand is initially unpacked into M/2 values, each representing a -/// sub-tile with dimensions <2x8>, and then each such sub-tile is replicated -/// VSCALE times. -/// Multiplying thus replicated LHS sub-tile by the corresponding RHS sub-tile -/// correctly computes an entire result sub-tile. -class LowerContractionToSVEI8MMPattern - : public OpRewritePattern { -public: - using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(vector::ContractionOp op, - PatternRewriter &rewriter) const override { - - Location loc = op.getLoc(); - mlir::VectorType lhsType = op.getLhsType(); - mlir::VectorType rhsType = op.getRhsType(); - - // Check the rank the types so we can safely examine their dimensions. - if (lhsType.getRank() != 2 || rhsType.getRank() != 2) - return rewriter.notifyMatchFailure(op, "non-matching operand shape"); - - auto M = lhsType.getDimSize(0); - auto N = rhsType.getDimSize(0); - auto K = rhsType.getDimSize(1); - - // Check the operands have the expected shape: - // * for LHS: fixed vector MxK - // * for RHS: scalable vector [N]xK - // * K == 8 - // * M and N even and at least 2 - if (lhsType.isScalable() || !rhsType.getScalableDims()[0] || - rhsType.getScalableDims()[1] || lhsType.getDimSize(1) != K || K != 8 || - M < 2 || M % 2 != 0 || N < 2 || N % 2 != 0 || - !rhsType.getScalableDims()[0]) - return rewriter.notifyMatchFailure(op, "non-matching operand shape"); - - // Check permutation maps. For now only accept - // lhs: (d0, d1, d2) -> (d0, d2) - // rhs: (d0, d1, d2) -> (d1, d2) - // acc: (d0, d1, d2) -> (d0, d1) - // This corresponds to matrix multiplication with transposed RHS. - if (op.getIndexingMapsArray()[0] != - AffineMap::getMultiDimMapWithTargets(3, ArrayRef{0u, 2u}, - op.getContext()) || - op.getIndexingMapsArray()[1] != - AffineMap::getMultiDimMapWithTargets(3, ArrayRef{1u, 2u}, - op.getContext()) || - op.getIndexingMapsArray()[2] != - AffineMap::getMultiDimMapWithTargets(3, ArrayRef{0u, 1u}, - op.getContext())) - return rewriter.notifyMatchFailure(op, "non-matching permutation maps"); - - // Check iterator types for matrix multiplication. 
- auto itTypes = op.getIteratorTypesArray(); - if (itTypes.size() != 3 || itTypes[0] != vector::IteratorType::parallel || - itTypes[1] != vector::IteratorType::parallel || - itTypes[2] != vector::IteratorType::reduction) - return rewriter.notifyMatchFailure( - op, "iterator types do not correspond to matrix multiplication"); - - // Check the combining kind is addition. - if (op.getKind() != vector::CombiningKind::ADD) - return rewriter.notifyMatchFailure(op, - "combining kind is not an addition"); - - // Check the output is a vector of i32 elements. - auto outTy = dyn_cast(op.getResultType()); - if (!outTy || outTy.getElementType() != rewriter.getI32Type()) - return rewriter.notifyMatchFailure(op, - "output type is not a vector of i32"); - - // Check inputs are sign-/zero- extensions from i8 to i32. Get the values - // before the extension. All four signed/unsigned combinations for input - // operands are supported, but they are lowered to different operations. - // Determine which is the appropriate operation to lower to. - MMLA mmlaOp = MMLA::Signed; - auto maybeLhs = getExtOperand(op.getLhs()); - if (!maybeLhs) { - mmlaOp = MMLA::Unsigned; - maybeLhs = getExtOperand(op.getLhs()); - } - if (!maybeLhs) - return rewriter.notifyMatchFailure( - op, "LHS is not a sign- or zero- extended i8"); - - auto maybeRhs = getExtOperand(op.getRhs()); - if (maybeRhs) { - if (mmlaOp == MMLA::Unsigned) - mmlaOp = MMLA::Mixed; - } else { - if (mmlaOp == MMLA::Signed) - mmlaOp = MMLA::MixedSwapped; - maybeRhs = getExtOperand(op.getRhs()); - } - if (!maybeRhs) - return rewriter.notifyMatchFailure( - op, "RHS is not a sign- or zero- extended i8"); - - // One-dimensional vector types for arm_sve.*mmla - auto nxv16i8 = VectorType::get(/*shape=*/16, rewriter.getI8Type(), - /*scalableDims=*/{true}); - auto nxv4i32 = VectorType::get(/*shape=*/4, rewriter.getI32Type(), - /*scalableDims=*/{true}); - - // Extract LHS sub-tiles with logicall shape <2x8>. - SmallVector lhsTile; - for (int64_t i = 0; i < M; i += 2) { - // Extract two consecutive rows of the LHS tile. - auto r0 = rewriter.create(loc, *maybeLhs, - ArrayRef{i}); - auto r1 = rewriter.create(loc, *maybeLhs, - ArrayRef{i + 1}); - // Concatenate to obtain a 16 x i8 flattened sub-tile. - auto t = rewriter.create( - loc, r0, r1, - llvm::ArrayRef{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, - 14, 15}); - // Turn it into a scalable vector. - auto s = rewriter.create( - loc, t, rewriter.create(loc, nxv16i8), 0); - // Replicate the sub-tile VSCALE times to fill the entire vector. - auto r = rewriter.create(loc, s, 0); - lhsTile.push_back(r); - } - - // "Flatten" the RHS tile from <[N]x8> to <[8*N]>. - auto rhs = rewriter.create( - maybeRhs->getLoc(), - VectorType::get(/*shape=*/8 * N, rewriter.getI8Type(), - /*scalableDims=*/{true}), - *maybeRhs); - - // Extract the RHS sub-tiles with logical shape <8x[2]>. - SmallVector rhsTile; - for (int64_t j = 0; j < N; j += 2) - rhsTile.push_back( - rewriter.create(loc, nxv16i8, rhs, j * 8)); - - // Handy types for packing/unpacking of the accumulator tile. - auto accRowTy = VectorType::get(/*shape=*/N, rewriter.getI32Type(), - /*scalableDims=*/{true}); - auto accRowX2Ty = VectorType::get(/*shape=*/2 * N, rewriter.getI32Type(), - /*scalableDims=*/{true}); - auto accRow64Ty = VectorType::get(/*shape=*/N / 2, rewriter.getI64Type(), - /*scalableDims=*/{true}); - auto accRowX264Ty = VectorType::get(/*shape=*/N, rewriter.getI64Type(), - /*scalableDims=*/{true}); - - // Extract and pack the ACC sub-tiles. 
- SmallVector accTile; - for (int64_t i = 0; i < M; i += 2) { - // Extract two consecutive rows of the accumulator tile. - auto r0 = rewriter.create(loc, op.getAcc(), - ArrayRef{i}); - auto r1 = rewriter.create(loc, op.getAcc(), - ArrayRef{i + 1}); - Value accTileVec; - if (mmlaOp == MMLA::MixedSwapped) { - // We need to swap the positions of the LHS and RHS (since we don't have - // a signed * unsigned operation), but then each individual 2x2 tile of - // the acumulator and (later) the result need to be transposed. - accTileVec = rewriter.create(loc, r0, r1); - } else { - // Bitcast them to 64-bit elements, so subsequent - // interleave/deinterleave work on pairs of 32-bit numbers. - auto r0I64 = rewriter.create(loc, accRow64Ty, r0); - auto r1I64 = rewriter.create(loc, accRow64Ty, r1); - - // Interleave the rows, effectively flattening each 2x2 tile into 4 - // consecutive elements. - auto intrI64 = rewriter.create(loc, r0I64, r1I64); - - // Bitcast back to 32-bit elements. - accTileVec = - rewriter.create(loc, accRowX2Ty, intrI64); - } - // Extract ACC sub-tiles. - for (int64_t j = 0; j < N; j += 2) - accTile.push_back(rewriter.create( - loc, nxv4i32, accTileVec, j * 2)); - } - - // Emit sub-tile matrix multiplications. - SmallVector outTile; - for (int64_t i = 0; i < M / 2; ++i) - for (int64_t j = 0; j < N / 2; ++j) { - Value mmla = createMMLA(rewriter, mmlaOp, loc, nxv4i32, - accTile[i * N / 2 + j], lhsTile[i], rhsTile[j]); - outTile.push_back(mmla); - } - - // Unpack the OUT sub-tiles and insert into the result. - Value result = rewriter.create(loc, op.getResultType()); - for (int64_t i = 0; i < M / 2; ++i) { - // Collect a number of sub-tiles in a row. - Value row = rewriter.create(loc, accRowX2Ty); - for (int64_t j = 0; j < N / 2; ++j) - row = rewriter.create( - loc, outTile[i * N / 2 + j], row, j * 4); - - // Unpack the row to obtain two rows of the output. If we have the out - // sub-tiles transposed we obtain two consecutive output rows by - // separating even and odd elements, i.e. a simple deinterleave. - // Otherwise, the interleave is by pairs. - Value out0, out1; - if (mmlaOp == MMLA::MixedSwapped) { - auto tmp = rewriter.create(loc, row); - out0 = tmp.getRes1(); - out1 = tmp.getRes2(); - } else { - // Deinterleave by pairs. - auto row64 = rewriter.create(loc, accRowX264Ty, row); - auto deintr64 = rewriter.create(loc, row64); - - // Bitcast back into 32-bit elements and insert into the result. 
- out0 = rewriter.create(loc, accRowTy, - deintr64.getRes1()); - out1 = rewriter.create(loc, accRowTy, - deintr64.getRes2()); - } - result = rewriter.create(loc, out0, result, i * 2); - result = rewriter.create(loc, out1, result, i * 2 + 1); - } - - rewriter.replaceOp(op, result); - return success(); - } -}; - -} // namespace - -void mlir::populateLowerContractionToSVEI8MMPatternPatterns( - RewritePatternSet &patterns) { - MLIRContext *context = patterns.getContext(); - patterns.add(context, /*benefit=*/2); -} diff --git a/mlir/test/Dialect/Vector/CPU/ArmSVE/vector-bfmmla.mlir b/mlir/test/Dialect/Vector/CPU/ArmSVE/vector-bfmmla.mlir new file mode 100644 index 0000000000000..ca9d91576b512 --- /dev/null +++ b/mlir/test/Dialect/Vector/CPU/ArmSVE/vector-bfmmla.mlir @@ -0,0 +1,105 @@ +// RUN: mlir-opt %s --transform-interpreter | FileCheck %s + +#attrs = { + indexing_maps = [ + affine_map<(d0, d1, d2) -> (d0, d2)>, + affine_map<(d0, d1, d2) -> (d1, d2)>, + affine_map<(d0, d1, d2) -> (d0, d1)> + ], + iterator_types = ["parallel", "parallel", "reduction"], + kind = #vector.kind +} + +// CHECK-LABEL: @test_vector_contract_to_bfmmla +// CHECK-SAME: %[[LHS:.+]]: vector<4x4xbf16>, %[[RHS:.+]]: vector<[4]x4xbf16>, %[[ACC:.+]]: vector<4x[4]xf32>) -> vector<4x[4]xf32> { +// CHECK-NEXT: %[[T0:.+]] = ub.poison : vector<[8]xf32> +// CHECK-NEXT: %[[UB:.+]] = ub.poison : vector<4x[4]xf32> +// CHECK-NEXT: %[[T2:.+]] = ub.poison : vector<[8]xbf16> + +// Extract rows 0 and 1 of the LHS, concatenate them, and replicate the resulting 8xbf16 vector +// VSCALE times to obtain a [8]xbf16 vector. +// CHECK-NEXT: %[[T3:.+]] = vector.extract %[[LHS]][0] : vector<4xbf16> from vector<4x4xbf16> +// CHECK-NEXT: %[[T4:.+]] = vector.extract %[[LHS]][1] : vector<4xbf16> from vector<4x4xbf16> +// CHECK-NEXT: %[[T5:.+]] = vector.shuffle %[[T3]], %[[T4]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<4xbf16>, vector<4xbf16> +// CHECK-NEXT: %[[T6:.+]] = vector.scalable.insert %[[T5]], %[[T2]][0] : vector<8xbf16> into vector<[8]xbf16> +// CHECK-NEXT: %[[LHS_00:.+]] = arm_sve.dupq_lane %[[T6]][0] : vector<[8]xbf16> + +// Same for rows 2 and 3 of the LHS. 
+// CHECK-NEXT: %[[T8:.+]] = vector.extract %[[LHS]][2] : vector<4xbf16> from vector<4x4xbf16> +// CHECK-NEXT: %[[T9:.+]] = vector.extract %[[LHS]][3] : vector<4xbf16> from vector<4x4xbf16> +// CHECK-NEXT: %[[T10:.+]] = vector.shuffle %[[T8]], %[[T9]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<4xbf16>, vector<4xbf16> +// CHECK-NEXT: %[[T11:.+]] = vector.scalable.insert %[[T10]], %[[T2]][0] : vector<8xbf16> into vector<[8]xbf16> +// CHECK-NEXT: %[[LHS_10:.+]] = arm_sve.dupq_lane %[[T11]][0] : vector<[8]xbf16> + +// Extract sub-tiles from the RHS +// CHECK-NEXT: %[[T13:.+]] = vector.shape_cast %[[RHS]] : vector<[4]x4xbf16> to vector<[16]xbf16> +// CHECK-NEXT: %[[RHS_00:.+]] = vector.scalable.extract %[[T13]][0] : vector<[8]xbf16> from vector<[16]xbf16> +// CHECK-NEXT: %[[RHS_01:.+]] = vector.scalable.extract %[[T13]][8] : vector<[8]xbf16> from vector<[16]xbf16> + + +// Extract accumulator rows 0 and 1 and pack (into "registers") +// CHECK-NEXT: %[[T16:.+]] = vector.extract %[[ACC]][0] : vector<[4]xf32> from vector<4x[4]xf32> +// CHECK-NEXT: %[[T17:.+]] = vector.extract %[[ACC]][1] : vector<[4]xf32> from vector<4x[4]xf32> +// CHECK-NEXT: %[[T18:.+]] = vector.bitcast %[[T16]] : vector<[4]xf32> to vector<[2]xi64> +// CHECK-NEXT: %[[T19:.+]] = vector.bitcast %[[T17]] : vector<[4]xf32> to vector<[2]xi64> +// CHECK-NEXT: %[[T20:.+]] = vector.interleave %[[T18]], %[[T19]] : vector<[2]xi64> -> vector<[4]xi64> +// CHECK-NEXT: %[[T21:.+]] = vector.bitcast %[[T20]] : vector<[4]xi64> to vector<[8]xf32> +// CHECK-NEXT: %[[ACC_00:.+]] = vector.scalable.extract %[[T21]][0] : vector<[4]xf32> from vector<[8]xf32> +// CHECK-NEXT: %[[ACC_01:.+]] = vector.scalable.extract %[[T21]][4] : vector<[4]xf32> from vector<[8]xf32> + +// Same for accumulator rows 2 and 3 +// CHECK-NEXT: %[[T24:.+]] = vector.extract %[[ACC]][2] : vector<[4]xf32> from vector<4x[4]xf32> +// CHECK-NEXT: %[[T25:.+]] = vector.extract %[[ACC]][3] : vector<[4]xf32> from vector<4x[4]xf32> +// CHECK-NEXT: %[[T26:.+]] = vector.bitcast %[[T24]] : vector<[4]xf32> to vector<[2]xi64> +// CHECK-NEXT: %[[T27:.+]] = vector.bitcast %[[T25]] : vector<[4]xf32> to vector<[2]xi64> +// CHECK-NEXT: %[[T28:.+]] = vector.interleave %[[T26]], %[[T27]] : vector<[2]xi64> -> vector<[4]xi64> +// CHECK-NEXT: %[[T29:.+]] = vector.bitcast %[[T28]] : vector<[4]xi64> to vector<[8]xf32> +// CHECK-NEXT: %[[ACC_10:.+]] = vector.scalable.extract %[[T29]][0] : vector<[4]xf32> from vector<[8]xf32> +// CHECK-NEXT: %[[ACC_11:.+]] = vector.scalable.extract %[[T29]][4] : vector<[4]xf32> from vector<[8]xf32> + +// Do the sub-tile matrix multiplications +// CHECK-NEXT: %[[PACK_RES_00:.+]] = arm_sve.intr.bfmmla %[[ACC_00]], %[[LHS_00]], %[[RHS_00]] : vector<[8]xbf16> to vector<[4]xf32> +// CHECK-NEXT: %[[PACK_RES_01:.+]] = arm_sve.intr.bfmmla %[[ACC_01]], %[[LHS_00]], %[[RHS_01]] : vector<[8]xbf16> to vector<[4]xf32> +// CHECK-NEXT: %[[PACK_RES_10:.+]] = arm_sve.intr.bfmmla %[[ACC_10]], %[[LHS_10]], %[[RHS_00]] : vector<[8]xbf16> to vector<[4]xf32> +// CHECK-NEXT: %[[PACK_RES_11:.+]] = arm_sve.intr.bfmmla %[[ACC_11]], %[[LHS_10]], %[[RHS_01]] : vector<[8]xbf16> to vector<[4]xf32> + +// Unpack (from "registers") and insert in the output result rows 0 and 1 +// CHECK-NEXT: %[[T36:.+]] = vector.scalable.insert %[[PACK_RES_00]], %[[T0]][0] : vector<[4]xf32> into vector<[8]xf32> +// CHECK-NEXT: %[[T37:.+]] = vector.scalable.insert %[[PACK_RES_01]], %[[T36]][4] : vector<[4]xf32> into vector<[8]xf32> +// CHECK-NEXT: %[[T38:.+]] = vector.bitcast %[[T37]] : vector<[8]xf32> to vector<[4]xi64> +// 
CHECK-NEXT: %res1, %res2 = vector.deinterleave %[[T38]] : vector<[4]xi64> -> vector<[2]xi64> +// CHECK-NEXT: %[[UNPACK_RES_00:.+]] = vector.bitcast %res1 : vector<[2]xi64> to vector<[4]xf32> +// CHECK-NEXT: %[[UNPACK_RES_01:.+]] = vector.bitcast %res2 : vector<[2]xi64> to vector<[4]xf32> +// CHECK-NEXT: %[[TMP_OUT_0:.+]] = vector.insert %[[UNPACK_RES_00]], %[[UB]] [0] : vector<[4]xf32> into vector<4x[4]xf32> +// CHECK-NEXT: %[[TMP_OUT_1:.+]] = vector.insert %[[UNPACK_RES_01]], %[[TMP_OUT_0]] [1] : vector<[4]xf32> into vector<4x[4]xf32> + +// Same for result rows 2 and 3 +// CHECK-NEXT: %[[T43:.+]] = vector.scalable.insert %[[PACK_RES_10]], %[[T0]][0] : vector<[4]xf32> into vector<[8]xf32> +// CHECK-NEXT: %[[T44:.+]] = vector.scalable.insert %[[PACK_RES_11]], %[[T43]][4] : vector<[4]xf32> into vector<[8]xf32> +// CHECK-NEXT: %[[T45:.+]] = vector.bitcast %[[T44]] : vector<[8]xf32> to vector<[4]xi64> +// CHECK-NEXT: %res1_0, %res2_1 = vector.deinterleave %[[T45]] : vector<[4]xi64> -> vector<[2]xi64> +// CHECK-NEXT: %[[UNPACK_RES_10:.+]] = vector.bitcast %res1_0 : vector<[2]xi64> to vector<[4]xf32> +// CHECK-NEXT: %[[UNPACK_RES_11:.+]] = vector.bitcast %res2_1 : vector<[2]xi64> to vector<[4]xf32> +// CHECK-NEXT: %[[TMP_OUT_2:.+]] = vector.insert %[[UNPACK_RES_10]], %[[TMP_OUT_1]] [2] : vector<[4]xf32> into vector<4x[4]xf32> +// CHECK-NEXT: %[[OUT:.+]] = vector.insert %[[UNPACK_RES_11]], %[[TMP_OUT_2]] [3] : vector<[4]xf32> into vector<4x[4]xf32> +// CHECK-NEXT: return %[[OUT]] : vector<4x[4]xf32> +func.func @test_vector_contract_to_bfmmla(%lhs: vector<4x4xbf16>, + %rhs: vector<[4]x4xbf16>, + %acc: vector<4x[4]xf32>) -> vector<4x[4]xf32> { + %0 = vector.contract #attrs %lhs, %rhs, %acc + : vector<4x4xbf16>, vector<[4]x4xbf16> into vector<4x[4]xf32> + + return %0 : vector<4x[4]xf32> +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%module: !transform.any_op {transform.readonly}) { + %func = transform.structured.match ops{["func.func"]} in %module : (!transform.any_op) -> !transform.op<"func.func"> + + transform.apply_patterns to %func { + transform.apply_patterns.arm_sve.vector_contract_to_bfmmla + } : !transform.op<"func.func"> + + transform.yield + } +} diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/vector-contract-bfmmla.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/vector-contract-bfmmla.mlir new file mode 100644 index 0000000000000..0e988d1c2f42c --- /dev/null +++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/vector-contract-bfmmla.mlir @@ -0,0 +1,201 @@ +// REQUIRES: arm-emulator + +// DEFINE: %{compile} = mlir-opt %s \ +// DEFINE: --convert-vector-to-scf --convert-scf-to-cf --convert-vector-to-llvm='enable-arm-sve enable-arm-bf16' \ +// DEFINE: --expand-strided-metadata --convert-to-llvm --finalize-memref-to-llvm \ +// DEFINE: --lower-affine --convert-arith-to-llvm --reconcile-unrealized-casts \ +// DEFINE: -o %t + +// DEFINE: %{entry_point} = main + +// DEFINE: %{run} = %mcr_aarch64_cmd %t -e %{entry_point} -entry-point-result=void --march=aarch64 --mattr="+sve,+bf16" \ +// DEFINE: -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%native_mlir_arm_runner_utils + +// RUN: rm -f %t && %{compile} && FileCheck %s --input-file=%t -check-prefix CHECK-IR && %{run} | FileCheck %s + +#packed_maps = [ + affine_map<(m, n, k) -> (m, k)>, + affine_map<(m, n, k) -> (n, k)>, + affine_map<(m, n, k) -> (m, n)> +] + +// +// Test the lowering of `vector.contract` using the `LowerContractionToSVEBFMMLAPattern` +// +// The 
operation that the `vector.contract` in this test performs is matrix
+// multiplication with accumulate
+//     OUT = ACC + LHS * RHS
+// of two BFloat16 matrices LHS and RHS, and a Float32 matrix ACC, into a
+// Float32 OUT.
+//
+// Tested are both the computed values and that the relevant `ArmSVE` dialect
+// operation (`arm_sve.intr.bfmmla`) is emitted.
+//
+// The pattern above handles (therefore this test prepares) input/output
+// vectors with specific shapes:
+//   * LHS:      vector<Mx4xbf16>
+//   * RHS:      vector<[N]x4xbf16>
+//   * ACC, OUT: vector<Mx[N]xf32>
+// Note that the RHS is transposed.
+// This data layout makes it efficient to load data into SVE
+// registers in the layout expected by the BFMMLA instruction.
+// Such a `vector.contract` is representative of the code we aim to generate
+// by scalable vectorisation of `linalg.mmt4d`.
+// See mlir/lib/Dialect/ArmSVE/Transforms/LowerContractToSVEPatterns.cpp
+// for more information and rationale about these shapes.
+//
+// In this specific test we use M == 4 and N == 4.
+//
+
+// Allocate and initialise a memref containing test data for use as the ACC
+// operand. The memref has one dynamic dimension whose extent depends on the
+// runtime value of VSCALE.
+//
+// The input parameter `%in` is a vector that is replicated VSCALE times
+// across the columns of the memref.
+func.func private @prepareAccTestData(%in: vector<4x4xf32>) -> memref<4x?xf32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c4 = arith.constant 4 : index
+
+  %vs = vector.vscale
+  %d = arith.muli %c4, %vs : index
+  %mem = memref.alloc(%d) : memref<4x?xf32>
+
+  scf.for %j = %c0 to %d step %c4 {
+    vector.transfer_write %in, %mem[%c0, %j] {in_bounds = [true, true]} : vector<4x4xf32>, memref<4x?xf32>
+  }
+
+  return %mem : memref<4x?xf32>
+}
+
+// Allocate and initialise a memref containing test data for use as the LHS
+// operand. This function just writes the parameter `%in` into the memref.
+// The size of the LHS does not depend on VSCALE.
+func.func private @prepareLHSTestData(%in: vector<4x4xbf16>) -> memref<4x4xbf16> {
+  %c0 = arith.constant 0 : index
+
+  %mem = memref.alloc() : memref<4x4xbf16>
+  vector.transfer_write %in, %mem[%c0, %c0] {in_bounds = [true, true]} : vector<4x4xbf16>, memref<4x4xbf16>
+
+  return %mem : memref<4x4xbf16>
+}
+
+// Allocate and initialise a memref containing test data for use as the RHS
+// operand. The memref has one dynamic dimension whose extent depends on the
+// runtime value of VSCALE.
+//
+// The input parameter `%in` is a vector that is replicated VSCALE times
+// across the rows of the memref.
+//
+// For convenience, flatten the memref, since the RHS vector is read first as a
+// single-dimensional scalable vector and then cast into [N]x4 shape.
+func.func private @prepareRHSTestData(%in: vector<4x4xbf16>) -> memref<?xbf16> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c4 = arith.constant 4 : index
+
+  %vs = vector.vscale
+  %d = arith.muli %c4, %vs : index
+  %mem = memref.alloc(%d) : memref<?x4xbf16>
+
+  scf.for %i = %c0 to %d step %c4 {
+    vector.transfer_write %in, %mem[%i, %c0] {in_bounds = [true, true]} : vector<4x4xbf16>, memref<?x4xbf16>
+  }
+
+  %mem_out = memref.collapse_shape %mem [[0, 1]] : memref<?x4xbf16> into memref<?xbf16>
+  return %mem_out : memref<?xbf16>
+}
+
+// CHECK-IR-LABEL: llvm.func @test_bfmmla
+// CHECK-IR-COUNT-4: arm_sve.intr.bfmmla
+func.func @test_bfmmla() {
+
+  %c0 = arith.constant 0 : index
+  %c0_f32 = arith.constant 0.0 : f32
+  %c0_bf16 = arith.constant 0.0 : bf16
+
+  // Accumulator test data
+  %acc_cst = arith.constant dense<[[ 0.7,  1.0, -0.1,  1.8],
+                                   [-0.5,  0.9,  0.7, -0.7],
+                                   [ 0.5, -1.3, -2.2,  0.1],
+                                   [-0.7,  1.0,  1.7, -1.0]]> : vector<4x4xf32>
+
+  %acc_mem = func.call @prepareAccTestData(%acc_cst) : (vector<4x4xf32>) -> memref<4x?xf32>
+  %acc = vector.transfer_read %acc_mem[%c0, %c0], %c0_f32 {in_bounds = [true, true]} : memref<4x?xf32>, vector<4x[4]xf32>
+
+  // LHS test data
+  %lhs_cst = arith.constant dense<[[ 0.1,  0.7, -0.9,  1.3],
+                                   [-1.6,  0.7, -0.3, -0.3],
+                                   [-0.4,  0.6,  0.8, -0.5],
+                                   [-0.6, -1.0, -1.0, -1.0]]> : vector<4x4xbf16>
+
+  %lhs_mem = func.call @prepareLHSTestData(%lhs_cst) : (vector<4x4xbf16>) -> memref<4x4xbf16>
+  %lhs = vector.transfer_read %lhs_mem[%c0, %c0], %c0_bf16 {in_bounds = [true, true]} : memref<4x4xbf16>, vector<4x4xbf16>
+
+  // RHS test data
+  %rhs_cst = arith.constant dense<[[ 0.6,  1.3,  0.1, -0.9],
+                                   [ 0.5,  1.6,  1.8,  1.6],
+                                   [-0.2,  0.4,  1.0,  0.4],
+                                   [-1.3, -0.2, -2.2,  0.3]]> : vector<4x4xbf16>
+
+  %rhs_mem = func.call @prepareRHSTestData(%rhs_cst) : (vector<4x4xbf16>) -> memref<?xbf16>
+  %rhs_flat = vector.transfer_read %rhs_mem[%c0], %c0_bf16 {in_bounds = [true]} : memref<?xbf16>, vector<[16]xbf16>
+  %rhs = vector.shape_cast %rhs_flat : vector<[16]xbf16> to vector<[4]x4xbf16>
+
+  // Matrix multiplication and accumulate with transposed RHS.
+  %0 = vector.contract {indexing_maps = #packed_maps,
+                        iterator_types = ["parallel", "parallel", "reduction"],
+                        kind = #vector.kind<add>} %lhs, %rhs, %acc
+    : vector<4x4xbf16>, vector<[4]x4xbf16> into vector<4x[4]xf32>
+
+  // Display the result of the multiplication
+  vector.print str "Result(BFMMLA):\n"
+  %u0 = vector.extract %0[0] : vector<[4]xf32> from vector<4x[4]xf32>
+  %u1 = vector.extract %0[1] : vector<[4]xf32> from vector<4x[4]xf32>
+  %u2 = vector.extract %0[2] : vector<[4]xf32> from vector<4x[4]xf32>
+  %u3 = vector.extract %0[3] : vector<[4]xf32> from vector<4x[4]xf32>
+  vector.print %u0 : vector<[4]xf32>
+  vector.print %u1 : vector<[4]xf32>
+  vector.print %u2 : vector<[4]xf32>
+  vector.print %u3 : vector<[4]xf32>
+
+  // Deallocate the buffers.
+  memref.dealloc %acc_mem : memref<4x?xf32>
+  memref.dealloc %lhs_mem : memref<4x4xbf16>
+  memref.dealloc %rhs_mem : memref<?xbf16>
+
+  return
+}
+
+// Perform each test with SVE vector lengths 128 bits and 256 bits (i.e. VSCALEs
+// 1 and 2, respectively). The vector length is set via the `setArmVLBits`
+// function. The effect of setting a different vector length is that the tests
+// allocate and operate on different sized buffers (see the `prepare*TestData`
+// functions).
+ +func.func @main() { + %c128 = arith.constant 128 : i32 + %c256 = arith.constant 256 : i32 + +// CHECK-LABEL: Result(BFMMLA): +// CHECK: ( 0.411922, 2.63254, -0.219259, 3.89965 ) +// CHECK: ( -0.316515, 0.196875, 0.879375, 1.80924 ) +// CHECK: ( 1.56867, 0.101367, -1.2784, -1.41579 ) +// CHECK: ( -1.56041, -4.30078, 0.0196488, 1.88269 ) + func.call @setArmVLBits(%c128) : (i32) -> () + func.call @test_bfmmla() : () -> () + +// CHECK: Result(BFMMLA): +// CHECK: ( 0.411922, 2.63254, -0.219259, 3.89965, 0.411922, 2.63254, -0.219259, 3.89965 ) +// CHECK: ( -0.316515, 0.196875, 0.879375, 1.80924, -0.316515, 0.196875, 0.879375, 1.80924 ) +// CHECK: ( 1.56867, 0.101367, -1.2784, -1.41579, 1.56867, 0.101367, -1.2784, -1.41579 ) +// CHECK: ( -1.56041, -4.30078, 0.0196488, 1.88269, -1.56041, -4.30078, 0.0196488, 1.88269 ) + func.call @setArmVLBits(%c256) : (i32) -> () + func.call @test_bfmmla() : () -> () + + return +} + +func.func private @setArmVLBits(%bits : i32) +func.func private @printMemrefF32(%ptr : memref<*xf32>) diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/vector-contract-i8mm.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/vector-contract-i8mm.mlir index 5f6e8e4c30892..8504d664fa0c6 100644 --- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/vector-contract-i8mm.mlir +++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/vector-contract-i8mm.mlir @@ -20,7 +20,7 @@ ] // -// Test the lowering of `vector.contract` using the `LowerContractionToSVEI8MMPattern` +// Test the lowering of `vector.contract` using the `LowerContractionToSVEBFMMLAPattern` // // The operation that the `vector.contract` in this test performs is matrix // multiplication with accumulate @@ -42,7 +42,7 @@ // registers in the layout expected by FEAT_I8MM instructions. // Such a `vector.contract` is representative of the code we aim to generate // by scalable vectorisation of `linalg.mmt4d`. -// See mlir/lib/Dialect/ArmSVE/Transforms/LowerContractionToSVEI8MMPattern.cpp +// See mlir/lib/Dialect/ArmSVE/Transforms/LowerContractToSVEPatterns.cpp // for more information and rationale about these shapes. // // In this specific test we use M == 4 and N == 4 @@ -316,7 +316,7 @@ func.func @test_usmmla() { // Test the operation where LHS is interpreted as signed and RHS is interpreted // as unsigned. In this test we ultimately emit end execute the `usmmla` -// instruction with reversed operands, see `LowerContractionToSVEI8MMPattern.cpp` +// instruction with reversed operands, see `LowerContractToSVEPatterns.cpp` // for more details. 
// CHECK-IR-LABEL: llvm.func @test_summla

From 9a94866aa12d01d05dae3c3e6eec54ad2177d861 Mon Sep 17 00:00:00 2001
From: Momchil Velikov
Date: Tue, 8 Jul 2025 14:58:34 +0000
Subject: [PATCH 2/2] [fixup] Some refactoring

---
 .../TransformOps/ArmSVEVectorTransformOps.td  |   8 +-
 .../Transforms/LowerContractToSVEPatterns.cpp | 149 ++++++++----------
 .../CPU/ArmSVE/vector-contract-bfmmla.mlir    |  12 +-
 3 files changed, 79 insertions(+), 90 deletions(-)

diff --git a/mlir/include/mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.td b/mlir/include/mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.td
index 81b3c736b93f3..7777e6060ea76 100644
--- a/mlir/include/mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.td
+++ b/mlir/include/mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.td
@@ -16,8 +16,8 @@ def ApplyArmSVELowerContractionToI8MMPatternsOp
     : Op]> {
   let description = [{
-    Indicates that vector contraction-like operations should be lowered to
-    finer-grained vector primitives using the ArmSVE dialect.
+    Indicates that vector contract operations should be lowered to
+    ArmSVE dialect operations mapping to instructions from FEAT_I8MM.
   }];
 
   let assemblyFormat = "attr-dict";
@@ -27,8 +27,8 @@ def ApplyArmSVELowerContractionToBFMMLAPatternsOp
     : Op]> {
   let description = [{
-    Indicates that vector contraction-like operations should be lowered to
-    finer-grained vector primitives using the ArmSVE dialect.
+    Indicates that vector contract operations should be lowered to
+    ArmSVE dialect operations mapping to instructions from FEAT_BF16.
   }];
 
   let assemblyFormat = "attr-dict";
diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/LowerContractToSVEPatterns.cpp b/mlir/lib/Dialect/ArmSVE/Transforms/LowerContractToSVEPatterns.cpp
index 2987287afe9cd..8652cc7fc60dd 100644
--- a/mlir/lib/Dialect/ArmSVE/Transforms/LowerContractToSVEPatterns.cpp
+++ b/mlir/lib/Dialect/ArmSVE/Transforms/LowerContractToSVEPatterns.cpp
@@ -28,6 +28,7 @@
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
+#include 
 #include 
 
 #define DEBUG_TYPE "lower-contract-to-arm-sve-i8mm"
@@ -172,6 +173,11 @@ class VectorContractRewriter {
   // Lower-level operation to be emitted.
   MMLA mmlaOp = MMLA::Nop;
 
+  // Indicate if the operands for the ArmSVE dialect operation need to be
+  // swapped. Currently this is needed in order to emulate a "summla"
+  // operation.
+  bool swapOperands = false;
+
   // The operand tiles. These are not necessarily the operends of
   // `vector.contract`, for example they could be operands to `arith.extsi`
   // that is in turn fed into `vector.contract`.
@@ -184,34 +190,6 @@ class VectorContractRewriter {
   int64_t N = 0;
   int64_t K = 0;
 
-  // Single-dimensional vector types for the operands of the ArmSVE dialect
-  // op.
-  VectorType flatLhsType;
-  VectorType flatRhsType;
-  VectorType flatAccType;
-
-  // Single-dimension vector type for the entire RHS tile.
-  VectorType flatRhsTileType;
-
-  // Vector type having the same number of elements as a row in the
-  // accumulator/output tile and the same element type.
-  VectorType accRowTy;
-
-  // Vector type having twice the number of elements as a row in the
-  // accumulator/output tile the same element type.
-  VectorType accRowX2Ty;
-
-  // Vector type having half the number of elements as a row in the
-  // accumulator/output tile and an integer element type with twice the bit
-  // width.
-  VectorType accRow64Ty;
-  VectorType accRowX264Ty;
-
-  // Indicate if the operands for the ArmSVE dialect operation need to be
-  // swapped. Currently this is needed in order to emulate an "summla"
-  // operation.
-  bool swapOperands = false;
-
   // Create the matrix mulitply and accumulate operation according to
   // `mmlaOp`.
   Value createMMLA(PatternRewriter &rewriter, Location loc, Value acc,
@@ -232,18 +210,20 @@ class VectorContractRewriter {
 Value VectorContractRewriter::createMMLA(PatternRewriter &rewriter,
                                          Location loc, Value acc, Value lhs,
                                          Value rhs) {
+
+  Type resTy = acc.getType();
   if (swapOperands)
     std::swap(lhs, rhs);
 
   switch (mmlaOp) {
   case MMLA::SignedInt:
-    return rewriter.create<arm_sve::SmmlaOp>(loc, flatAccType, acc, lhs, rhs);
+    return rewriter.create<arm_sve::SmmlaOp>(loc, resTy, acc, lhs, rhs);
   case MMLA::UnsignedInt:
-    return rewriter.create<arm_sve::UmmlaOp>(loc, flatAccType, acc, lhs, rhs);
+    return rewriter.create<arm_sve::UmmlaOp>(loc, resTy, acc, lhs, rhs);
   case MMLA::MixedInt:
-    return rewriter.create<arm_sve::UsmmlaOp>(loc, flatAccType, acc, lhs, rhs);
+    return rewriter.create<arm_sve::UsmmlaOp>(loc, resTy, acc, lhs, rhs);
   case MMLA::Bfloat:
-    return rewriter.create<arm_sve::BfmmlaOp>(loc, flatAccType, acc, lhs, rhs);
+    return rewriter.create<arm_sve::BfmmlaOp>(loc, resTy, acc, lhs, rhs);
   default:
     llvm_unreachable("Uninitialized operation kind");
   }
@@ -283,6 +263,55 @@ LogicalResult VectorContractRewriter::match(vector::ContractionOp op,
 
 Value VectorContractRewriter::rewrite(vector::ContractionOp op,
                                       PatternRewriter &rewriter) {
+
+  // Initialize some helper types.
+  Type operandEltType = cast<VectorType>(lhs.getType()).getElementType();
+  Type resultEltType = cast<VectorType>(op.getResultType()).getElementType();
+
+  const int64_t numOperandSubTileElts =
+      128 / operandEltType.getIntOrFloatBitWidth();
+
+  assert(resultEltType.getIntOrFloatBitWidth() == 32 &&
+         "Only implemented for i32 or f32 output");
+  const int64_t numResultSubTileElts = 4;
+
+  // Single-dimensional vector types for the operands of the ArmSVE dialect
+  // op.
+  auto flatLhsType =
+      VectorType::get(/*shape=*/numOperandSubTileElts, operandEltType,
+                      /*scalableDims=*/{true});
+  auto flatRhsType =
+      VectorType::get(/*shape=*/numOperandSubTileElts, operandEltType,
+                      /*scalableDims=*/{true});
+  auto flatAccType =
+      VectorType::get(/*shape=*/numResultSubTileElts, resultEltType,
+                      /*scalableDims=*/{true});
+
+  // Single-dimension vector type for the entire RHS tile.
+
+  auto flatRhsTileType = VectorType::get(/*shape=*/K * N, operandEltType,
+                                         /*scalableDims=*/{true});
+
+  // Vector type having the same number of elements as a row in the
+  // accumulator/output tile and the same element type.
+  auto accRowTy = VectorType::get(/*shape=*/N, resultEltType,
+                                  /*scalableDims=*/{true});
+
+  // Vector type having twice the number of elements as a row in the
+  // accumulator/output tile and the same element type.
+  auto accRowX2Ty = VectorType::get(/*shape=*/2 * N, resultEltType,
+                                    /*scalableDims=*/{true});
+  // Vector type having half the number of elements as a row in the
+  // accumulator/output tile and an integer element type with twice the bit
+  // width.
+  auto accRow64Ty = VectorType::get(/*shape=*/N / 2, rewriter.getI64Type(),
+                                    /*scalableDims=*/{true});
+  // Vector type having the same number of elements as a row in the
+  // accumulator/output tile and an integer element type with twice the bit
+  // width.
+  auto accRowX264Ty = VectorType::get(/*shape=*/N, rewriter.getI64Type(),
+                                     /*scalableDims=*/{true});
+
   Location loc = op.getLoc();
 
   // Extract LHS sub-tiles with logical shape <2xK>.
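// As a concrete illustration of the parametrisation above (an editorial
// worked example derived from the code in this hunk): for the bfloat16 case
// operandEltType is bf16, so numOperandSubTileElts == 128 / 16 == 8 and
// flatLhsType/flatRhsType are vector<[8]xbf16>, flatAccType is vector<[4]xf32>,
// and, with K == 4, flatRhsTileType is vector<[4*N]xbf16>. For the 8-bit
// integer case (K == 8) the corresponding types are vector<[16]xi8>,
// vector<[4]xi32> and vector<[8*N]xi8>.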
@@ -397,9 +426,9 @@ class VectorContractRewriterI8MM : public VectorContractRewriter { public: // Check the specific preconditions for the integer case. Initialise // parametrisation types and dimensions. - LogicalResult match(vector::ContractionOp op, PatternRewriter &rewriter) { - - if (failed(VectorContractRewriter::match(op, rewriter))) + LogicalResult matchAndInit(vector::ContractionOp op, + PatternRewriter &rewriter) { + if (failed(match(op, rewriter))) return failure(); VectorType lhsType = op.getLhsType(); @@ -461,26 +490,6 @@ class VectorContractRewriterI8MM : public VectorContractRewriter { rhs = *maybeRhs; acc = op.getAcc(); - flatLhsType = VectorType::get(/*shape=*/16, rewriter.getI8Type(), - /*scalableDims=*/{true}); - flatRhsType = VectorType::get(/*shape=*/16, rewriter.getI8Type(), - /*scalableDims=*/{true}); - - flatAccType = VectorType::get(/*shape=*/4, rewriter.getI32Type(), - /*scalableDims=*/{true}); - - flatRhsTileType = VectorType::get(/*shape=*/8 * N, rewriter.getI8Type(), - /*scalableDims=*/{true}); - - accRowTy = VectorType::get(/*shape=*/N, rewriter.getI32Type(), - /*scalableDims=*/{true}); - accRowX2Ty = VectorType::get(/*shape=*/2 * N, rewriter.getI32Type(), - /*scalableDims=*/{true}); - accRow64Ty = VectorType::get(/*shape=*/N / 2, rewriter.getI64Type(), - /*scalableDims=*/{true}); - accRowX264Ty = VectorType::get(/*shape=*/N, rewriter.getI64Type(), - /*scalableDims=*/{true}); - return success(); } }; @@ -489,9 +498,9 @@ class VectorContractRewriterBfloat : public VectorContractRewriter { public: // Check the specific preconditions for the bfloat16 case. Initialise // parametrisation types and dimensions. - LogicalResult match(vector::ContractionOp op, PatternRewriter &rewriter) { - - if (failed(VectorContractRewriter::match(op, rewriter))) + LogicalResult matchAndInit(vector::ContractionOp op, + PatternRewriter &rewriter) { + if (failed(match(op, rewriter))) return failure(); VectorType lhsType = op.getLhsType(); @@ -530,26 +539,6 @@ class VectorContractRewriterBfloat : public VectorContractRewriter { rhs = op.getRhs(); acc = op.getAcc(); - flatLhsType = VectorType::get(/*shape=*/8, rewriter.getBF16Type(), - /*scalableDims=*/{true}); - flatRhsType = VectorType::get(/*shape=*/8, rewriter.getBF16Type(), - /*scalableDims=*/{true}); - - flatAccType = VectorType::get(/*shape=*/4, rewriter.getF32Type(), - /*scalableDims=*/{true}); - - flatRhsTileType = VectorType::get(/*shape=*/4 * N, rewriter.getBF16Type(), - /*scalableDims=*/{true}); - - accRowTy = VectorType::get(/*shape=*/N, rewriter.getF32Type(), - /*scalableDims=*/{true}); - accRowX2Ty = VectorType::get(/*shape=*/2 * N, rewriter.getF32Type(), - /*scalableDims=*/{true}); - accRow64Ty = VectorType::get(/*shape=*/N / 2, rewriter.getI64Type(), - /*scalableDims=*/{true}); - accRowX264Ty = VectorType::get(/*shape=*/N, rewriter.getI64Type(), - /*scalableDims=*/{true}); - return success(); } }; @@ -563,7 +552,7 @@ class LowerContractionToSVEI8MMPattern // Match i8xi8 -> i32 matrix multiply and accumulate. VectorContractRewriterI8MM vcr; - if (failed(vcr.match(op, rewriter))) + if (failed(vcr.matchAndInit(op, rewriter))) return failure(); Value result = vcr.rewrite(op, rewriter); @@ -582,7 +571,7 @@ class LowerContractionToSVEBFMMLAPattern // Match bf16xbf16 -> f32 matrix multiply and accumulate. 
    VectorContractRewriterBfloat vcr;
-    if (failed(vcr.match(op, rewriter)))
+    if (failed(vcr.matchAndInit(op, rewriter)))
       return failure();
 
     Value result = vcr.rewrite(op, rewriter);
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/vector-contract-bfmmla.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/vector-contract-bfmmla.mlir
index 0e988d1c2f42c..8b209d3f777b5 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/vector-contract-bfmmla.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/vector-contract-bfmmla.mlir
@@ -58,10 +58,10 @@ func.func private @prepareAccTestData(%in: vector<4x4xf32>) -> memref<4x?xf32> {
   %c4 = arith.constant 4 : index
 
   %vs = vector.vscale
-  %d = arith.muli %c4, %vs : index
-  %mem = memref.alloc(%d) : memref<4x?xf32>
+  %nCols = arith.muli %c4, %vs : index
+  %mem = memref.alloc(%nCols) : memref<4x?xf32>
 
-  scf.for %j = %c0 to %d step %c4 {
+  scf.for %j = %c0 to %nCols step %c4 {
     vector.transfer_write %in, %mem[%c0, %j] {in_bounds = [true, true]} : vector<4x4xf32>, memref<4x?xf32>
   }
 
@@ -95,10 +95,10 @@ func.func private @prepareRHSTestData(%in: vector<4x4xbf16>) -> memref<?xbf16> {
   %c4 = arith.constant 4 : index
 
   %vs = vector.vscale
-  %d = arith.muli %c4, %vs : index
-  %mem = memref.alloc(%d) : memref<?x4xbf16>
+  %nRows = arith.muli %c4, %vs : index
+  %mem = memref.alloc(%nRows) : memref<?x4xbf16>
 
-  scf.for %i = %c0 to %d step %c4 {
+  scf.for %i = %c0 to %nRows step %c4 {
     vector.transfer_write %in, %mem[%i, %c0] {in_bounds = [true, true]} : vector<4x4xbf16>, memref<?x4xbf16>
   }
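For reference, below is a minimal sketch of the contraction form these patterns
target, mirroring the integration test above with M == 4, N == 4 and K == 4.
It is not taken verbatim from the patch; the map and function names are
illustrative, chosen to be consistent with the transposed-RHS
(`linalg.mmt4d`-style) layout described in the test comments.

#packed_maps = [
  affine_map<(m, n, k) -> (m, k)>,
  affine_map<(m, n, k) -> (n, k)>,
  affine_map<(m, n, k) -> (m, n)>
]

// LHS is vector<Mx4xbf16>, RHS is vector<[N]x4xbf16> (transposed) and the
// accumulator/result is vector<Mx[N]xf32>.
func.func @bf16_mmt4d_tile(%lhs: vector<4x4xbf16>,
                           %rhs: vector<[4]x4xbf16>,
                           %acc: vector<4x[4]xf32>) -> vector<4x[4]xf32> {
  %0 = vector.contract {indexing_maps = #packed_maps,
                        iterator_types = ["parallel", "parallel", "reduction"],
                        kind = #vector.kind<add>} %lhs, %rhs, %acc
    : vector<4x4xbf16>, vector<[4]x4xbf16> into vector<4x[4]xf32>
  return %0 : vector<4x[4]xf32>
}

With the FEAT_BF16 patterns enabled, a contraction of this shape lowers to four
`arm_sve.intr.bfmmla` calls, as checked by the CHECK-IR-COUNT-4 line in the
integration test.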