Skip to content

Commit 6c06bde

Browse files
authored
[mlir][sparse] support loop range query using SparseTensorLevel. (#75670)
1 parent b3e353d commit 6c06bde

File tree

4 files changed

+175
-112
lines changed

4 files changed

+175
-112
lines changed

mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp

Lines changed: 23 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -244,12 +244,12 @@ Value LoopEmitter::genAddress(OpBuilder &builder, Location loc, TensorId tid,
244244
Value LoopEmitter::genSegmentHigh(OpBuilder &builder, Location loc,
245245
TensorId tid, Level lvl, Value pLo,
246246
Value pHi) {
247-
SparseTensorLevel &level = *lvls[tid][lvl];
248-
const Value sameCrd = level.peekCrdAt(builder, loc, pLo);
247+
SparseTensorLevel &stl = *lvls[tid][lvl];
248+
const Value sameCrd = stl.peekCrdAt(builder, loc, pLo);
249249
auto whileOp = builder.create<scf::WhileOp>(
250250
loc, builder.getIndexType(), pLo,
251251
/*beforeBuilder=*/
252-
[pHi, &level, sameCrd](OpBuilder &builder, Location loc, ValueRange ivs) {
252+
[pHi, &stl, sameCrd](OpBuilder &builder, Location loc, ValueRange ivs) {
253253
const auto pos = ivs[0];
254254
Value inBound = builder.create<arith::CmpIOp>(
255255
loc, arith::CmpIPredicate::ult, pos, pHi);
@@ -260,7 +260,7 @@ Value LoopEmitter::genSegmentHigh(OpBuilder &builder, Location loc,
260260
// Load the next coordinates only when inbound (to avoid OOB
261261
// accesses).
262262
builder.setInsertionPointToStart(ifInBound.thenBlock());
263-
Value crd = level.peekCrdAt(builder, loc, pos);
263+
Value crd = stl.peekCrdAt(builder, loc, pos);
264264
Value isSameCrd = builder.create<arith::CmpIOp>(
265265
loc, arith::CmpIPredicate::eq, crd, sameCrd);
266266
YIELD(isSameCrd);
@@ -1226,27 +1226,19 @@ void LoopEmitter::prepareLoopOverTensorAtLvl(OpBuilder &builder, Location loc,
12261226

12271227
const Value c0 = C_IDX(0);
12281228
const Value c1 = C_IDX(1);
1229-
const Value c2 = C_IDX(2);
12301229
// Either the first level, or the previous level has been set.
12311230
/// FIXME: See the [CLARIFY_POSITS_LVL] note in the header.
12321231
assert(lvl == 0 || posits[tid][lvl - 1]);
1233-
if (isCompressedLT(lvlTp) || isLooseCompressedLT(lvlTp)) {
1234-
// TODO: eliminate the cast upon feature complete.
1235-
const Value mem =
1236-
isCompressedLT(lvlTp)
1237-
? static_cast<CompressedLevel &>(*lvls[tid][lvl]).posBuffer
1238-
: static_cast<LooseCompressedLevel &>(*lvls[tid][lvl]).posBuffer;
1239-
1240-
Value pLo = lvl == 0 ? c0 : posits[tid][lvl - 1];
1241-
if (isLooseCompressedLT(lvlTp))
1242-
pLo = builder.create<arith::MulIOp>(loc, pLo, c2);
1243-
posits[tid][lvl] = genIndexLoad(builder, loc, mem, pLo);
1244-
1245-
const Value pHi = ADDI(pLo, c1);
1246-
highs[tid][lvl] = genIndexLoad(builder, loc, mem, pHi);
1232+
if (isCompressedLT(lvlTp) || isLooseCompressedLT(lvlTp) ||
1233+
is2OutOf4LT(lvlTp)) {
1234+
1235+
Value pos = lvl == 0 ? c0 : posits[tid][lvl - 1];
1236+
std::tie(posits[tid][lvl], highs[tid][lvl]) =
1237+
lvls[tid][lvl]->peekRangeAt(builder, loc, pos);
12471238
return;
12481239
}
12491240
if (isSingletonLT(lvlTp)) {
1241+
// TODO: merge this as well when SparseTensorLevel supports dedup.
12501242
const Value pLo = lvl == 0 ? c0 : posits[tid][lvl - 1];
12511243
posits[tid][lvl] = pLo;
12521244

@@ -1262,13 +1254,6 @@ void LoopEmitter::prepareLoopOverTensorAtLvl(OpBuilder &builder, Location loc,
12621254
: ADDI(pLo, c1);
12631255
return;
12641256
}
1265-
if (is2OutOf4LT(lvlTp)) {
1266-
const Value pLo = lvl == 0 ? c0 : posits[tid][lvl - 1];
1267-
// Each 2:4 block has exactly two specified elements.
1268-
posits[tid][lvl] = MULI(pLo, c2);
1269-
highs[tid][lvl] = ADDI(posits[tid][lvl], c2);
1270-
return;
1271-
}
12721257
llvm_unreachable("Unrecognized level-type!");
12731258
}
12741259

@@ -1824,18 +1809,11 @@ void LoopEmitter::genResolvedSliceBegin(OpBuilder &builder, Location loc,
18241809
auto [nxSz, stride] = sliceMeta[tid][lvl][1];
18251810
assert(stride == 1 && "Not yet implemented");
18261811
Value sPtrBuf = slicePosBuffer[tid][lvl][0];
1827-
Value pHi, pLo;
1828-
if (lvl == 0) {
1829-
pLo = c0;
1830-
// TODO: eliminate the cast upon feature complete.
1831-
Value pBuf = static_cast<CompressedLevel &>(*lvls[tid][0]).posBuffer;
1832-
pHi = genIndexLoad(builder, loc, pBuf, c1);
1833-
} else {
1834-
// TODO: eliminate the cast upon feature complete.
1835-
Value pBuf = static_cast<CompressedLevel &>(*lvls[tid][lvl]).posBuffer;
1836-
pLo = genIndexLoad(builder, loc, pBuf, posits[tid][lvl - 1]);
1837-
pHi = genIndexLoad(builder, loc, pBuf, ADDI(posits[tid][lvl - 1], c1));
1838-
}
1812+
const SparseTensorLevel &stl = *lvls[tid][lvl];
1813+
1814+
Value p = lvl == 0 ? c0 : posits[tid][lvl - 1];
1815+
auto [pLo, pHi] = stl.peekRangeAt(builder, loc, p);
1816+
18391817
// Fills out pIdxBuffer[tid][lvl][0] with [pLo, pHi]
18401818
updateSlicePos(builder, loc, sPtrBuf, pLo, c0, SlicePosKind::kLo);
18411819
updateSlicePos(builder, loc, sPtrBuf, pHi, c0, SlicePosKind::kHi);
@@ -1849,7 +1827,7 @@ void LoopEmitter::genResolvedSliceBegin(OpBuilder &builder, Location loc,
18491827
// nonempty. though we assume that even on empty sparse tensors, a non-empty
18501828
// ptr/idx buffer is allocated for each level so it would not cause OOB to
18511829
// avoid generating a ifOp here.
1852-
Value minCrd = lvls[tid][lvl]->peekCrdAt(builder, loc, pLo);
1830+
Value minCrd = stl.peekCrdAt(builder, loc, pLo);
18531831

18541832
// FIXME: We need the relative offset related to the base slice.
18551833
Value absOffset = offsetFromMinCoord(builder, loc, minCrd, nxSz, isNonEmpty);
@@ -1879,7 +1857,7 @@ void LoopEmitter::genResolvedSliceBegin(OpBuilder &builder, Location loc,
18791857
// }
18801858
void LoopEmitter::genUnResolvedSliceBegin(OpBuilder &builder, Location loc,
18811859
TensorId tid, Level lvl) {
1882-
Value c0 = C_IDX(0), c1 = C_IDX(1);
1860+
Value c0 = C_IDX(0);
18831861
unsigned depth = levelReducedDep[tid][lvl];
18841862
// The remaining slice size after reduction.
18851863
Value remSz = sliceMeta[tid][lvl][depth + 1].first;
@@ -1929,17 +1907,14 @@ void LoopEmitter::genUnResolvedSliceBegin(OpBuilder &builder, Location loc,
19291907

19301908
ValueRange result = genUnResolvedSliceTreeTraverse(
19311909
builder, loc, tid, unResSlices, firstResLvl, reduc,
1932-
[this, c1, tid, lvl, sPtrBuf](OpBuilder &builder, Location loc, Value iv,
1933-
MutableArrayRef<Value> reduc) {
1910+
[this, tid, lvl, sPtrBuf](OpBuilder &builder, Location loc, Value iv,
1911+
MutableArrayRef<Value> reduc) {
19341912
Value &nonEmpty = reduc[0];
19351913
Value &minCrd = reduc[1];
19361914
Value &curTupleCnt = reduc[2];
19371915

1938-
Value pHi = ADDI(iv, c1);
1939-
// TODO: eliminate the cast upon feature complete.
1940-
Value pBuf = static_cast<CompressedLevel &>(*lvls[tid][lvl]).posBuffer;
1941-
Value sPLo = genIndexLoad(builder, loc, pBuf, iv);
1942-
Value sPHi = genIndexLoad(builder, loc, pBuf, pHi);
1916+
const SparseTensorLevel &stl = *lvls[tid][lvl];
1917+
auto [sPLo, sPHi] = stl.peekRangeAt(builder, loc, iv);
19431918

19441919
// isNonEmpty = isNonEmpty || lvlNonEmpty, i.e., as long as there is
19451920
// one non-empty lvl, the slice is non-empty.
@@ -1957,7 +1932,7 @@ void LoopEmitter::genUnResolvedSliceBegin(OpBuilder &builder, Location loc,
19571932
// }
19581933
OpBuilder::InsertionGuard guard(builder);
19591934
builder.setInsertionPointToStart(ifNonEmpty.thenBlock());
1960-
Value curC = lvls[tid][lvl]->peekCrdAt(builder, loc, sPLo);
1935+
Value curC = stl.peekCrdAt(builder, loc, sPLo);
19611936
Value isSmaller = CMPI(ult, curC, minCrd);
19621937
Value newMin = SELECT(isSmaller, curC, minCrd);
19631938
YIELD(newMin);

mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp

Lines changed: 135 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,133 @@
1313

1414
using namespace mlir;
1515
using namespace mlir::sparse_tensor;
16+
using ValuePair = std::pair<Value, Value>;
17+
18+
//===----------------------------------------------------------------------===//
19+
// File local helper functions/macros.
20+
//===----------------------------------------------------------------------===//
21+
// NOTE: every macro below expands to an expression that refers to the
// identifiers `b` and `l`, so both must be in scope wherever a macro is used
// (in this file they are the `OpBuilder &b` and `Location l` parameters).
#define CMPI(p, lhs, rhs) \
22+
(b.create<arith::CmpIOp>(l, arith::CmpIPredicate::p, (lhs), (rhs)))
23+
24+
#define C_IDX(v) (constantIndex(b, l, (v)))
25+
#define YIELD(vs) (b.create<scf::YieldOp>(l, (vs)))
26+
#define ADDI(lhs, rhs) (b.create<arith::AddIOp>(l, (lhs), (rhs)))
27+
#define ANDI(lhs, rhs) (b.create<arith::AndIOp>(l, (lhs), (rhs)))
28+
#define SUBI(lhs, rhs) (b.create<arith::SubIOp>(l, (lhs), (rhs)))
29+
#define MULI(lhs, rhs) (b.create<arith::MulIOp>(l, (lhs), (rhs)))
30+
#define REMUI(lhs, rhs) (b.create<arith::RemUIOp>(l, (lhs), (rhs)))
31+
#define DIVUI(lhs, rhs) (b.create<arith::DivUIOp>(l, (lhs), (rhs)))
32+
#define SELECT(c, lhs, rhs) (b.create<arith::SelectOp>(l, (c), (lhs), (rhs)))
33+
34+
/// Returns the pair `(lo, lo + sz)`. The upper bound is materialized with a
/// single `arith.addi` emitted through `b` at location `l`.
static ValuePair constantRange(OpBuilder &b, Location l, Value lo, Value sz) {
  const Value hi = ADDI(lo, sz);
  return {lo, hi};
}
37+
38+
//===----------------------------------------------------------------------===//
39+
// SparseTensorLevel derived classes.
40+
//===----------------------------------------------------------------------===//
41+
42+
namespace {
43+
44+
class SparseLevel : public SparseTensorLevel {
45+
public:
46+
SparseLevel(LevelType lt, Value lvlSize, Value crdBuffer)
47+
: SparseTensorLevel(lt, lvlSize), crdBuffer(crdBuffer) {}
48+
49+
Value peekCrdAt(OpBuilder &b, Location l, Value pos) const override {
50+
return genIndexLoad(b, l, crdBuffer, pos);
51+
}
52+
53+
protected:
54+
const Value crdBuffer;
55+
};
56+
57+
class DenseLevel : public SparseTensorLevel {
58+
public:
59+
DenseLevel(Value lvlSize) : SparseTensorLevel(LevelType::Dense, lvlSize) {
60+
// Dense level, loop upper bound equals to the level size.
61+
loopHi = lvlSize;
62+
}
63+
64+
Value peekCrdAt(OpBuilder &, Location, Value pos) const override {
65+
return pos;
66+
}
67+
68+
ValuePair peekRangeAt(OpBuilder &b, Location l, Value p,
69+
Value max) const override {
70+
assert(max == nullptr && "Dense level can not be non-unique.");
71+
return constantRange(b, l, C_IDX(0), lvlSize);
72+
}
73+
};
74+
75+
class CompressedLevel : public SparseLevel {
76+
public:
77+
CompressedLevel(LevelType lt, Value lvlSize, Value posBuffer, Value crdBuffer)
78+
: SparseLevel(lt, lvlSize, crdBuffer), posBuffer(posBuffer) {}
79+
80+
ValuePair peekRangeAt(OpBuilder &b, Location l, Value p,
81+
Value max) const override {
82+
if (max == nullptr) {
83+
Value pLo = genIndexLoad(b, l, posBuffer, p);
84+
Value pHi = genIndexLoad(b, l, posBuffer, ADDI(p, C_IDX(1)));
85+
return {pLo, pHi};
86+
}
87+
llvm_unreachable("TODO: dedup not implemented");
88+
}
89+
90+
private:
91+
const Value posBuffer;
92+
};
93+
94+
class LooseCompressedLevel : public SparseLevel {
95+
public:
96+
LooseCompressedLevel(LevelType lt, Value lvlSize, Value posBuffer,
97+
Value crdBuffer)
98+
: SparseLevel(lt, lvlSize, crdBuffer), posBuffer(posBuffer) {}
99+
100+
ValuePair peekRangeAt(OpBuilder &b, Location l, Value p,
101+
Value max) const override {
102+
// Allows this?
103+
assert(max == nullptr && "loss compressed level can not be non-unique.");
104+
105+
p = MULI(p, C_IDX(2));
106+
Value pLo = genIndexLoad(b, l, posBuffer, p);
107+
Value pHi = genIndexLoad(b, l, posBuffer, ADDI(p, C_IDX(1)));
108+
return {pLo, pHi};
109+
}
110+
111+
private:
112+
const Value posBuffer;
113+
};
114+
115+
class SingletonLevel : public SparseLevel {
116+
public:
117+
SingletonLevel(LevelType lt, Value lvlSize, Value crdBuffer)
118+
: SparseLevel(lt, lvlSize, crdBuffer) {}
119+
120+
ValuePair peekRangeAt(OpBuilder &b, Location l, Value p,
121+
Value max) const override {
122+
if (max == nullptr)
123+
return constantRange(b, l, p, C_IDX(1));
124+
llvm_unreachable("TODO: dedup not implemented");
125+
}
126+
};
127+
128+
class TwoOutFourLevel : public SparseLevel {
129+
public:
130+
TwoOutFourLevel(LevelType lt, Value lvlSize, Value crdBuffer)
131+
: SparseLevel(lt, lvlSize, crdBuffer) {}
132+
133+
ValuePair peekRangeAt(OpBuilder &b, Location l, Value p,
134+
Value max) const override {
135+
assert(max == nullptr && "2:4 level can not be non-unique.");
136+
// Each 2:4 block has exactly two specified elements.
137+
Value c2 = C_IDX(2);
138+
return constantRange(b, l, MULI(p, c2), c2);
139+
}
140+
};
141+
142+
} // namespace
16143

17144
std::unique_ptr<SparseTensorLevel>
18145
sparse_tensor::makeSparseTensorLevel(OpBuilder &builder, Location loc, Value t,
@@ -49,6 +176,11 @@ sparse_tensor::makeSparseTensorLevel(OpBuilder &builder, Location loc, Value t,
49176
llvm_unreachable("unrecognizable level format");
50177
}
51178

52-
Value SparseLevel::peekCrdAt(OpBuilder &b, Location l, Value pos) const {
53-
return genIndexLoad(b, l, crdBuffer, pos);
54-
}
179+
// Undefine every file-local helper macro defined in this file, including
// REMUI and DIVUI, so that none of them leaks past this point.
#undef CMPI
#undef C_IDX
#undef YIELD
#undef ADDI
#undef ANDI
#undef SUBI
#undef MULI
#undef REMUI
#undef DIVUI
#undef SELECT

mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.h

Lines changed: 9 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,22 @@ namespace sparse_tensor {
1717
class SparseTensorLevel {
1818
SparseTensorLevel(SparseTensorLevel &&) = delete;
1919
SparseTensorLevel(const SparseTensorLevel &) = delete;
20+
SparseTensorLevel &operator=(SparseTensorLevel &&) = delete;
21+
SparseTensorLevel &operator=(const SparseTensorLevel &) = delete;
2022

2123
public:
2224
SparseTensorLevel() : SparseTensorLevel(LevelType::Undef, nullptr){};
2325
virtual ~SparseTensorLevel() = default;
2426

2527
virtual Value peekCrdAt(OpBuilder &b, Location l, Value p) const = 0;
2628

29+
/// Peeks the lower and upper bound to *fully* traverse the level with
30+
/// the given position `p` that the immediate parent level is currently at.
31+
/// `bound` is only used when the level is `non-unique` and deduplication is
32+
/// required. It specifies the max upper bound of the non-unique segment.
33+
virtual std::pair<Value, Value> peekRangeAt(OpBuilder &b, Location l, Value p,
34+
Value bound = Value()) const = 0;
35+
2736
LevelType getLT() const { return lt; }
2837
Value getPos() const { return pos; }
2938
Value getCrd() const { return crd; }
@@ -49,60 +58,6 @@ class SparseTensorLevel {
4958
std::unique_ptr<SparseTensorLevel>
5059
makeSparseTensorLevel(OpBuilder &builder, Location loc, Value t, Level l);
5160

52-
class DenseLevel : public SparseTensorLevel {
53-
public:
54-
DenseLevel(Value lvlSize) : SparseTensorLevel(LevelType::Dense, lvlSize) {
55-
// Dense level, loop upper bound equals to the level size.
56-
loopHi = lvlSize;
57-
}
58-
59-
Value peekCrdAt(OpBuilder &, Location, Value pos) const override {
60-
return pos;
61-
}
62-
};
63-
64-
class SparseLevel : public SparseTensorLevel {
65-
public:
66-
SparseLevel(LevelType lt, Value lvlSize, Value crdBuffer)
67-
: SparseTensorLevel(lt, lvlSize), crdBuffer(crdBuffer) {}
68-
69-
Value peekCrdAt(OpBuilder &b, Location l, Value pos) const override;
70-
71-
public: // TODO: make these values private upon feature complete.
72-
const Value crdBuffer;
73-
};
74-
75-
class CompressedLevel : public SparseLevel {
76-
public:
77-
CompressedLevel(LevelType lt, Value lvlSize, Value posBuffer, Value crdBuffer)
78-
: SparseLevel(lt, lvlSize, crdBuffer), posBuffer(posBuffer) {}
79-
80-
public: // TODO: make these values private upon feature complete.
81-
const Value posBuffer;
82-
};
83-
84-
class LooseCompressedLevel : public SparseLevel {
85-
public:
86-
LooseCompressedLevel(LevelType lt, Value lvlSize, Value posBuffer,
87-
Value crdBuffer)
88-
: SparseLevel(lt, lvlSize, crdBuffer), posBuffer(posBuffer) {}
89-
90-
public: // TODO: make these values private upon feature complete.
91-
const Value posBuffer;
92-
};
93-
94-
class SingletonLevel : public SparseLevel {
95-
public:
96-
SingletonLevel(LevelType lt, Value lvlSize, Value crdBuffer)
97-
: SparseLevel(lt, lvlSize, crdBuffer) {}
98-
};
99-
100-
class TwoOutFourLevel : public SparseLevel {
101-
public:
102-
TwoOutFourLevel(LevelType lt, Value lvlSize, Value crdBuffer)
103-
: SparseLevel(lt, lvlSize, crdBuffer) {}
104-
};
105-
10661
} // namespace sparse_tensor
10762
} // namespace mlir
10863

0 commit comments

Comments
 (0)