Skip to content

[mlir][vector] Refactor WarpOpScfForOp to support unused or swapped forOp results. #147620

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
171 changes: 124 additions & 47 deletions mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,12 @@
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Value.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/RegionUtils.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallVectorExtras.h"
#include "llvm/Support/FormatVariadic.h"
Expand Down Expand Up @@ -1745,19 +1749,18 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
: WarpDistributionPattern(ctx, b), distributionMapFn(std::move(fn)) {}
LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
PatternRewriter &rewriter) const override {
auto yield = cast<gpu::YieldOp>(
auto newWarpOpYield = cast<gpu::YieldOp>(
warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
// Only pick up forOp if it is the last op in the region.
Operation *lastNode = yield->getPrevNode();
Operation *lastNode = newWarpOpYield->getPrevNode();
auto forOp = dyn_cast_or_null<scf::ForOp>(lastNode);
if (!forOp)
return failure();
// Collect Values that come from the warp op but are outside the forOp.
// Those Value needs to be returned by the original warpOp and passed to
// the new op.
// Those values need to be returned by the new warp op.
llvm::SmallSetVector<Value, 32> escapingValues;
SmallVector<Type> inputTypes;
SmallVector<Type> distTypes;
SmallVector<Type> escapingValueInputTypes;
SmallVector<Type> escapingValuedistTypes;

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

dist->Dist

mlir::visitUsedValuesDefinedAbove(
forOp.getBodyRegion(), [&](OpOperand *operand) {
Operation *parent = operand->get().getParentRegion()->getParentOp();
Expand All @@ -1769,81 +1772,155 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
AffineMap map = distributionMapFn(operand->get());
distType = getDistributedType(vecType, map, warpOp.getWarpSize());
}
inputTypes.push_back(operand->get().getType());
distTypes.push_back(distType);
escapingValueInputTypes.push_back(operand->get().getType());
escapingValuedistTypes.push_back(distType);
}
});

if (llvm::is_contained(distTypes, Type{}))
if (llvm::is_contained(escapingValuedistTypes, Type{}))
return failure();
// Warp op can yield two types of values:
// 1. Values that are not results of the forOp:
// These values must also be yielded by the new warp op. Also, we need to
// record the index mapping for these values to replace them later.
// 2. Values that are results of the forOp:
// In this case, we record the index mapping between the warp op result
// index and matching forOp result index.
SmallVector<Value> nonForYieldedValues;
SmallVector<unsigned> nonForResultIndices;
DenseMap<unsigned, unsigned> forResultMapping;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can probably do the mapping with some existing tools like a value to value map.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I thought about using IRMapping. but it has the same amount of book keeping. So I don't see any added benefit. DenseMap does the job and code looks easy to read in my view.

But if you can point to some code example with value mapping I will reconsider it.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Something like function cloning comes to mind, but if it's the same effort/complexity, it doesn't really matter, I suppose.

for (OpOperand &yieldOperand : newWarpOpYield->getOpOperands()) {
// Yielded value is not a result of the forOp.
if (yieldOperand.get().getDefiningOp() != forOp.getOperation()) {
nonForYieldedValues.push_back(yieldOperand.get());
nonForResultIndices.push_back(yieldOperand.getOperandNumber());
continue;
}
OpResult forResult = cast<OpResult>(yieldOperand.get());
forResultMapping[yieldOperand.getOperandNumber()] =
forResult.getResultNumber();
}

SmallVector<size_t> newRetIndices;
WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
rewriter, warpOp, escapingValues.getArrayRef(), distTypes,
newRetIndices);
yield = cast<gpu::YieldOp>(
// Newly created warp op will yield values in following order:
// 1. All init args of the forOp.
// 2. All escaping values.
// 3. All non-for yielded values.
SmallVector<Value> newWarpOpYieldValues;
SmallVector<Type> newWarpOpDistTypes;
for (auto [i, initArg] : llvm::enumerate(forOp.getInitArgs())) {
newWarpOpYieldValues.push_back(initArg);
// Compute the distributed type for this init arg.
Type distType = initArg.getType();
if (auto vecType = dyn_cast<VectorType>(distType)) {
AffineMap map = distributionMapFn(initArg);
distType = getDistributedType(vecType, map, warpOp.getWarpSize());
}
newWarpOpDistTypes.push_back(distType);
}
// Insert escaping values and their distributed types.
newWarpOpYieldValues.insert(newWarpOpYieldValues.end(),
escapingValues.begin(), escapingValues.end());
newWarpOpDistTypes.insert(newWarpOpDistTypes.end(),
escapingValuedistTypes.begin(),
escapingValuedistTypes.end());
// Next, we insert all non-for yielded values and their distributed types.
// We also create a mapping between the non-for yielded value index and the
// corresponding new warp op yield value index (needed to update users
// later).
DenseMap<unsigned, unsigned> warpResultMapping;
for (auto [i, v] : llvm::enumerate(nonForYieldedValues)) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: maybe using llvm::zip is more straight forward?

warpResultMapping[nonForResultIndices[i]] = newWarpOpYieldValues.size();
newWarpOpYieldValues.push_back(v);
newWarpOpDistTypes.push_back(
warpOp.getResult(nonForResultIndices[i]).getType());
}
// Create the new warp op with the updated yield values and types.
WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
rewriter, warpOp, newWarpOpYieldValues, newWarpOpDistTypes);
newWarpOpYield = cast<gpu::YieldOp>(
newWarpOp.getBodyRegion().getBlocks().begin()->getTerminator());

