[mlir][linalg] Add support for scalable vectorization of linalg.mmt4d #146531
Changes from all commits: 487db47, 8ef6661, 2b6019c
@@ -222,7 +222,8 @@ struct VectorizationState {
   /// canonical vector shape for vectorization.
   LogicalResult initState(RewriterBase &rewriter, LinalgOp linalgOp,
                           ArrayRef<int64_t> inputVectorSizes,
-                          ArrayRef<bool> inputScalableVecDims);
+                          ArrayRef<bool> inputScalableVecDims,
+                          bool assumeDynamicDimsMatchVecSizes = false);
 
   /// Returns the canonical vector shape used to vectorize the iteration space.
   ArrayRef<int64_t> getCanonicalVecShape() const { return canonicalVecShape; }
@@ -331,6 +332,14 @@ struct VectorizationState {
   /// Global vectorization guard for the incoming rewriter. It's initialized
   /// when the vectorization state is initialized.
   OpBuilder::InsertionGuard rewriterGuard;
+
+  /// Do all dynamic dims match the corresponding vector sizes?
+  ///
+  /// When a dynamic tensor/memref dimension matches the corresponding vector
+  /// dimension, masking can be safely skipped, despite the presence of dynamic
+  /// shapes. Use this flag with care and only for cases where you are
+  /// confident the assumption holds.
+  bool assumeDynamicDimsMatchVecSizes = false;
 };
 
 LogicalResult
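The contract this flag encodes can be restated as a minimal, std-only C++ sketch (the `loadLanes` helper below is hypothetical and not part of the patch): if the runtime extent of a dimension is guaranteed to equal the runtime vector length, an unmasked read of that many lanes is in-bounds by construction, so the mask carries no information.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper: read `vecLen` lanes without a mask. This is safe only
// under the flag's contract, i.e. the dynamic dim size equals the (runtime)
// vector length, e.g. a dim tiled to `4 x vscale` and read with a scalable
// vector size of [4].
std::vector<float> loadLanes(const std::vector<float> &buf,
                             std::size_t vecLen) {
  assert(buf.size() == vecLen && "assumeDynamicDimsMatchVecSizes violated");
  return std::vector<float>(buf.begin(), buf.begin() + vecLen);
}
```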
@@ -367,10 +376,12 @@ VectorizationState::precomputeIterSpaceValueSizes(RewriterBase &rewriter,
 /// Initializes the vectorization state, including the computation of the
 /// canonical vector shape for vectorization.
 // TODO: Move this to the constructor when we can remove the failure cases.
-LogicalResult
-VectorizationState::initState(RewriterBase &rewriter, LinalgOp linalgOp,
-                              ArrayRef<int64_t> inputVectorSizes,
-                              ArrayRef<bool> inputScalableVecDims) {
+LogicalResult VectorizationState::initState(RewriterBase &rewriter,
+                                            LinalgOp linalgOp,
+                                            ArrayRef<int64_t> inputVectorSizes,
+                                            ArrayRef<bool> inputScalableVecDims,
+                                            bool assumeDimsMatchVec) {
+  assumeDynamicDimsMatchVecSizes = assumeDimsMatchVec;
   // Initialize the insertion point.
   rewriter.setInsertionPoint(linalgOp);
@@ -470,6 +481,21 @@ Value VectorizationState::getOrCreateMaskFor(
     return Value();
   }
 
+  if (assumeDynamicDimsMatchVecSizes) {
+    // Given that all _scalable vector sizes_ match the corresponding
+    // memref/tensor dim sizes, masking can be skipped provided that:
+    // * all vector sizes corresponding to dynamic dims are scalable.
+    if (llvm::all_of(llvm::zip(permutedStaticSizes, maskType.getScalableDims()),
+                     [](auto it) {
+                       return std::get<0>(it) == ShapedType::kDynamic
+                                  ? std::get<1>(it)
+                                  : false;
+                     }))
+      LDBG("Masking is not needed for masking map: " << maskingMap << "\n");
+    activeMaskCache[maskingMap] = Value();
+    return Value();
+  }
+
   // Permute the iteration space value sizes to compute the mask upper bounds.
   SmallVector<Value> upperBounds =
       applyPermutationMap(maskingMap, ArrayRef<Value>(iterSpaceValueSizes));

Review comment on lines +490 to +492: Am I misinterpreting this for some reason, or does this imply: "if every size in […]

Reply: +1
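To make the question above concrete, here is a standalone, std-only re-implementation of the predicate (a sketch; `kDynamic` stands in for `ShapedType::kDynamic`, which in MLIR is the minimum `int64_t`). A static dim takes the `: false` branch, so the conjunction holds only when every dim is dynamic and its vector dim scalable:

```cpp
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <limits>
#include <vector>

// Stand-in for ShapedType::kDynamic.
constexpr int64_t kDynamic = std::numeric_limits<int64_t>::min();

// Mirrors the llvm::all_of(llvm::zip(...)) predicate above.
bool allDynamicDimsScalable(const std::vector<int64_t> &staticSizes,
                            const std::vector<bool> &scalableDims) {
  for (std::size_t i = 0; i < staticSizes.size(); ++i)
    if (!(staticSizes[i] == kDynamic ? scalableDims[i] : false))
      return false;
  return true;
}

int main() {
  // All dims dynamic and scalable: predicate holds, the mask is skipped.
  std::cout << allDynamicDimsScalable({kDynamic, kDynamic}, {true, true})
            << '\n'; // 1
  // One static dim: predicate fails, even though a static dim needs no
  // runtime bounds check; this is the behaviour the review comment queries.
  std::cout << allDynamicDimsScalable({8, kDynamic}, {false, true})
            << '\n'; // 0
}
```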
@@ -2479,7 +2505,8 @@ vectorizeScalableVectorPrecondition(Operation *op,
   return success(isElementwise(linalgOp) || isa<linalg::MatmulOp>(op) ||
                  isa<linalg::MatmulTransposeAOp>(op) ||
                  isa<linalg::DepthwiseConv1DNwcWcOp>(op) ||
-                 isa<linalg::MatvecOp>(op) || hasReductionIterator(linalgOp));
+                 isa<linalg::MatvecOp>(op) || isa<linalg::Mmt4DOp>(op) ||
+                 hasReductionIterator(linalgOp));
 }
 
 LogicalResult mlir::linalg::vectorizeOpPrecondition(

Review comment: I was wondering if there is a particular reason why this wouldn't work for […]
@@ -2535,11 +2562,10 @@ bool mlir::linalg::hasVectorizationImpl(Operation *op) {
                    tensor::InsertSliceOp>(op);
 }
 
-FailureOr<VectorizationResult>
-mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
-                        ArrayRef<int64_t> inputVectorSizes,
-                        ArrayRef<bool> inputScalableVecDims,
-                        bool vectorizeNDExtract, bool flatten1DDepthwiseConv) {
+FailureOr<VectorizationResult> mlir::linalg::vectorize(
+    RewriterBase &rewriter, Operation *op, ArrayRef<int64_t> inputVectorSizes,
+    ArrayRef<bool> inputScalableVecDims, bool vectorizeNDExtract,
+    bool flatten1DDepthwiseConv, bool assumeDynamicDimsMatchVecSizes) {
   LDBG("Attempting to vectorize:\n" << *op << "\n");
   LDBG("Input vector sizes: ");
   LLVM_DEBUG(llvm::interleaveComma(inputVectorSizes, llvm::dbgs()));
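For context, a hypothetical caller-side sketch of the extended entry point (the wrapper, the concrete tile sizes, and the header paths are assumptions; the `vectorize` signature itself is the one from the diff above). `linalg.mmt4d` iterates over (M, N, K, m, n, k); here only the inner `n` tile is marked scalable:

```cpp
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"

using namespace mlir;

// Hypothetical wrapper: vectorize a linalg.mmt4d with a scalable inner `n`
// tile, promising that any dynamic dims match the chosen vector sizes so
// that mask generation can be skipped.
static LogicalResult vectorizeMmt4dScalable(RewriterBase &rewriter,
                                            linalg::Mmt4DOp mmt4dOp) {
  // Iteration space (M, N, K, m, n, k); the tile sizes are illustrative only.
  SmallVector<int64_t> vectorSizes = {1, 1, 1, 8, 8, 1};
  SmallVector<bool> scalableDims = {false, false, false, false, true, false};

  FailureOr<linalg::VectorizationResult> result = linalg::vectorize(
      rewriter, mmt4dOp.getOperation(), vectorSizes, scalableDims,
      /*vectorizeNDExtract=*/false, /*flatten1DDepthwiseConv=*/false,
      /*assumeDynamicDimsMatchVecSizes=*/true);
  return failed(result) ? failure() : success();
}
```

Whether a caller should set the flag depends on how the op was tiled: it is only sound when the tiling guarantees that every dynamic dim equals its (scalable) vector size, which is the point stressed in the flag's documentation above.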
@@ -2559,7 +2585,8 @@ mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
   VectorizationState state(rewriter);
   if (auto linalgOp = dyn_cast<linalg::LinalgOp>(op)) {
     if (failed(state.initState(rewriter, linalgOp, inputVectorSizes,
-                               inputScalableVecDims))) {
+                               inputScalableVecDims,
+                               assumeDynamicDimsMatchVecSizes))) {
       LDBG("Vectorization state couldn't be initialized\n");
       return failure();
     }
Review comment: Some of the changes to comments are not relevant; can you revert them?

Reply: Not sure how that crept in. Fixed in this commit.