[mlir][vector] Refactor WarpOpScfForOp to support unused or swapped forOp results. #147620
@@ -17,8 +17,12 @@
 #include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Value.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "mlir/Support/LLVM.h"
 #include "mlir/Transforms/RegionUtils.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SetVector.h"
+#include "llvm/ADT/SmallVectorExtras.h"
 #include "llvm/Support/FormatVariadic.h"
@@ -1745,19 +1749,18 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
       : WarpDistributionPattern(ctx, b), distributionMapFn(std::move(fn)) {}
   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
-    auto yield = cast<gpu::YieldOp>(
+    auto newWarpOpYield = cast<gpu::YieldOp>(
         warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
     // Only pick up forOp if it is the last op in the region.
-    Operation *lastNode = yield->getPrevNode();
+    Operation *lastNode = newWarpOpYield->getPrevNode();
     auto forOp = dyn_cast_or_null<scf::ForOp>(lastNode);
     if (!forOp)
       return failure();
     // Collect Values that come from the warp op but are outside the forOp.
-    // Those Values need to be returned by the original warpOp and passed to
-    // the new op.
+    // Those Values need to be returned by the new warp op.
     llvm::SmallSetVector<Value, 32> escapingValues;
-    SmallVector<Type> inputTypes;
-    SmallVector<Type> distTypes;
+    SmallVector<Type> escapingValueInputTypes;
+    SmallVector<Type> escapingValuedistTypes;
     mlir::visitUsedValuesDefinedAbove(
         forOp.getBodyRegion(), [&](OpOperand *operand) {
           Operation *parent = operand->get().getParentRegion()->getParentOp();
@@ -1769,81 +1772,155 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
             AffineMap map = distributionMapFn(operand->get());
             distType = getDistributedType(vecType, map, warpOp.getWarpSize());
           }
-          inputTypes.push_back(operand->get().getType());
-          distTypes.push_back(distType);
+          escapingValueInputTypes.push_back(operand->get().getType());
+          escapingValuedistTypes.push_back(distType);
         }
       });

-    if (llvm::is_contained(distTypes, Type{}))
+    if (llvm::is_contained(escapingValuedistTypes, Type{}))
       return failure();
+    // Warp op can yield two types of values:
+    // 1. Values that are not results of the forOp:
+    //    These values must also be yielded by the new warp op. Also, we need
+    //    to record the index mapping for these values to replace them later.
+    // 2. Values that are results of the forOp:
+    //    In this case, we record the index mapping between the warp op result
+    //    index and matching forOp result index.
+    SmallVector<Value> nonForYieldedValues;
+    SmallVector<unsigned> nonForResultIndices;
+    DenseMap<unsigned, unsigned> forResultMapping;
Review comment: You can probably do the mapping with some existing tools like a value-to-value map.
Author reply: I thought about using IRMapping, but it has the same amount of bookkeeping, so I don't see any added benefit. DenseMap does the job and the code looks easy to read in my view. But if you can point to some code example with value mapping, I will reconsider it.
Review comment: Something like function cloning comes to mind, but if it's the same effort/complexity, it doesn't really matter, I suppose.
+    for (OpOperand &yieldOperand : newWarpOpYield->getOpOperands()) {
+      // Yielded value is not a result of the forOp.
+      if (yieldOperand.get().getDefiningOp() != forOp.getOperation()) {
+        nonForYieldedValues.push_back(yieldOperand.get());
+        nonForResultIndices.push_back(yieldOperand.getOperandNumber());
+        continue;
+      }
+      OpResult forResult = cast<OpResult>(yieldOperand.get());
+      forResultMapping[yieldOperand.getOperandNumber()] =
+          forResult.getResultNumber();
+    }
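To make the two cases concrete, here is a hypothetical input in the shape this pattern matches (all SSA names, shapes, bounds, and the opaque "compute_*" ops are invented for illustration; constant and lane-id definitions are omitted). Note that %other is both an escaping value (defined in the warp op, used inside the loop) and a non-forOp yield, while %for is a forOp result:

%r:2 = gpu.warp_execute_on_lane_0(%laneid)[32]
    -> (vector<4x1xf32>, vector<1xf32>) {
  %src = "compute_a"() : () -> (vector<4x32xf32>)
  %other = "compute_b"() : () -> (vector<32xf32>)
  %for = scf.for %i = %c0 to %c128 step %c1
      iter_args(%arg = %src) -> (vector<4x32xf32>) {
    %next = "compute_c"(%arg, %other)
        : (vector<4x32xf32>, vector<32xf32>) -> (vector<4x32xf32>)
    scf.yield %next : vector<4x32xf32>
  }
  // %for falls under case 2; %other falls under case 1.
  gpu.yield %for, %other : vector<4x32xf32>, vector<32xf32>
}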

-    SmallVector<size_t> newRetIndices;
-    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
-        rewriter, warpOp, escapingValues.getArrayRef(), distTypes,
-        newRetIndices);
-    yield = cast<gpu::YieldOp>(
+    // Newly created warp op will yield values in the following order:
+    // 1. All init args of the forOp.
+    // 2. All escaping values.
+    // 3. All non-for yielded values.
+    SmallVector<Value> newWarpOpYieldValues;
+    SmallVector<Type> newWarpOpDistTypes;
+    for (auto [i, initArg] : llvm::enumerate(forOp.getInitArgs())) {
+      newWarpOpYieldValues.push_back(initArg);
+      // Compute the distributed type for this init arg.
+      Type distType = initArg.getType();
+      if (auto vecType = dyn_cast<VectorType>(distType)) {
+        AffineMap map = distributionMapFn(initArg);
+        distType = getDistributedType(vecType, map, warpOp.getWarpSize());
+      }
+      newWarpOpDistTypes.push_back(distType);
+    }
+    // Insert escaping values and their distributed types.
+    newWarpOpYieldValues.insert(newWarpOpYieldValues.end(),
+                                escapingValues.begin(), escapingValues.end());
+    newWarpOpDistTypes.insert(newWarpOpDistTypes.end(),
+                              escapingValuedistTypes.begin(),
+                              escapingValuedistTypes.end());
+    // Next, we insert all non-for yielded values and their distributed types.
+    // We also create a mapping between the non-for yielded value index and the
+    // corresponding new warp op yield value index (needed to update users
+    // later).
+    DenseMap<unsigned, unsigned> warpResultMapping;
+    for (auto [i, v] : llvm::enumerate(nonForYieldedValues)) {
Review comment: nit: maybe using llvm::zip is more straightforward?
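For reference, the zip-based variant the reviewer suggests would read roughly as follows (a sketch only; this revision keeps the enumerate form):

// Iterate indices and values in lockstep instead of indexing by i.
for (auto [idx, val] : llvm::zip(nonForResultIndices, nonForYieldedValues)) {
  warpResultMapping[idx] = newWarpOpYieldValues.size();
  newWarpOpYieldValues.push_back(val);
  newWarpOpDistTypes.push_back(warpOp.getResult(idx).getType());
}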
+      warpResultMapping[nonForResultIndices[i]] = newWarpOpYieldValues.size();
+      newWarpOpYieldValues.push_back(v);
+      newWarpOpDistTypes.push_back(
+          warpOp.getResult(nonForResultIndices[i]).getType());
+    }
+    // Create the new warp op with the updated yield values and types.
+    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
+        rewriter, warpOp, newWarpOpYieldValues, newWarpOpDistTypes);
+    newWarpOpYield = cast<gpu::YieldOp>(
         newWarpOp.getBodyRegion().getBlocks().begin()->getTerminator());
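Continuing the hypothetical example from above, the new warp op yields the forOp init arg first, then the escaping value, then the non-forOp yield, so %other appears twice, once in each role (sketch, same invented names and shapes):

%w:3 = gpu.warp_execute_on_lane_0(%laneid)[32]
    -> (vector<4x1xf32>, vector<1xf32>, vector<1xf32>) {
  %src = "compute_a"() : () -> (vector<4x32xf32>)
  %other = "compute_b"() : () -> (vector<32xf32>)
  gpu.yield %src, %other, %other
      : vector<4x32xf32>, vector<32xf32>, vector<32xf32>
}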

-    SmallVector<Value> newOperands;
-    SmallVector<unsigned> resultIdx;
-    // Collect all the outputs coming from the forOp.
-    for (OpOperand &yieldOperand : yield->getOpOperands()) {
-      if (yieldOperand.get().getDefiningOp() != forOp.getOperation())
-        continue;
-      auto forResult = cast<OpResult>(yieldOperand.get());
-      newOperands.push_back(
-          newWarpOp.getResult(yieldOperand.getOperandNumber()));
-      yieldOperand.set(forOp.getInitArgs()[forResult.getResultNumber()]);
-      resultIdx.push_back(yieldOperand.getOperandNumber());
-    }
+    // Next, we create a new for op with the init args yielded by the new
+    // warp op.
+    unsigned escapingValuesStartIdx =
+        forOp.getInitArgs().size(); // ForOp init args are positioned before
+                                    // escaping values in the new warp op.
+    SmallVector<Value> newForOpOperands;
+    for (size_t i = 0; i < escapingValuesStartIdx; ++i)
+      newForOpOperands.push_back(newWarpOp.getResult(i));

+    // Create a new for op outside the new warp op region.
     OpBuilder::InsertionGuard g(rewriter);
     rewriter.setInsertionPointAfter(newWarpOp);

-    // Create a new for op outside the region with a WarpExecuteOnLane0Op
-    // region inside.
     auto newForOp = rewriter.create<scf::ForOp>(
         forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
-        forOp.getStep(), newOperands);
+        forOp.getStep(), newForOpOperands);
+    // Next, we insert a new warp op (called inner warp op) inside the
+    // newly created for op. This warp op will contain all ops that were
Review comment: btw, the comments would be easier to read if they highlight the op names.
Author reply: cleaned up the comments. thanks!
+    // contained within the original for op body.
     rewriter.setInsertionPointToStart(newForOp.getBody());

-    SmallVector<Value> warpInput(newForOp.getRegionIterArgs().begin(),
-                                 newForOp.getRegionIterArgs().end());
-    SmallVector<Type> warpInputType(forOp.getResultTypes().begin(),
-                                    forOp.getResultTypes().end());
+    SmallVector<Value> innerWarpInput(newForOp.getRegionIterArgs().begin(),
+                                      newForOp.getRegionIterArgs().end());
+    SmallVector<Type> innerWarpInputType(forOp.getResultTypes().begin(),
+                                         forOp.getResultTypes().end());
+    // Escaping values are forwarded to the inner warp op as its (additional)
+    // arguments. We keep track of the mapping between these values and their
+    // argument index in the inner warp op (to replace uses later).
     llvm::SmallDenseMap<Value, int64_t> argIndexMapping;
-    for (auto [i, retIdx] : llvm::enumerate(newRetIndices)) {
-      warpInput.push_back(newWarpOp.getResult(retIdx));
-      argIndexMapping[escapingValues[i]] = warpInputType.size();
-      warpInputType.push_back(inputTypes[i]);
+    for (size_t i = escapingValuesStartIdx;
+         i < escapingValuesStartIdx + escapingValues.size(); ++i) {
+      innerWarpInput.push_back(newWarpOp.getResult(i));
+      argIndexMapping[escapingValues[i - escapingValuesStartIdx]] =
+          innerWarpInputType.size();
+      innerWarpInputType.push_back(
+          escapingValueInputTypes[i - escapingValuesStartIdx]);
     }
+    // Create the inner warp op with the new input values and types.
     auto innerWarp = rewriter.create<WarpExecuteOnLane0Op>(
         newWarpOp.getLoc(), newForOp.getResultTypes(), newWarpOp.getLaneid(),
-        newWarpOp.getWarpSize(), warpInput, warpInputType);
+        newWarpOp.getWarpSize(), innerWarpInput, innerWarpInputType);
+    // Inline the for op body into the inner warp op body.
     SmallVector<Value> argMapping;
     argMapping.push_back(newForOp.getInductionVar());
-    for (Value args : innerWarp.getBody()->getArguments()) {
+    for (Value args : innerWarp.getBody()->getArguments())
       argMapping.push_back(args);
-    }

     argMapping.resize(forOp.getBody()->getNumArguments());
     SmallVector<Value> yieldOperands;
     for (Value operand : forOp.getBody()->getTerminator()->getOperands())
       yieldOperands.push_back(operand);

     rewriter.eraseOp(forOp.getBody()->getTerminator());
     rewriter.mergeBlocks(forOp.getBody(), innerWarp.getBody(), argMapping);
+    // Insert a gpu yieldOp at the end of the inner warp op body that yields
+    // original forOp results.
     rewriter.setInsertionPointToEnd(innerWarp.getBody());
     rewriter.create<gpu::YieldOp>(innerWarp.getLoc(), yieldOperands);
     rewriter.setInsertionPointAfter(innerWarp);
+    // Insert a scf.yield op at the end of the new for op body that yields
+    // the inner warp op results.
     if (!innerWarp.getResults().empty())
       rewriter.create<scf::YieldOp>(forOp.getLoc(), innerWarp.getResults());

+    // Update the users of original warp op results that were coming from the
+    // original forOp to the corresponding new forOp result.
+    for (auto [origIdx, newIdx] : forResultMapping)
+      rewriter.replaceAllUsesExcept(warpOp.getResult(origIdx),
+                                    newForOp.getResult(newIdx), newForOp);
+    // Similarly, update any users of the warp op results that were not
+    // results of the forOp.
+    for (auto [origIdx, newIdx] : warpResultMapping)
+      rewriter.replaceAllUsesWith(warpOp.getResult(origIdx),
+                                  newWarpOp.getResult(newIdx));
+    // Remove the original warp op and for op; they should not have any uses
+    // at this point.
     rewriter.eraseOp(forOp);
-    // Replace the warpOp result coming from the original ForOp.
-    for (const auto &res : llvm::enumerate(resultIdx)) {
-      rewriter.replaceAllUsesWith(newWarpOp.getResult(res.value()),
-                                  newForOp.getResult(res.index()));
-      newForOp->setOperand(res.index() + 3, newWarpOp.getResult(res.value()));
-    }
+    rewriter.eraseOp(warpOp);
+    // Update any users of escaping values that were forwarded to the
+    // inner warp op. These values are now arguments of the inner warp op.
     newForOp.walk([&](Operation *op) {
       for (OpOperand &operand : op->getOpOperands()) {
         auto it = argIndexMapping.find(operand.get());
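Putting the pieces together for the same hypothetical example, the rewritten IR has the loop outside the new warp op, iterating on distributed types, with an inner warp op wrapping the original loop body (a sketch, assuming warp size 32 and an identity distribution map):

%for = scf.for %i = %c0 to %c128 step %c1
    iter_args(%arg = %w#0) -> (vector<4x1xf32>) {
  // The iter_arg and the escaping value %w#1 enter as inner warp op
  // arguments; the block arguments carry the original sequential types.
  %inner = gpu.warp_execute_on_lane_0(%laneid)[32]
      args(%arg, %w#1 : vector<4x1xf32>, vector<1xf32>)
      -> (vector<4x1xf32>) {
  ^bb0(%body_arg: vector<4x32xf32>, %esc: vector<32xf32>):
    %next = "compute_c"(%body_arg, %esc)
        : (vector<4x32xf32>, vector<32xf32>) -> (vector<4x32xf32>)
    gpu.yield %next : vector<4x32xf32>
  }
  scf.yield %inner : vector<4x1xf32>
}

Users of the original warp op's forOp result are redirected to %for (via forResultMapping), and users of the non-forOp result to %w#2 (via warpResultMapping); the unused and swapped forOp result cases from the PR title fall out of these two index maps.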
@@ -34,6 +34,7 @@
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/SmallVectorExtras.h"

 namespace mlir {
 namespace xegpu {
@@ -876,15 +877,32 @@ void XeGPUSubgroupDistributePass::runOnOperation() {
   // Step 3: Apply subgroup to workitem distribution patterns.
   RewritePatternSet patterns(&getContext());
   xegpu::populateXeGPUSubgroupDistributePatterns(patterns);
-  // TODO: distributionFn and shuffleFn are not used at this point.
+  // distributionFn is used by vector distribution patterns to determine the
+  // distributed vector type for a given vector value. In XeGPU subgroup
+  // distribution context, we compute this based on lane layout.
   auto distributionFn = [](Value val) {
     VectorType vecType = dyn_cast<VectorType>(val.getType());
     int64_t vecRank = vecType ? vecType.getRank() : 0;
-    OpBuilder builder(val.getContext());
     if (vecRank == 0)
       return AffineMap::get(val.getContext());
-    return AffineMap::getMultiDimIdentityMap(vecRank, val.getContext());
+    // Get the layout of the vector type.
+    xegpu::LayoutAttr layout = xegpu::getLayoutAttr(val);
+    // If no layout is specified, assume the innermost dimension is distributed
+    // for now.
+    if (!layout)
+      return AffineMap::getMultiDimMapWithTargets(
+          vecRank, {static_cast<unsigned int>(vecRank - 1)}, val.getContext());
Review comment: This assumes 2d, right?
Author reply: Yes. At SG level, layout is currently only 2d, but upstream distribution makes no such assumption; the only assumption there is that only one dim is distributed. We need to add more support there in the future.

+    SmallVector<unsigned int> distributedDims;
+    // Get the distributed dimensions based on the layout.
+    ArrayRef<int> laneLayout = layout.getLaneLayout().asArrayRef();
+    for (unsigned i = 0; i < laneLayout.size(); ++i) {
+      if (laneLayout[i] > 1)
+        distributedDims.push_back(i);
+    }
+    return AffineMap::getMultiDimMapWithTargets(vecRank, distributedDims,
+                                                val.getContext());
   };
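As a concrete illustration of the lambda above (hypothetical layout and shape, not taken from this PR's tests):

// val: vector<16x16xf16> carrying lane_layout = [1, 16].
//   laneLayout = [1, 16]  =>  distributedDims = {1}
//   returned map: (d0, d1) -> (d1)
// getDistributedType then divides the distributed dim by the 16 lanes:
//   vector<16x16xf16> -> vector<16x1xf16>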
+  // TODO: shuffleFn is not used.
   auto shuffleFn = [](Location loc, OpBuilder &builder, Value val, Value srcIdx,
                       int64_t warpSz) { return Value(); };
   vector::populatePropagateWarpVectorDistributionPatterns(
Review comment: dist->Dist