
Commit 74cd9e5

[mlir][sparse] hoist loop invariant tensor loads in sparse compiler
After bufferization, the backend has much more trouble hoisting loop invariant loads from the loops generated by the sparse compiler. Therefore, this is done during sparse code generation. Note that we don't bother hoisting derived invariant expressions on SSA values, since the backend does that very well.

Still TBD: scalarize reductions to avoid load-add-store cycles.

Reviewed By: penpornk

Differential Revision: https://reviews.llvm.org/D92534
Parent: 1c98f98
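To illustrate the effect, here is a hand-written sketch (not output of this commit; the surrounding function, the buffer setup, and the names %a, %b, %x, %c0, %c1, %c10, %c20 are placeholders), a 2-D analog of the @invariants test added to sparse_3d.mlir below. Before this change, the loop-invariant load of a(i) was emitted inside the innermost loop:

scf.for %i = %c0 to %c10 step %c1 {
  scf.for %j = %c0 to %c20 step %c1 {
    // a(i) does not depend on %j, yet is reloaded on every iteration
    %va = load %a[%i] : memref<10xf32>
    %vb = load %b[%j] : memref<20xf32>
    %p = mulf %va, %vb : f32
    store %p, %x[%i, %j] : memref<10x20xf32>
  }
}

After this change, genInvariants detects that all indices of a(i) are exhausted once the i-loop is open, emits the load at that level, and caches the value in TensorExp::val so that genTensorLoad simply reuses it while generating the inner loop body:

scf.for %i = %c0 to %c10 step %c1 {
  // hoisted: the load now executes once per i instead of once per (i, j)
  %va = load %a[%i] : memref<10xf32>
  scf.for %j = %c0 to %c20 step %c1 {
    %vb = load %b[%j] : memref<20xf32>
    %p = mulf %va, %vb : f32
    store %p, %x[%i, %j] : memref<10x20xf32>
  }
}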

File tree (3 files changed, +161 −42 lines):

mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp
mlir/test/Dialect/Linalg/sparse_2d.mlir
mlir/test/Dialect/Linalg/sparse_3d.mlir

mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp

Lines changed: 81 additions & 20 deletions
@@ -59,22 +59,31 @@ enum class Kind { kTensor, kInvariant, kMulF, kMulI, kAddF, kAddI };
 /// children tensor expressions.
 struct TensorExp {
   TensorExp(Kind k, unsigned x, unsigned y, Value v)
-      : kind(k), e0(x), e1(y), val(v) {}
+      : kind(k), e0(x), e1(y), val(v) {
+    assert((kind == Kind::kTensor && e0 != -1u && e1 == -1u && !val) ||
+           (kind == Kind::kInvariant && e0 == -1u && e1 == -1u && val) ||
+           (kind >= Kind::kMulF && e0 != -1u && e1 != -1u && !val));
+  }
   Kind kind;
+  /// Indices of children expression(s).
   unsigned e0;
   unsigned e1;
+  /// Direct link to IR for an invariant. During code generation,
+  /// field is used to cache "hoisted" loop invariant tensor loads.
   Value val;
 };

-/// Lattice point. Each lattice point consist of a conjunction of tensor
+/// Lattice point. Each lattice point consists of a conjunction of tensor
 /// loop indices (encoded in a bitvector) and the index of the corresponding
 /// tensor expression.
 struct LatPoint {
   LatPoint(unsigned n, unsigned e, unsigned b) : bits(n, false), exp(e) {
     bits.set(b);
   }
   LatPoint(const llvm::BitVector &b, unsigned e) : bits(b), exp(e) {}
+  /// Conjunction of tensor loop indices as bitvector.
   llvm::BitVector bits;
+  /// Index of the tensor expresssion.
   unsigned exp;
 };

@@ -502,8 +511,16 @@ static void genBuffers(Merger &merger, CodeGen &codegen,
 /// Generates a load on a dense or sparse tensor.
 static Value genTensorLoad(Merger &merger, CodeGen &codegen,
                            PatternRewriter &rewriter, linalg::GenericOp op,
-                           unsigned tensor) {
+                           unsigned exp) {
+  // Test if the load was hoisted to a higher loop nest.
+  Value val = merger.exp(exp).val;
+  if (val) {
+    merger.exp(exp).val = Value(); // reset
+    return val;
+  }
+  // Actual load.
   SmallVector<Value, 4> args;
+  unsigned tensor = merger.exp(exp).e0;
   auto map = op.getIndexingMap(tensor);
   bool sparse = false;
   for (unsigned i = 0, m = map.getNumResults(); i < m; ++i) {
@@ -515,7 +532,9 @@ static Value genTensorLoad(Merger &merger, CodeGen &codegen,
       args.push_back(codegen.pidxs[tensor][idx]); // position index
     }
   }
-  return rewriter.create<LoadOp>(op.getLoc(), codegen.buffers[tensor], args);
+  Location loc = op.getLoc();
+  Value ptr = codegen.buffers[tensor];
+  return rewriter.create<LoadOp>(loc, ptr, args);
 }

 /// Generates a store on a dense tensor.
@@ -528,25 +547,33 @@ static void genTensorStore(Merger &merger, CodeGen &codegen,
     unsigned idx = map.getDimPosition(i);
     args.push_back(codegen.loops[idx]); // universal dense index
   }
-  rewriter.create<StoreOp>(op.getLoc(), rhs, codegen.buffers[tensor], args);
+  Location loc = op.getLoc();
+  Value ptr = codegen.buffers[tensor];
+  rewriter.create<StoreOp>(loc, rhs, ptr, args);
 }

 /// Generates a pointer/index load from the sparse storage scheme.
-static Value genIntLoad(PatternRewriter &rewriter, Location loc, Value ptr,
-                        Value s) {
+static Value genLoad(PatternRewriter &rewriter, Location loc, Value ptr,
+                     Value s) {
   Value load = rewriter.create<LoadOp>(loc, ptr, s);
   return load.getType().isa<IndexType>()
              ? load
              : rewriter.create<IndexCastOp>(loc, load, rewriter.getIndexType());
 }

+/// Generates an invariant value.
+static Value genInvariantValue(Merger &merger, CodeGen &codegen,
+                               PatternRewriter &rewriter, unsigned exp) {
+  return merger.exp(exp).val;
+}
+
 /// Recursively generates tensor expression.
 static Value genExp(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
                     linalg::GenericOp op, unsigned exp) {
   if (merger.exp(exp).kind == Kind::kTensor)
-    return genTensorLoad(merger, codegen, rewriter, op, merger.exp(exp).e0);
+    return genTensorLoad(merger, codegen, rewriter, op, exp);
   else if (merger.exp(exp).kind == Kind::kInvariant)
-    return merger.exp(exp).val;
+    return genInvariantValue(merger, codegen, rewriter, exp);
   Value v0 = genExp(merger, codegen, rewriter, op, merger.exp(exp).e0);
   Value v1 = genExp(merger, codegen, rewriter, op, merger.exp(exp).e1);
   switch (merger.exp(exp).kind) {
@@ -564,6 +591,33 @@ static Value genExp(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
   }
 }

+/// Hoists loop invariant tensor loads for which indices have been exhausted.
+static void genInvariants(Merger &merger, CodeGen &codegen,
+                          PatternRewriter &rewriter, linalg::GenericOp op,
+                          unsigned exp) {
+  if (merger.exp(exp).kind == Kind::kTensor) {
+    unsigned lhs = op.getNumInputsAndOutputs() - 1;
+    unsigned tensor = merger.exp(exp).e0;
+    if (tensor == lhs)
+      return; // TODO: scalarize reduction as well (using scf.yield)
+    auto map = op.getIndexingMap(tensor);
+    for (unsigned i = 0, m = map.getNumResults(); i < m; ++i) {
+      unsigned idx = map.getDimPosition(i);
+      if (!codegen.loops[idx])
+        return; // still in play
+    }
+    // All exhausted at this level.
+    merger.exp(exp).val = genTensorLoad(merger, codegen, rewriter, op, exp);
+
+  } else if (merger.exp(exp).kind != Kind::kInvariant) {
+    // Traverse into the binary operations. Note that we only hoist
+    // tensor loads, since subsequent MLIR/LLVM passes know how to
+    // deal with all other kinds of derived loop invariants.
+    genInvariants(merger, codegen, rewriter, op, merger.exp(exp).e0);
+    genInvariants(merger, codegen, rewriter, op, merger.exp(exp).e1);
+  }
+}
+
 /// Generates initialization code for the subsequent loop sequence at
 /// current index level. Returns true if the loop sequence needs to
 /// maintain the universal index.
@@ -590,9 +644,9 @@ static bool genInit(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
       Value one = rewriter.create<ConstantIndexOp>(loc, 1);
       Value p0 = (pat == 0) ? rewriter.create<ConstantIndexOp>(loc, 0)
                             : codegen.pidxs[tensor][topSort[pat - 1]];
-      codegen.pidxs[tensor][idx] = genIntLoad(rewriter, loc, ptr, p0);
+      codegen.pidxs[tensor][idx] = genLoad(rewriter, loc, ptr, p0);
       Value p1 = rewriter.create<AddIOp>(loc, p0, one);
-      codegen.highs[tensor][idx] = genIntLoad(rewriter, loc, ptr, p1);
+      codegen.highs[tensor][idx] = genLoad(rewriter, loc, ptr, p1);
     } else {
       // Dense index still in play.
       needsUniv = true;
@@ -608,7 +662,8 @@ static bool genInit(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
 /// Generates a for-loop on a single index.
 static Operation *genFor(Merger &merger, CodeGen &codegen,
                          PatternRewriter &rewriter, linalg::GenericOp op,
-                         bool isOuter, unsigned idx, llvm::BitVector &indices) {
+                         bool isOuter, bool isInner, unsigned idx,
+                         llvm::BitVector &indices) {
   unsigned fb = indices.find_first();
   unsigned tensor = merger.tensor(fb);
   assert(idx == merger.index(fb));
@@ -725,10 +780,15 @@ static Operation *genWhile(Merger &merger, CodeGen &codegen,
 /// singleton iteration or co-iteration over the given conjunction.
 static Operation *genLoop(Merger &merger, CodeGen &codegen,
                           PatternRewriter &rewriter, linalg::GenericOp op,
-                          bool isOuter, unsigned idx, bool needsUniv,
-                          llvm::BitVector &indices) {
-  if (indices.count() == 1)
-    return genFor(merger, codegen, rewriter, op, isOuter, idx, indices);
+                          std::vector<unsigned> &topSort, unsigned at,
+                          bool needsUniv, llvm::BitVector &indices) {
+  unsigned idx = topSort[at];
+  if (indices.count() == 1) {
+    bool isOuter = at == 0;
+    bool isInner = at == topSort.size() - 1;
+    return genFor(merger, codegen, rewriter, op, isOuter, isInner, idx,
+                  indices);
+  }
   return genWhile(merger, codegen, rewriter, op, idx, needsUniv, indices);
 }

@@ -749,7 +809,7 @@ static void genLocals(Merger &merger, CodeGen &codegen,
       assert(idx == merger.index(b));
       Value ptr = codegen.indices[tensor][idx];
       Value s = codegen.pidxs[tensor][idx];
-      Value load = genIntLoad(rewriter, loc, ptr, s);
+      Value load = genLoad(rewriter, loc, ptr, s);
       codegen.idxs[tensor][idx] = load;
       if (!needsUniv) {
         if (min) {
@@ -886,6 +946,7 @@ static void genStmt(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
   assert(lsize != 0);
   unsigned l0 = merger.set(lts)[0];
   LatPoint lat0 = merger.lat(l0);
+  genInvariants(merger, codegen, rewriter, op, exp);
   bool needsUniv =
       genInit(merger, codegen, rewriter, op, topSort, at, lat0.bits) &&
       lsize > 1;
@@ -897,9 +958,8 @@ static void genStmt(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
     // Emit loop.
     llvm::BitVector indices = lati.bits;
     optimizeIndices(merger, lsize, indices);
-    bool isOuter = at == 0;
-    Operation *loop = genLoop(merger, codegen, rewriter, op, isOuter, idx,
-                              needsUniv, indices);
+    Operation *loop =
+        genLoop(merger, codegen, rewriter, op, topSort, at, needsUniv, indices);
     genLocals(merger, codegen, rewriter, op, topSort, at, needsUniv, lati.bits);

     // Visit all lattices points with Li >= Lj to generate the
@@ -931,6 +991,7 @@ static void genStmt(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
     }
     rewriter.setInsertionPointAfter(loop);
   }
+  codegen.loops[idx] = Value();
 }

 namespace {

mlir/test/Dialect/Linalg/sparse_2d.mlir

Lines changed: 15 additions & 15 deletions
@@ -1071,8 +1071,8 @@ func @matvec(%argA: tensor<16x32xf32>, %argb: tensor<32xf32>, %argx: tensor<16xf
 }

 // CHECK-LABEL: func @sum_reduction(
-// CHECK-SAME: %[[VAL_0:.*]]: tensor<10x20xf32>,
-// CHECK-SAME: %[[VAL_1:.*]]: tensor<f32>) -> tensor<f32> {
+// CHECK-SAME: %[[VAL_0:.*0]]: tensor<10x20xf32>,
+// CHECK-SAME: %[[VAL_1:.*1]]: tensor<f32>) -> tensor<f32> {
 // CHECK: %[[VAL_2:.*]] = constant 999 : index
 // CHECK: %[[VAL_3:.*]] = constant 10 : index
 // CHECK: %[[VAL_4:.*]] = constant 0 : index
@@ -1200,19 +1200,19 @@ func @scale(%arga: tensor<?x?xf64>) -> tensor<?x?xf64> {
 // CHECK: scf.for %[[VAL_23:.*]] = %[[VAL_21]] to %[[VAL_22]] step %[[VAL_6]] {
 // CHECK: %[[VAL_24:.*]] = load %[[VAL_8]]{{\[}}%[[VAL_23]]] : memref<?xindex>
 // CHECK: scf.for %[[VAL_25:.*]] = %[[VAL_5]] to %[[VAL_15]] step %[[VAL_6]] {
-// CHECK: %[[VAL_26:.*]] = load %[[VAL_9]]{{\[}}%[[VAL_23]]] : memref<?xindex>
-// CHECK: %[[VAL_27:.*]] = addi %[[VAL_23]], %[[VAL_6]] : index
-// CHECK: %[[VAL_28:.*]] = load %[[VAL_9]]{{\[}}%[[VAL_27]]] : memref<?xindex>
-// CHECK: scf.for %[[VAL_29:.*]] = %[[VAL_26]] to %[[VAL_28]] step %[[VAL_6]] {
-// CHECK: %[[VAL_30:.*]] = load %[[VAL_10]]{{\[}}%[[VAL_29]]] : memref<?xindex>
-// CHECK: %[[VAL_31:.*]] = load %[[VAL_20]]{{\[}}%[[VAL_24]], %[[VAL_30]]] : memref<?x?xf32>
-// CHECK: %[[VAL_32:.*]] = load %[[VAL_11]]{{\[}}%[[VAL_29]]] : memref<?xf32>
-// CHECK: %[[VAL_33:.*]] = load %[[VAL_14]]{{\[}}%[[VAL_24]], %[[VAL_25]]] : memref<?x?xf32>
-// CHECK: %[[VAL_34:.*]] = load %[[VAL_17]]{{\[}}%[[VAL_25]], %[[VAL_30]]] : memref<?x?xf32>
-// CHECK: %[[VAL_35:.*]] = mulf %[[VAL_33]], %[[VAL_34]] : f32
-// CHECK: %[[VAL_36:.*]] = mulf %[[VAL_32]], %[[VAL_35]] : f32
-// CHECK: %[[VAL_37:.*]] = addf %[[VAL_31]], %[[VAL_36]] : f32
-// CHECK: store %[[VAL_37]], %[[VAL_20]]{{\[}}%[[VAL_24]], %[[VAL_30]]] : memref<?x?xf32>
+// CHECK: %[[VAL_26:.*]] = load %[[VAL_14]]{{\[}}%[[VAL_24]], %[[VAL_25]]] : memref<?x?xf32>
+// CHECK: %[[VAL_27:.*]] = load %[[VAL_9]]{{\[}}%[[VAL_23]]] : memref<?xindex>
+// CHECK: %[[VAL_28:.*]] = addi %[[VAL_23]], %[[VAL_6]] : index
+// CHECK: %[[VAL_29:.*]] = load %[[VAL_9]]{{\[}}%[[VAL_28]]] : memref<?xindex>
+// CHECK: scf.for %[[VAL_30:.*]] = %[[VAL_27]] to %[[VAL_29]] step %[[VAL_6]] {
+// CHECK: %[[VAL_31:.*]] = load %[[VAL_10]]{{\[}}%[[VAL_30]]] : memref<?xindex>
+// CHECK: %[[VAL_32:.*]] = load %[[VAL_20]]{{\[}}%[[VAL_24]], %[[VAL_31]]] : memref<?x?xf32>
+// CHECK: %[[VAL_33:.*]] = load %[[VAL_11]]{{\[}}%[[VAL_30]]] : memref<?xf32>
+// CHECK: %[[VAL_34:.*]] = load %[[VAL_17]]{{\[}}%[[VAL_25]], %[[VAL_31]]] : memref<?x?xf32>
+// CHECK: %[[VAL_35:.*]] = mulf %[[VAL_26]], %[[VAL_34]] : f32
+// CHECK: %[[VAL_36:.*]] = mulf %[[VAL_33]], %[[VAL_35]] : f32
+// CHECK: %[[VAL_37:.*]] = addf %[[VAL_32]], %[[VAL_36]] : f32
+// CHECK: store %[[VAL_37]], %[[VAL_20]]{{\[}}%[[VAL_24]], %[[VAL_31]]] : memref<?x?xf32>
 // CHECK: }
 // CHECK: }
 // CHECK: }

mlir/test/Dialect/Linalg/sparse_3d.mlir

Lines changed: 65 additions & 7 deletions
@@ -1192,15 +1192,15 @@ func @mul_sss(%arga: tensor<32x16x8xf32>, %argb: tensor<32x16x8xf32>) -> tensor<
 // CHECK: %[[VAL_25:.*]] = load %[[VAL_7]]{{\[}}%[[VAL_24]]] : memref<?xindex>
 // CHECK: scf.for %[[VAL_26:.*]] = %[[VAL_23]] to %[[VAL_25]] step %[[VAL_6]] {
 // CHECK: %[[VAL_27:.*]] = load %[[VAL_8]]{{\[}}%[[VAL_26]]] : memref<?xindex>
-// CHECK: scf.for %[[VAL_28:.*]] = %[[VAL_5]] to %[[VAL_17]] step %[[VAL_6]] {
-// CHECK: %[[VAL_29:.*]] = load %[[VAL_9]]{{\[}}%[[VAL_26]]] : memref<?xf32>
-// CHECK: %[[VAL_30:.*]] = load %[[VAL_12]]{{\[}}%[[VAL_20]], %[[VAL_28]]] : memref<?x?xf32>
-// CHECK: %[[VAL_31:.*]] = mulf %[[VAL_29]], %[[VAL_30]] : f32
-// CHECK: %[[VAL_32:.*]] = load %[[VAL_15]]{{\[}}%[[VAL_27]], %[[VAL_28]]] : memref<?x?xf32>
+// CHECK: %[[VAL_28:.*]] = load %[[VAL_9]]{{\[}}%[[VAL_26]]] : memref<?xf32>
+// CHECK: scf.for %[[VAL_29:.*]] = %[[VAL_5]] to %[[VAL_17]] step %[[VAL_6]] {
+// CHECK: %[[VAL_30:.*]] = load %[[VAL_12]]{{\[}}%[[VAL_20]], %[[VAL_29]]] : memref<?x?xf32>
+// CHECK: %[[VAL_31:.*]] = mulf %[[VAL_28]], %[[VAL_30]] : f32
+// CHECK: %[[VAL_32:.*]] = load %[[VAL_15]]{{\[}}%[[VAL_27]], %[[VAL_29]]] : memref<?x?xf32>
 // CHECK: %[[VAL_33:.*]] = mulf %[[VAL_31]], %[[VAL_32]] : f32
-// CHECK: %[[VAL_34:.*]] = load %[[VAL_18]]{{\[}}%[[VAL_19]], %[[VAL_28]]] : memref<?x?xf32>
+// CHECK: %[[VAL_34:.*]] = load %[[VAL_18]]{{\[}}%[[VAL_19]], %[[VAL_29]]] : memref<?x?xf32>
 // CHECK: %[[VAL_35:.*]] = addf %[[VAL_33]], %[[VAL_34]] : f32
-// CHECK: store %[[VAL_35]], %[[VAL_18]]{{\[}}%[[VAL_19]], %[[VAL_28]]] : memref<?x?xf32>
+// CHECK: store %[[VAL_35]], %[[VAL_18]]{{\[}}%[[VAL_19]], %[[VAL_29]]] : memref<?x?xf32>
 // CHECK: }
 // CHECK: }
 // CHECK: }
@@ -1281,3 +1281,61 @@ func @sum_reduction(%arga: tensor<10x20x30xf32>, %argx: tensor<f32>) -> tensor<f
   } -> tensor<f32>
   return %0 : tensor<f32>
 }
+
+#trait_invariants = {
+  indexing_maps = [
+    affine_map<(i,j,k) -> (i)>,      // a
+    affine_map<(i,j,k) -> (j)>,      // b
+    affine_map<(i,j,k) -> (k)>,      // c
+    affine_map<(i,j,k) -> (i,j,k)>   // x
+  ],
+  sparse = [
+    [ "D" ],           // a
+    [ "D" ],           // b
+    [ "D" ],           // c
+    [ "D", "D", "D" ]  // x
+  ],
+  iterator_types = ["parallel", "parallel", "parallel"],
+  doc = "x(i,j,k) = a(i) * b(j) * c(k)"
+}
+
+// CHECK-LABEL: func @invariants(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<10xf32>,
+// CHECK-SAME: %[[VAL_1:.*]]: tensor<20xf32>,
+// CHECK-SAME: %[[VAL_2:.*]]: tensor<30xf32>) -> tensor<10x20x30xf32> {
+// CHECK: %[[VAL_3:.*]] = constant 10 : index
+// CHECK: %[[VAL_4:.*]] = constant 20 : index
+// CHECK: %[[VAL_5:.*]] = constant 30 : index
+// CHECK: %[[VAL_6:.*]] = constant 0 : index
+// CHECK: %[[VAL_7:.*]] = constant 1 : index
+// CHECK: %[[VAL_8:.*]] = alloca() : memref<10xf32>
+// CHECK: %[[VAL_9:.*]] = alloca() : memref<20xf32>
+// CHECK: %[[VAL_10:.*]] = alloca() : memref<30xf32>
+// CHECK: %[[VAL_11:.*]] = alloca() : memref<10x20x30xf32>
+// CHECK: scf.for %[[VAL_12:.*]] = %[[VAL_6]] to %[[VAL_3]] step %[[VAL_7]] {
+// CHECK: %[[VAL_13:.*]] = load %[[VAL_8]]{{\[}}%[[VAL_12]]] : memref<10xf32>
+// CHECK: scf.for %[[VAL_14:.*]] = %[[VAL_6]] to %[[VAL_4]] step %[[VAL_7]] {
+// CHECK: %[[VAL_15:.*]] = load %[[VAL_9]]{{\[}}%[[VAL_14]]] : memref<20xf32>
+// CHECK: scf.for %[[VAL_16:.*]] = %[[VAL_6]] to %[[VAL_5]] step %[[VAL_7]] {
+// CHECK: %[[VAL_17:.*]] = mulf %[[VAL_13]], %[[VAL_15]] : f32
+// CHECK: %[[VAL_18:.*]] = load %[[VAL_10]]{{\[}}%[[VAL_16]]] : memref<30xf32>
+// CHECK: %[[VAL_19:.*]] = mulf %[[VAL_17]], %[[VAL_18]] : f32
+// CHECK: store %[[VAL_19]], %[[VAL_11]]{{\[}}%[[VAL_12]], %[[VAL_14]], %[[VAL_16]]] : memref<10x20x30xf32>
+// CHECK: }
+// CHECK: }
+// CHECK: }
+// CHECK: %[[VAL_20:.*]] = tensor_load %[[VAL_11]] : memref<10x20x30xf32>
+// CHECK: return %[[VAL_20]] : tensor<10x20x30xf32>
+// CHECK: }
+func @invariants(%arga: tensor<10xf32>,
+                 %argb: tensor<20xf32>,
+                 %argc: tensor<30xf32>) -> tensor<10x20x30xf32> {
+  %0 = linalg.generic #trait_invariants
+     ins(%arga, %argb, %argc : tensor<10xf32>, tensor<20xf32>, tensor<30xf32>) {
+      ^bb(%a : f32, %b : f32, %c : f32):
+        %0 = mulf %a, %b : f32
+        %1 = mulf %0, %c : f32
+        linalg.yield %1: f32
+  } -> tensor<10x20x30xf32>
+  return %0 : tensor<10x20x30xf32>
+}
