Skip to content

Commit 6bb7d24

Browse files
[mlir][Linalg] Add a first vectorization pattern for conv1d in NWCxWCF format.
This revision uses the newly refactored StructuredGenerator to create a simple vectorization for conv1d_nwc_wcf. Note that the pattern is not specific to the op and is technically not even specific to the ConvolutionOpInterface (modulo minor details related to dilations and strides). The overall design follows the same ideas as the lowering of vector::ContractionOp -> vector::OuterProduct: it seeks to be minimally complex, composable and extensible while avoiding inference analysis. Instead, we metaprogram the maps/indexings we expect and we match against them. This is just a first stab and still needs to be evaluated for performance. Other tradeoffs are possible that should be explored. Reviewed By: ftynse Differential Revision: https://reviews.llvm.org/D111894
1 parent 9e7b730 commit 6bb7d24

File tree

6 files changed

+360
-26
lines changed

6 files changed

+360
-26
lines changed

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,11 +41,15 @@ bool skipUnitDimReshape(const OpResult &producer, OpOperand &consumer);
4141
//===----------------------------------------------------------------------===//
4242
using LinalgLoops = SmallVector<Operation *, 4>;
4343

44-
/// Populates patterns for vectorization of all ConvN-D ops.
44+
/// [DEPRECATED] Populates patterns for vectorization of all ConvN-D ops.
4545
void populateConvVectorizationPatterns(
4646
MLIRContext *context, SmallVectorImpl<RewritePatternSet> &patterns,
4747
ArrayRef<int64_t> tileSizes);
4848

49+
/// Populates patterns for vectorizing convolution ops.
50+
void populateConvolutionVectorizationPatterns(RewritePatternSet &patterns,
51+
PatternBenefit benefit = 1);
52+
4953
/// Populate patterns that convert `ElementwiseMappable` ops to linalg
5054
/// parallel loops.
5155
void populateElementwiseToLinalgConversionPatterns(RewritePatternSet &patterns);

mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525

