Skip to content

[mlir][vector] Refactor WarpOpScfForOp to support unused or swapped forOp results. #147620

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
177 changes: 127 additions & 50 deletions mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,12 @@
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Value.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/RegionUtils.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallVectorExtras.h"
#include "llvm/Support/FormatVariadic.h"
Expand Down Expand Up @@ -1745,19 +1749,18 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
: WarpDistributionPattern(ctx, b), distributionMapFn(std::move(fn)) {}
LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
PatternRewriter &rewriter) const override {
auto yield = cast<gpu::YieldOp>(
auto newWarpOpYield = cast<gpu::YieldOp>(
warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
// Only pick up forOp if it is the last op in the region.
Operation *lastNode = yield->getPrevNode();
// Only pick up `ForOp` if it is the last op in the region.
Operation *lastNode = newWarpOpYield->getPrevNode();
auto forOp = dyn_cast_or_null<scf::ForOp>(lastNode);
if (!forOp)
return failure();
// Collect Values that come from the warp op but are outside the forOp.
// Those Value needs to be returned by the original warpOp and passed to
// the new op.
// Collect Values that come from the `WarpOp` but are outside the `ForOp`.
// Those Values need to be returned by the new warp op.
llvm::SmallSetVector<Value, 32> escapingValues;
SmallVector<Type> inputTypes;
SmallVector<Type> distTypes;
SmallVector<Type> escapingValueInputTypes;
SmallVector<Type> escapingValuedistTypes;

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

dist->Dist

mlir::visitUsedValuesDefinedAbove(
forOp.getBodyRegion(), [&](OpOperand *operand) {
Operation *parent = operand->get().getParentRegion()->getParentOp();
Expand All @@ -1769,81 +1772,155 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
AffineMap map = distributionMapFn(operand->get());
distType = getDistributedType(vecType, map, warpOp.getWarpSize());
}
inputTypes.push_back(operand->get().getType());
distTypes.push_back(distType);
escapingValueInputTypes.push_back(operand->get().getType());
escapingValuedistTypes.push_back(distType);
}
});

if (llvm::is_contained(distTypes, Type{}))
if (llvm::is_contained(escapingValuedistTypes, Type{}))
return failure();
// `WarpOp` can yield two types of values:
// 1. Values that are not results of the `ForOp`:
// These values must also be yielded by the new `WarpOp`. Also, we need
// to record the index mapping for these values to replace them later.
// 2. Values that are results of the `ForOp`:
// In this case, we record the index mapping between the `WarpOp` result
// index and matching `ForOp` result index.
SmallVector<Value> nonForYieldedValues;
SmallVector<unsigned> nonForResultIndices;
llvm::SmallDenseMap<unsigned, unsigned> forResultMapping;
for (OpOperand &yieldOperand : newWarpOpYield->getOpOperands()) {
// Yielded value is not a result of the forOp.
if (yieldOperand.get().getDefiningOp() != forOp.getOperation()) {
nonForYieldedValues.push_back(yieldOperand.get());
nonForResultIndices.push_back(yieldOperand.getOperandNumber());
continue;
}
OpResult forResult = cast<OpResult>(yieldOperand.get());
forResultMapping[yieldOperand.getOperandNumber()] =
forResult.getResultNumber();
}

SmallVector<size_t> newRetIndices;
WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
rewriter, warpOp, escapingValues.getArrayRef(), distTypes,
newRetIndices);
yield = cast<gpu::YieldOp>(
// Newly created `WarpOp` will yield values in following order:
// 1. All init args of the `ForOp`.
// 2. All escaping values.
// 3. All non-`ForOp` yielded values.
SmallVector<Value> newWarpOpYieldValues;
SmallVector<Type> newWarpOpDistTypes;
for (auto [i, initArg] : llvm::enumerate(forOp.getInitArgs())) {
newWarpOpYieldValues.push_back(initArg);
// Compute the distributed type for this init arg.
Type distType = initArg.getType();
if (auto vecType = dyn_cast<VectorType>(distType)) {
AffineMap map = distributionMapFn(initArg);
distType = getDistributedType(vecType, map, warpOp.getWarpSize());
}
newWarpOpDistTypes.push_back(distType);
}
// Insert escaping values and their distributed types.
newWarpOpYieldValues.insert(newWarpOpYieldValues.end(),
escapingValues.begin(), escapingValues.end());
newWarpOpDistTypes.insert(newWarpOpDistTypes.end(),
escapingValuedistTypes.begin(),
escapingValuedistTypes.end());
// Next, we insert all non-`ForOp` yielded values and their distributed
// types. We also create a mapping between the non-`ForOp` yielded value
// index and the corresponding new `WarpOp` yield value index (needed to
// update users later).
llvm::SmallDenseMap<unsigned, unsigned> warpResultMapping;
for (auto [i, v] : llvm::enumerate(nonForYieldedValues)) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: maybe using llvm::zip is more straightforward?

warpResultMapping[nonForResultIndices[i]] = newWarpOpYieldValues.size();
newWarpOpYieldValues.push_back(v);
newWarpOpDistTypes.push_back(
warpOp.getResult(nonForResultIndices[i]).getType());
}
// Create the new `WarpOp` with the updated yield values and types.
WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
rewriter, warpOp, newWarpOpYieldValues, newWarpOpDistTypes);
newWarpOpYield = cast<gpu::YieldOp>(
newWarpOp.getBodyRegion().getBlocks().begin()->getTerminator());

