-
Notifications
You must be signed in to change notification settings - Fork 14.5k
[mlir][vector] Refactor WarpOpScfForOp to support unused or swapped forOp results. #147620
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 9 commits
5868390
4c36317
3595f17
ba94ee2
28ef9c9
99c340b
537ca0e
8ecece4
164e9d6
683fad8
2c2703e
f0451dc
6297e47
bf7058a
3690c61
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -34,6 +34,7 @@ | |
#include "llvm/ADT/ArrayRef.h" | ||
#include "llvm/ADT/STLExtras.h" | ||
#include "llvm/ADT/SmallVector.h" | ||
#include "llvm/ADT/SmallVectorExtras.h" | ||
|
||
namespace mlir { | ||
namespace xegpu { | ||
|
@@ -876,15 +877,32 @@ void XeGPUSubgroupDistributePass::runOnOperation() { | |
// Step 3: Apply subgroup to workitem distribution patterns. | ||
RewritePatternSet patterns(&getContext()); | ||
xegpu::populateXeGPUSubgroupDistributePatterns(patterns); | ||
// TODO: distributionFn and shuffleFn are not used at this point. | ||
// distributionFn is used by vector distribution patterns to determine the | ||
// distributed vector type for a given vector value. In XeGPU subgroup | ||
// distribution context, we compute this based on lane layout. | ||
auto distributionFn = [](Value val) { | ||
VectorType vecType = dyn_cast<VectorType>(val.getType()); | ||
int64_t vecRank = vecType ? vecType.getRank() : 0; | ||
OpBuilder builder(val.getContext()); | ||
if (vecRank == 0) | ||
return AffineMap::get(val.getContext()); | ||
return AffineMap::getMultiDimIdentityMap(vecRank, val.getContext()); | ||
// Get the layout of the vector type. | ||
xegpu::LayoutAttr layout = xegpu::getLayoutAttr(val); | ||
// If no layout is specified, assume the inner most dimension is distributed | ||
// for now. | ||
if (!layout) | ||
return AffineMap::getMultiDimMapWithTargets( | ||
vecRank, {static_cast<unsigned int>(vecRank - 1)}, val.getContext()); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This assumes 2d, right? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes. At SG level layout is currently only 2d. But upstream distribution makes no such assumption. Only assumption there is only 1 dim is distributed. We need to add more support there in future. |
||
SmallVector<unsigned int> distributedDims; | ||
// Get the distributed dimensions based on the layout. | ||
ArrayRef<int> laneLayout = layout.getLaneLayout().asArrayRef(); | ||
for (unsigned i = 0; i < laneLayout.size(); ++i) { | ||
if (laneLayout[i] > 1) | ||
distributedDims.push_back(i); | ||
} | ||
return AffineMap::getMultiDimMapWithTargets(vecRank, distributedDims, | ||
val.getContext()); | ||
}; | ||
// TODO: shuffleFn is not used. | ||
auto shuffleFn = [](Location loc, OpBuilder &builder, Value val, Value srcIdx, | ||
int64_t warpSz) { return Value(); }; | ||
vector::populatePropagateWarpVectorDistributionPatterns( | ||
|
Uh oh!
There was an error while loading. Please reload this page.