@@ -1600,8 +1600,8 @@ SDValue WebAssemblyTargetLowering::LowerBUILD_VECTOR(SDValue Op,
1600
1600
// TODO: Tune this. For example, lanewise swizzling is very expensive, so
1601
1601
// swizzled lanes should be given greater weight.
1602
1602
1603
- // TODO: Investigate building vectors by shuffling together vectors built by
1604
- // separately specialized means .
1603
+ // TODO: Investigate looping rather than always extracting/replacing specific
1604
+ // lanes to fill gaps .
1605
1605
1606
1606
auto IsConstant = [](const SDValue &V) {
1607
1607
return V.getOpcode () == ISD::Constant || V.getOpcode () == ISD::ConstantFP;
@@ -1632,12 +1632,30 @@ SDValue WebAssemblyTargetLowering::LowerBUILD_VECTOR(SDValue Op,
1632
1632
return std::make_pair (SwizzleSrc, SwizzleIndices);
1633
1633
};
1634
1634
1635
+ // If the lane is extracted from another vector at a constant index, return
1636
+ // that vector. The source vector must not have more lanes than the dest
1637
+ // because the shufflevector indices are in terms of the destination lanes and
1638
+ // would not be able to address the smaller individual source lanes.
1639
+ auto GetShuffleSrc = [&](const SDValue &Lane) {
1640
+ if (Lane->getOpcode () != ISD::EXTRACT_VECTOR_ELT)
1641
+ return SDValue ();
1642
+ if (!isa<ConstantSDNode>(Lane->getOperand (1 ).getNode ()))
1643
+ return SDValue ();
1644
+ if (Lane->getOperand (0 ).getValueType ().getVectorNumElements () >
1645
+ VecT.getVectorNumElements ())
1646
+ return SDValue ();
1647
+ return Lane->getOperand (0 );
1648
+ };
1649
+
1635
1650
using ValueEntry = std::pair<SDValue, size_t >;
1636
1651
SmallVector<ValueEntry, 16 > SplatValueCounts;
1637
1652
1638
1653
using SwizzleEntry = std::pair<std::pair<SDValue, SDValue>, size_t >;
1639
1654
SmallVector<SwizzleEntry, 16 > SwizzleCounts;
1640
1655
1656
+ using ShuffleEntry = std::pair<SDValue, size_t >;
1657
+ SmallVector<ShuffleEntry, 16 > ShuffleCounts;
1658
+
1641
1659
auto AddCount = [](auto &Counts, const auto &Val) {
1642
1660
auto CountIt =
1643
1661
llvm::find_if (Counts, [&Val](auto E) { return E.first == Val; });
@@ -1666,9 +1684,11 @@ SDValue WebAssemblyTargetLowering::LowerBUILD_VECTOR(SDValue Op,
1666
1684
1667
1685
AddCount (SplatValueCounts, Lane);
1668
1686
1669
- if (IsConstant (Lane)) {
1687
+ if (IsConstant (Lane))
1670
1688
NumConstantLanes++;
1671
- } else if (CanSwizzle) {
1689
+ if (auto ShuffleSrc = GetShuffleSrc (Lane))
1690
+ AddCount (ShuffleCounts, ShuffleSrc);
1691
+ if (CanSwizzle) {
1672
1692
auto SwizzleSrcs = GetSwizzleSrcs (I, Lane);
1673
1693
if (SwizzleSrcs.first )
1674
1694
AddCount (SwizzleCounts, SwizzleSrcs);
@@ -1686,18 +1706,81 @@ SDValue WebAssemblyTargetLowering::LowerBUILD_VECTOR(SDValue Op,
1686
1706
std::forward_as_tuple (std::tie (SwizzleSrc, SwizzleIndices),
1687
1707
NumSwizzleLanes) = GetMostCommon (SwizzleCounts);
1688
1708
1709
+ // Shuffles can draw from up to two vectors, so find the two most common
1710
+ // sources.
1711
+ SDValue ShuffleSrc1, ShuffleSrc2;
1712
+ size_t NumShuffleLanes = 0 ;
1713
+ if (ShuffleCounts.size ()) {
1714
+ std::tie (ShuffleSrc1, NumShuffleLanes) = GetMostCommon (ShuffleCounts);
1715
+ ShuffleCounts.erase (std::remove_if (ShuffleCounts.begin (),
1716
+ ShuffleCounts.end (),
1717
+ [&](const auto &Pair) {
1718
+ return Pair.first == ShuffleSrc1;
1719
+ }),
1720
+ ShuffleCounts.end ());
1721
+ }
1722
+ if (ShuffleCounts.size ()) {
1723
+ size_t AdditionalShuffleLanes;
1724
+ std::tie (ShuffleSrc2, AdditionalShuffleLanes) =
1725
+ GetMostCommon (ShuffleCounts);
1726
+ NumShuffleLanes += AdditionalShuffleLanes;
1727
+ }
1728
+
1689
1729
// Predicate returning true if the lane is properly initialized by the
1690
1730
// original instruction
1691
1731
std::function<bool (size_t , const SDValue &)> IsLaneConstructed;
1692
1732
SDValue Result;
1693
- // Prefer swizzles over vector consts over splats
1694
- if (NumSwizzleLanes >= NumSplatLanes && NumSwizzleLanes >= NumConstantLanes) {
1733
+ // Prefer swizzles over shuffles over vector consts over splats
1734
+ if (NumSwizzleLanes >= NumShuffleLanes &&
1735
+ NumSwizzleLanes >= NumConstantLanes && NumSwizzleLanes >= NumSplatLanes) {
1695
1736
Result = DAG.getNode (WebAssemblyISD::SWIZZLE, DL, VecT, SwizzleSrc,
1696
1737
SwizzleIndices);
1697
1738
auto Swizzled = std::make_pair (SwizzleSrc, SwizzleIndices);
1698
1739
IsLaneConstructed = [&, Swizzled](size_t I, const SDValue &Lane) {
1699
1740
return Swizzled == GetSwizzleSrcs (I, Lane);
1700
1741
};
1742
+ } else if (NumShuffleLanes >= NumConstantLanes &&
1743
+ NumShuffleLanes >= NumSplatLanes) {
1744
+ size_t DestLaneSize = VecT.getVectorElementType ().getFixedSizeInBits () / 8 ;
1745
+ size_t DestLaneCount = VecT.getVectorNumElements ();
1746
+ size_t Scale1 = 1 ;
1747
+ size_t Scale2 = 1 ;
1748
+ SDValue Src1 = ShuffleSrc1;
1749
+ SDValue Src2 = ShuffleSrc2 ? ShuffleSrc2 : DAG.getUNDEF (VecT);
1750
+ if (Src1.getValueType () != VecT) {
1751
+ size_t LaneSize =
1752
+ Src1.getValueType ().getVectorElementType ().getFixedSizeInBits () / 8 ;
1753
+ assert (LaneSize > DestLaneSize);
1754
+ Scale1 = LaneSize / DestLaneSize;
1755
+ Src1 = DAG.getBitcast (VecT, Src1);
1756
+ }
1757
+ if (Src2.getValueType () != VecT) {
1758
+ size_t LaneSize =
1759
+ Src2.getValueType ().getVectorElementType ().getFixedSizeInBits () / 8 ;
1760
+ assert (LaneSize > DestLaneSize);
1761
+ Scale2 = LaneSize / DestLaneSize;
1762
+ Src2 = DAG.getBitcast (VecT, Src2);
1763
+ }
1764
+
1765
+ int Mask[16 ];
1766
+ assert (DestLaneCount <= 16 );
1767
+ for (size_t I = 0 ; I < DestLaneCount; ++I) {
1768
+ const SDValue &Lane = Op->getOperand (I);
1769
+ SDValue Src = GetShuffleSrc (Lane);
1770
+ if (Src == ShuffleSrc1) {
1771
+ Mask[I] = Lane->getConstantOperandVal (1 ) * Scale1;
1772
+ } else if (Src && Src == ShuffleSrc2) {
1773
+ Mask[I] = DestLaneCount + Lane->getConstantOperandVal (1 ) * Scale2;
1774
+ } else {
1775
+ Mask[I] = -1 ;
1776
+ }
1777
+ }
1778
+ ArrayRef<int > MaskRef (Mask, DestLaneCount);
1779
+ Result = DAG.getVectorShuffle (VecT, DL, Src1, Src2, MaskRef);
1780
+ IsLaneConstructed = [&](size_t , const SDValue &Lane) {
1781
+ auto Src = GetShuffleSrc (Lane);
1782
+ return Src == ShuffleSrc1 || (Src && Src == ShuffleSrc2);
1783
+ };
1701
1784
} else if (NumConstantLanes >= NumSplatLanes) {
1702
1785
SmallVector<SDValue, 16 > ConstLanes;
1703
1786
for (const SDValue &Lane : Op->op_values ()) {
0 commit comments