SmallVector<Value> newOperands;
SmallVector<unsigned> resultIdx;
// Collect all the outputs coming from the forOp.
for (OpOperand &yieldOperand : yield->getOpOperands()) {
if (yieldOperand.get().getDefiningOp() != forOp.getOperation())
continue;
auto forResult = cast<OpResult>(yieldOperand.get());
newOperands.push_back(
newWarpOp.getResult(yieldOperand.getOperandNumber()));
yieldOperand.set(forOp.getInitArgs()[forResult.getResultNumber()]);
resultIdx.push_back(yieldOperand.getOperandNumber());
}
// Next, we create a new `ForOp` with the init args yielded by the new
// `WarpOp`.
const unsigned escapingValuesStartIdx =
forOp.getInitArgs().size(); // `ForOp` init args are positioned before
// escaping values in the new `WarpOp`.
SmallVector<Value> newForOpOperands;
for (size_t i = 0; i < escapingValuesStartIdx; ++i)
newForOpOperands.push_back(newWarpOp.getResult(i));

// Create a new `ForOp` outside the new `WarpOp` region.
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPointAfter(newWarpOp);

// Create a new for op outside the region with a WarpExecuteOnLane0Op
// region inside.
auto newForOp = rewriter.create<scf::ForOp>(
forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
forOp.getStep(), newOperands);
forOp.getStep(), newForOpOperands);
// Next, we insert a new `WarpOp` (called inner `WarpOp`) inside the
// newly created `ForOp`. This `WarpOp` will contain all ops that were
// contained within the original `ForOp` body.
rewriter.setInsertionPointToStart(newForOp.getBody());

SmallVector<Value> warpInput(newForOp.getRegionIterArgs().begin(),
newForOp.getRegionIterArgs().end());
SmallVector<Type> warpInputType(forOp.getResultTypes().begin(),
forOp.getResultTypes().end());
SmallVector<Value> innerWarpInput(newForOp.getRegionIterArgs().begin(),
newForOp.getRegionIterArgs().end());
SmallVector<Type> innerWarpInputType(forOp.getResultTypes().begin(),
forOp.getResultTypes().end());
// Escaping values are forwarded to the inner `WarpOp` as its (additional)
// arguments. We keep track of the mapping between these values and their
// argument index in the inner `WarpOp` (to replace users later).
llvm::SmallDenseMap<Value, int64_t> argIndexMapping;
for (auto [i, retIdx] : llvm::enumerate(newRetIndices)) {
warpInput.push_back(newWarpOp.getResult(retIdx));
argIndexMapping[escapingValues[i]] = warpInputType.size();
warpInputType.push_back(inputTypes[i]);
for (size_t i = escapingValuesStartIdx;
i < escapingValuesStartIdx + escapingValues.size(); ++i) {
innerWarpInput.push_back(newWarpOp.getResult(i));
argIndexMapping[escapingValues[i - escapingValuesStartIdx]] =
innerWarpInputType.size();
innerWarpInputType.push_back(
escapingValueInputTypes[i - escapingValuesStartIdx]);
}
// Create the inner `WarpOp` with the new input values and types.
auto innerWarp = rewriter.create<WarpExecuteOnLane0Op>(
newWarpOp.getLoc(), newForOp.getResultTypes(), newWarpOp.getLaneid(),
newWarpOp.getWarpSize(), warpInput, warpInputType);
newWarpOp.getWarpSize(), innerWarpInput, innerWarpInputType);

