[mlir][vector] Refactor WarpOpScfForOp to support unused or swapped forOp results. #147620
base: main
Changes from all commits
5868390
4c36317
3595f17
ba94ee2
28ef9c9
99c340b
537ca0e
8ecece4
164e9d6
@@ -17,8 +17,12 @@
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Value.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/RegionUtils.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallVectorExtras.h"
#include "llvm/Support/FormatVariadic.h"
@@ -1745,19 +1749,18 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
: WarpDistributionPattern(ctx, b), distributionMapFn(std::move(fn)) {}
LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
PatternRewriter &rewriter) const override {
auto yield = cast<gpu::YieldOp>(
auto newWarpOpYield = cast<gpu::YieldOp>(
warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
// Only pick up forOp if it is the last op in the region.
Operation *lastNode = yield->getPrevNode();
// Only pick up `ForOp` if it is the last op in the region.
Operation *lastNode = newWarpOpYield->getPrevNode();
auto forOp = dyn_cast_or_null<scf::ForOp>(lastNode);
if (!forOp)
return failure();
// Collect Values that come from the warp op but are outside the forOp.
// Those Value needs to be returned by the original warpOp and passed to
// the new op.
// Collect Values that come from the `WarpOp` but are outside the `ForOp`.
// Those Values need to be returned by the new warp op.
llvm::SmallSetVector<Value, 32> escapingValues;
SmallVector<Type> inputTypes;
SmallVector<Type> distTypes;
SmallVector<Type> escapingValueInputTypes;
SmallVector<Type> escapingValuedistTypes;
mlir::visitUsedValuesDefinedAbove(
forOp.getBodyRegion(), [&](OpOperand *operand) {
Operation *parent = operand->get().getParentRegion()->getParentOp();
@@ -1769,81 +1772,155 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
AffineMap map = distributionMapFn(operand->get());
distType = getDistributedType(vecType, map, warpOp.getWarpSize());
}
inputTypes.push_back(operand->get().getType());
distTypes.push_back(distType);
escapingValueInputTypes.push_back(operand->get().getType());
escapingValuedistTypes.push_back(distType);
}
});

if (llvm::is_contained(distTypes, Type{}))
if (llvm::is_contained(escapingValuedistTypes, Type{}))
return failure();
// `WarpOp` can yield two types of values:
// 1. Values that are not results of the `ForOp`:
// These values must also be yielded by the new `WarpOp`. Also, we need
// to record the index mapping for these values to replace them later.
// 2. Values that are results of the `ForOp`:
// In this case, we record the index mapping between the `WarpOp` result
// index and matching `ForOp` result index.
SmallVector<Value> nonForYieldedValues;
SmallVector<unsigned> nonForResultIndices;
llvm::SmallDenseMap<unsigned, unsigned> forResultMapping;
for (OpOperand &yieldOperand : newWarpOpYield->getOpOperands()) {
// Yielded value is not a result of the forOp.
if (yieldOperand.get().getDefiningOp() != forOp.getOperation()) {
nonForYieldedValues.push_back(yieldOperand.get());
nonForResultIndices.push_back(yieldOperand.getOperandNumber());
continue;
}
OpResult forResult = cast<OpResult>(yieldOperand.get());
forResultMapping[yieldOperand.getOperandNumber()] =
forResult.getResultNumber();
}

SmallVector<size_t> newRetIndices;
WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
rewriter, warpOp, escapingValues.getArrayRef(), distTypes,
newRetIndices);
yield = cast<gpu::YieldOp>(
// Newly created `WarpOp` will yield values in following order:
// 1. All init args of the `ForOp`.
// 2. All escaping values.
// 3. All non-`ForOp` yielded values.
SmallVector<Value> newWarpOpYieldValues;
SmallVector<Type> newWarpOpDistTypes;
for (auto [i, initArg] : llvm::enumerate(forOp.getInitArgs())) {
newWarpOpYieldValues.push_back(initArg);
// Compute the distributed type for this init arg.
Type distType = initArg.getType();
if (auto vecType = dyn_cast<VectorType>(distType)) {
AffineMap map = distributionMapFn(initArg);
distType = getDistributedType(vecType, map, warpOp.getWarpSize());
}
newWarpOpDistTypes.push_back(distType);
}
// Insert escaping values and their distributed types.
newWarpOpYieldValues.insert(newWarpOpYieldValues.end(),
escapingValues.begin(), escapingValues.end());
newWarpOpDistTypes.insert(newWarpOpDistTypes.end(),
escapingValuedistTypes.begin(),
escapingValuedistTypes.end());
// Next, we insert all non-`ForOp` yielded values and their distributed
// types. We also create a mapping between the non-`ForOp` yielded value
// index and the corresponding new `WarpOp` yield value index (needed to
// update users later).
llvm::SmallDenseMap<unsigned, unsigned> warpResultMapping;
for (auto [i, v] : llvm::enumerate(nonForYieldedValues)) {
[review comment] nit: maybe using llvm::zip is more straightforward?
warpResultMapping[nonForResultIndices[i]] = newWarpOpYieldValues.size();
newWarpOpYieldValues.push_back(v);
newWarpOpDistTypes.push_back(
warpOp.getResult(nonForResultIndices[i]).getType());
}
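Following up on the reviewer's nit above, a minimal sketch of how this loop could look with llvm::zip instead of llvm::enumerate plus parallel indexing (not part of the patch; it reuses the surrounding variable names, and llvm::zip comes from llvm/ADT/STLExtras.h, which is already included):

// Hypothetical alternative: walk the non-ForOp yield indices and values
// in lockstep instead of indexing with the enumerate counter.
for (auto [idx, val] : llvm::zip(nonForResultIndices, nonForYieldedValues)) {
  warpResultMapping[idx] = newWarpOpYieldValues.size();
  newWarpOpYieldValues.push_back(val);
  newWarpOpDistTypes.push_back(warpOp.getResult(idx).getType());
}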
// Create the new `WarpOp` with the updated yield values and types.
WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
rewriter, warpOp, newWarpOpYieldValues, newWarpOpDistTypes);
newWarpOpYield = cast<gpu::YieldOp>(
newWarpOp.getBodyRegion().getBlocks().begin()->getTerminator());

SmallVector<Value> newOperands;
SmallVector<unsigned> resultIdx;
// Collect all the outputs coming from the forOp.
for (OpOperand &yieldOperand : yield->getOpOperands()) {
if (yieldOperand.get().getDefiningOp() != forOp.getOperation())
continue;
auto forResult = cast<OpResult>(yieldOperand.get());
newOperands.push_back(
newWarpOp.getResult(yieldOperand.getOperandNumber()));
yieldOperand.set(forOp.getInitArgs()[forResult.getResultNumber()]);
resultIdx.push_back(yieldOperand.getOperandNumber());
}
// Next, we create a new `ForOp` with the init args yielded by the new
// `WarpOp`.
const unsigned escapingValuesStartIdx =
forOp.getInitArgs().size(); // `ForOp` init args are positioned before
// escaping values in the new `WarpOp`.
SmallVector<Value> newForOpOperands;
for (size_t i = 0; i < escapingValuesStartIdx; ++i)
newForOpOperands.push_back(newWarpOp.getResult(i));

// Create a new `ForOp` outside the new `WarpOp` region.
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPointAfter(newWarpOp);

// Create a new for op outside the region with a WarpExecuteOnLane0Op
// region inside.
auto newForOp = rewriter.create<scf::ForOp>(
forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
forOp.getStep(), newOperands);
forOp.getStep(), newForOpOperands);
// Next, we insert a new `WarpOp` (called inner `WarpOp`) inside the
// newly created `ForOp`. This `WarpOp` will contain all ops that were
// contained within the original `ForOp` body.
rewriter.setInsertionPointToStart(newForOp.getBody());

SmallVector<Value> warpInput(newForOp.getRegionIterArgs().begin(),
newForOp.getRegionIterArgs().end());
SmallVector<Type> warpInputType(forOp.getResultTypes().begin(),
forOp.getResultTypes().end());
SmallVector<Value> innerWarpInput(newForOp.getRegionIterArgs().begin(),
newForOp.getRegionIterArgs().end());
SmallVector<Type> innerWarpInputType(forOp.getResultTypes().begin(),
forOp.getResultTypes().end());
// Escaping values are forwarded to the inner `WarpOp` as its (additional)
// arguments. We keep track of the mapping between these values and their
// argument index in the inner `WarpOp` (to replace users later).
llvm::SmallDenseMap<Value, int64_t> argIndexMapping;
for (auto [i, retIdx] : llvm::enumerate(newRetIndices)) {
warpInput.push_back(newWarpOp.getResult(retIdx));
argIndexMapping[escapingValues[i]] = warpInputType.size();
warpInputType.push_back(inputTypes[i]);
for (size_t i = escapingValuesStartIdx;
i < escapingValuesStartIdx + escapingValues.size(); ++i) {
innerWarpInput.push_back(newWarpOp.getResult(i));
argIndexMapping[escapingValues[i - escapingValuesStartIdx]] =
innerWarpInputType.size();
innerWarpInputType.push_back(
escapingValueInputTypes[i - escapingValuesStartIdx]);
}
// Create the inner `WarpOp` with the new input values and types.
auto innerWarp = rewriter.create<WarpExecuteOnLane0Op>(
newWarpOp.getLoc(), newForOp.getResultTypes(), newWarpOp.getLaneid(),
newWarpOp.getWarpSize(), warpInput, warpInputType);
newWarpOp.getWarpSize(), innerWarpInput, innerWarpInputType);

// Inline the `ForOp` body into the inner `WarpOp` body.
SmallVector<Value> argMapping;
argMapping.push_back(newForOp.getInductionVar());
for (Value args : innerWarp.getBody()->getArguments()) {
for (Value args : innerWarp.getBody()->getArguments())
argMapping.push_back(args);
}

argMapping.resize(forOp.getBody()->getNumArguments());
SmallVector<Value> yieldOperands;
for (Value operand : forOp.getBody()->getTerminator()->getOperands())
yieldOperands.push_back(operand);

rewriter.eraseOp(forOp.getBody()->getTerminator());
rewriter.mergeBlocks(forOp.getBody(), innerWarp.getBody(), argMapping);

// Insert a gpu `YieldOp` at the end of the inner `WarpOp` body that yields
// original `ForOp` results.
rewriter.setInsertionPointToEnd(innerWarp.getBody());
rewriter.create<gpu::YieldOp>(innerWarp.getLoc(), yieldOperands);
rewriter.setInsertionPointAfter(innerWarp);
// Insert a scf.yield op at the end of the new `ForOp` body that yields
// the inner `WarpOp` results.
if (!innerWarp.getResults().empty())
rewriter.create<scf::YieldOp>(forOp.getLoc(), innerWarp.getResults());

// Update the users of original `WarpOp` results that were coming from the
// original `ForOp` to the corresponding new `ForOp` result.
for (auto [origIdx, newIdx] : forResultMapping)
rewriter.replaceAllUsesExcept(warpOp.getResult(origIdx),
newForOp.getResult(newIdx), newForOp);
// Similarly, update any users of the `WarpOp` results that were not
// results of the `ForOp`.
for (auto [origIdx, newIdx] : warpResultMapping)
rewriter.replaceAllUsesWith(warpOp.getResult(origIdx),
newWarpOp.getResult(newIdx));
// Remove the original `WarpOp` and `ForOp`, they should not have any uses
// at this point.
rewriter.eraseOp(forOp);
// Replace the warpOp result coming from the original ForOp.
for (const auto &res : llvm::enumerate(resultIdx)) {
rewriter.replaceAllUsesWith(newWarpOp.getResult(res.value()),
newForOp.getResult(res.index()));
newForOp->setOperand(res.index() + 3, newWarpOp.getResult(res.value()));
}
rewriter.eraseOp(warpOp);
// Update any users of escaping values that were forwarded to the
// inner `WarpOp`. These values are now arguments of the inner `WarpOp`.
newForOp.walk([&](Operation *op) {
for (OpOperand &operand : op->getOpOperands()) {
auto it = argIndexMapping.find(operand.get());
@@ -1853,7 +1930,7 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
}
});

// Finally, hoist out any now uniform code from the inner warp op.
// Finally, hoist out any now uniform code from the inner `WarpOp`.
mlir::vector::moveScalarUniformCode(innerWarp);
return success();
}
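For orientation, here is a minimal sketch (a hypothetical helper, not code from the patch) of the result ordering the comments above describe for the new outer `WarpOp`: `ForOp` init args come first, then escaping values, then values not produced by the `ForOp`. Together with `forResultMapping` and `warpResultMapping`, this ordering is what lets the pattern handle `ForOp` results that are unused or yielded by the `WarpOp` in a different order.

// Hypothetical illustration of the result index ranges of the new outer
// WarpOp built by this pattern; the names are made up for this sketch.
struct NewWarpResultLayout {
  unsigned initArgsBegin = 0; // [0, escapingBegin): ForOp init args
  unsigned escapingBegin;     // [escapingBegin, nonForBegin): escaping values
  unsigned nonForBegin;       // [nonForBegin, end): non-ForOp yielded values
};

static NewWarpResultLayout makeLayout(unsigned numInitArgs,
                                      unsigned numEscaping) {
  NewWarpResultLayout layout;
  layout.escapingBegin = numInitArgs; // escapingValuesStartIdx in the patch
  layout.nonForBegin = numInitArgs + numEscaping;
  return layout;
}

For example, if the original `WarpOp` yielded (forOp result 1, some other value, forOp result 0), `forResultMapping` would hold {0 -> 1, 2 -> 0} and `warpResultMapping` would map yield index 1 to its appended position after the escaping values. The remaining hunks below are in the XeGPU subgroup distribution pass, which supplies the `distributionFn` consumed by these patterns.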
@@ -34,6 +34,7 @@
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/SmallVectorExtras.h"

namespace mlir {
namespace xegpu {

@@ -876,15 +877,32 @@ void XeGPUSubgroupDistributePass::runOnOperation() {
// Step 3: Apply subgroup to workitem distribution patterns.
RewritePatternSet patterns(&getContext());
xegpu::populateXeGPUSubgroupDistributePatterns(patterns);
// TODO: distributionFn and shuffleFn are not used at this point.
// distributionFn is used by vector distribution patterns to determine the
// distributed vector type for a given vector value. In XeGPU subgroup
// distribution context, we compute this based on lane layout.
auto distributionFn = [](Value val) {
VectorType vecType = dyn_cast<VectorType>(val.getType());
int64_t vecRank = vecType ? vecType.getRank() : 0;
OpBuilder builder(val.getContext());
if (vecRank == 0)
return AffineMap::get(val.getContext());
return AffineMap::getMultiDimIdentityMap(vecRank, val.getContext());
// Get the layout of the vector type.
xegpu::LayoutAttr layout = xegpu::getLayoutAttr(val);
// If no layout is specified, assume the inner most dimension is distributed
// for now.
if (!layout)
return AffineMap::getMultiDimMapWithTargets(
vecRank, {static_cast<unsigned int>(vecRank - 1)}, val.getContext());
[review comment] This assumes 2d, right?
[reply] Yes. At SG level, layout is currently only 2D, but upstream distribution makes no such assumption. The only assumption there is that only one dim is distributed. We need to add more support there in the future.
SmallVector<unsigned int> distributedDims;
// Get the distributed dimensions based on the layout.
ArrayRef<int> laneLayout = layout.getLaneLayout().asArrayRef();
for (unsigned i = 0; i < laneLayout.size(); ++i) {
if (laneLayout[i] > 1)
distributedDims.push_back(i);
}
return AffineMap::getMultiDimMapWithTargets(vecRank, distributedDims,
val.getContext());
};
// TODO: shuffleFn is not used.
auto shuffleFn = [](Location loc, OpBuilder &builder, Value val, Value srcIdx,
int64_t warpSz) { return Value(); };
vector::populatePropagateWarpVectorDistributionPatterns(
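To make the layout-to-map logic concrete, here is a standalone sketch (illustrative only, assuming a 2-D vector with a lane_layout of [1, 16]; not code from the patch) of the AffineMap the lambda above would produce: only dimensions whose lane-layout entry is greater than 1 are distributed, so the map is (d0, d1) -> (d1).

#include "mlir/IR/AffineMap.h"
#include "mlir/IR/MLIRContext.h"
#include "llvm/ADT/SmallVector.h"

// Sketch: for lane_layout = [1, 16], only dim 1 has more than one lane,
// so only dim 1 is distributed and the map is (d0, d1) -> (d1).
mlir::AffineMap exampleDistributionMap(mlir::MLIRContext *ctx) {
  const int laneLayout[] = {1, 16}; // assumed example layout
  llvm::SmallVector<unsigned> distributedDims;
  for (unsigned i = 0; i < 2; ++i)
    if (laneLayout[i] > 1)
      distributedDims.push_back(i);
  return mlir::AffineMap::getMultiDimMapWithTargets(/*numDims=*/2,
                                                    distributedDims, ctx);
}

As the surrounding code suggests, getDistributedType then uses such a map to decide which vector dimensions get divided by the warp size.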
[review comment] dist->Dist