#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/VectorOps.h"
+#include "mlir/Dialect/Vector/VectorTransforms.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
@@ -44,6 +45,12 @@ using llvm::dbgs;
#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
#define LDBG(X) LLVM_DEBUG(DBGS() << X)

+// Forward declarations.
+static LogicalResult vectorizeContraction(OpBuilder &b, LinalgOp linalgOp,
+                                          SmallVectorImpl<Value> &newResults);
+static FailureOr<Operation *>
+vectorizeConvolution(OpBuilder &b, ConvolutionOpInterface convOp);
+
/// Return the unique instance of OpType in `block` if it is indeed unique.
/// Return null if none or more than 1 instances exist.
template <typename OpType>
@@ -147,7 +154,7 @@ static Operation *matchLinalgReduction(OpOperand *outputOperand) {
  auto linalgOp = cast<LinalgOp>(outputOperand->getOwner());
  unsigned outputPos =
      outputOperand->getOperandNumber() - linalgOp.getNumInputs();
-  // Only single combiner operatios are supported for now.
+  // Only single combiner operations are supported for now.
  SmallVector<Operation *, 4> combinerOps;
  if (!matchReduction(linalgOp.getRegionOutputArgs(), outputPos, combinerOps) ||
      combinerOps.size() != 1)
@@ -575,6 +582,11 @@ LogicalResult vectorizeAsLinalgGeneric(
  return success();
}

+/// Helper function to vectorize a `linalgOp` with contraction semantics in a
+/// generic fashion.
+/// This helper is needed atm because the truly generic implementation requires
+/// good vector.multi_reduce folding patterns that are currently NYI.
+// TODO: drop reliance on a specific pattern.
static LogicalResult vectorizeContraction(OpBuilder &b, LinalgOp linalgOp,
                                          SmallVectorImpl<Value> &newResults) {
  assert(isaContractionOpInterface(linalgOp) &&
@@ -664,6 +676,11 @@ LogicalResult mlir::linalg::vectorizeLinalgOpPrecondition(Operation *op) {
    return success();
  if (isaContractionOpInterface(linalgOp))
    return success();
+  // TODO: isaConvolutionOpInterface that can also infer from generic features.
+  // But we will still need stride/dilation attributes that will be annoying to
+  // reverse-engineer...
+  if (isa<ConvolutionOpInterface>(op))
+    return success();
  // TODO: the common vector shape is equal to the static loop sizes only when
  // all indexing maps are projected permutations. For convs and stencils the
  // logic will need to evolve.
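Taken together with the dispatch added in the next hunk, a driver is expected to use these entry points roughly as follows. This is a hypothetical sketch (the enclosing rewrite pattern, `rewriter`, and `op` are assumed), not code from the patch:

```cpp
// Gate on the precondition, which after this patch also accepts ops
// implementing ConvolutionOpInterface, then vectorize and consume results.
if (failed(mlir::linalg::vectorizeLinalgOpPrecondition(op)))
  return failure(); // neither elementwise, contraction, nor convolution
SmallVector<Value> newResults;
if (failed(mlir::linalg::vectorizeLinalgOp(rewriter, op, newResults)))
  return failure();
if (!newResults.empty())
  rewriter.replaceOp(op, newResults); // tensor semantics: forward new values
else
  rewriter.eraseOp(op); // memref semantics: writes happened in place
return success();
```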
@@ -688,6 +705,18 @@ mlir::linalg::vectorizeLinalgOp(OpBuilder &b, Operation *op,
  if (isaContractionOpInterface(linalgOp))
    return vectorizeContraction(b, linalgOp, newResults);

+  // TODO: isaConvolutionOpInterface that can also infer from generic features.
+  // But we will still need stride/dilation attributes that will be annoying to
+  // reverse-engineer...
+  if (auto convOp = dyn_cast<ConvolutionOpInterface>(op)) {
+    FailureOr<Operation *> resultOrFail = vectorizeConvolution(b, convOp);
+    if (failed(resultOrFail))
+      return failure();
+    Operation *newOp = *resultOrFail;
+    llvm::append_range(newResults, newOp->getResults());
+    return success();
+  }
+
  LDBG(""
       << "Vectorize linalg op as a generic by broadcasting to "
          "maximal common shape: "
@@ -1421,3 +1450,188 @@ LogicalResult LinalgCopyVTWForwardingPattern::matchAndRewrite(

  return success();
}
+
+//===----------------------------------------------------------------------===//
+// Convolution vectorization patterns
+//===----------------------------------------------------------------------===//
+namespace {
+/// Generate a vector implementation for:
+/// ```
+///   Op def: (     n,     w,     c,    kw,    f   )
+///    Iters: ({Par(), Par(), Par(), Red(), Red()})
+///   Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}}
+/// ```
+/// w and kw are unrolled.
+/// TODO: do not unroll w (resp. kw) when the strideW (resp. dilationW) is > 1.
+struct Conv1D_NWC_WCF_Generator : public StructuredGenerator<LinalgOp> {
+  Conv1D_NWC_WCF_Generator(OpBuilder &builder, LinalgOp linalgOp, int strideW,
+                           int dilationW)
+      : StructuredGenerator<LinalgOp>(builder, linalgOp), valid(false),
+        strideW(strideW), dilationW(dilationW) {
+    // Determine whether `linalgOp` can be generated with this generator.
+    if (linalgOp.getNumInputs() != 2 || linalgOp.getNumOutputs() != 1)
+      return;
+    lhsShaped = linalgOp.inputs()[0];
+    rhsShaped = linalgOp.inputs()[1];
+    resShaped = linalgOp.outputs()[0];
+    lhsShapedType = lhsShaped.getType().dyn_cast<ShapedType>();
+    rhsShapedType = rhsShaped.getType().dyn_cast<ShapedType>();
+    resShapedType = resShaped.getType().dyn_cast<ShapedType>();
+    if (!lhsShapedType || !rhsShapedType || !resShapedType)
+      return;
+    if (lhsShapedType.getRank() != 3 || rhsShapedType.getRank() != 3 ||
+        resShapedType.getRank() != 3)
+      return;
+
+    // Check for reduction `add` preceded by `mul`.
+    Operation *reduceOp = matchLinalgReduction(linalgOp.getOutputOperand(0));
+    if (!reduceOp)
+      return;
+    llvm::Optional<vector::CombiningKind> maybeKind;
+    maybeKind = getKindForOp(reduceOp);
+    if (!maybeKind || *maybeKind != vector::CombiningKind::ADD)
+      return;
+    maybeKind = getKindForOp(&(linalgOp->getRegion(0).front().front()));
+    if (!maybeKind || *maybeKind != vector::CombiningKind::MUL)
+      return;
+
+    // The op is now known to be valid.
+    valid = true;
+  }
+
+  /// Generate a vector implementation for:
+  /// ```
+  ///   Op def: (     n,     w,     c,    kw,    f   )
+  ///    Iters: ({Par(), Par(), Par(), Red(), Red()})
+  ///   Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}}
+  /// ```
+  /// w and kw are unrolled.
+  /// TODO: do not unroll w (resp. kw) when the strideW (resp. dilationW) is > 1.
+  FailureOr<Operation *> conv() {
+    if (!valid)
+      return failure();
+
+    int nSize = lhsShapedType.getShape()[0];
+    int wSize = resShapedType.getShape()[1];
+    int cSize = lhsShapedType.getShape()[2];
+    int kwSize = rhsShapedType.getShape()[0];
+    int fSize = rhsShapedType.getShape()[2];
+
+    vector::TransferWriteOp write;
+    Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
+
+    // Unroll along kw and read slices of lhs and rhs.
+    // Alternatively we could preload both 3-d slices and extract smaller
+    // slices iteratively without touching memory. But this will quickly spill.
+    for (int64_t kw = 0; kw < kwSize; ++kw) {
+      // Read rhs slice of size {1, c, f} @ [kw, 0, 0].
+      Value kwVal = builder.create<arith::ConstantIndexOp>(loc, kw);
+      VectorType rhsType =
+          VectorType::get({1, cSize, fSize}, rhsShapedType.getElementType());
+      Value rhs = builder.create<vector::TransferReadOp>(
+          loc, rhsType, rhsShaped, ValueRange{kwVal, zero, zero});
+
+      for (int64_t w = 0; w < wSize; ++w) {
+        // Read lhs slice of size {n, 1, c} @ [0, sw * w + dw * kw, 0].
+        Value lhsStridedIdx = builder.create<arith::ConstantIndexOp>(
+            loc, strideW * w + dilationW * kw);
+        VectorType lhsType =
+            VectorType::get({nSize, 1, cSize}, lhsShapedType.getElementType());
+        Value lhs = builder.create<vector::TransferReadOp>(
+            loc, lhsType, lhsShaped, ValueRange{zero, lhsStridedIdx, zero});
+
+        // Read res slice: {n, 1, f} @ [0, w, 0].
+        Value wVal = builder.create<arith::ConstantIndexOp>(loc, w);
+        VectorType resType =
+            VectorType::get({nSize, 1, fSize}, resShapedType.getElementType());
+        // When operating on tensors, reading from the updated value is
+        // required for vector.transfer_read/write hoisting to function as
+        // expected.
+        Value res = builder.create<vector::TransferReadOp>(
+            loc, resType, resShaped, ValueRange{zero, wVal, zero});
+
+        // Compute contraction: I{n, 1, c} * F{1, c, f} -> O{n, 1, f}.
+        StringRef par = Par().strRef, red = Red().strRef;
+        AffineExpr n, one, f, c;
+        bindDims(ctx, n, one, f, c);
+        // clang-format off
+        res = builder.create<vector::ContractionOp>(
+            loc, lhs, rhs, res,
+            /*indexingMaps=*/MapList{{n, one, c}, {one, c, f}, {n, one, f}},
+            /*iteratorTypes=*/ArrayRef<StringRef>{par, par, par, red});
+        // clang-format on
+
+        // Write back res slice: {n, 1, f} @ [0, w, 0].
+        write = builder.create<vector::TransferWriteOp>(
+            loc, res, resShaped, ValueRange{zero, wVal, zero});
+        if (write.getNumResults() == 1)
+          resShaped = write->getResult(0);
+      }
+    }
+
+    return write.getOperation();
+  }
+
+  /// Entry point that transposes into the common form:
+  ///   {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}}
+  FailureOr<Operation *> generateConv() {
+    AffineExpr n, w, f, kw, c;
+    bindDims(ctx, n, w, f, kw, c);
+
+    if (!iters({Par(), Par(), Par(), Red(), Red()}))
+      return failure();
+
+    // No transposition needed.
+    if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
+                /*rhsIndex*/ {kw, c, f},
+                /*resIndex*/ {n, w, f}}))
+      return conv();
+    return failure();
+  }
+
+private:
+  bool valid;
+  int strideW, dilationW;
+  Value lhsShaped, rhsShaped, resShaped;
+  ShapedType lhsShapedType, rhsShapedType, resShapedType;
+};
+} // namespace
+
+/// Helper function to vectorize a `linalgOp` with convolution semantics.
+// TODO: extend the generic vectorization to support windows and drop this.
+static FailureOr<Operation *>
+vectorizeConvolution(OpBuilder &b, ConvolutionOpInterface convOp) {
+  // TODO: these are legitimately part of ConvolutionOpInterface.
+  auto strides = convOp->getAttrOfType<DenseIntElementsAttr>("strides");
+  auto dilations = convOp->getAttrOfType<DenseIntElementsAttr>("dilations");
+  auto stride = strides ? *strides.getValues<uint64_t>().begin() : 1;
+  auto dilation = dilations ? *dilations.getValues<uint64_t>().begin() : 1;
+  LinalgOp linalgOp = cast<LinalgOp>(convOp.getOperation());
+  Conv1D_NWC_WCF_Generator e(b, linalgOp, stride, dilation);
+  return e.generateConv();
+}
+
+struct VectorizeConvolution
+    : public OpInterfaceRewritePattern<ConvolutionOpInterface> {
+  using OpInterfaceRewritePattern::OpInterfaceRewritePattern;
+
+  LogicalResult matchAndRewrite(ConvolutionOpInterface convOp,
+                                PatternRewriter &rewriter) const override {
+    FailureOr<Operation *> resultOrFail =
+        vectorizeConvolution(rewriter, convOp);
+    if (failed(resultOrFail))
+      return failure();
+    Operation *newOp = *resultOrFail;
+    if (newOp->getNumResults() == 0) {
+      rewriter.eraseOp(convOp.getOperation());
+      return success();
+    }
+    assert(newOp->getNumResults() == 1 && "expected single result");
+    rewriter.replaceOp(convOp.getOperation(), newOp->getResult(0));
+    return success();
+  }
+};
+
+void mlir::linalg::populateConvolutionVectorizationPatterns(
+    RewritePatternSet &patterns, PatternBenefit benefit) {
+  patterns.add<VectorizeConvolution>(patterns.getContext(), benefit);
+}
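Downstream users enable the new rewrite through the populate function above. A minimal usage sketch, assuming the declaration is exposed next to the other Linalg vectorization hooks in `mlir/Dialect/Linalg/Transforms/Transforms.h`; the helper name and scaffolding are illustrative, not part of the patch:

```cpp
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

// Collect the convolution vectorization pattern and apply it greedily to
// every op nested under `root`, e.g. a function body containing a 1-D NWC/WCF
// convolution in the {n, strideW * w + dilationW * kw, c} form handled by
// Conv1D_NWC_WCF_Generator.
static LogicalResult vectorizeConvolutions(Operation *root) {
  RewritePatternSet patterns(root->getContext());
  linalg::populateConvolutionVectorizationPatterns(patterns, /*benefit=*/1);
  return applyPatternsAndFoldGreedily(root, std::move(patterns));
}
```

Ops whose iterator types or indexing maps do not match the single supported layout simply fail to match in `generateConv`, so the pattern is safe to run over mixed payloads.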