Skip to content

Commit 1ece4d3

Browse files
authored
[mlir][sparse] code simplification: always use synthetical tensor for… (#73597)
… loop bound.
1 parent a3529aa commit 1ece4d3

File tree

2 files changed

+10
-20
lines changed

2 files changed

+10
-20
lines changed

mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -339,9 +339,7 @@ void LoopEmitter::initialize(ValueRange ts, StringAttr loopTag, bool hasOutput,
339339
const SparseTensorType stt(rtp);
340340
lvlRank = stt.getLvlRank();
341341

342-
// We always treat sparse output tensor as dense so that we always iterate
343-
// it based on lvl size.
344-
if (stt.hasEncoding() && !(isOutputTensor(tid) && isSparseOut)) {
342+
if (stt.hasEncoding()) {
345343
const auto enc = stt.getEncoding();
346344
isSparseSlices[tid] = enc.isSlice();
347345
for (auto lvlTp : enc.getLvlTypes())

mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp

Lines changed: 9 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1059,28 +1059,20 @@ static bool translateBitsToTidLvlPairs(
10591059
}
10601060
if (isUndefLT(lt)) {
10611061
// An undefined lt in the lattices, we probably mean to
1062-
// iterate based on the level of output tensor. E.g., this
1063-
// could be a synthetic tensor (for invariants and sparse
1064-
// output tensor).
1065-
auto itType = env.op().getIteratorTypesArray()[ldx];
1066-
if (linalg::isReductionIterator(itType) &&
1067-
env.merger().getSynTensorID() == tid) {
1068-
// Coiterating with an invariant, and this is a reduction loop
1062+
// generate a dense loop according to the synthetic tensor (for
1063+
// invariants and sparse output tensor).
1064+
if (env.merger().getSynTensorID() == tid) {
1065+
// Coiterating with an invariant
10691066
// e.g., out = prod(in[i][j] op invariant);
1070-
// In this case, we can not infer the loop bound from output
1071-
// (whose level is reduced). Instead we use the synthetic tensor
1072-
// to infer the bound.
1067+
// or a broadcast
1068+
// e.g., out[i][j] = in[i] (j is undef for input)
1069+
//
10731070
// The level of the synthetic tensor is the current loop depth;
10741071
// the rank of the synthetic tensor equals to number of loops.
10751072
lvl = env.emitter().getCurrentDepth();
1076-
} else {
1077-
// or a broadcast
1078-
// out[i][j] = in[i] (j is undef for input)
1079-
tid = outTid;
1080-
lvl = outLvl;
1073+
} else if (!lvl) {
10811074
// Skips invalid lvl (e.g., when this is a zero ranked tensor).
1082-
if (!lvl)
1083-
return;
1075+
return;
10841076
}
10851077
}
10861078
hasNonUnique = !isUniqueLT(lt) || hasNonUnique;

0 commit comments

Comments
 (0)