SmallVector<Value> newOperands;
SmallVector<unsigned> resultIdx;
// Collect all the outputs coming from the forOp.
for (OpOperand &yieldOperand : yield->getOpOperands()) {
if (yieldOperand.get().getDefiningOp() != forOp.getOperation())
continue;
auto forResult = cast<OpResult>(yieldOperand.get());
newOperands.push_back(
newWarpOp.getResult(yieldOperand.getOperandNumber()));
yieldOperand.set(forOp.getInitArgs()[forResult.getResultNumber()]);
resultIdx.push_back(yieldOperand.getOperandNumber());
}
// Next, we create a new for op with the init args yielded by the new
// warp op.
unsigned escapingValuesStartIdx =
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
unsigned escapingValuesStartIdx =
const unsigned escapingValuesStartIdx =

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed.

forOp.getInitArgs().size(); // ForOp init args are positioned before
// escaping values in the new warp op.
SmallVector<Value> newForOpOperands;
for (size_t i = 0; i < escapingValuesStartIdx; ++i)
newForOpOperands.push_back(newWarpOp.getResult(i));

// Create a new for op outside the new warp op region.
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPointAfter(newWarpOp);

// Create a new for op outside the region with a WarpExecuteOnLane0Op
// region inside.
auto newForOp = rewriter.create<scf::ForOp>(
forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
forOp.getStep(), newOperands);
forOp.getStep(), newForOpOperands);
// Next, we insert a new warp op (called inner warp op) inside the
// newly created for op. This warp op will contain all ops that were
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

btw, the comments would be easier to read if they highlight the op names, e.g.

Suggested change
// newly created for op. This warp op will contain all ops that were
// newly created `ForOp`. This warp op will contain all ops that were

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cleaned up the comments. thanks!

// contained within the original for op body.
rewriter.setInsertionPointToStart(newForOp.getBody());

SmallVector<Value> warpInput(newForOp.getRegionIterArgs().begin(),
newForOp.getRegionIterArgs().end());
SmallVector<Type> warpInputType(forOp.getResultTypes().begin(),
forOp.getResultTypes().end());
SmallVector<Value> innerWarpInput(newForOp.getRegionIterArgs().begin(),
newForOp.getRegionIterArgs().end());
SmallVector<Type> innerWarpInputType(forOp.getResultTypes().begin(),
forOp.getResultTypes().end());
// Escaping values are forwarded to the inner warp op as its (additional)
// arguments. We keep track of the mapping between these values and their
// argument index in the inner warp op (to replace uses later).
llvm::SmallDenseMap<Value, int64_t> argIndexMapping;
for (auto [i, retIdx] : llvm::enumerate(newRetIndices)) {
warpInput.push_back(newWarpOp.getResult(retIdx));
argIndexMapping[escapingValues[i]] = warpInputType.size();
warpInputType.push_back(inputTypes[i]);
for (size_t i = escapingValuesStartIdx;
i < escapingValuesStartIdx + escapingValues.size(); ++i) {
innerWarpInput.push_back(newWarpOp.getResult(i));
argIndexMapping[escapingValues[i - escapingValuesStartIdx]] =
innerWarpInputType.size();
innerWarpInputType.push_back(
escapingValueInputTypes[i - escapingValuesStartIdx]);
}
// Create the inner warp op with the new input values and types.
auto innerWarp = rewriter.create<WarpExecuteOnLane0Op>(
newWarpOp.getLoc(), newForOp.getResultTypes(), newWarpOp.getLaneid(),
newWarpOp.getWarpSize(), warpInput, warpInputType);
newWarpOp.getWarpSize(), innerWarpInput, innerWarpInputType);

