[MLIR][AArch64] Lower vector.contract
to SVE FEAT_BF16 operations
#147052
Conversation
This patch adds lowering of Bfloat16 widening matrix multiply and accumulate `vector.contract`, by parametrising and refactoring the pattern for 8-bit integers.
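For reference, a contraction of the shape this pattern targets might look like the following (the shapes are illustrative, not taken from the patch; the RHS is pre-transposed, matching the permutation maps the pattern accepts):

```mlir
%res = vector.contract {
    indexing_maps = [affine_map<(m, n, k) -> (m, k)>,
                     affine_map<(m, n, k) -> (n, k)>,
                     affine_map<(m, n, k) -> (m, n)>],
    iterator_types = ["parallel", "parallel", "reduction"],
    kind = #vector.kind<add>}
  %lhs, %rhs, %acc
  : vector<4x4xbf16>, vector<[4]x4xbf16> into vector<4x[4]xf32>
```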
@llvm/pr-subscribers-mlir-sve @llvm/pr-subscribers-mlir-neon

Author: Momchil Velikov (momchil-velikov)

Changes: This patch adds lowering of Bfloat16 widening matrix multiply and accumulate `vector.contract`, by parametrising and refactoring the pattern for 8-bit integers.

Patch is 65.28 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/147052.diff

12 Files Affected:
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 5a864865adffc..4f304b39a0528 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -1437,6 +1437,10 @@ def ConvertVectorToLLVMPass : Pass<"convert-vector-to-llvm"> {
"bool", /*default=*/"false",
"Enables the use of Arm FEAT_I8MM instructions while lowering "
"the vector dialect.">,
+ Option<"armBF16", "enable-arm-bf16",
+ "bool", /*default=*/"false",
+ "Enables the use of Arm FEAT_BF16 instructions while lowering "
+ "the vector dialect.">,
Option<"x86Vector", "enable-x86vector",
"bool", /*default=*/"false",
"Enables the use of X86Vector dialect while lowering the vector "
diff --git a/mlir/include/mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.td b/mlir/include/mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.td
index 53784982be6dc..81b3c736b93f3 100644
--- a/mlir/include/mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.td
+++ b/mlir/include/mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.td
@@ -12,7 +12,7 @@ include "mlir/Dialect/Transform/IR/TransformAttrs.td"
include "mlir/Dialect/Transform/IR/TransformDialect.td"
include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td"
-def ApplyArmSVELowerContractionPatternsOp
+def ApplyArmSVELowerContractionToI8MMPatternsOp
: Op<Transform_Dialect, "apply_patterns.arm_sve.vector_contract_to_i8mm",
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
let description = [{
@@ -23,4 +23,14 @@ def ApplyArmSVELowerContractionPatternsOp
let assemblyFormat = "attr-dict";
}
+def ApplyArmSVELowerContractionToBFMMLAPatternsOp
+ : Op<Transform_Dialect, "apply_patterns.arm_sve.vector_contract_to_bfmmla",
+ [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+ let description = [{
+ Indicates that vector contraction-like operations should be lowered to
+ finer-grained vector primitives using the ArmSVE dialect.
+ }];
+
+ let assemblyFormat = "attr-dict";
+}
#endif // ARMSVE_VECTOR_TRANSFORM_OPS
diff --git a/mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h b/mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h
index 232e2be29e574..de160dbf8ed94 100644
--- a/mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h
@@ -23,6 +23,8 @@ void populateArmSVELegalizeForLLVMExportPatterns(
void populateLowerContractionToSVEI8MMPatternPatterns(
RewritePatternSet &patterns);
+void populateLowerContractionToSVEBFMMLAPatterns(RewritePatternSet &patterns);
+
/// Configure the target to support lowering ArmSVE ops to ops that map to LLVM
/// intrinsics.
void configureArmSVELegalizeForExportTarget(LLVMConversionTarget &target);
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
index 67c0eca15638a..4d74aabcaa50d 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -89,6 +89,9 @@ void ConvertVectorToLLVMPass::runOnOperation() {
if (armSVE)
populateLowerContractionToSVEI8MMPatternPatterns(patterns);
}
+ if (armBF16)
+ populateLowerContractionToSVEBFMMLAPatterns(patterns);
+
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
}
diff --git a/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToNeonI8MMPattern.cpp b/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToNeonI8MMPattern.cpp
index 7180884c77e98..a95fc51d562c2 100644
--- a/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToNeonI8MMPattern.cpp
+++ b/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToNeonI8MMPattern.cpp
@@ -12,7 +12,7 @@
// TODO: There may be opportunities to unify this with a similar pattern
// for SVE. See:
// https://github.com/llvm/llvm-project/issues/145559
-// LowerContractionToSVEI8MMPattern.cpp
+// LowerContractToSVEPatterns.cpp
//
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.cpp b/mlir/lib/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.cpp
index b2ca4fc1eaa8c..8572c34c8b12b 100644
--- a/mlir/lib/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.cpp
+++ b/mlir/lib/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.cpp
@@ -18,11 +18,16 @@ using namespace mlir;
// Apply...PatternsOp
//===----------------------------------------------------------------------===//
-void transform::ApplyArmSVELowerContractionPatternsOp::populatePatterns(
+void transform::ApplyArmSVELowerContractionToI8MMPatternsOp::populatePatterns(
RewritePatternSet &patterns) {
mlir::populateLowerContractionToSVEI8MMPatternPatterns(patterns);
}
+void transform::ApplyArmSVELowerContractionToBFMMLAPatternsOp::populatePatterns(
+ RewritePatternSet &patterns) {
+ mlir::populateLowerContractionToSVEBFMMLAPatterns(patterns);
+}
+
//===----------------------------------------------------------------------===//
// Transform op registration
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/CMakeLists.txt b/mlir/lib/Dialect/ArmSVE/Transforms/CMakeLists.txt
index 65f98b44b1b69..c29eaca244b4a 100644
--- a/mlir/lib/Dialect/ArmSVE/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/ArmSVE/Transforms/CMakeLists.txt
@@ -1,7 +1,7 @@
add_mlir_dialect_library(MLIRArmSVETransforms
LegalizeForLLVMExport.cpp
LegalizeVectorStorage.cpp
- LowerContractionToSVEI8MMPattern.cpp
+ LowerContractToSVEPatterns.cpp
DEPENDS
MLIRArmSVEConversionsIncGen
diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/LowerContractToSVEPatterns.cpp b/mlir/lib/Dialect/ArmSVE/Transforms/LowerContractToSVEPatterns.cpp
new file mode 100644
index 0000000000000..2987287afe9cd
--- /dev/null
+++ b/mlir/lib/Dialect/ArmSVE/Transforms/LowerContractToSVEPatterns.cpp
@@ -0,0 +1,607 @@
+//===- LowerContractToSVEPatterns.cpp - Contract to I8MM/BF16 ---*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements lowering patterns from vector.contract to operations
+// that map to instructions from the SVE FEAT_I8MM and FEAT_BF16 extensions.
+//
+// TODO: There may be opportunities to unify this with a similar pattern
+// for Neon. See:
+// https://github.com/llvm/llvm-project/issues/145559
+// LowerContractionToNeonI8MMPattern.cpp
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h"
+#include "mlir/Dialect/ArmSVE/Transforms/Transforms.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/UB/IR/UBOps.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+#include <numeric>
+
+#define DEBUG_TYPE "lower-contract-to-arm-sve-i8mm"
+
+using namespace mlir;
+
+namespace {
+// Get the operand of a `vector.contract`. This function is intended to abstract
+// away from the particular way a value is extended before feeding it into the
+// `vector.contract` - via zero-extend or an explicit or implicit sign-extend
+// (for implicit sign-extension see `vector.contract` documentation).
+//
+// The template parameter `Op` indicates the extension operation (explicit or
+// implicit) for which we are checking.
+//
+// Return success only for extensions from `i8` to `i32`.
+template <typename Op>
+std::optional<Value> getExtOperand(Value v) {
+
+ static_assert(llvm::is_one_of<Op, arith::ExtSIOp, arith::ExtUIOp>::value,
+ "Must be instantiated with either sign- or zero- extension op");
+
+ // If the operand is not defined by an explicit extend operation of the
+ // accepted operation type allow for an implicit sign-extension.
+ auto extOp = dyn_cast_or_null<Op>(v.getDefiningOp());
+ if (!extOp) {
+ if constexpr (std::is_same<Op, arith::ExtSIOp>::value) {
+ auto vTy = cast<VectorType>(v.getType());
+ if (!vTy.getElementType().isSignlessInteger(8))
+ return {};
+ return v;
+ }
+ return {};
+ }
+
+ // If the operand is defined by an explicit extend operation of the accepted
+ // operation type, check it's extended from `i8` to `i32`.
+ auto inOp = extOp.getIn();
+ auto inTy = dyn_cast<VectorType>(inOp.getType());
+ if (!inTy || !inTy.getElementType().isSignlessInteger(8))
+ return {};
+
+ auto outTy = dyn_cast<VectorType>(extOp.getType());
+ if (!outTy || !outTy.getElementType().isSignlessInteger(32))
+ return {};
+
+ return inOp;
+}
+
+/// This class encapsulates the algorithm and parametrisation (in terms of types
+/// and dimensions) of lowering a `vector.contract` to "primitive" matrix
+/// multiplication operations of the SVE dialect (here "primitive" would mean
+/// corresponding to a single target instruction).
+///
+/// Supported are lowering to FEAT_I8MM `smmla`, `ummla`, and `usmmla`, and to
+/// FEAT_BF16 `bfmmla`. All the transformations are very similar to each
+/// other; for concreteness, the description below is given for `smmla`.
+///
+/// The lowering triggers for a contraction operation that performs a matrix
+/// multiply of two 8-bit integer matrix tiles with logical dimensions
+/// <Mx8> and <8x[N]> for the left-hand side (LHS) and the right-hand side
+/// (RHS), respectively, added to a 32-bit integer accumulator operand (ACC)
+/// with dimensions <Mx[N]>, yielding a <Mx[N]> 32-bit integer result (OUT).
+///
+/// The operands' shapes are such that the operands can be evenly split into
+/// sub-tiles with dimensions as expected by the targeted FEAT_I8MM
+/// instructions. The intent is that M and N are chosen (by higher level
+/// transforms) in such a way as to maximise register usage. The main use case
+/// we envision as of now is MMT4D, thus the RHS operand is expected
+/// pre-transposed.
+///
+/// The matrix multiplication is performed by unrolling the usual tiled matrix
+/// multiplication algorithm using sub-tiles with dimensions <2x8> for the
+/// LHS, <8x[2]> for the RHS, and <2x[2]> for the result and the input
+/// accumulator.
+///
+/// One way to illustrate the operation is as follows:
+///
+/// RHS<8x[N]>: <8x[2]> <8x[2]> ... <8x[2]>
+/// +-----------------------------
+/// LHS<Mx8>: <2x8> | <2x[2]> <2x[2]> ... <2x[2]>
+/// <2x8> | <2x[2]> <2x[2]> ... <2x[2]>
+/// ... | ... ... ... ...
+/// <2x8> | <2x[2]> <2x[2]> ... <2x[2]>
+///
+/// The RHS operand is unpacked into N/2 values, each representing a sequence
+/// of VSCALE number of sub-tiles with dimensions <8x2>.
+/// The LHS operand is initially unpacked into M/2 values, each representing a
+/// sub-tile with dimensions <2x8>, and then each such sub-tile is replicated
+/// VSCALE times. Multiplying such a replicated LHS sub-tile by the corresponding
+/// RHS sub-tile correctly computes an entire result sub-tile.
+/// The 2x2 sub-tiles of the ACC and OUT have rows that are not adjacent
+/// (in memory or when imposing a row-major layout on the 2D vector value).
+/// Reading the ACC is implemented as reading two consecutive rows and
+/// interleaving them by pairs to obtain a vector having twice the length
+/// of an ACC row. This vector is now a sequence of one-dimensional tiles with
+/// the exact layout needed by the `smmla`/`bfmmla`/etc. instructions, which
+/// tiles are extracted one by one. For illustration, if we have a 2x4 ACC tile
+/// a0 a1 b0 b1
+/// a2 a3 b2 b3
+/// we read the two rows as separate values and then interleave by pairs
+/// to obtain
+/// a0 a1 a2 a3 b0 b1 b2 b3
+/// from which we extract `a0 a1 a2 a3` and `b0 b1 b2 b3`.
+///
+/// Writing the OUT tile is done by the reverse of the above procedure:
+/// concatenate two "flattened" sub-tiles into
+/// c0 c1 c2 c3 d0 d1 d2 d3
+/// deinterleave by pairs to obtain as separate values
+/// c0 c1 d0 d1
+/// c2 c3 d2 d3
+/// which are then inserted into the final result.
+///
+/// Multiplication of a signed LHS by an unsigned RHS is performed by
+/// swapping the order of the operands and emitting an `usmmla` (since there
+/// isn't an `summla` instruction). Therefore each ACC sub-tile needs
+/// to be transposed before the addition and the sum, an OUT sub-tile,
+/// needs to be transposed before insertion into the final result.
+/// This is done very elegantly by a modification of the above to
+/// interleave/deinterleave not by pairs, but by individual elements, e.g.
+/// after ordinary interleave we obtain
+/// a0 a2 a1 a3 b0 b2 b1 b3
+/// which is exactly the desired layout of having each individual 2x2 tile
+/// transposed.
+///
+/// All of the above readily applies to FEAT_BF16 `bfmmla` with the
+/// difference that the shapes of the LHS and RHS are <Mx4> and <4x[N]>,
+/// respectively, that is, the "K" dimension is fixed to 4, instead of 8
+/// (like for the integer case).
+class VectorContractRewriter {
+protected:
+ // Designate the operation (resp. instruction) used to do sub-tile matrix
+ // multiplications.
+ enum class MMLA {
+ Nop,
+ SignedInt, // smmla
+ UnsignedInt, // ummla
+ MixedInt, // usmmla
+ Bfloat // bfmmla
+ };
+
+ // Lower-level operation to be emitted.
+ MMLA mmlaOp = MMLA::Nop;
+
+ // The operand tiles. These are not necessarily the operands of
+ // `vector.contract`, for example they could be operands to `arith.extsi`
+ // that is in turn fed into `vector.contract`.
+ Value lhs;
+ Value rhs;
+ Value acc;
+
+ // Conventional names for matrix dimensions.
+ int64_t M = 0;
+ int64_t N = 0;
+ int64_t K = 0;
+
+ // Single-dimensional vector types for the operands of the ArmSVE dialect
+ // op.
+ VectorType flatLhsType;
+ VectorType flatRhsType;
+ VectorType flatAccType;
+
+ // Single-dimension vector type for the entire RHS tile.
+ VectorType flatRhsTileType;
+
+ // Vector type having the same number of elements as a row in the
+ // accumulator/output tile and the same element type.
+ VectorType accRowTy;
+
+ // Vector type having twice the number of elements as a row in the
+ // accumulator/output tile and the same element type.
+ VectorType accRowX2Ty;
+
+ // Vector type having half the number of elements as a row in the
+ // accumulator/output tile and an integer element type with twice the bit
+ // width.
+ VectorType accRow64Ty;
+ VectorType accRowX264Ty;
+
+ // Indicate if the operands for the ArmSVE dialect operation need to be
+ // swapped. Currently this is needed in order to emulate a "summla"
+ // operation.
+ bool swapOperands = false;
+
+ // Create the matrix multiply and accumulate operation according to
+ // `mmlaOp`.
+ Value createMMLA(PatternRewriter &rewriter, Location loc, Value acc,
+ Value lhs, Value rhs);
+
+ // Check general preconditions for applying the transformation, common to the
+ // integer and the bfloat16 case.
+ LogicalResult match(vector::ContractionOp op, PatternRewriter &rewriter);
+
+public:
+ VectorContractRewriter() = default;
+
+ // Do the actual rewrite. This member function is shared by both integer and
+ // bfloat16 rewrites.
+ Value rewrite(vector::ContractionOp op, PatternRewriter &rewriter);
+};
+
+Value VectorContractRewriter::createMMLA(PatternRewriter &rewriter,
+ Location loc, Value acc, Value lhs,
+ Value rhs) {
+ if (swapOperands)
+ std::swap(lhs, rhs);
+
+ switch (mmlaOp) {
+ case MMLA::SignedInt:
+ return rewriter.create<arm_sve::SmmlaOp>(loc, flatAccType, acc, lhs, rhs);
+ case MMLA::UnsignedInt:
+ return rewriter.create<arm_sve::UmmlaOp>(loc, flatAccType, acc, lhs, rhs);
+ case MMLA::MixedInt:
+ return rewriter.create<arm_sve::UsmmlaOp>(loc, flatAccType, acc, lhs, rhs);
+ case MMLA::Bfloat:
+ return rewriter.create<arm_sve::BfmmlaOp>(loc, flatAccType, acc, lhs, rhs);
+ default:
+ llvm_unreachable("Uninitialized operation kind");
+ }
+}
+
+LogicalResult VectorContractRewriter::match(vector::ContractionOp op,
+ PatternRewriter &rewriter) {
+ // Check iterator types for matrix multiplication.
+ auto itTypes = op.getIteratorTypesArray();
+ if (itTypes.size() != 3 || itTypes[0] != vector::IteratorType::parallel ||
+ itTypes[1] != vector::IteratorType::parallel ||
+ itTypes[2] != vector::IteratorType::reduction)
+ return rewriter.notifyMatchFailure(
+ op, "iterator types do not correspond to matrix multiplication");
+
+ // Check permutation maps. For now only accept
+ // lhs: (d0, d1, d2) -> (d0, d2)
+ // rhs: (d0, d1, d2) -> (d1, d2)
+ // acc: (d0, d1, d2) -> (d0, d1)
+ // This corresponds to matrix multiplication with transposed RHS.
+ if (op.getIndexingMapsArray()[0] !=
+ AffineMap::getMultiDimMapWithTargets(3, ArrayRef{0u, 2u},
+ op.getContext()) ||
+ op.getIndexingMapsArray()[1] !=
+ AffineMap::getMultiDimMapWithTargets(3, ArrayRef{1u, 2u},
+ op.getContext()) ||
+ op.getIndexingMapsArray()[2] != AffineMap::getMultiDimMapWithTargets(
+ 3, ArrayRef{0u, 1u}, op.getContext()))
+ return rewriter.notifyMatchFailure(op, "non-matching permutation maps");
+
+ // Check the combining kind is addition.
+ if (op.getKind() != vector::CombiningKind::ADD)
+ return rewriter.notifyMatchFailure(op, "combining kind is not an addition");
+
+ return success();
+}
+
+Value VectorContractRewriter::rewrite(vector::ContractionOp op,
+ PatternRewriter &rewriter) {
+ Location loc = op.getLoc();
+
+ // Extract LHS sub-tiles with logical shape <2xK>.
+ SmallVector<Value> lhsTile;
+ for (int64_t i = 0; i < M; i += 2) {
+ // Extract two consecutive rows of the LHS tile.
+ auto r0 =
+ rewriter.create<vector::ExtractOp>(loc, lhs, ArrayRef<int64_t>{i});
+ auto r1 =
+ rewriter.create<vector::ExtractOp>(loc, lhs, ArrayRef<int64_t>{i + 1});
+ // Concatenate to obtain a 2 x K x <input-type> flattened sub-tile.
+ SmallVector<int64_t> shuffleIdx(2 * K);
+ std::iota(shuffleIdx.begin(), shuffleIdx.end(), 0);
+ auto t = rewriter.create<vector::ShuffleOp>(loc, r0, r1, shuffleIdx);
+ // Turn it into a scalable vector.
+ auto s = rewriter.create<vector::ScalableInsertOp>(
+ loc, t, rewriter.create<ub::PoisonOp>(loc, flatLhsType), 0);
+ // Replicate the sub-tile VSCALE times to fill the entire vector.
+ auto r = rewriter.create<arm_sve::DupQLaneOp>(loc, s, 0);
+ lhsTile.push_back(r);
+ }
+
+ // "Flatten" the RHS tile from <[N]xK> to <[N*K]>.
+ auto rhs = rewriter.create<vector::ShapeCastOp>(this->rhs.getLoc(),
+ flatRhsTileType, this->rhs);
+
+ // Extract the RHS sub-tiles with logical shape <Kx[2]>.
+ SmallVector<Value> rhsTile;
+ for (int64_t j = 0; j < N; j += 2)
+ rhsTile.push_back(rewriter.create<vector::ScalableExtractOp>(
+ loc, flatRhsType, rhs, j * K));
+
+ // Extract and pack the ACC sub-tiles.
+ SmallVector<Value> accTile;
+ for (int64_t i = 0; i < M; i += 2) {
+ // Extract two consecutive rows of the accumulator tile.
+ auto r0 = rewriter.create<vector::ExtractOp>(loc, op.getAcc(),
+ ArrayRef<int64_t>{i});
+ auto r1 = rewriter.create<vector::ExtractOp>(loc, op.getAcc(),
+ ArrayRef<int64_t>{i + 1});
+ Value accTileVec;
+ if (swapOperands) {
+ // We are performing the operation with swapped LHS and RHS we need to
+ // transpose e...
[truncated]
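The interleave/deinterleave-by-pairs layout trick described in the new file's documentation comment can be modelled in plain Python (a toy sketch of the data movement only, not the actual MLIR lowering):

```python
def interleave_pairs(r0, r1):
    # Interleave two accumulator rows pair-by-pair, so that each 2x2
    # sub-tile becomes contiguous in the flat vector.
    out = []
    for i in range(0, len(r0), 2):
        out += r0[i:i + 2] + r1[i:i + 2]
    return out

def deinterleave_pairs(flat):
    # Inverse step: split a flat vector of 2x2 sub-tiles back into two rows.
    r0, r1 = [], []
    for i in range(0, len(flat), 4):
        r0 += flat[i:i + 2]
        r1 += flat[i + 2:i + 4]
    return r0, r1

def interleave_elems(r0, r1):
    # Element-wise interleave, used in the swapped-operand (usmmla) case:
    # it leaves every 2x2 sub-tile transposed in the flat vector.
    return [x for pair in zip(r0, r1) for x in pair]

# The 2x4 ACC tile from the comment: rows (a0 a1 b0 b1) and (a2 a3 b2 b3).
flat = interleave_pairs(["a0", "a1", "b0", "b1"], ["a2", "a3", "b2", "b3"])
print(flat)  # ['a0', 'a1', 'a2', 'a3', 'b0', 'b1', 'b2', 'b3']
print(interleave_elems(["a0", "a1", "b0", "b1"], ["a2", "a3", "b2", "b3"]))
# ['a0', 'a2', 'a1', 'a3', 'b0', 'b2', 'b1', 'b3']
```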
Thanks! LG % nits
It's really nice how you combined patterns for I8MM and BF16!
Indicates that vector contraction-like operations should be lowered to
finer-grained vector primitives using the ArmSVE dialect.
ATM this is just a copy of the description for ApplyArmSVELowerContractionToI8MMPatternsOp. Could you extend it so that the difference between the two is clear? Thanks!
Done.
// Check the specific preconditions for the bfloat16 case. Initialise
// parametrisation types and dimensions.
LogicalResult match(vector::ContractionOp op, PatternRewriter &rewriter) {
This does a bit more than just matching :) I would split this into two methods, but I am also fine with just renaming.
Suggested change:
// Check the specific preconditions for the bfloat16 case. Initialise
// parametrisation types and dimensions.
LogicalResult matchAndInitialize(vector::ContractionOp op, PatternRewriter &rewriter) {
Further refactored and renamed.
%c4 = arith.constant 4 : index
%vs = vector.vscale
%d = arith.muli %c4, %vs : index
[nit] What does d stand for in %d? Why not %ub (upper bound) or %c4_vscale (4 x vscale)? A more descriptive name would be helpful, otherwise it's hard to see what happens in the loop below (scf.for %j = %c0 to %d step %c4).
It stands for "dimension".
OK, that's too broad IMO. That immediately raises two questions in my head:
- The dimension of what?
- Which dimension? (in the case of multi-dimensional shapes)
Something more descriptive would be really appreciated. Thanks!
I changed the name.