@@ -59,22 +59,31 @@ enum class Kind { kTensor, kInvariant, kMulF, kMulI, kAddF, kAddI };
 /// children tensor expressions.
 struct TensorExp {
   TensorExp(Kind k, unsigned x, unsigned y, Value v)
-      : kind(k), e0(x), e1(y), val(v) {}
+      : kind(k), e0(x), e1(y), val(v) {
+    assert((kind == Kind::kTensor && e0 != -1u && e1 == -1u && !val) ||
+           (kind == Kind::kInvariant && e0 == -1u && e1 == -1u && val) ||
+           (kind >= Kind::kMulF && e0 != -1u && e1 != -1u && !val));
+  }
   Kind kind;
+  /// Indices of children expression(s).
   unsigned e0;
   unsigned e1;
+  /// Direct link to IR for an invariant. During code generation,
+  /// field is used to cache "hoisted" loop invariant tensor loads.
   Value val;
 };

-/// Lattice point. Each lattice point consist of a conjunction of tensor
+/// Lattice point. Each lattice point consists of a conjunction of tensor
 /// loop indices (encoded in a bitvector) and the index of the corresponding
 /// tensor expression.
 struct LatPoint {
   LatPoint(unsigned n, unsigned e, unsigned b) : bits(n, false), exp(e) {
     bits.set(b);
   }
   LatPoint(const llvm::BitVector &b, unsigned e) : bits(b), exp(e) {}
+  /// Conjunction of tensor loop indices as bitvector.
   llvm::BitVector bits;
+  /// Index of the tensor expression.
   unsigned exp;
 };

@@ -502,8 +511,16 @@ static void genBuffers(Merger &merger, CodeGen &codegen,
 /// Generates a load on a dense or sparse tensor.
 static Value genTensorLoad(Merger &merger, CodeGen &codegen,
                            PatternRewriter &rewriter, linalg::GenericOp op,
-                           unsigned tensor) {
+                           unsigned exp) {
+  // Test if the load was hoisted to a higher loop nest.
+  Value val = merger.exp(exp).val;
+  if (val) {
+    merger.exp(exp).val = Value(); // reset
+    return val;
+  }
+  // Actual load.
   SmallVector<Value, 4> args;
+  unsigned tensor = merger.exp(exp).e0;
   auto map = op.getIndexingMap(tensor);
   bool sparse = false;
   for (unsigned i = 0, m = map.getNumResults(); i < m; ++i) {
@@ -515,7 +532,9 @@ static Value genTensorLoad(Merger &merger, CodeGen &codegen,
       args.push_back(codegen.pidxs[tensor][idx]); // position index
     }
   }
-  return rewriter.create<LoadOp>(op.getLoc(), codegen.buffers[tensor], args);
+  Location loc = op.getLoc();
+  Value ptr = codegen.buffers[tensor];
+  return rewriter.create<LoadOp>(loc, ptr, args);
 }

 /// Generates a store on a dense tensor.
@@ -528,25 +547,33 @@ static void genTensorStore(Merger &merger, CodeGen &codegen,
     unsigned idx = map.getDimPosition(i);
     args.push_back(codegen.loops[idx]); // universal dense index
   }
-  rewriter.create<StoreOp>(op.getLoc(), rhs, codegen.buffers[tensor], args);
+  Location loc = op.getLoc();
+  Value ptr = codegen.buffers[tensor];
+  rewriter.create<StoreOp>(loc, rhs, ptr, args);
 }

 /// Generates a pointer/index load from the sparse storage scheme.
-static Value genIntLoad(PatternRewriter &rewriter, Location loc, Value ptr,
-                        Value s) {
+static Value genLoad(PatternRewriter &rewriter, Location loc, Value ptr,
+                     Value s) {
   Value load = rewriter.create<LoadOp>(loc, ptr, s);
   return load.getType().isa<IndexType>()
              ? load
              : rewriter.create<IndexCastOp>(loc, load, rewriter.getIndexType());
 }

+/// Generates an invariant value.
+static Value genInvariantValue(Merger &merger, CodeGen &codegen,
+                               PatternRewriter &rewriter, unsigned exp) {
+  return merger.exp(exp).val;
+}
+
 /// Recursively generates tensor expression.
 static Value genExp(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
                     linalg::GenericOp op, unsigned exp) {
   if (merger.exp(exp).kind == Kind::kTensor)
-    return genTensorLoad(merger, codegen, rewriter, op, merger.exp(exp).e0);
+    return genTensorLoad(merger, codegen, rewriter, op, exp);
   else if (merger.exp(exp).kind == Kind::kInvariant)
-    return merger.exp(exp).val;
+    return genInvariantValue(merger, codegen, rewriter, exp);
   Value v0 = genExp(merger, codegen, rewriter, op, merger.exp(exp).e0);
   Value v1 = genExp(merger, codegen, rewriter, op, merger.exp(exp).e1);
   switch (merger.exp(exp).kind) {
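(Aside: genExp walks the expression tree by index rather than by pointer: e0/e1 name positions in the merger's expression array, kTensor leaves become loads, and kInvariant leaves return their cached SSA value. A minimal standalone C++ sketch of that pattern, with hypothetical names that are not the pass's actual API:

#include <cassert>
#include <vector>

// Illustrative stand-in for TensorExp: children are indices into the
// expression vector (-1u when unused); 'val' plays the role of the cached Value.
enum class NodeKind { kLeaf, kConstant, kMul, kAdd };
struct Node {
  NodeKind kind;
  unsigned e0, e1;
  double val;
};

// Recursive evaluation in the same shape as genExp: leaves are "loaded",
// constants return their cached value, binary nodes recurse on e0/e1.
static double evalExp(const std::vector<Node> &exps,
                      const std::vector<double> &leaves, unsigned e) {
  const Node &n = exps[e];
  switch (n.kind) {
  case NodeKind::kLeaf:
    return leaves[n.e0];
  case NodeKind::kConstant:
    return n.val;
  case NodeKind::kMul:
    return evalExp(exps, leaves, n.e0) * evalExp(exps, leaves, n.e1);
  case NodeKind::kAdd:
    return evalExp(exps, leaves, n.e0) + evalExp(exps, leaves, n.e1);
  }
  assert(false && "unhandled kind");
  return 0.0;
}
)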
@@ -564,6 +591,33 @@ static Value genExp(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
   }
 }

+/// Hoists loop invariant tensor loads for which indices have been exhausted.
+static void genInvariants(Merger &merger, CodeGen &codegen,
+                          PatternRewriter &rewriter, linalg::GenericOp op,
+                          unsigned exp) {
+  if (merger.exp(exp).kind == Kind::kTensor) {
+    unsigned lhs = op.getNumInputsAndOutputs() - 1;
+    unsigned tensor = merger.exp(exp).e0;
+    if (tensor == lhs)
+      return; // TODO: scalarize reduction as well (using scf.yield)
+    auto map = op.getIndexingMap(tensor);
+    for (unsigned i = 0, m = map.getNumResults(); i < m; ++i) {
+      unsigned idx = map.getDimPosition(i);
+      if (!codegen.loops[idx])
+        return; // still in play
+    }
+    // All exhausted at this level.
+    merger.exp(exp).val = genTensorLoad(merger, codegen, rewriter, op, exp);
+
+  } else if (merger.exp(exp).kind != Kind::kInvariant) {
+    // Traverse into the binary operations. Note that we only hoist
+    // tensor loads, since subsequent MLIR/LLVM passes know how to
+    // deal with all other kinds of derived loop invariants.
+    genInvariants(merger, codegen, rewriter, op, merger.exp(exp).e0);
+    genInvariants(merger, codegen, rewriter, op, merger.exp(exp).e1);
+  }
+}
+
 /// Generates initialization code for the subsequent loop sequence at
 /// current index level. Returns true if the loop sequence needs to
 /// maintain the universal index.
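(For intuition, the effect genInvariants is after is classic loop-invariant load hoisting: once every index of a tensor access is bound by an enclosing loop, the load is performed once at that level and the cached value, stored in TensorExp::val, is reused. An illustrative standalone C++ sketch of a dense analogue, not the MLIR the pass actually emits:

// Before hoisting: a[i] is reloaded on every iteration of the inner j-loop.
void kernelNaive(int n, int m, const double *a, const double *b, double *x) {
  for (int i = 0; i < n; i++)
    for (int j = 0; j < m; j++)
      x[i * m + j] = a[i] * b[i * m + j];
}

// After hoisting: all indices of a[] are "exhausted" at the i-level, so the
// load happens once per i and the cached scalar is reused in the inner loop.
void kernelHoisted(int n, int m, const double *a, const double *b, double *x) {
  for (int i = 0; i < n; i++) {
    double a_i = a[i];
    for (int j = 0; j < m; j++)
      x[i * m + j] = a_i * b[i * m + j];
  }
}
)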
@@ -590,9 +644,9 @@ static bool genInit(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
         Value one = rewriter.create<ConstantIndexOp>(loc, 1);
         Value p0 = (pat == 0) ? rewriter.create<ConstantIndexOp>(loc, 0)
                               : codegen.pidxs[tensor][topSort[pat - 1]];
-        codegen.pidxs[tensor][idx] = genIntLoad(rewriter, loc, ptr, p0);
+        codegen.pidxs[tensor][idx] = genLoad(rewriter, loc, ptr, p0);
         Value p1 = rewriter.create<AddIOp>(loc, p0, one);
-        codegen.highs[tensor][idx] = genIntLoad(rewriter, loc, ptr, p1);
+        codegen.highs[tensor][idx] = genLoad(rewriter, loc, ptr, p1);
       } else {
         // Dense index still in play.
         needsUniv = true;
@@ -608,7 +662,8 @@ static bool genInit(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
 /// Generates a for-loop on a single index.
 static Operation *genFor(Merger &merger, CodeGen &codegen,
                          PatternRewriter &rewriter, linalg::GenericOp op,
-                         bool isOuter, unsigned idx, llvm::BitVector &indices) {
+                         bool isOuter, bool isInner, unsigned idx,
+                         llvm::BitVector &indices) {
   unsigned fb = indices.find_first();
   unsigned tensor = merger.tensor(fb);
   assert(idx == merger.index(fb));
@@ -725,10 +780,15 @@ static Operation *genWhile(Merger &merger, CodeGen &codegen,
 /// singleton iteration or co-iteration over the given conjunction.
 static Operation *genLoop(Merger &merger, CodeGen &codegen,
                           PatternRewriter &rewriter, linalg::GenericOp op,
-                          bool isOuter, unsigned idx, bool needsUniv,
-                          llvm::BitVector &indices) {
-  if (indices.count() == 1)
-    return genFor(merger, codegen, rewriter, op, isOuter, idx, indices);
+                          std::vector<unsigned> &topSort, unsigned at,
+                          bool needsUniv, llvm::BitVector &indices) {
+  unsigned idx = topSort[at];
+  if (indices.count() == 1) {
+    bool isOuter = at == 0;
+    bool isInner = at == topSort.size() - 1;
+    return genFor(merger, codegen, rewriter, op, isOuter, isInner, idx,
+                  indices);
+  }
   return genWhile(merger, codegen, rewriter, op, idx, needsUniv, indices);
 }

@@ -749,7 +809,7 @@ static void genLocals(Merger &merger, CodeGen &codegen,
       assert(idx == merger.index(b));
       Value ptr = codegen.indices[tensor][idx];
       Value s = codegen.pidxs[tensor][idx];
-      Value load = genIntLoad(rewriter, loc, ptr, s);
+      Value load = genLoad(rewriter, loc, ptr, s);
       codegen.idxs[tensor][idx] = load;
       if (!needsUniv) {
         if (min) {
@@ -886,6 +946,7 @@ static void genStmt(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
   assert(lsize != 0);
   unsigned l0 = merger.set(lts)[0];
   LatPoint lat0 = merger.lat(l0);
+  genInvariants(merger, codegen, rewriter, op, exp);
   bool needsUniv =
       genInit(merger, codegen, rewriter, op, topSort, at, lat0.bits) &&
       lsize > 1;
@@ -897,9 +958,8 @@ static void genStmt(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
     // Emit loop.
     llvm::BitVector indices = lati.bits;
     optimizeIndices(merger, lsize, indices);
-    bool isOuter = at == 0;
-    Operation *loop = genLoop(merger, codegen, rewriter, op, isOuter, idx,
-                              needsUniv, indices);
+    Operation *loop =
+        genLoop(merger, codegen, rewriter, op, topSort, at, needsUniv, indices);
     genLocals(merger, codegen, rewriter, op, topSort, at, needsUniv, lati.bits);

     // Visit all lattices points with Li >= Lj to generate the
@@ -931,6 +991,7 @@ static void genStmt(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
     }
     rewriter.setInsertionPointAfter(loop);
   }
+  codegen.loops[idx] = Value();
 }

 namespace {