@@ -1704,18 +1704,19 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
1704
1704
: WarpDistributionPattern(ctx, b), distributionMapFn(std::move(fn)) {}
1705
1705
LogicalResult matchAndRewrite (WarpExecuteOnLane0Op warpOp,
1706
1706
PatternRewriter &rewriter) const override {
1707
- auto warpOpYield = cast<gpu::YieldOp>(
1707
+ auto yield = cast<gpu::YieldOp>(
1708
1708
warpOp.getBodyRegion ().getBlocks ().begin ()->getTerminator ());
1709
- // Only pick up `ForOp` if it is the last op in the region.
1710
- Operation *lastNode = warpOpYield ->getPrevNode ();
1709
+ // Only pick up forOp if it is the last op in the region.
1710
+ Operation *lastNode = yield ->getPrevNode ();
1711
1711
auto forOp = dyn_cast_or_null<scf::ForOp>(lastNode);
1712
1712
if (!forOp)
1713
1713
return failure ();
1714
- // Collect Values that come from the `WarpOp` but are outside the `ForOp`.
1715
- // Those Values need to be returned by the new warp op.
1714
+ // Collect Values that come from the warp op but are outside the forOp.
1715
+ // Those Values need to be returned by the original warpOp and passed to
1716
+ // the new op.
1716
1717
llvm::SmallSetVector<Value, 32 > escapingValues;
1717
- SmallVector<Type> escapingValueInputTypes ;
1718
- SmallVector<Type> escapingValueDistTypes ;
1718
+ SmallVector<Type> inputTypes ;
1719
+ SmallVector<Type> distTypes ;
1719
1720
mlir::visitUsedValuesDefinedAbove (
1720
1721
forOp.getBodyRegion (), [&](OpOperand *operand) {
1721
1722
Operation *parent = operand->get ().getParentRegion ()->getParentOp ();
@@ -1727,153 +1728,81 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
1727
1728
AffineMap map = distributionMapFn (operand->get ());
1728
1729
distType = getDistributedType (vecType, map, warpOp.getWarpSize ());
1729
1730
}
1730
- escapingValueInputTypes .push_back (operand->get ().getType ());
1731
- escapingValueDistTypes .push_back (distType);
1731
+ inputTypes .push_back (operand->get ().getType ());
1732
+ distTypes .push_back (distType);
1732
1733
}
1733
1734
});
1734
1735
1735
- if (llvm::is_contained (escapingValueDistTypes , Type{}))
1736
+ if (llvm::is_contained (distTypes , Type{}))
1736
1737
return failure ();
1737
- // `WarpOp` can yield two types of values:
1738
- // 1. Values that are not results of the `ForOp`:
1739
- // These values must also be yielded by the new `WarpOp`. Also, we need
1740
- // to record the index mapping for these values to replace them later.
1741
- // 2. Values that are results of the `ForOp`:
1742
- // In this case, we record the index mapping between the `WarpOp` result
1743
- // index and matching `ForOp` result index.
1744
- SmallVector<Value> nonForYieldedValues;
1745
- SmallVector<unsigned > nonForResultIndices;
1746
- llvm::SmallDenseMap<unsigned , unsigned > forResultMapping;
1747
- for (OpOperand &yieldOperand : warpOpYield->getOpOperands ()) {
1748
- // Yielded value is not a result of the forOp.
1749
- if (yieldOperand.get ().getDefiningOp () != forOp.getOperation ()) {
1750
- nonForYieldedValues.push_back (yieldOperand.get ());
1751
- nonForResultIndices.push_back (yieldOperand.getOperandNumber ());
1738
+
1739
+ SmallVector<size_t > newRetIndices;
1740
+ WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns (
1741
+ rewriter, warpOp, escapingValues.getArrayRef (), distTypes,
1742
+ newRetIndices);
1743
+ yield = cast<gpu::YieldOp>(
1744
+ newWarpOp.getBodyRegion ().getBlocks ().begin ()->getTerminator ());
1745
+
1746
+ SmallVector<Value> newOperands;
1747
+ SmallVector<unsigned > resultIdx;
1748
+ // Collect all the outputs coming from the forOp.
1749
+ for (OpOperand &yieldOperand : yield->getOpOperands ()) {
1750
+ if (yieldOperand.get ().getDefiningOp () != forOp.getOperation ())
1752
1751
continue ;
1753
- }
1754
- OpResult forResult = cast<OpResult>(yieldOperand.get ());
1755
- forResultMapping[yieldOperand.getOperandNumber ()] =
1756
- forResult.getResultNumber ();
1752
+ auto forResult = cast<OpResult>(yieldOperand.get ());
1753
+ newOperands.push_back (
1754
+ newWarpOp.getResult (yieldOperand.getOperandNumber ()));
1755
+ yieldOperand.set (forOp.getInitArgs ()[forResult.getResultNumber ()]);
1756
+ resultIdx.push_back (yieldOperand.getOperandNumber ());
1757
1757
}
1758
1758
1759
- // Newly created `WarpOp` will yield values in following order:
1760
- // 1. All init args of the `ForOp`.
1761
- // 2. All escaping values.
1762
- // 3. All non-`ForOp` yielded values.
1763
- SmallVector<Value> newWarpOpYieldValues;
1764
- SmallVector<Type> newWarpOpDistTypes;
1765
- for (auto [i, initArg] : llvm::enumerate (forOp.getInitArgs ())) {
1766
- newWarpOpYieldValues.push_back (initArg);
1767
- // Compute the distributed type for this init arg.
1768
- Type distType = initArg.getType ();
1769
- if (auto vecType = dyn_cast<VectorType>(distType)) {
1770
- AffineMap map = distributionMapFn (initArg);
1771
- distType = getDistributedType (vecType, map, warpOp.getWarpSize ());
1772
- }
1773
- newWarpOpDistTypes.push_back (distType);
1774
- }
1775
- // Insert escaping values and their distributed types.
1776
- newWarpOpYieldValues.insert (newWarpOpYieldValues.end (),
1777
- escapingValues.begin (), escapingValues.end ());
1778
- newWarpOpDistTypes.insert (newWarpOpDistTypes.end (),
1779
- escapingValueDistTypes.begin (),
1780
- escapingValueDistTypes.end ());
1781
- // Next, we insert all non-`ForOp` yielded values and their distributed
1782
- // types. We also create a mapping between the non-`ForOp` yielded value
1783
- // index and the corresponding new `WarpOp` yield value index (needed to
1784
- // update users later).
1785
- llvm::SmallDenseMap<unsigned , unsigned > nonForResultMapping;
1786
- for (auto [i, v] :
1787
- llvm::zip_equal (nonForResultIndices, nonForYieldedValues)) {
1788
- nonForResultMapping[i] = newWarpOpYieldValues.size ();
1789
- newWarpOpYieldValues.push_back (v);
1790
- newWarpOpDistTypes.push_back (warpOp.getResult (i).getType ());
1791
- }
1792
- // Create the new `WarpOp` with the updated yield values and types.
1793
- WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns (
1794
- rewriter, warpOp, newWarpOpYieldValues, newWarpOpDistTypes);
1795
-
1796
- // Next, we create a new `ForOp` with the init args yielded by the new
1797
- // `WarpOp`.
1798
- const unsigned escapingValuesStartIdx =
1799
- forOp.getInitArgs ().size (); // `ForOp` init args are positioned before
1800
- // escaping values in the new `WarpOp`.
1801
- SmallVector<Value> newForOpOperands;
1802
- for (size_t i = 0 ; i < escapingValuesStartIdx; ++i)
1803
- newForOpOperands.push_back (newWarpOp.getResult (i));
1804
-
1805
- // Create a new `ForOp` outside the new `WarpOp` region.
1806
1759
OpBuilder::InsertionGuard g (rewriter);
1807
1760
rewriter.setInsertionPointAfter (newWarpOp);
1761
+
1762
+ // Create a new for op outside the region with a WarpExecuteOnLane0Op
1763
+ // region inside.
1808
1764
auto newForOp = rewriter.create <scf::ForOp>(
1809
1765
forOp.getLoc (), forOp.getLowerBound (), forOp.getUpperBound (),
1810
- forOp.getStep (), newForOpOperands);
1811
- // Next, we insert a new `WarpOp` (called inner `WarpOp`) inside the
1812
- // newly created `ForOp`. This `WarpOp` will contain all ops that were
1813
- // contained within the original `ForOp` body.
1766
+ forOp.getStep (), newOperands);
1814
1767
rewriter.setInsertionPointToStart (newForOp.getBody ());
1815
1768
1816
- SmallVector<Value> innerWarpInput (newForOp.getRegionIterArgs ().begin (),
1817
- newForOp.getRegionIterArgs ().end ());
1818
- SmallVector<Type> innerWarpInputType (forOp.getResultTypes ().begin (),
1819
- forOp.getResultTypes ().end ());
1820
- // Escaping values are forwarded to the inner `WarpOp` as its (additional)
1821
- // arguments. We keep track of the mapping between these values and their
1822
- // argument index in the inner `WarpOp` (to replace users later).
1769
+ SmallVector<Value> warpInput (newForOp.getRegionIterArgs ().begin (),
1770
+ newForOp.getRegionIterArgs ().end ());
1771
+ SmallVector<Type> warpInputType (forOp.getResultTypes ().begin (),
1772
+ forOp.getResultTypes ().end ());
1823
1773
llvm::SmallDenseMap<Value, int64_t > argIndexMapping;
1824
- for (size_t i = escapingValuesStartIdx;
1825
- i < escapingValuesStartIdx + escapingValues.size (); ++i) {
1826
- innerWarpInput.push_back (newWarpOp.getResult (i));
1827
- argIndexMapping[escapingValues[i - escapingValuesStartIdx]] =
1828
- innerWarpInputType.size ();
1829
- innerWarpInputType.push_back (
1830
- escapingValueInputTypes[i - escapingValuesStartIdx]);
1774
+ for (auto [i, retIdx] : llvm::enumerate (newRetIndices)) {
1775
+ warpInput.push_back (newWarpOp.getResult (retIdx));
1776
+ argIndexMapping[escapingValues[i]] = warpInputType.size ();
1777
+ warpInputType.push_back (inputTypes[i]);
1831
1778
}
1832
- // Create the inner `WarpOp` with the new input values and types.
1833
1779
auto innerWarp = rewriter.create <WarpExecuteOnLane0Op>(
1834
1780
newWarpOp.getLoc (), newForOp.getResultTypes (), newWarpOp.getLaneid (),
1835
- newWarpOp.getWarpSize (), innerWarpInput, innerWarpInputType );
1781
+ newWarpOp.getWarpSize (), warpInput, warpInputType );
1836
1782
1837
- // Inline the `ForOp` body into the inner `WarpOp` body.
1838
1783
SmallVector<Value> argMapping;
1839
1784
argMapping.push_back (newForOp.getInductionVar ());
1840
- for (Value args : innerWarp.getBody ()->getArguments ())
1785
+ for (Value args : innerWarp.getBody ()->getArguments ()) {
1841
1786
argMapping.push_back (args);
1842
-
1787
+ }
1843
1788
argMapping.resize (forOp.getBody ()->getNumArguments ());
1844
1789
SmallVector<Value> yieldOperands;
1845
1790
for (Value operand : forOp.getBody ()->getTerminator ()->getOperands ())
1846
1791
yieldOperands.push_back (operand);
1847
-
1848
1792
rewriter.eraseOp (forOp.getBody ()->getTerminator ());
1849
1793
rewriter.mergeBlocks (forOp.getBody (), innerWarp.getBody (), argMapping);
1850
-
1851
- // Insert a gpu `YieldOp` at the end of the inner `WarpOp` body that yields
1852
- // original `ForOp` results.
1853
1794
rewriter.setInsertionPointToEnd (innerWarp.getBody ());
1854
1795
rewriter.create <gpu::YieldOp>(innerWarp.getLoc (), yieldOperands);
1855
1796
rewriter.setInsertionPointAfter (innerWarp);
1856
- // Insert a scf.yield op at the end of the new `ForOp` body that yields
1857
- // the inner `WarpOp` results.
1858
1797
if (!innerWarp.getResults ().empty ())
1859
1798
rewriter.create <scf::YieldOp>(forOp.getLoc (), innerWarp.getResults ());
1860
-
1861
- // Update the users of original `WarpOp` results that were coming from the
1862
- // original `ForOp` to the corresponding new `ForOp` result.
1863
- for (auto [origIdx, newIdx] : forResultMapping)
1864
- rewriter.replaceAllUsesExcept (warpOp.getResult (origIdx),
1865
- newForOp.getResult (newIdx), newForOp);
1866
- // Similarly, update any users of the `WarpOp` results that were not
1867
- // results of the `ForOp`.
1868
- for (auto [origIdx, newIdx] : nonForResultMapping)
1869
- rewriter.replaceAllUsesWith (warpOp.getResult (origIdx),
1870
- newWarpOp.getResult (newIdx));
1871
- // Remove the original `WarpOp` and `ForOp`, they should not have any uses
1872
- // at this point.
1873
1799
rewriter.eraseOp (forOp);
1874
- rewriter.eraseOp (warpOp);
1875
- // Update any users of escaping values that were forwarded to the
1876
- // inner `WarpOp`. These values are now arguments of the inner `WarpOp`.
1800
+ // Replace the warpOp result coming from the original ForOp.
1801
+ for (const auto &res : llvm::enumerate (resultIdx)) {
1802
+ rewriter.replaceAllUsesWith (newWarpOp.getResult (res.value ()),
1803
+ newForOp.getResult (res.index ()));
1804
+ newForOp->setOperand (res.index () + 3 , newWarpOp.getResult (res.value ()));
1805
+ }
1877
1806
newForOp.walk ([&](Operation *op) {
1878
1807
for (OpOperand &operand : op->getOpOperands ()) {
1879
1808
auto it = argIndexMapping.find (operand.get ());
@@ -1883,7 +1812,7 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
1883
1812
}
1884
1813
});
1885
1814
1886
- // Finally, hoist out any now uniform code from the inner `WarpOp` .
1815
+ // Finally, hoist out any now uniform code from the inner warp op .
1887
1816
mlir::vector::moveScalarUniformCode (innerWarp);
1888
1817
return success ();
1889
1818
}
0 commit comments