// Inline the for op body into the inner warp op body.
SmallVector<Value> argMapping;
argMapping.push_back(newForOp.getInductionVar());
for (Value args : innerWarp.getBody()->getArguments()) {
for (Value args : innerWarp.getBody()->getArguments())
argMapping.push_back(args);
}

argMapping.resize(forOp.getBody()->getNumArguments());
SmallVector<Value> yieldOperands;
for (Value operand : forOp.getBody()->getTerminator()->getOperands())
yieldOperands.push_back(operand);

rewriter.eraseOp(forOp.getBody()->getTerminator());
rewriter.mergeBlocks(forOp.getBody(), innerWarp.getBody(), argMapping);

// Insert a gpu yieldOp at the end of the inner warp op body that yields
// original forOp results.
rewriter.setInsertionPointToEnd(innerWarp.getBody());
rewriter.create<gpu::YieldOp>(innerWarp.getLoc(), yieldOperands);
rewriter.setInsertionPointAfter(innerWarp);
// Insert a scf.yield op at the end of the new for op body that yields
// the inner warp op results.
if (!innerWarp.getResults().empty())
rewriter.create<scf::YieldOp>(forOp.getLoc(), innerWarp.getResults());

// Update the users of original warp op results that were coming from the
// original forOp to the corresponding new forOp result.
for (auto [origIdx, newIdx] : forResultMapping)
rewriter.replaceAllUsesExcept(warpOp.getResult(origIdx),
newForOp.getResult(newIdx), newForOp);
// Similarly, update any users of the warp op results that were not
// results of the forOp.
for (auto [origIdx, newIdx] : warpResultMapping)
rewriter.replaceAllUsesWith(warpOp.getResult(origIdx),
newWarpOp.getResult(newIdx));
// Remove the original warp op and for op, they should not have any uses
// at this point.
rewriter.eraseOp(forOp);
// Replace the warpOp result coming from the original ForOp.
for (const auto &res : llvm::enumerate(resultIdx)) {
rewriter.replaceAllUsesWith(newWarpOp.getResult(res.value()),
newForOp.getResult(res.index()));
newForOp->setOperand(res.index() + 3, newWarpOp.getResult(res.value()));
}
rewriter.eraseOp(warpOp);
// Update any users of escaping values that were forwarded to the
// inner warp op. These values are now arguments of the inner warp op.
newForOp.walk([&](Operation *op) {
for (OpOperand &operand : op->getOpOperands()) {
auto it = argIndexMapping.find(operand.get());
Expand Down
24 changes: 21 additions & 3 deletions mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/SmallVectorExtras.h"

namespace mlir {
namespace xegpu {
Expand Down Expand Up @@ -876,15 +877,32 @@ void XeGPUSubgroupDistributePass::runOnOperation() {
// Step 3: Apply subgroup to workitem distribution patterns.
RewritePatternSet patterns(&getContext());
xegpu::populateXeGPUSubgroupDistributePatterns(patterns);
// TODO: distributionFn and shuffleFn are not used at this point.
// distributionFn is used by vector distribution patterns to determine the
// distributed vector type for a given vector value. In XeGPU subgroup
// distribution context, we compute this based on lane layout.
auto distributionFn = [](Value val) {
VectorType vecType = dyn_cast<VectorType>(val.getType());
int64_t vecRank = vecType ? vecType.getRank() : 0;
OpBuilder builder(val.getContext());
if (vecRank == 0)
return AffineMap::get(val.getContext());
return AffineMap::getMultiDimIdentityMap(vecRank, val.getContext());
// Get the layout of the vector type.
xegpu::LayoutAttr layout = xegpu::getLayoutAttr(val);
// If no layout is specified, assume the inner most dimension is distributed
// for now.
if (!layout)
return AffineMap::getMultiDimMapWithTargets(
vecRank, {static_cast<unsigned int>(vecRank - 1)}, val.getContext());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This assumes 2d, right?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. At SG level layout is currently only 2d. But upstream distribution makes no such assumption. Only assumption there is only 1 dim is distributed. We need to add more support there in future.

SmallVector<unsigned int> distributedDims;
// Get the distributed dimensions based on the layout.
ArrayRef<int> laneLayout = layout.getLaneLayout().asArrayRef();
for (unsigned i = 0; i < laneLayout.size(); ++i) {
if (laneLayout[i] > 1)
distributedDims.push_back(i);
}
return AffineMap::getMultiDimMapWithTargets(vecRank, distributedDims,
val.getContext());
};
// TODO: shuffleFn is not used.
auto shuffleFn = [](Location loc, OpBuilder &builder, Value val, Value srcIdx,
int64_t warpSz) { return Value(); };
vector::populatePropagateWarpVectorDistributionPatterns(
Expand Down
79 changes: 79 additions & 0 deletions mlir/test/Dialect/Vector/vector-warp-distribute.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -584,6 +584,85 @@ func.func @warp_scf_for_multiple_yield(%arg0: index, %arg1: memref<?xf32>, %arg2
return
}

// -----
// CHECK-PROP-LABEL: func.func @warp_scf_for_unused_for_result(
// CHECK-PROP: %[[W0:.*]]:2 = gpu.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<4xf32>, vector<4xf32>) {
// CHECK-PROP: %[[INI0:.*]] = "some_def"() : () -> vector<128xf32>
// CHECK-PROP: %[[INI1:.*]] = "some_def"() : () -> vector<128xf32>
// CHECK-PROP: gpu.yield %[[INI0]], %[[INI1]] : vector<128xf32>, vector<128xf32>
// CHECK-PROP: }
// CHECK-PROP: %[[F:.*]]:2 = scf.for %{{.*}} iter_args(%{{.*}} = %[[W0]]#0, %{{.*}} = %[[W0]]#1) -> (vector<4xf32>, vector<4xf32>) {
// CHECK-PROP: %[[W1:.*]]:2 = gpu.warp_execute_on_lane_0(%{{.*}})[32] args(%{{.*}} : vector<4xf32>, vector<4xf32>) -> (vector<4xf32>, vector<4xf32>) {
// CHECK-PROP: %[[ACC0:.*]] = "some_def"(%{{.*}}) : (vector<128xf32>, index) -> vector<128xf32>
// CHECK-PROP: %[[ACC1:.*]] = "some_def"(%{{.*}}) : (index, vector<128xf32>, vector<128xf32>) -> vector<128xf32>
// CHECK-PROP: gpu.yield %[[ACC1]], %[[ACC0]] : vector<128xf32>, vector<128xf32>
// CHECK-PROP: }
// CHECK-PROP: scf.yield %[[W1]]#0, %[[W1]]#1 : vector<4xf32>, vector<4xf32>
// CHECK-PROP: }
// CHECK-PROP: "some_use"(%[[F]]#0) : (vector<4xf32>) -> ()
// Exercises distribution of an scf.for whose results are only partially
// yielded by the enclosing warp op: the loop produces two results but the
// warp op forwards only %3#0, leaving %3#1 unused outside the loop.
func.func @warp_scf_for_unused_for_result(%arg0: index) {
%c128 = arith.constant 128 : index
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
%0 = gpu.warp_execute_on_lane_0(%arg0)[32] -> (vector<4xf32>) {
// Two loop-carried values; both are live across iterations because each
// iteration reads the other's previous value (%arg5 feeds %1, %arg4 feeds %acc).
%ini = "some_def"() : () -> (vector<128xf32>)
%ini1 = "some_def"() : () -> (vector<128xf32>)
%3:2 = scf.for %arg3 = %c0 to %c128 step %c1 iter_args(%arg4 = %ini, %arg5 = %ini1) -> (vector<128xf32>, vector<128xf32>) {
%add = arith.addi %arg3, %c1 : index
%1 = "some_def"(%arg5, %add) : (vector<128xf32>, index) -> (vector<128xf32>)
%acc = "some_def"(%add, %arg4, %1) : (index, vector<128xf32>, vector<128xf32>) -> (vector<128xf32>)
scf.yield %acc, %1 : vector<128xf32>, vector<128xf32>
}
// Only the first loop result escapes the warp region; %3#1 is dropped here.
gpu.yield %3#0 : vector<128xf32>
}
"some_use"(%0) : (vector<4xf32>) -> ()
return
}

// -----
// CHECK-PROP-LABEL: func.func @warp_scf_for_swapped_for_results(
// CHECK-PROP: %[[W0:.*]]:3 = gpu.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<8xf32>, vector<4xf32>, vector<4xf32>) {
// CHECK-PROP-NEXT: %[[INI0:.*]] = "some_def"() : () -> vector<256xf32>
// CHECK-PROP-NEXT: %[[INI1:.*]] = "some_def"() : () -> vector<128xf32>
// CHECK-PROP-NEXT: %[[INI2:.*]] = "some_def"() : () -> vector<128xf32>
// CHECK-PROP-NEXT: gpu.yield %[[INI0]], %[[INI1]], %[[INI2]] : vector<256xf32>, vector<128xf32>, vector<128xf32>
// CHECK-PROP-NEXT: }
// CHECK-PROP-NEXT: %[[F0:.*]]:3 = scf.for {{.*}} iter_args(%{{.*}} = %[[W0]]#0, %{{.*}} = %[[W0]]#1, %{{.*}} = %[[W0]]#2) -> (vector<8xf32>, vector<4xf32>, vector<4xf32>) {
// CHECK-PROP-NEXT: %[[W1:.*]]:3 = gpu.warp_execute_on_lane_0(%{{.*}})[32] args(%{{.*}} :
// CHECK-PROP-SAME: vector<8xf32>, vector<4xf32>, vector<4xf32>) -> (vector<8xf32>, vector<4xf32>, vector<4xf32>) {
// CHECK-PROP-NEXT: ^bb0(%{{.*}}: vector<256xf32>, %{{.*}}: vector<128xf32>, %{{.*}}: vector<128xf32>):
// CHECK-PROP-NEXT: %[[T3:.*]] = "some_def_1"(%{{.*}}) : (vector<256xf32>) -> vector<256xf32>
// CHECK-PROP-NEXT: %[[T4:.*]] = "some_def_2"(%{{.*}}) : (vector<128xf32>) -> vector<128xf32>
// CHECK-PROP-NEXT: %[[T5:.*]] = "some_def_3"(%{{.*}}) : (vector<128xf32>) -> vector<128xf32>
// CHECK-PROP-NEXT: gpu.yield %[[T3]], %[[T4]], %[[T5]] : vector<256xf32>, vector<128xf32>, vector<128xf32>
// CHECK-PROP-NEXT: }
// CHECK-PROP-NEXT: scf.yield %[[W1]]#0, %[[W1]]#1, %[[W1]]#2 : vector<8xf32>, vector<4xf32>, vector<4xf32>
// CHECK-PROP-NEXT: }
// CHECK-PROP-NEXT: "some_use_1"(%[[F0]]#2) : (vector<4xf32>) -> ()
// CHECK-PROP-NEXT: "some_use_2"(%[[F0]]#1) : (vector<4xf32>) -> ()
// CHECK-PROP-NEXT: "some_use_3"(%[[F0]]#0) : (vector<8xf32>) -> ()
// Exercises distribution when the warp op yields the scf.for results in a
// different (reversed) order than the loop produces them: gpu.yield forwards
// %3#2, %3#1, %3#0, so warp-op result i does not map to forOp result i.
func.func @warp_scf_for_swapped_for_results(%arg0: index) {
%c128 = arith.constant 128 : index
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
%0:3 = gpu.warp_execute_on_lane_0(%arg0)[32] -> (vector<4xf32>, vector<4xf32>, vector<8xf32>) {
// Note the mixed widths: %ini1 is 256 wide while %ini2/%ini3 are 128 wide,
// so the swapped yield order also swaps the distributed result types.
%ini1 = "some_def"() : () -> (vector<256xf32>)
%ini2 = "some_def"() : () -> (vector<128xf32>)
%ini3 = "some_def"() : () -> (vector<128xf32>)
%3:3 = scf.for %arg3 = %c0 to %c128 step %c1 iter_args(%arg4 = %ini1, %arg5 = %ini2, %arg6 = %ini3) -> (vector<256xf32>, vector<128xf32>, vector<128xf32>) {
%acc1 = "some_def_1"(%arg4) : (vector<256xf32>) -> (vector<256xf32>)
%acc2 = "some_def_2"(%arg5) : (vector<128xf32>) -> (vector<128xf32>)
%acc3 = "some_def_3"(%arg6) : (vector<128xf32>) -> (vector<128xf32>)
scf.yield %acc1, %acc2, %acc3 : vector<256xf32>, vector<128xf32>, vector<128xf32>
}
// Reversed order relative to the loop's result list.
gpu.yield %3#2, %3#1, %3#0 : vector<128xf32>, vector<128xf32>, vector<256xf32>
}
"some_use_1"(%0#0) : (vector<4xf32>) -> ()
"some_use_2"(%0#1) : (vector<4xf32>) -> ()
"some_use_3"(%0#2) : (vector<8xf32>) -> ()
return
}

// -----

// CHECK-PROP-LABEL: func @vector_reduction(
Expand Down