[mlir][linalg] Add support for scalable vectorization of linalg.mmt4d #146531
base: main
Commits: 487db47, 8ef6661, 2b6019c
Reviewer: Some changes on comments are not relevant, can you revert them?

Author: Not sure how that crept in. Fixed in this commit.
@@ -222,9 +222,11 @@ struct VectorizationState {
   /// canonical vector shape for vectorization.
   LogicalResult initState(RewriterBase &rewriter, LinalgOp linalgOp,
                           ArrayRef<int64_t> inputVectorSizes,
-                          ArrayRef<bool> inputScalableVecDims);
+                          ArrayRef<bool> inputScalableVecDims,
+                          bool assumeScalableVecSizesMatchDimSize = false);

-  /// Returns the canonical vector shape used to vectorize the iteration space.
+  /// Returns the canonical vector shape used to vectorize the iteration
+  /// space.
   ArrayRef<int64_t> getCanonicalVecShape() const { return canonicalVecShape; }

   /// Returns the vector dimensions that are scalable in the canonical vector
@@ -233,8 +235,8 @@ struct VectorizationState {

   /// Returns a vector type of the provided `elementType` with the canonical
   /// vector shape and the corresponding fixed/scalable dimensions bit. If
-  /// `dimPermutation` is provided, the canonical vector dimensions are permuted
-  /// accordingly.
+  /// `dimPermutation` is provided, the canonical vector dimensions are
+  /// permuted accordingly.
   VectorType getCanonicalVecType(
       Type elementType,
       std::optional<AffineMap> dimPermutation = std::nullopt) const {
@@ -254,9 +256,9 @@ struct VectorizationState {
   }

   /// Masks an operation with the canonical vector mask if the operation needs
-  /// masking. Returns the masked operation or the original operation if masking
-  /// is not needed. If provided, the canonical mask for this operation is
-  /// permuted using `maybeIndexingMap`.
+  /// masking. Returns the masked operation or the original operation if
+  /// masking is not needed. If provided, the canonical mask for this
+  /// operation is permuted using `maybeIndexingMap`.
   Operation *
   maskOperation(RewriterBase &rewriter, Operation *opToMask, LinalgOp linalgOp,
                 std::optional<AffineMap> maybeIndexingMap = std::nullopt);
@@ -276,15 +278,15 @@ struct VectorizationState {

   /// Create or retrieve an existing mask value to mask `opToMask` in the
   /// canonical vector iteration space. If `maybeMaskingMap` the mask is
-  /// permuted using that permutation map. If a new mask is created, it will be
-  /// cached for future users.
+  /// permuted using that permutation map. If a new mask is created, it will
+  /// be cached for future users.
   Value getOrCreateMaskFor(RewriterBase &rewriter, Operation *opToMask,
                            LinalgOp linalgOp,
                            std::optional<AffineMap> maybeMaskingMap);

   /// Check whether this permutation map can be used for masking. At the
-  /// moment we only make sure that there are no broadcast dimensions, but this
-  /// might change if indexing maps evolve.
+  /// moment we only make sure that there are no broadcast dimensions, but
+  /// this might change if indexing maps evolve.
   bool isValidMaskingMap(AffineMap maskingMap) {
     return maskingMap.getBroadcastDims().size() == 0;
   }
@@ -324,13 +326,24 @@ struct VectorizationState {
   /// shape.
   SmallVector<bool> scalableVecDims;

-  /// Holds the active masks for permutations of the canonical vector iteration
-  /// space.
+  /// Holds the active masks for permutations of the canonical vector
+  /// iteration space.
   DenseMap<AffineMap, Value> activeMaskCache;

   /// Global vectorization guard for the incoming rewriter. It's initialized
   /// when the vectorization state is initialized.
   OpBuilder::InsertionGuard rewriterGuard;

+  /// Do all scalable vector sizes match the corresponding input dim sizes?
+  /// (tensor or memref)
+  ///
+  /// At the Tensor + MemRef levels, scalable sizes are modelled using
+  /// dynamic dimensions (i.e. `?`). In many cases these sizes result from
+  /// e.g. "scalable packing + tiling" and are known to always match the
+  /// scalable vector sizes. In such cases, masking can be safely skipped,
+  /// despite the presence of dynamic shapes. Use this flag with care and
+  /// only for cases where you are confident the assumption holds.
+  bool assumeScalableVecSizesMatchDimSize = false;
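To make the documented assumption concrete, here is a minimal standalone model (plain C++ with hypothetical numbers, not MLIR API) of why the mask becomes redundant when a dynamic dim extent is known to equal the scalable vector length:

#include <cstdint>
#include <iostream>

int main() {
  const int64_t vscale = 4;             // fixed by the hardware at runtime
  const int64_t vecLen = 8 * vscale;    // vector<[8]xf32> has 8 * vscale lanes
  const int64_t dimExtent = 8 * vscale; // tensor<?xf32> sized by 8 * vscale tiling
  // A mask enables lane i iff i < dimExtent. Since dimExtent == vecLen by
  // construction, every lane is enabled, so generating the mask is redundant.
  bool maskAllTrue = true;
  for (int64_t i = 0; i < vecLen; ++i)
    maskAllTrue &= (i < dimExtent);
  std::cout << (maskAllTrue ? "mask redundant\n" : "mask needed\n");
  return 0;
}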
Reviewer: Could we make this generic, not only for scalable vectors? Assuming that a dynamic dimension is a multiple of a vector size, scalable or not, is useful in general.

Reviewer: I think we should also add a TODO, as this should eventually be a list containing the divisibility information for each dimension...

Reviewer: Having the bool option could be beneficial to IREE, because IREE may use TensorDynamicDimAnalysis and can use such information to determine whether the assumption is true or not. The analysis uses two MLIR analyses and one IREE-specific analysis, which may be upstreamable. I seldom work on this area, so I'm not entirely sure about it. I mainly want to share a data point about the "assume" mechanism that I learned before, which seems related to @dcaballe's comment.

TensorDynamicDimAnalysis::TensorDynamicDimAnalysis(Operation *rootOp)
    : rootOperation(rootOp) {
  solver.load<mlir::dataflow::DeadCodeAnalysis>();
  solver.load<mlir::dataflow::IntegerRangeAnalysis>();
  solver.load<IREE::Util::IntegerDivisibilityAnalysis>();
}

In IREE, we have the hint in the IR and we can apply the above analysis to retrieve the information. It is IREE-specific because the … However, IIUC, what's missing is something like ValueRangeAnalysis (not just for integers), because what you wonder is whether the size is a multiple of … Anyway, I think it is a good start, and I feel that we are missing some analysis to determine if we can make the value true or not. I don't suggest kicking in an analysis during vectorization, because it could be expensive; my understanding is that it is better to run it once at the beginning of a pass and then use it to vectorize all the ops. (I'm happy to move the discussion to the issue, if that is better.)

Author: Sorry, I over-indexed on "scalable vectors". For my needs ATM, equality (rather than divisibility) is sufficient. Do you reckon that we will require divisibility in the future? We do actually have and use ValueRangeAnalysis for "scalable vectors", see e.g. Codegen/Transforms/Transforms.cpp in IREE. However, AFAIK, that analysis is costly, and in the case of …
 };

 LogicalResult
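As a hypothetical sketch of the generalization floated in the thread above (names invented for illustration, not part of this patch): the single bool could become per-dimension divisibility facts, with masking skipped whenever each dynamic dim is provably a multiple of its vector size.

#include <cstdint>
#include <optional>

// Hypothetical per-dim record, along the lines of "a list containing the
// divisibility information for each dimension" suggested in review. An
// analysis such as IREE's IntegerDivisibilityAnalysis could populate it.
struct DimInfo {
  std::optional<int64_t> knownMultipleOf; // set if dim % k == 0 is proven
};

// Masking for one dim is skippable if its extent is a proven multiple of
// the vector size. Equality, as assumed in this patch, is the special case
// where the extent is exactly one vector length.
bool maskSkippable(int64_t vecSize, const DimInfo &info) {
  // If dim % k == 0 and vecSize divides k, then dim % vecSize == 0.
  return info.knownMultipleOf && *info.knownMultipleOf % vecSize == 0;
}

int main() {
  DimInfo proven{/*knownMultipleOf=*/16};
  DimInfo unknown{};
  return (maskSkippable(8, proven) && !maskSkippable(8, unknown)) ? 0 : 1;
}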
@@ -367,10 +380,12 @@ VectorizationState::precomputeIterSpaceValueSizes(RewriterBase &rewriter,
 /// Initializes the vectorization state, including the computation of the
 /// canonical vector shape for vectorization.
 // TODO: Move this to the constructor when we can remove the failure cases.
-LogicalResult
-VectorizationState::initState(RewriterBase &rewriter, LinalgOp linalgOp,
-                              ArrayRef<int64_t> inputVectorSizes,
-                              ArrayRef<bool> inputScalableVecDims) {
+LogicalResult VectorizationState::initState(RewriterBase &rewriter,
+                                            LinalgOp linalgOp,
+                                            ArrayRef<int64_t> inputVectorSizes,
+                                            ArrayRef<bool> inputScalableVecDims,
+                                            bool assumeScalableSizes) {
+  assumeScalableVecSizesMatchDimSize = assumeScalableSizes;
   // Initialize the insertion point.
   rewriter.setInsertionPoint(linalgOp);
@@ -470,6 +485,21 @@ Value VectorizationState::getOrCreateMaskFor(
     return Value();
   }

+  if (assumeScalableVecSizesMatchDimSize) {
+    // Given that all _scalable vector sizes_ match the corresponding
+    // memref/tensor dim sizes, masking can be skipped provided that:
+    // * all vector sizes corresponding to dynamic dims are scalable.
+    if (llvm::all_of(llvm::zip(permutedStaticSizes, maskType.getScalableDims()),
+                     [](auto it) {
+                       return std::get<0>(it) == ShapedType::kDynamic
+                                  ? std::get<1>(it)
+                                  : false;
+                     }))
+      LDBG("Masking is not needed for masking map: " << maskingMap << "\n");
+    activeMaskCache[maskingMap] = Value();
+    return Value();
+  }
+
   // Permute the iteration space value sizes to compute the mask upper bounds.
   SmallVector<Value> upperBounds =
       applyPermutationMap(maskingMap, ArrayRef<Value>(iterSpaceValueSizes));

Review comment on lines +490 to +492:

Reviewer: Am I misinterpreting this for some reason, or does this imply: "if every size in …

Reviewer: +1
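To unpack the condition being questioned above, here is a standalone C++ model of the predicate (plain C++, with a stand-in for ShapedType::kDynamic): it returns true only when every size is dynamic and paired with a scalable dim, so a single static size defeats it.

#include <cstdint>
#include <limits>
#include <vector>

// Stand-in for ShapedType::kDynamic (in MLIR it is int64_t's minimum value).
constexpr int64_t kDynamic = std::numeric_limits<int64_t>::min();

// Model of the llvm::all_of(llvm::zip(...)) check above: the lambda
// `size == kDynamic ? scalable : false` is equivalent to
// `size == kDynamic && scalable`, so all_of holds only if *every* size is
// dynamic *and* its dim is scalable.
bool allDynamicDimsScalable(const std::vector<int64_t> &sizes,
                            const std::vector<bool> &scalable) {
  for (size_t i = 0; i < sizes.size(); ++i)
    if (!(sizes[i] == kDynamic && scalable[i]))
      return false;
  return true;
}

int main() {
  // {?, ?} with both dims scalable -> predicate holds, mask is skipped.
  bool a = allDynamicDimsScalable({kDynamic, kDynamic}, {true, true});
  // {8, ?}: the static 8 makes the predicate false, even though a static
  // in-bounds size never needs a mask in the first place -- the behaviour
  // the reviewers question above.
  bool b = allDynamicDimsScalable({8, kDynamic}, {false, true});
  return (a && !b) ? 0 : 1;
}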
@@ -2479,7 +2509,8 @@ vectorizeScalableVectorPrecondition(Operation *op,
   return success(isElementwise(linalgOp) || isa<linalg::MatmulOp>(op) ||
                  isa<linalg::MatmulTransposeAOp>(op) ||
                  isa<linalg::DepthwiseConv1DNwcWcOp>(op) ||
-                 isa<linalg::MatvecOp>(op) || hasReductionIterator(linalgOp));
+                 isa<linalg::MatvecOp>(op) || isa<linalg::Mmt4DOp>(op) ||
+                 hasReductionIterator(linalgOp));
 }

 LogicalResult mlir::linalg::vectorizeOpPrecondition(

Review comment: I was wondering if there is a particular reason why this wouldn't work for …
@@ -2535,11 +2566,10 @@ bool mlir::linalg::hasVectorizationImpl(Operation *op) {
                     tensor::InsertSliceOp>(op);
 }

-FailureOr<VectorizationResult>
-mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
-                        ArrayRef<int64_t> inputVectorSizes,
-                        ArrayRef<bool> inputScalableVecDims,
-                        bool vectorizeNDExtract, bool flatten1DDepthwiseConv) {
+FailureOr<VectorizationResult> mlir::linalg::vectorize(
+    RewriterBase &rewriter, Operation *op, ArrayRef<int64_t> inputVectorSizes,
+    ArrayRef<bool> inputScalableVecDims, bool vectorizeNDExtract,
+    bool flatten1DDepthwiseConv, bool assumeScalableSizesMultipleOfDim) {
   LDBG("Attempting to vectorize:\n" << *op << "\n");
   LDBG("Input vector sizes: ");
   LLVM_DEBUG(llvm::interleaveComma(inputVectorSizes, llvm::dbgs()));
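For context, a hedged sketch of how a caller might invoke the entry point as it stands in this revision (the vector sizes and the `mmt4dOp` variable are invented; the flag was later renamed assumeDynamicDimsMatchVecSizes, per the review exchange at the end):

// Fragment, not a complete pass: assumes `rewriter` (RewriterBase &) and
// `mmt4dOp` (a linalg.mmt4d op) are in scope, with MLIR headers included.
// Six vector sizes cover mmt4d's (M, N, K, M0, N0, K0) iteration space;
// here only the N0 tile is marked scalable (values are illustrative).
SmallVector<int64_t> vectorSizes = {1, 1, 1, 4, 4, 1};
SmallVector<bool> scalableVecDims = {false, false, false, false, true, false};

FailureOr<VectorizationResult> result = mlir::linalg::vectorize(
    rewriter, mmt4dOp, vectorSizes, scalableVecDims,
    /*vectorizeNDExtract=*/false, /*flatten1DDepthwiseConv=*/false,
    /*assumeScalableSizesMultipleOfDim=*/true);
if (failed(result))
  return failure();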
@@ -2559,7 +2589,8 @@ mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
   VectorizationState state(rewriter);
   if (auto linalgOp = dyn_cast<linalg::LinalgOp>(op)) {
     if (failed(state.initState(rewriter, linalgOp, inputVectorSizes,
-                               inputScalableVecDims))) {
+                               inputScalableVecDims,
+                               assumeScalableSizesMultipleOfDim))) {
       LDBG("Vectorization state couldn't be initialized\n");
       return failure();
     }
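Taken together, the plumbing in this revision is: mlir::linalg::vectorize forwards assumeScalableSizesMultipleOfDim to VectorizationState::initState, which stores it in the assumeScalableVecSizesMatchDimSize member, and getOrCreateMaskFor later consults that member to decide whether mask generation can be skipped.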
Reviewer: We have assumeScalableSizesMultipleOfDim and getAssumeScalableSizesMatchDimSize. Should we have only one? Also, should it be "dim multiple of vector sizes"?

Author: Thanks for the suggestion, I went with assumeDynamicDimsMatchVecSizes, see this commit. As mentioned in my other comment, for now I am "assuming" equality rather than divisibility. Not sure whether we will need the latter?