-
Notifications
You must be signed in to change notification settings - Fork 14.4k
[Vector] Add canonicalization for select(pred, true, false) -> broadcast(pred) #147934
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||
---|---|---|---|---|---|---|---|---|---|---|
|
@@ -2913,13 +2913,74 @@ struct BroadcastFolder : public OpRewritePattern<BroadcastOp> { | |||||||||
return success(); | ||||||||||
} | ||||||||||
}; | ||||||||||
|
||||||||||
/// true: vector | ||||||||||
/// false: vector | ||||||||||
/// pred: i1 | ||||||||||
/// | ||||||||||
/// select(pred, true, false) -> broadcast(pred) | ||||||||||
/// select(pred, false, true) -> broadcast(not(pred)) | ||||||||||
/// | ||||||||||
/// Ideally, this would be a canonicalization pattern on arith::SelectOp, but | ||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think that this paragraph makes more sense near where you register this canonicalization. |
||||||||||
/// we cannot have arith depending on vector. Also, it would implicitly force | ||||||||||
/// users only using arith and vector dialect to use vector dialect. Instead, | ||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Typo?
Suggested change
|
||||||||||
/// this canonicalization only runs if vector::BroadcastOp was a registered | ||||||||||
/// operation. | ||||||||||
struct FoldI1SelectToBroadcast : public OpRewritePattern<arith::SelectOp> { | ||||||||||
using OpRewritePattern::OpRewritePattern; | ||||||||||
|
||||||||||
LogicalResult matchAndRewrite(arith::SelectOp selectOp, | ||||||||||
PatternRewriter &rewriter) const override { | ||||||||||
auto vecType = dyn_cast<VectorType>(selectOp.getType()); | ||||||||||
if (!vecType || !vecType.getElementType().isInteger(1)) | ||||||||||
return failure(); | ||||||||||
|
||||||||||
// Vector conditionals do not need broadcast and are already handled by | ||||||||||
// the arith.select folder. | ||||||||||
Value pred = selectOp.getCondition(); | ||||||||||
if (isa<VectorType>(pred.getType())) | ||||||||||
return failure(); | ||||||||||
|
||||||||||
std::optional<int64_t> trueInt = | ||||||||||
getConstantIntValue(selectOp.getTrueValue()); | ||||||||||
std::optional<int64_t> falseInt = | ||||||||||
getConstantIntValue(selectOp.getFalseValue()); | ||||||||||
if (!trueInt || !falseInt) | ||||||||||
return failure(); | ||||||||||
|
||||||||||
// Redundant selects are already handled by arith.select canonicalizations. | ||||||||||
if (trueInt.value() == falseInt.value()) { | ||||||||||
return failure(); | ||||||||||
} | ||||||||||
|
||||||||||
// The only remaining possibilities are: | ||||||||||
// | ||||||||||
// select(pred, true, false) | ||||||||||
// select(pred, false, true) | ||||||||||
|
||||||||||
// select(pred, false, true) -> select(not(pred), true, false) | ||||||||||
if (trueInt.value() == 0) { | ||||||||||
Value one = rewriter.create<arith::ConstantIntOp>( | ||||||||||
selectOp.getLoc(), /*value=*/1, /*width=*/1); | ||||||||||
pred = rewriter.create<arith::XOrIOp>(selectOp.getLoc(), pred, one); | ||||||||||
} | ||||||||||
|
||||||||||
/// select(pred, true, false) -> broadcast(pred) | ||||||||||
rewriter.replaceOpWithNewOp<vector::BroadcastOp>( | ||||||||||
selectOp, vecType.clone(rewriter.getI1Type()), pred); | ||||||||||
return success(); | ||||||||||
|
||||||||||
return failure(); | ||||||||||
Comment on lines
+2972
to
+2973
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. dead return
Suggested change
|
||||||||||
} | ||||||||||
}; | ||||||||||
|
||||||||||
} // namespace | ||||||||||
|
||||||||||
void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results, | ||||||||||
MLIRContext *context) { | ||||||||||
// BroadcastToShapeCast is not a default canonicalization, it is opt-in by | ||||||||||
// calling `populateCastAwayVectorLeadingOneDimPatterns` | ||||||||||
results.add<BroadcastFolder>(context); | ||||||||||
results.add<BroadcastFolder, FoldI1SelectToBroadcast>(context); | ||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do we have some precedent for canon patterns hooked up to the op they produce? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is quite counter-intuitive :( But not that uncommon: llvm-project/mlir/lib/Dialect/Vector/IR/VectorOps.cpp Lines 5357 to 5360 in 5f1141d
|
||||||||||
} | ||||||||||
|
||||||||||
//===----------------------------------------------------------------------===// | ||||||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.