2626
namespace mlir {
2727

28-
class PatternRewriter;
28+
class OpBuilder;
2929

3030
/// Tests whether the given maps describe a row major matmul. The test is
3131
/// permutation-invariant. Note that this only checks the affine maps from an
@@ -161,8 +161,8 @@ class StructuredGenerator {
161161
Win() : IteratorType(getWindowIteratorTypeName()) {}
162162
};
163163

164-
StructuredGenerator(PatternRewriter &rewriter, StructuredOpInterface op)
165-
: rewriter(rewriter), ctx(op.getContext()), loc(op.getLoc()),
164+
StructuredGenerator(OpBuilder &builder, StructuredOpInterface op)
165+
: builder(builder), ctx(op.getContext()), loc(op.getLoc()),
166166
iterators(op.iterator_types()), maps(op.getIndexingMaps()), op(op) {}
167167

168168
bool iters(ArrayRef<IteratorType> its) {
@@ -181,7 +181,7 @@ class StructuredGenerator {
181181
}
182182

183183
protected:
184-
PatternRewriter &rewriter;
184+
OpBuilder &builder;
185185
MLIRContext *ctx;
186186
Location loc;
187187
ArrayAttr iterators;

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp

Lines changed: 215 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
#include "mlir/Dialect/Tensor/IR/Tensor.h"
2121
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
2222
#include "mlir/Dialect/Vector/VectorOps.h"
23+
#include "mlir/Dialect/Vector/VectorTransforms.h"
2324
#include "mlir/IR/AffineExpr.h"
2425
#include "mlir/IR/Matchers.h"
2526
#include "mlir/IR/PatternMatch.h"
@@ -44,6 +45,12 @@ using llvm::dbgs;
4445
#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
4546
#define LDBG(X) LLVM_DEBUG(DBGS() << X)
4647

48+
// Forward declarations.
49+
static LogicalResult vectorizeContraction(OpBuilder &b, LinalgOp linalgOp,
50+
SmallVectorImpl<Value> &newResults);
51+
static FailureOr<Operation *>
52+
vectorizeConvolution(OpBuilder &b, ConvolutionOpInterface convOp);
53+
4754
/// Return the unique instance of OpType in `block` if it is indeed unique.
4855
/// Return null if none or more than 1 instances exist.
4956
template <typename OpType>
@@ -147,7 +154,7 @@ static Operation *matchLinalgReduction(OpOperand *outputOperand) {
147154
auto linalgOp = cast<LinalgOp>(outputOperand->getOwner());
148155
unsigned outputPos =
149156
outputOperand->getOperandNumber() - linalgOp.getNumInputs();
150-
// Only single combiner operatios are supported for now.
157+
// Only single combiner operations are supported for now.
151158
SmallVector<Operation *, 4> combinerOps;
152159
if (!matchReduction(linalgOp.getRegionOutputArgs(), outputPos, combinerOps) ||
153160
combinerOps.size() != 1)
@@ -575,6 +582,11 @@ LogicalResult vectorizeAsLinalgGeneric(
575582
return success();
576583
}
577584

585+
/// Helper function to vectorize a `linalgOp` with contraction semantics in a
586+
/// generic fashion.
587+
/// This helper is needed atm because the truly generic implementation requires
588+
/// good vector.multi_reduce folding patterns that are currently NYI.
589+
// TODO: drop reliance on a specific pattern.
578590
static LogicalResult vectorizeContraction(OpBuilder &b, LinalgOp linalgOp,
579591
SmallVectorImpl<Value> &newResults) {
580592
assert(isaContractionOpInterface(linalgOp) &&
@@ -664,6 +676,11 @@ LogicalResult mlir::linalg::vectorizeLinalgOpPrecondition(Operation *op) {
664676
return success();
665677
if (isaContractionOpInterface(linalgOp))
666678
return success();
679+
// TODO: isaConvolutionOpInterface that can also infer from generic features.
680+
// But we will still need stride/dilation attributes that will be annoying to
681+
// reverse-engineer...
682+
if (isa<ConvolutionOpInterface>(op))
683+
return success();
667684
// TODO: the common vector shape is equal to the static loop sizes only when
668685
// all indexing maps are projected permutations. For convs and stencils the
669686
// logic will need to evolve.
@@ -688,6 +705,18 @@ mlir::linalg::vectorizeLinalgOp(OpBuilder &b, Operation *op,
688705
if (isaContractionOpInterface(linalgOp))
689706
return vectorizeContraction(b, linalgOp, newResults);
690707

708+
// TODO: isaConvolutionOpInterface that can also infer from generic features.
709+
// But we will still need stride/dilation attributes that will be annoying to
710+
// reverse-engineer...
711+
if (auto convOp = dyn_cast<ConvolutionOpInterface>(op)) {
712+
FailureOr<Operation *> resultOrFail = vectorizeConvolution(b, convOp);
713+
if (failed(resultOrFail))
714+
return failure();
715+
Operation *newOp = *resultOrFail;
716+
llvm::append_range(newResults, newOp->getResults());
717+
return success();
718+
}
719+
691720
LDBG(""
692721
<< "Vectorize linalg op as a generic by broadcasting to "
693722
"maximal common shape: "
@@ -1421,3 +1450,188 @@ LogicalResult LinalgCopyVTWForwardingPattern::matchAndRewrite(
14211450

14221451
return success();
14231452
}
1453+
1454+
//===----------------------------------------------------------------------===//
1455+
// Convolution vectorization patterns
1456+
//===----------------------------------------------------------------------===//
1457+
namespace {
1458+
/// Generate a vector implementation for:
1459+
/// ```
1460+
/// Op def: ( n, w, c, kw, f )
1461+
/// Iters: ({Par(), Par(), Par(), Red(), Red()})
1462+
/// Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}}
1463+
/// ```
1464+
/// w and kw are unrolled.
1465+
/// TODO: do not unroll w (resp. kw) when the strideW ( resp. dilationW) is > 1.
1466+
struct Conv1D_NWC_WCF_Generator : public StructuredGenerator<LinalgOp> {
1467+
Conv1D_NWC_WCF_Generator(OpBuilder &builder, LinalgOp linalgOp, int strideW,
1468+
int dilationW)
1469+
: StructuredGenerator<LinalgOp>(builder, linalgOp), valid(false),
1470+
strideW(strideW), dilationW(dilationW) {
1471+
// Determine whether `linalgOp` can be generated with this generator
1472+
if (linalgOp.getNumInputs() != 2 || linalgOp.getNumOutputs() != 1)
1473+
return;
1474+
lhsShaped = linalgOp.inputs()[0];
1475+
rhsShaped = linalgOp.inputs()[1];
1476+
resShaped = linalgOp.outputs()[0];
1477+
lhsShapedType = lhsShaped.getType().dyn_cast<ShapedType>();
1478+
rhsShapedType = rhsShaped.getType().dyn_cast<ShapedType>();
1479+
resShapedType = resShaped.getType().dyn_cast<ShapedType>();
1480+
if (!lhsShapedType || !rhsShapedType || !resShapedType)
1481+
return;
1482+
if (lhsShapedType.getRank() != 3 || rhsShapedType.getRank() != 3 ||
1483+
resShapedType.getRank() != 3)
1484+
return;
1485+
1486+
// Check for reduction `add` preceded by `mul`.
1487+
Operation *reduceOp = matchLinalgReduction(linalgOp.getOutputOperand(0));
1488+
if (!reduceOp)
1489+
return;
1490+
llvm::Optional<vector::CombiningKind> maybeKind;
1491+
maybeKind = getKindForOp(reduceOp);
1492+
if (!maybeKind || *maybeKind != vector::CombiningKind::ADD)
1493+
return;
1494+
maybeKind = getKindForOp(&(linalgOp->getRegion(0).front().front()));
1495+
if (!maybeKind || *maybeKind != vector::CombiningKind::MUL)
1496+
return;
1497+
1498+
// The op is now known to be valid.
1499+
valid = true;
1500+
}
1501+
1502+
/// Generate a vector implementation for:
1503+
/// ```
1504+
/// Op def: ( n, w, c, kw, f )
1505+
/// Iters: ({Par(), Par(), Par(), Red(), Red()})
1506+
/// Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}}
1507+
/// ```
1508+
/// w and kw are unrolled.
1509+
/// TODO: w (resp. kw) is unrolled when the strideW ( resp. dilationW) is > 1.
1510+
FailureOr<Operation *> conv() {
1511+
if (!valid)
1512+
return failure();
1513+
1514+
int nSize = lhsShapedType.getShape()[0];
1515+
int wSize = resShapedType.getShape()[1];
1516+
int cSize = lhsShapedType.getShape()[2];
1517+
int kwSize = rhsShapedType.getShape()[0];
1518+
int fSize = rhsShapedType.getShape()[2];
1519+
1520+
vector::TransferWriteOp write;
1521+
Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
1522+
1523+
// Unroll along kw and read slices of lhs and rhs.
1524+
// Alternatively we could preload both 3-d slices and extract smaller slices
1525+
// iteratively without touching memory. But this will quickly spill.
1526+
for (int64_t kw = 0; kw < kwSize; ++kw) {
1527+
// Read rhs slice of size {1, c, f} @ [kw, 0, 0].
1528+
Value kwVal = builder.create<arith::ConstantIndexOp>(loc, kw);
1529+
VectorType rhsType =
1530+
VectorType::get({1, cSize, fSize}, rhsShapedType.getElementType());
1531+
Value rhs = builder.create<vector::TransferReadOp>(
1532+
loc, rhsType, rhsShaped, ValueRange{kwVal, zero, zero});
1533+
1534+
for (int64_t w = 0; w < wSize; ++w) {
1535+
// Read lhs slice of size {n, 1, c} @ [0, sw * w + dw * kw, 0].
1536+
Value lhsStridedIdx = builder.create<arith::ConstantIndexOp>(
1537+
loc, strideW * w + dilationW * kw);
1538+
VectorType lhsType =
1539+
VectorType::get({nSize, 1, cSize}, lhsShapedType.getElementType());
1540+
Value lhs = builder.create<vector::TransferReadOp>(
1541+
loc, lhsType, lhsShaped, ValueRange{zero, lhsStridedIdx, zero});
1542+
1543+
// Read res slice: {n, 1, f} @ [0, w, 0].
1544+
Value wVal = builder.create<arith::ConstantIndexOp>(loc, w);
1545+
VectorType resType =
1546+
VectorType::get({nSize, 1, fSize}, resShapedType.getElementType());
1547+
// When operating on tensors, reading from the updated value is required
1548+
// for vector.transfer_read/write hoisting to function as expected.
1549+
Value res = builder.create<vector::TransferReadOp>(
1550+
loc, resType, resShaped, ValueRange{zero, wVal, zero});
1551+
1552+
// Compute contraction: I{n, 1, c} * F{1, c, f} -> O{n, 1, f}
1553+
StringRef par = Par().strRef, red = Red().strRef;
1554+
AffineExpr n, one, f, c;
1555+
bindDims(ctx, n, one, f, c);
1556+
// clang-format off
1557+
res = builder.create<vector::ContractionOp>(
1558+
loc, lhs, rhs, res,
1559+
/*indexingMaps=*/MapList{{n, one, c}, {one, c, f}, {n, one, f}},
1560+
/*iteratorTypes=*/ArrayRef<StringRef>{par, par, par, red});
1561+
// clang-format on
1562+
1563+
// Write back res slice: {n, 1, f} @ [0, w, 0].
1564+
write = builder.create<vector::TransferWriteOp>(
1565+
loc, res, resShaped, ValueRange{zero, wVal, zero});
1566+
if (write.getNumResults() == 1)
1567+
resShaped = write->getResult(0);
1568+
}
1569+
}
1570+
1571+
return write.getOperation();
1572+
}
1573+
1574+
/// Entry point that transposes into the common form:
1575+
/// {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}}
1576+
FailureOr<Operation *> generateConv() {
1577+
AffineExpr n, w, f, kw, c;
1578+
bindDims(ctx, n, w, f, kw, c);
1579+
1580+
if (!iters({Par(), Par(), Par(), Red(), Red()}))
1581+
return failure();
1582+
1583+
// No transposition needed.
1584+
if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
1585+
/*rhsIndex*/ {kw, c, f},
1586+
/*resIndex*/ {n, w, f}}))
1587+
return conv();
1588+
return failure();
1589+
}
1590+
1591+
private:
1592+
bool valid;
1593+
int strideW, dilationW;
1594+
Value lhsShaped, rhsShaped, resShaped;
1595+
ShapedType lhsShapedType, rhsShapedType, resShapedType;
1596+
};
1597+
} // namespace
1598+
1599+
/// Helper function to vectorize a `linalgOp` with convolution semantics.
1600+
// TODO: extend the generic vectorization to support windows and drop this.
1601+
static FailureOr<Operation *>
1602+
vectorizeConvolution(OpBuilder &b, ConvolutionOpInterface convOp) {
1603+
// TODO: these are legitimately part of ConvolutionOpInterface.
1604+
auto strides = convOp->getAttrOfType<DenseIntElementsAttr>("strides");
1605+
auto dilations = convOp->getAttrOfType<DenseIntElementsAttr>("dilations");
1606+
auto stride = strides ? *strides.getValues<uint64_t>().begin() : 1;
1607+
auto dilation = dilations ? *dilations.getValues<uint64_t>().begin() : 1;
1608+
LinalgOp linalgOp = cast<LinalgOp>(convOp.getOperation());
1609+
Conv1D_NWC_WCF_Generator e(b, linalgOp, stride, dilation);
1610+
return e.generateConv();
1611+
}
1612+
1613+
struct VectorizeConvolution
1614+
: public OpInterfaceRewritePattern<ConvolutionOpInterface> {
1615+
using OpInterfaceRewritePattern::OpInterfaceRewritePattern;
1616+
1617+
LogicalResult matchAndRewrite(ConvolutionOpInterface convOp,
1618+
PatternRewriter &rewriter) const override {
1619+
FailureOr<Operation *> resultOrFail =
1620+
vectorizeConvolution(rewriter, convOp);
1621+
if (failed(resultOrFail))
1622+
return failure();
1623+
Operation *newOp = *resultOrFail;
1624+
if (newOp->getNumResults() == 0) {
1625+
rewriter.eraseOp(convOp.getOperation());
1626+
return success();
1627+
}
1628+
assert(newOp->getNumResults() == 1 && "expected single result");
1629+
rewriter.replaceOp(convOp.getOperation(), newOp->getResult(0));
1630+
return success();
1631+
}
1632+
};
1633+
1634+
/// Populate `patterns` with the convolution vectorization pattern, registered
/// with the given `benefit`.
void mlir::linalg::populateConvolutionVectorizationPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  MLIRContext *context = patterns.getContext();
  patterns.add<VectorizeConvolution>(context, benefit);
}

0 commit comments

Comments
 (0)