-
Notifications
You must be signed in to change notification settings - Fork 14.5k
[mlir][vector] Folder: shape_cast(extract) -> extract #146368
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1696,59 +1696,71 @@ static bool hasZeroDimVectors(Operation *op) { | |
llvm::any_of(op->getResultTypes(), hasZeroDimVectorType); | ||
} | ||
|
||
/// All BroadcastOps and SplatOps, as well as ShapeCastOps that only prepends | ||
/// 1s, are considered 'broadcastlike'. | ||
static bool isBroadcastLike(Operation *op) { | ||
if (isa<BroadcastOp, SplatOp>(op)) | ||
return true; | ||
|
||
auto shapeCast = dyn_cast<ShapeCastOp>(op); | ||
if (!shapeCast) | ||
return false; | ||
|
||
// Check that it just prepends 1s, like (2,3) -> (1,1,2,3). | ||
// Condition 1: dst has hight rank. | ||
// Condition 2: src shape is a suffix of dst shape. | ||
VectorType srcType = shapeCast.getSourceVectorType(); | ||
ArrayRef<int64_t> srcShape = srcType.getShape(); | ||
uint64_t srcRank = srcType.getRank(); | ||
ArrayRef<int64_t> dstShape = shapeCast.getType().getShape(); | ||
return dstShape.size() >= srcRank && dstShape.take_back(srcRank) == srcShape; | ||
} | ||
|
||
/// Fold extractOp with scalar result coming from BroadcastOp or SplatOp. | ||
static Value foldExtractFromBroadcast(ExtractOp extractOp) { | ||
Operation *defOp = extractOp.getVector().getDefiningOp(); | ||
if (!defOp || !isa<vector::BroadcastOp, SplatOp>(defOp)) | ||
|
||
Operation *broadcastLikeOp = extractOp.getVector().getDefiningOp(); | ||
if (!broadcastLikeOp || !isBroadcastLike(broadcastLikeOp)) | ||
return Value(); | ||
|
||
Value source = defOp->getOperand(0); | ||
if (extractOp.getType() == source.getType()) | ||
return source; | ||
auto getRank = [](Type type) { | ||
return llvm::isa<VectorType>(type) ? llvm::cast<VectorType>(type).getRank() | ||
: 0; | ||
}; | ||
Value src = broadcastLikeOp->getOperand(0); | ||
|
||
// Replace extract(broadcast(X)) with X | ||
if (extractOp.getType() == src.getType()) | ||
return src; | ||
|
||
// If splat or broadcast from a scalar, just return the source scalar. | ||
unsigned broadcastSrcRank = getRank(source.getType()); | ||
if (broadcastSrcRank == 0 && source.getType() == extractOp.getType()) | ||
return source; | ||
// Get required types and ranks in the chain | ||
// src -> broadcastDst -> dst | ||
auto srcType = llvm::dyn_cast<VectorType>(src.getType()); | ||
auto dstType = llvm::dyn_cast<VectorType>(extractOp.getType()); | ||
unsigned srcRank = srcType ? srcType.getRank() : 0; | ||
unsigned broadcastDstRank = extractOp.getSourceVectorType().getRank(); | ||
unsigned dstRank = dstType ? dstType.getRank() : 0; | ||
|
||
unsigned extractResultRank = getRank(extractOp.getType()); | ||
if (extractResultRank > broadcastSrcRank) | ||
// Cannot do without the broadcast if overall the rank increases. | ||
if (dstRank > srcRank) | ||
return Value(); | ||
// Check that the dimension of the result haven't been broadcasted. | ||
auto extractVecType = llvm::dyn_cast<VectorType>(extractOp.getType()); | ||
auto broadcastVecType = llvm::dyn_cast<VectorType>(source.getType()); | ||
if (extractVecType && broadcastVecType && | ||
extractVecType.getShape() != | ||
broadcastVecType.getShape().take_back(extractResultRank)) | ||
|
||
assert(srcType && "src must be a vector type because of previous checks"); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hm, this doesn't quite agree with: unsigned srcRank = srcType ? srcType.getRank() : 0; There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It's the intervening rank check which allows this assertion. Code is like if (extractOp.getType() == src.getType())
return src;
[...]
unsigned srcRank = srcType ? srcType.getRank() : 0;
[...]
if (dstRank > srcRank)
return Value();
[...]
assert(srcType && "src must be a vector type because of previous checks"); Suppose src is scalar at the point of assertion. TBH this is reasoning is probably too complicated, and could be replaced with a
|
||
|
||
ArrayRef<int64_t> srcShape = srcType.getShape(); | ||
if (dstType && dstType.getShape() != srcShape.take_back(dstRank)) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could you add a comment explaining what case this is? |
||
return Value(); | ||
|
||
auto broadcastOp = cast<vector::BroadcastOp>(defOp); | ||
int64_t broadcastDstRank = broadcastOp.getResultVectorType().getRank(); | ||
// Replace extract(broadcast(X)) with extract(X). | ||
// First, determine the new extraction position. | ||
unsigned deltaOverall = srcRank - dstRank; | ||
unsigned deltaBroadcast = broadcastDstRank - srcRank; | ||
|
||
// Detect all the positions that come from "dim-1" broadcasting. | ||
// These dimensions correspond to "dim-1" broadcasted dims; set the mathching | ||
// extract position to `0` when extracting from the source operand. | ||
llvm::SetVector<int64_t> broadcastedUnitDims = | ||
broadcastOp.computeBroadcastedUnitDims(); | ||
SmallVector<OpFoldResult> extractPos(extractOp.getMixedPosition()); | ||
OpBuilder b(extractOp.getContext()); | ||
int64_t broadcastRankDiff = broadcastDstRank - broadcastSrcRank; | ||
for (int64_t i = broadcastRankDiff, e = extractPos.size(); i < e; ++i) | ||
if (broadcastedUnitDims.contains(i)) | ||
extractPos[i] = b.getIndexAttr(0); | ||
// `rankDiff` leading dimensions correspond to new broadcasted dims, drop the | ||
// matching extract position when extracting from the source operand. | ||
int64_t rankDiff = broadcastSrcRank - extractResultRank; | ||
extractPos.erase(extractPos.begin(), | ||
std::next(extractPos.begin(), extractPos.size() - rankDiff)); | ||
// OpBuilder is only used as a helper to build an I64ArrayAttr. | ||
auto [staticPos, dynPos] = decomposeMixedValues(extractPos); | ||
SmallVector<OpFoldResult> oldPositions = extractOp.getMixedPosition(); | ||
SmallVector<OpFoldResult> newPositions(deltaOverall); | ||
IntegerAttr zero = OpBuilder(extractOp.getContext()).getIndexAttr(0); | ||
for (auto [i, size] : llvm::enumerate(srcShape.take_front(deltaOverall))) { | ||
newPositions[i] = size == 1 ? zero : oldPositions[i + deltaBroadcast]; | ||
} | ||
auto [staticPos, dynPos] = decomposeMixedValues(newPositions); | ||
extractOp->setOperands( | ||
llvm::to_vector(llvm::concat<Value>(ValueRange(source), dynPos))); | ||
llvm::to_vector(llvm::concat<Value>(ValueRange(src), dynPos))); | ||
extractOp.setStaticPosition(staticPos); | ||
return extractOp.getResult(); | ||
} | ||
|
@@ -2193,32 +2205,18 @@ class ExtractOpFromBroadcast final : public OpRewritePattern<ExtractOp> { | |
|
||
LogicalResult matchAndRewrite(ExtractOp extractOp, | ||
PatternRewriter &rewriter) const override { | ||
Operation *defOp = extractOp.getVector().getDefiningOp(); | ||
if (!defOp || !isa<vector::BroadcastOp, SplatOp>(defOp)) | ||
return failure(); | ||
|
||
Value source = defOp->getOperand(0); | ||
if (extractOp.getType() == source.getType()) | ||
Operation *broadcastLikeOp = extractOp.getVector().getDefiningOp(); | ||
VectorType outType = dyn_cast<VectorType>(extractOp.getType()); | ||
if (!broadcastLikeOp || !isBroadcastLike(broadcastLikeOp) || !outType) | ||
return failure(); | ||
auto getRank = [](Type type) { | ||
return llvm::isa<VectorType>(type) | ||
? llvm::cast<VectorType>(type).getRank() | ||
: 0; | ||
}; | ||
unsigned broadcastSrcRank = getRank(source.getType()); | ||
unsigned extractResultRank = getRank(extractOp.getType()); | ||
// We only consider the case where the rank of the source is less than or | ||
// equal to the rank of the extract dst. The other cases are handled in the | ||
// folding patterns. | ||
if (extractResultRank < broadcastSrcRank) | ||
return failure(); | ||
// For scalar result, the input can only be a rank-0 vector, which will | ||
// be handled by the folder. | ||
if (extractResultRank == 0) | ||
|
||
Value source = broadcastLikeOp->getOperand(0); | ||
if (isBroadcastableTo(source.getType(), outType) != | ||
BroadcastableToResult::Success) | ||
return failure(); | ||
|
||
rewriter.replaceOpWithNewOp<vector::BroadcastOp>( | ||
extractOp, extractOp.getType(), source); | ||
rewriter.replaceOpWithNewOp<BroadcastOp>(extractOp, outType, source); | ||
return success(); | ||
} | ||
}; | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I guess
ExtractOp
is missing somewhere here?