// Inline the `ForOp` body into the inner `WarpOp` body.
SmallVector<Value> argMapping;
argMapping.push_back(newForOp.getInductionVar());
for (Value args : innerWarp.getBody()->getArguments()) {
for (Value args : innerWarp.getBody()->getArguments())
argMapping.push_back(args);
}

argMapping.resize(forOp.getBody()->getNumArguments());
SmallVector<Value> yieldOperands;
for (Value operand : forOp.getBody()->getTerminator()->getOperands())
yieldOperands.push_back(operand);

rewriter.eraseOp(forOp.getBody()->getTerminator());
rewriter.mergeBlocks(forOp.getBody(), innerWarp.getBody(), argMapping);

// Insert a gpu `YieldOp` at the end of the inner `WarpOp` body that yields
// original `ForOp` results.
rewriter.setInsertionPointToEnd(innerWarp.getBody());
rewriter.create<gpu::YieldOp>(innerWarp.getLoc(), yieldOperands);
rewriter.setInsertionPointAfter(innerWarp);
// Insert a scf.yield op at the end of the new `ForOp` body that yields
// the inner `WarpOp` results.
if (!innerWarp.getResults().empty())
rewriter.create<scf::YieldOp>(forOp.getLoc(), innerWarp.getResults());

// Update the users of original `WarpOp` results that were coming from the
// original `ForOp` to the corresponding new `ForOp` result.
for (auto [origIdx, newIdx] : forResultMapping)
rewriter.replaceAllUsesExcept(warpOp.getResult(origIdx),
newForOp.getResult(newIdx), newForOp);
// Similarly, update any users of the `WarpOp` results that were not
// results of the `ForOp`.
for (auto [origIdx, newIdx] : warpResultMapping)
rewriter.replaceAllUsesWith(warpOp.getResult(origIdx),
newWarpOp.getResult(newIdx));
// Remove the original `WarpOp` and `ForOp`, they should not have any uses
// at this point.
rewriter.eraseOp(forOp);
// Replace the warpOp result coming from the original ForOp.
for (const auto &res : llvm::enumerate(resultIdx)) {
rewriter.replaceAllUsesWith(newWarpOp.getResult(res.value()),
newForOp.getResult(res.index()));
newForOp->setOperand(res.index() + 3, newWarpOp.getResult(res.value()));
}
rewriter.eraseOp(warpOp);
// Update any users of escaping values that were forwarded to the
// inner `WarpOp`. These values are now arguments of the inner `WarpOp`.
newForOp.walk([&](Operation *op) {
for (OpOperand &operand : op->getOpOperands()) {
auto it = argIndexMapping.find(operand.get());
Expand All @@ -1853,7 +1930,7 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
}
});

// Finally, hoist out any now uniform code from the inner warp op.
// Finally, hoist out any now uniform code from the inner `WarpOp`.
mlir::vector::moveScalarUniformCode(innerWarp);
return success();
}
Expand Down
24 changes: 21 additions & 3 deletions mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/SmallVectorExtras.h"

