Commit 3092b76
[mlir][vector] Refactor WarpOpScfForOp to support unused or swapped forOp results. (#147620)
The current implementation generates incorrect code or crashes in the following valid cases.

1. At least one of the forOp results is not yielded by the warpOp. Example:

```
%0 = gpu.warp_execute_on_lane_0(%arg0)[32] -> (vector<4xf32>) {
  ...
  %3:2 = scf.for %arg3 = %c0 to %c128 step %c1 iter_args(%arg4 = %ini, %arg5 = %ini1) -> (vector<128xf32>, vector<128xf32>) {
    %1 = ...
    %acc = ...
    scf.yield %acc, %1 : vector<128xf32>, vector<128xf32>
  }
  gpu.yield %3#0 : vector<128xf32> // %3#1 is not used but cannot be removed as dead code (loop carried).
}
"some_use"(%0) : (vector<4xf32>) -> ()
return
```

2. The enclosing warpOp yields the forOp results in a different order than the forOp defines them. Example:

```
%0:3 = gpu.warp_execute_on_lane_0(%arg0)[32] -> (vector<4xf32>, vector<4xf32>, vector<8xf32>) {
  ...
  %3:3 = scf.for %arg3 = %c0 to %c128 step %c1 iter_args(%arg4 = %ini1, %arg5 = %ini2, %arg6 = %ini3) -> (vector<256xf32>, vector<128xf32>, vector<128xf32>) {
    ...
    scf.yield %acc1, %acc2, %acc3 : vector<256xf32>, vector<128xf32>, vector<128xf32>
  }
  gpu.yield %3#2, %3#1, %3#0 : vector<128xf32>, vector<128xf32>, vector<256xf32> // swapped order
}
"some_use_1"(%0#0) : (vector<4xf32>) -> ()
"some_use_2"(%0#1) : (vector<4xf32>) -> ()
"some_use_3"(%0#2) : (vector<8xf32>) -> ()
```
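With this refactor, the first example distributes to roughly the shape sketched below (based on the `@warp_scf_for_unused_for_result` test added in this commit; the SSA names and the elided loop body are illustrative, not part of the change itself). Every loop-carried value becomes an iter_arg of the distributed `scf.for`, the original loop body moves into an inner warp op, and only the result the original warp op yielded is rewired to its user:

```mlir
// Sketch only: value names are illustrative.
%w:2 = gpu.warp_execute_on_lane_0(%arg0)[32] -> (vector<4xf32>, vector<4xf32>) {
  %ini = "some_def"() : () -> vector<128xf32>
  %ini1 = "some_def"() : () -> vector<128xf32>
  // The new warp op yields *all* forOp init args, not only the used result.
  gpu.yield %ini, %ini1 : vector<128xf32>, vector<128xf32>
}
%f:2 = scf.for %arg3 = %c0 to %c128 step %c1
    iter_args(%arg4 = %w#0, %arg5 = %w#1) -> (vector<4xf32>, vector<4xf32>) {
  // Inner warp op: distributed iter_args come in as `args`, and the original
  // 128-wide values reappear as block arguments of the region.
  %inner:2 = gpu.warp_execute_on_lane_0(%arg0)[32]
      args(%arg4, %arg5 : vector<4xf32>, vector<4xf32>) -> (vector<4xf32>, vector<4xf32>) {
  ^bb0(%b0: vector<128xf32>, %b1: vector<128xf32>):
    // ... original loop body ...
    gpu.yield %acc, %1 : vector<128xf32>, vector<128xf32>
  }
  scf.yield %inner#0, %inner#1 : vector<4xf32>, vector<4xf32>
}
"some_use"(%f#0) : (vector<4xf32>) -> ()
```

The swapped-result case (example 2) is handled the same way: the pattern records an index mapping between warp-op results and for-op results instead of assuming they line up.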
1 parent a0895d0 commit 3092b76

File tree

3 files changed: +221, -54 lines

mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp

Lines changed: 122 additions & 51 deletions
@@ -1704,19 +1704,18 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
       : WarpDistributionPattern(ctx, b), distributionMapFn(std::move(fn)) {}
   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
-    auto yield = cast<gpu::YieldOp>(
+    auto warpOpYield = cast<gpu::YieldOp>(
         warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
-    // Only pick up forOp if it is the last op in the region.
-    Operation *lastNode = yield->getPrevNode();
+    // Only pick up `ForOp` if it is the last op in the region.
+    Operation *lastNode = warpOpYield->getPrevNode();
     auto forOp = dyn_cast_or_null<scf::ForOp>(lastNode);
     if (!forOp)
       return failure();
-    // Collect Values that come from the warp op but are outside the forOp.
-    // Those Value needs to be returned by the original warpOp and passed to
-    // the new op.
+    // Collect Values that come from the `WarpOp` but are outside the `ForOp`.
+    // Those Values need to be returned by the new warp op.
     llvm::SmallSetVector<Value, 32> escapingValues;
-    SmallVector<Type> inputTypes;
-    SmallVector<Type> distTypes;
+    SmallVector<Type> escapingValueInputTypes;
+    SmallVector<Type> escapingValueDistTypes;
     mlir::visitUsedValuesDefinedAbove(
         forOp.getBodyRegion(), [&](OpOperand *operand) {
           Operation *parent = operand->get().getParentRegion()->getParentOp();
@@ -1728,81 +1727,153 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
               AffineMap map = distributionMapFn(operand->get());
               distType = getDistributedType(vecType, map, warpOp.getWarpSize());
             }
-            inputTypes.push_back(operand->get().getType());
-            distTypes.push_back(distType);
+            escapingValueInputTypes.push_back(operand->get().getType());
+            escapingValueDistTypes.push_back(distType);
           }
         });
 
-    if (llvm::is_contained(distTypes, Type{}))
+    if (llvm::is_contained(escapingValueDistTypes, Type{}))
       return failure();
-
-    SmallVector<size_t> newRetIndices;
-    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
-        rewriter, warpOp, escapingValues.getArrayRef(), distTypes,
-        newRetIndices);
-    yield = cast<gpu::YieldOp>(
-        newWarpOp.getBodyRegion().getBlocks().begin()->getTerminator());
-
-    SmallVector<Value> newOperands;
-    SmallVector<unsigned> resultIdx;
-    // Collect all the outputs coming from the forOp.
-    for (OpOperand &yieldOperand : yield->getOpOperands()) {
-      if (yieldOperand.get().getDefiningOp() != forOp.getOperation())
+    // `WarpOp` can yield two types of values:
+    // 1. Values that are not results of the `ForOp`:
+    //    These values must also be yielded by the new `WarpOp`. Also, we need
+    //    to record the index mapping for these values to replace them later.
+    // 2. Values that are results of the `ForOp`:
+    //    In this case, we record the index mapping between the `WarpOp` result
+    //    index and matching `ForOp` result index.
+    SmallVector<Value> nonForYieldedValues;
+    SmallVector<unsigned> nonForResultIndices;
+    llvm::SmallDenseMap<unsigned, unsigned> forResultMapping;
+    for (OpOperand &yieldOperand : warpOpYield->getOpOperands()) {
+      // Yielded value is not a result of the forOp.
+      if (yieldOperand.get().getDefiningOp() != forOp.getOperation()) {
+        nonForYieldedValues.push_back(yieldOperand.get());
+        nonForResultIndices.push_back(yieldOperand.getOperandNumber());
         continue;
-      auto forResult = cast<OpResult>(yieldOperand.get());
-      newOperands.push_back(
-          newWarpOp.getResult(yieldOperand.getOperandNumber()));
-      yieldOperand.set(forOp.getInitArgs()[forResult.getResultNumber()]);
-      resultIdx.push_back(yieldOperand.getOperandNumber());
+      }
+      OpResult forResult = cast<OpResult>(yieldOperand.get());
+      forResultMapping[yieldOperand.getOperandNumber()] =
+          forResult.getResultNumber();
     }
 
+    // Newly created `WarpOp` will yield values in following order:
+    // 1. All init args of the `ForOp`.
+    // 2. All escaping values.
+    // 3. All non-`ForOp` yielded values.
+    SmallVector<Value> newWarpOpYieldValues;
+    SmallVector<Type> newWarpOpDistTypes;
+    for (auto [i, initArg] : llvm::enumerate(forOp.getInitArgs())) {
+      newWarpOpYieldValues.push_back(initArg);
+      // Compute the distributed type for this init arg.
+      Type distType = initArg.getType();
+      if (auto vecType = dyn_cast<VectorType>(distType)) {
+        AffineMap map = distributionMapFn(initArg);
+        distType = getDistributedType(vecType, map, warpOp.getWarpSize());
+      }
+      newWarpOpDistTypes.push_back(distType);
+    }
+    // Insert escaping values and their distributed types.
+    newWarpOpYieldValues.insert(newWarpOpYieldValues.end(),
+                                escapingValues.begin(), escapingValues.end());
+    newWarpOpDistTypes.insert(newWarpOpDistTypes.end(),
+                              escapingValueDistTypes.begin(),
+                              escapingValueDistTypes.end());
+    // Next, we insert all non-`ForOp` yielded values and their distributed
+    // types. We also create a mapping between the non-`ForOp` yielded value
+    // index and the corresponding new `WarpOp` yield value index (needed to
+    // update users later).
+    llvm::SmallDenseMap<unsigned, unsigned> nonForResultMapping;
+    for (auto [i, v] :
+         llvm::zip_equal(nonForResultIndices, nonForYieldedValues)) {
+      nonForResultMapping[i] = newWarpOpYieldValues.size();
+      newWarpOpYieldValues.push_back(v);
+      newWarpOpDistTypes.push_back(warpOp.getResult(i).getType());
+    }
+    // Create the new `WarpOp` with the updated yield values and types.
+    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
+        rewriter, warpOp, newWarpOpYieldValues, newWarpOpDistTypes);
+
+    // Next, we create a new `ForOp` with the init args yielded by the new
+    // `WarpOp`.
+    const unsigned escapingValuesStartIdx =
+        forOp.getInitArgs().size(); // `ForOp` init args are positioned before
+                                    // escaping values in the new `WarpOp`.
+    SmallVector<Value> newForOpOperands;
+    for (size_t i = 0; i < escapingValuesStartIdx; ++i)
+      newForOpOperands.push_back(newWarpOp.getResult(i));
+
+    // Create a new `ForOp` outside the new `WarpOp` region.
     OpBuilder::InsertionGuard g(rewriter);
     rewriter.setInsertionPointAfter(newWarpOp);
-
-    // Create a new for op outside the region with a WarpExecuteOnLane0Op
-    // region inside.
     auto newForOp = rewriter.create<scf::ForOp>(
         forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
-        forOp.getStep(), newOperands);
+        forOp.getStep(), newForOpOperands);
+    // Next, we insert a new `WarpOp` (called inner `WarpOp`) inside the
+    // newly created `ForOp`. This `WarpOp` will contain all ops that were
+    // contained within the original `ForOp` body.
     rewriter.setInsertionPointToStart(newForOp.getBody());
 
-    SmallVector<Value> warpInput(newForOp.getRegionIterArgs().begin(),
-                                 newForOp.getRegionIterArgs().end());
-    SmallVector<Type> warpInputType(forOp.getResultTypes().begin(),
-                                    forOp.getResultTypes().end());
+    SmallVector<Value> innerWarpInput(newForOp.getRegionIterArgs().begin(),
+                                      newForOp.getRegionIterArgs().end());
+    SmallVector<Type> innerWarpInputType(forOp.getResultTypes().begin(),
+                                         forOp.getResultTypes().end());
+    // Escaping values are forwarded to the inner `WarpOp` as its (additional)
+    // arguments. We keep track of the mapping between these values and their
+    // argument index in the inner `WarpOp` (to replace users later).
     llvm::SmallDenseMap<Value, int64_t> argIndexMapping;
-    for (auto [i, retIdx] : llvm::enumerate(newRetIndices)) {
-      warpInput.push_back(newWarpOp.getResult(retIdx));
-      argIndexMapping[escapingValues[i]] = warpInputType.size();
-      warpInputType.push_back(inputTypes[i]);
+    for (size_t i = escapingValuesStartIdx;
+         i < escapingValuesStartIdx + escapingValues.size(); ++i) {
+      innerWarpInput.push_back(newWarpOp.getResult(i));
+      argIndexMapping[escapingValues[i - escapingValuesStartIdx]] =
+          innerWarpInputType.size();
+      innerWarpInputType.push_back(
+          escapingValueInputTypes[i - escapingValuesStartIdx]);
     }
+    // Create the inner `WarpOp` with the new input values and types.
     auto innerWarp = rewriter.create<WarpExecuteOnLane0Op>(
         newWarpOp.getLoc(), newForOp.getResultTypes(), newWarpOp.getLaneid(),
-        newWarpOp.getWarpSize(), warpInput, warpInputType);
+        newWarpOp.getWarpSize(), innerWarpInput, innerWarpInputType);
 
+    // Inline the `ForOp` body into the inner `WarpOp` body.
     SmallVector<Value> argMapping;
     argMapping.push_back(newForOp.getInductionVar());
-    for (Value args : innerWarp.getBody()->getArguments()) {
+    for (Value args : innerWarp.getBody()->getArguments())
       argMapping.push_back(args);
-    }
+
     argMapping.resize(forOp.getBody()->getNumArguments());
     SmallVector<Value> yieldOperands;
     for (Value operand : forOp.getBody()->getTerminator()->getOperands())
       yieldOperands.push_back(operand);
+
     rewriter.eraseOp(forOp.getBody()->getTerminator());
     rewriter.mergeBlocks(forOp.getBody(), innerWarp.getBody(), argMapping);
+
+    // Insert a gpu `YieldOp` at the end of the inner `WarpOp` body that yields
+    // original `ForOp` results.
     rewriter.setInsertionPointToEnd(innerWarp.getBody());
     rewriter.create<gpu::YieldOp>(innerWarp.getLoc(), yieldOperands);
     rewriter.setInsertionPointAfter(innerWarp);
+    // Insert a scf.yield op at the end of the new `ForOp` body that yields
+    // the inner `WarpOp` results.
     if (!innerWarp.getResults().empty())
       rewriter.create<scf::YieldOp>(forOp.getLoc(), innerWarp.getResults());
+
+    // Update the users of original `WarpOp` results that were coming from the
+    // original `ForOp` to the corresponding new `ForOp` result.
+    for (auto [origIdx, newIdx] : forResultMapping)
+      rewriter.replaceAllUsesExcept(warpOp.getResult(origIdx),
+                                    newForOp.getResult(newIdx), newForOp);
+    // Similarly, update any users of the `WarpOp` results that were not
+    // results of the `ForOp`.
+    for (auto [origIdx, newIdx] : nonForResultMapping)
+      rewriter.replaceAllUsesWith(warpOp.getResult(origIdx),
+                                  newWarpOp.getResult(newIdx));
+    // Remove the original `WarpOp` and `ForOp`, they should not have any uses
+    // at this point.
     rewriter.eraseOp(forOp);
-    // Replace the warpOp result coming from the original ForOp.
-    for (const auto &res : llvm::enumerate(resultIdx)) {
-      rewriter.replaceAllUsesWith(newWarpOp.getResult(res.value()),
-                                  newForOp.getResult(res.index()));
-      newForOp->setOperand(res.index() + 3, newWarpOp.getResult(res.value()));
-    }
+    rewriter.eraseOp(warpOp);
+    // Update any users of escaping values that were forwarded to the
+    // inner `WarpOp`. These values are now arguments of the inner `WarpOp`.
     newForOp.walk([&](Operation *op) {
       for (OpOperand &operand : op->getOpOperands()) {
         auto it = argIndexMapping.find(operand.get());
@@ -1812,7 +1883,7 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
       }
     });
 
-    // Finally, hoist out any now uniform code from the inner warp op.
+    // Finally, hoist out any now uniform code from the inner `WarpOp`.
     mlir::vector::moveScalarUniformCode(innerWarp);
     return success();
   }

mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp

Lines changed: 20 additions & 3 deletions
@@ -876,15 +876,32 @@ void XeGPUSubgroupDistributePass::runOnOperation() {
   // Step 3: Apply subgroup to workitem distribution patterns.
   RewritePatternSet patterns(&getContext());
   xegpu::populateXeGPUSubgroupDistributePatterns(patterns);
-  // TODO: distributionFn and shuffleFn are not used at this point.
+  // distributionFn is used by vector distribution patterns to determine the
+  // distributed vector type for a given vector value. In XeGPU subgroup
+  // distribution context, we compute this based on lane layout.
   auto distributionFn = [](Value val) {
     VectorType vecType = dyn_cast<VectorType>(val.getType());
     int64_t vecRank = vecType ? vecType.getRank() : 0;
-    OpBuilder builder(val.getContext());
     if (vecRank == 0)
       return AffineMap::get(val.getContext());
-    return AffineMap::getMultiDimIdentityMap(vecRank, val.getContext());
+    // Get the layout of the vector type.
+    xegpu::LayoutAttr layout = xegpu::getLayoutAttr(val);
+    // If no layout is specified, assume the inner most dimension is
+    // distributed for now.
+    if (!layout)
+      return AffineMap::getMultiDimMapWithTargets(
+          vecRank, {static_cast<unsigned int>(vecRank - 1)}, val.getContext());
+    SmallVector<unsigned int> distributedDims;
+    // Get the distributed dimensions based on the layout.
+    ArrayRef<int> laneLayout = layout.getLaneLayout().asArrayRef();
+    for (unsigned i = 0; i < laneLayout.size(); ++i) {
+      if (laneLayout[i] > 1)
+        distributedDims.push_back(i);
+    }
+    return AffineMap::getMultiDimMapWithTargets(vecRank, distributedDims,
+                                                val.getContext());
   };
+  // TODO: shuffleFn is not used.
   auto shuffleFn = [](Location loc, OpBuilder &builder, Value val, Value srcIdx,
                       int64_t warpSz) { return Value(); };
   vector::populatePropagateWarpVectorDistributionPatterns(
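For illustration, here is roughly how the new `distributionFn` behaves (a sketch; the concrete value, the way the layout is attached, and the subgroup size of 16 are assumptions made for this example, not taken from the patch):

```mlir
// Assume a 2-D value whose XeGPU layout has lane_layout = [1, 16]
// (the producer op and attribute placement here are illustrative).
%v = "some_producer"() : () -> vector<16x16xf32>
// Only dimension 1 has a lane_layout entry > 1, so distributionFn returns
// the affine map (d0, d1) -> (d1), i.e. only dim 1 is distributed.
// The vector distribution patterns then shrink that dimension by the
// subgroup size: vector<16x16xf32> -> vector<16x1xf32> for 16 lanes.
// With no layout attribute at all, the innermost dimension is assumed to be
// distributed, which happens to give the same map for this 2-D example.
```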

mlir/test/Dialect/Vector/vector-warp-distribute.mlir

Lines changed: 79 additions & 0 deletions
@@ -584,6 +584,85 @@ func.func @warp_scf_for_multiple_yield(%arg0: index, %arg1: memref<?xf32>, %arg2
   return
 }
 
+// -----
+// CHECK-PROP-LABEL: func.func @warp_scf_for_unused_for_result(
+// CHECK-PROP: %[[W0:.*]]:2 = gpu.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<4xf32>, vector<4xf32>) {
+// CHECK-PROP: %[[INI0:.*]] = "some_def"() : () -> vector<128xf32>
+// CHECK-PROP: %[[INI1:.*]] = "some_def"() : () -> vector<128xf32>
+// CHECK-PROP: gpu.yield %[[INI0]], %[[INI1]] : vector<128xf32>, vector<128xf32>
+// CHECK-PROP: }
+// CHECK-PROP: %[[F:.*]]:2 = scf.for %{{.*}} iter_args(%{{.*}} = %[[W0]]#0, %{{.*}} = %[[W0]]#1) -> (vector<4xf32>, vector<4xf32>) {
+// CHECK-PROP: %[[W1:.*]]:2 = gpu.warp_execute_on_lane_0(%{{.*}})[32] args(%{{.*}} : vector<4xf32>, vector<4xf32>) -> (vector<4xf32>, vector<4xf32>) {
+// CHECK-PROP: %[[ACC0:.*]] = "some_def"(%{{.*}}) : (vector<128xf32>, index) -> vector<128xf32>
+// CHECK-PROP: %[[ACC1:.*]] = "some_def"(%{{.*}}) : (index, vector<128xf32>, vector<128xf32>) -> vector<128xf32>
+// CHECK-PROP: gpu.yield %[[ACC1]], %[[ACC0]] : vector<128xf32>, vector<128xf32>
+// CHECK-PROP: }
+// CHECK-PROP: scf.yield %[[W1]]#0, %[[W1]]#1 : vector<4xf32>, vector<4xf32>
+// CHECK-PROP: }
+// CHECK-PROP: "some_use"(%[[F]]#0) : (vector<4xf32>) -> ()
+func.func @warp_scf_for_unused_for_result(%arg0: index) {
+  %c128 = arith.constant 128 : index
+  %c1 = arith.constant 1 : index
+  %c0 = arith.constant 0 : index
+  %0 = gpu.warp_execute_on_lane_0(%arg0)[32] -> (vector<4xf32>) {
+    %ini = "some_def"() : () -> (vector<128xf32>)
+    %ini1 = "some_def"() : () -> (vector<128xf32>)
+    %3:2 = scf.for %arg3 = %c0 to %c128 step %c1 iter_args(%arg4 = %ini, %arg5 = %ini1) -> (vector<128xf32>, vector<128xf32>) {
+      %add = arith.addi %arg3, %c1 : index
+      %1 = "some_def"(%arg5, %add) : (vector<128xf32>, index) -> (vector<128xf32>)
+      %acc = "some_def"(%add, %arg4, %1) : (index, vector<128xf32>, vector<128xf32>) -> (vector<128xf32>)
+      scf.yield %acc, %1 : vector<128xf32>, vector<128xf32>
+    }
+    gpu.yield %3#0 : vector<128xf32>
+  }
+  "some_use"(%0) : (vector<4xf32>) -> ()
+  return
+}
+
+// -----
+// CHECK-PROP-LABEL: func.func @warp_scf_for_swapped_for_results(
+// CHECK-PROP: %[[W0:.*]]:3 = gpu.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<8xf32>, vector<4xf32>, vector<4xf32>) {
+// CHECK-PROP-NEXT: %[[INI0:.*]] = "some_def"() : () -> vector<256xf32>
+// CHECK-PROP-NEXT: %[[INI1:.*]] = "some_def"() : () -> vector<128xf32>
+// CHECK-PROP-NEXT: %[[INI2:.*]] = "some_def"() : () -> vector<128xf32>
+// CHECK-PROP-NEXT: gpu.yield %[[INI0]], %[[INI1]], %[[INI2]] : vector<256xf32>, vector<128xf32>, vector<128xf32>
+// CHECK-PROP-NEXT: }
+// CHECK-PROP-NEXT: %[[F0:.*]]:3 = scf.for {{.*}} iter_args(%{{.*}} = %[[W0]]#0, %{{.*}} = %[[W0]]#1, %{{.*}} = %[[W0]]#2) -> (vector<8xf32>, vector<4xf32>, vector<4xf32>) {
+// CHECK-PROP-NEXT: %[[W1:.*]]:3 = gpu.warp_execute_on_lane_0(%{{.*}})[32] args(%{{.*}} :
+// CHECK-PROP-SAME: vector<8xf32>, vector<4xf32>, vector<4xf32>) -> (vector<8xf32>, vector<4xf32>, vector<4xf32>) {
+// CHECK-PROP-NEXT: ^bb0(%{{.*}}: vector<256xf32>, %{{.*}}: vector<128xf32>, %{{.*}}: vector<128xf32>):
+// CHECK-PROP-NEXT: %[[T3:.*]] = "some_def_1"(%{{.*}}) : (vector<256xf32>) -> vector<256xf32>
+// CHECK-PROP-NEXT: %[[T4:.*]] = "some_def_2"(%{{.*}}) : (vector<128xf32>) -> vector<128xf32>
+// CHECK-PROP-NEXT: %[[T5:.*]] = "some_def_3"(%{{.*}}) : (vector<128xf32>) -> vector<128xf32>
+// CHECK-PROP-NEXT: gpu.yield %[[T3]], %[[T4]], %[[T5]] : vector<256xf32>, vector<128xf32>, vector<128xf32>
+// CHECK-PROP-NEXT: }
+// CHECK-PROP-NEXT: scf.yield %[[W1]]#0, %[[W1]]#1, %[[W1]]#2 : vector<8xf32>, vector<4xf32>, vector<4xf32>
+// CHECK-PROP-NEXT: }
+// CHECK-PROP-NEXT: "some_use_1"(%[[F0]]#2) : (vector<4xf32>) -> ()
+// CHECK-PROP-NEXT: "some_use_2"(%[[F0]]#1) : (vector<4xf32>) -> ()
+// CHECK-PROP-NEXT: "some_use_3"(%[[F0]]#0) : (vector<8xf32>) -> ()
+func.func @warp_scf_for_swapped_for_results(%arg0: index) {
+  %c128 = arith.constant 128 : index
+  %c1 = arith.constant 1 : index
+  %c0 = arith.constant 0 : index
+  %0:3 = gpu.warp_execute_on_lane_0(%arg0)[32] -> (vector<4xf32>, vector<4xf32>, vector<8xf32>) {
+    %ini1 = "some_def"() : () -> (vector<256xf32>)
+    %ini2 = "some_def"() : () -> (vector<128xf32>)
+    %ini3 = "some_def"() : () -> (vector<128xf32>)
+    %3:3 = scf.for %arg3 = %c0 to %c128 step %c1 iter_args(%arg4 = %ini1, %arg5 = %ini2, %arg6 = %ini3) -> (vector<256xf32>, vector<128xf32>, vector<128xf32>) {
+      %acc1 = "some_def_1"(%arg4) : (vector<256xf32>) -> (vector<256xf32>)
+      %acc2 = "some_def_2"(%arg5) : (vector<128xf32>) -> (vector<128xf32>)
+      %acc3 = "some_def_3"(%arg6) : (vector<128xf32>) -> (vector<128xf32>)
+      scf.yield %acc1, %acc2, %acc3 : vector<256xf32>, vector<128xf32>, vector<128xf32>
+    }
+    gpu.yield %3#2, %3#1, %3#0 : vector<128xf32>, vector<128xf32>, vector<256xf32>
+  }
+  "some_use_1"(%0#0) : (vector<4xf32>) -> ()
+  "some_use_2"(%0#1) : (vector<4xf32>) -> ()
+  "some_use_3"(%0#2) : (vector<8xf32>) -> ()
+  return
+}
+
 // -----
 
 // CHECK-PROP-LABEL: func @vector_reduction(