namespace mlir {
namespace xegpu {
Expand Down Expand Up @@ -876,15 +877,32 @@ void XeGPUSubgroupDistributePass::runOnOperation() {
// Step 3: Apply subgroup to workitem distribution patterns.
RewritePatternSet patterns(&getContext());
xegpu::populateXeGPUSubgroupDistributePatterns(patterns);
// TODO: distributionFn and shuffleFn are not used at this point.
// distributionFn is used by vector distribution patterns to determine the
// distributed vector type for a given vector value. In XeGPU subgroup
// distribution context, we compute this based on lane layout.
auto distributionFn = [](Value val) {
VectorType vecType = dyn_cast<VectorType>(val.getType());
int64_t vecRank = vecType ? vecType.getRank() : 0;
OpBuilder builder(val.getContext());
if (vecRank == 0)
return AffineMap::get(val.getContext());
return AffineMap::getMultiDimIdentityMap(vecRank, val.getContext());
// Get the layout of the vector type.
xegpu::LayoutAttr layout = xegpu::getLayoutAttr(val);
// If no layout is specified, assume the inner most dimension is distributed
// for now.
if (!layout)
return AffineMap::getMultiDimMapWithTargets(
vecRank, {static_cast<unsigned int>(vecRank - 1)}, val.getContext());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This assumes 2d, right?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. At the SG level, the layout is currently only 2D. But upstream distribution makes no such assumption; the only assumption there is that exactly one dim is distributed. We need to add more support there in the future.

SmallVector<unsigned int> distributedDims;
// Get the distributed dimensions based on the layout.
ArrayRef<int> laneLayout = layout.getLaneLayout().asArrayRef();
for (unsigned i = 0; i < laneLayout.size(); ++i) {
if (laneLayout[i] > 1)
distributedDims.push_back(i);
}
return AffineMap::getMultiDimMapWithTargets(vecRank, distributedDims,
val.getContext());
};
// TODO: shuffleFn is not used.
auto shuffleFn = [](Location loc, OpBuilder &builder, Value val, Value srcIdx,
int64_t warpSz) { return Value(); };
vector::populatePropagateWarpVectorDistributionPatterns(
Expand Down
79 changes: 79 additions & 0 deletions mlir/test/Dialect/Vector/vector-warp-distribute.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -584,6 +584,85 @@ func.func @warp_scf_for_multiple_yield(%arg0: index, %arg1: memref<?xf32>, %arg2
return
}

// -----
// CHECK-PROP-LABEL: func.func @warp_scf_for_unused_for_result(
// CHECK-PROP: %[[W0:.*]]:2 = gpu.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<4xf32>, vector<4xf32>) {
// CHECK-PROP: %[[INI0:.*]] = "some_def"() : () -> vector<128xf32>
// CHECK-PROP: %[[INI1:.*]] = "some_def"() : () -> vector<128xf32>
// CHECK-PROP: gpu.yield %[[INI0]], %[[INI1]] : vector<128xf32>, vector<128xf32>
// CHECK-PROP: }
// CHECK-PROP: %[[F:.*]]:2 = scf.for %{{.*}} iter_args(%{{.*}} = %[[W0]]#0, %{{.*}} = %[[W0]]#1) -> (vector<4xf32>, vector<4xf32>) {
// CHECK-PROP: %[[W1:.*]]:2 = gpu.warp_execute_on_lane_0(%{{.*}})[32] args(%{{.*}} : vector<4xf32>, vector<4xf32>) -> (vector<4xf32>, vector<4xf32>) {
// CHECK-PROP: %[[ACC0:.*]] = "some_def"(%{{.*}}) : (vector<128xf32>, index) -> vector<128xf32>
// CHECK-PROP: %[[ACC1:.*]] = "some_def"(%{{.*}}) : (index, vector<128xf32>, vector<128xf32>) -> vector<128xf32>
// CHECK-PROP: gpu.yield %[[ACC1]], %[[ACC0]] : vector<128xf32>, vector<128xf32>
// CHECK-PROP: }
// CHECK-PROP: scf.yield %[[W1]]#0, %[[W1]]#1 : vector<4xf32>, vector<4xf32>
// CHECK-PROP: }
// CHECK-PROP: "some_use"(%[[F]]#0) : (vector<4xf32>) -> ()
func.func @warp_scf_for_unused_for_result(%arg0: index) {
%c128 = arith.constant 128 : index
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
%0 = gpu.warp_execute_on_lane_0(%arg0)[32] -> (vector<4xf32>) {
%ini = "some_def"() : () -> (vector<128xf32>)
%ini1 = "some_def"() : () -> (vector<128xf32>)
%3:2 = scf.for %arg3 = %c0 to %c128 step %c1 iter_args(%arg4 = %ini, %arg5 = %ini1) -> (vector<128xf32>, vector<128xf32>) {
%add = arith.addi %arg3, %c1 : index
%1 = "some_def"(%arg5, %add) : (vector<128xf32>, index) -> (vector<128xf32>)
%acc = "some_def"(%add, %arg4, %1) : (index, vector<128xf32>, vector<128xf32>) -> (vector<128xf32>)
scf.yield %acc, %1 : vector<128xf32>, vector<128xf32>
}
gpu.yield %3#0 : vector<128xf32>
}
"some_use"(%0) : (vector<4xf32>) -> ()
return
}

// -----
// CHECK-PROP-LABEL: func.func @warp_scf_for_swapped_for_results(
// CHECK-PROP: %[[W0:.*]]:3 = gpu.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<8xf32>, vector<4xf32>, vector<4xf32>) {
// CHECK-PROP-NEXT: %[[INI0:.*]] = "some_def"() : () -> vector<256xf32>
// CHECK-PROP-NEXT: %[[INI1:.*]] = "some_def"() : () -> vector<128xf32>
// CHECK-PROP-NEXT: %[[INI2:.*]] = "some_def"() : () -> vector<128xf32>
// CHECK-PROP-NEXT: gpu.yield %[[INI0]], %[[INI1]], %[[INI2]] : vector<256xf32>, vector<128xf32>, vector<128xf32>
// CHECK-PROP-NEXT: }
// CHECK-PROP-NEXT: %[[F0:.*]]:3 = scf.for {{.*}} iter_args(%{{.*}} = %[[W0]]#0, %{{.*}} = %[[W0]]#1, %{{.*}} = %[[W0]]#2) -> (vector<8xf32>, vector<4xf32>, vector<4xf32>) {
// CHECK-PROP-NEXT: %[[W1:.*]]:3 = gpu.warp_execute_on_lane_0(%{{.*}})[32] args(%{{.*}} :
// CHECK-PROP-SAME: vector<8xf32>, vector<4xf32>, vector<4xf32>) -> (vector<8xf32>, vector<4xf32>, vector<4xf32>) {
// CHECK-PROP-NEXT: ^bb0(%{{.*}}: vector<256xf32>, %{{.*}}: vector<128xf32>, %{{.*}}: vector<128xf32>):
// CHECK-PROP-NEXT: %[[T3:.*]] = "some_def_1"(%{{.*}}) : (vector<256xf32>) -> vector<256xf32>
// CHECK-PROP-NEXT: %[[T4:.*]] = "some_def_2"(%{{.*}}) : (vector<128xf32>) -> vector<128xf32>
// CHECK-PROP-NEXT: %[[T5:.*]] = "some_def_3"(%{{.*}}) : (vector<128xf32>) -> vector<128xf32>
// CHECK-PROP-NEXT: gpu.yield %[[T3]], %[[T4]], %[[T5]] : vector<256xf32>, vector<128xf32>, vector<128xf32>
// CHECK-PROP-NEXT: }
// CHECK-PROP-NEXT: scf.yield %[[W1]]#0, %[[W1]]#1, %[[W1]]#2 : vector<8xf32>, vector<4xf32>, vector<4xf32>
// CHECK-PROP-NEXT: }
// CHECK-PROP-NEXT: "some_use_1"(%[[F0]]#2) : (vector<4xf32>) -> ()
// CHECK-PROP-NEXT: "some_use_2"(%[[F0]]#1) : (vector<4xf32>) -> ()
// CHECK-PROP-NEXT: "some_use_3"(%[[F0]]#0) : (vector<8xf32>) -> ()
func.func @warp_scf_for_swapped_for_results(%arg0: index) {
%c128 = arith.constant 128 : index
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
%0:3 = gpu.warp_execute_on_lane_0(%arg0)[32] -> (vector<4xf32>, vector<4xf32>, vector<8xf32>) {
%ini1 = "some_def"() : () -> (vector<256xf32>)
%ini2 = "some_def"() : () -> (vector<128xf32>)
%ini3 = "some_def"() : () -> (vector<128xf32>)
%3:3 = scf.for %arg3 = %c0 to %c128 step %c1 iter_args(%arg4 = %ini1, %arg5 = %ini2, %arg6 = %ini3) -> (vector<256xf32>, vector<128xf32>, vector<128xf32>) {
%acc1 = "some_def_1"(%arg4) : (vector<256xf32>) -> (vector<256xf32>)
%acc2 = "some_def_2"(%arg5) : (vector<128xf32>) -> (vector<128xf32>)
%acc3 = "some_def_3"(%arg6) : (vector<128xf32>) -> (vector<128xf32>)
scf.yield %acc1, %acc2, %acc3 : vector<256xf32>, vector<128xf32>, vector<128xf32>
}
gpu.yield %3#2, %3#1, %3#0 : vector<128xf32>, vector<128xf32>, vector<256xf32>
}
"some_use_1"(%0#0) : (vector<4xf32>) -> ()
"some_use_2"(%0#1) : (vector<4xf32>) -> ()
"some_use_3"(%0#2) : (vector<8xf32>) -> ()
return
}

// -----

// CHECK-PROP-LABEL: func @vector_reduction(
Expand Down