Commit 0352690

Author: Peiming Liu

[mlir][sparse] make foreach operation support sparse tensor slices.

Reviewed By: aartbik
Differential Revision: https://reviews.llvm.org/D140713

1 parent dc6f8ef commit 0352690

File tree: 5 files changed, +259 -24 lines

mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp

Lines changed: 4 additions & 1 deletion
@@ -534,7 +534,10 @@ getNormalizedEncodingForSpecifier(SparseTensorEncodingAttr enc) {
       enc.getContext(), dlts,
       AffineMap(), // dimOrdering (irrelavant to storage speicifer)
       AffineMap(), // highLvlOrdering (irrelavant to storage specifer)
-      enc.getPointerBitWidth(), enc.getIndexBitWidth());
+      enc.getPointerBitWidth(), enc.getIndexBitWidth(),
+      // FIXME: we should keep the slice information, for now it is okay as only
+      // constant can be used for slice
+      ArrayRef<SparseTensorDimSliceAttr>{} /*enc.getDimSlices()*/);
 }
 
 StorageSpecifierType

mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp

Lines changed: 107 additions & 6 deletions
@@ -42,6 +42,50 @@ static Value genIndexLoad(OpBuilder &builder, Location loc, Value ptr,
   return load;
 }
 
+// TODO: Support dynamic sized slice.
+static Value getSliceOffset(OpBuilder &builder, Location loc,
+                            SparseTensorEncodingAttr enc, unsigned lvl) {
+  return constantIndex(builder, loc, *enc.getStaticLvlSliceOffset(lvl));
+}
+
+static Value getSliceSize(OpBuilder &builder, Location loc,
+                          SparseTensorEncodingAttr enc, unsigned lvl) {
+  return constantIndex(builder, loc, *enc.getStaticLvlSliceSize(lvl));
+}
+
+static Value getSliceStride(OpBuilder &builder, Location loc,
+                            SparseTensorEncodingAttr enc, unsigned lvl) {
+  return constantIndex(builder, loc, *enc.getStaticLvlSliceStride(lvl));
+}
+
+// Converts a coordinate relative to the slice to the coordinate relative
+// to the underlying tensor.
+static Value toSliceCoord(OpBuilder &builder, Location loc, Value v,
+                          SparseTensorEncodingAttr enc, unsigned lvl) {
+
+  Value stride = getSliceStride(builder, loc, enc, lvl);
+  Value offset = getSliceOffset(builder, loc, enc, lvl);
+  // iv = iv * stride + offset
+  v = builder.create<arith::MulIOp>(loc, v, stride);
+  v = builder.create<arith::AddIOp>(loc, v, offset);
+  return v;
+}
+
+// Converts a coordinate relative to the underlying tensor to the coordinate
+// relative to the slice, returns a extra reminder value
+static std::pair<Value, Value> fromSliceCoord(OpBuilder &builder, Location loc,
+                                              Value v,
+                                              SparseTensorEncodingAttr enc,
+                                              unsigned lvl) {
+  Value stride = getSliceStride(builder, loc, enc, lvl);
+  Value offset = getSliceOffset(builder, loc, enc, lvl);
+  // iv = (iv - offset) / stride
+  v = builder.create<arith::SubIOp>(loc, v, offset);
+  Value rem = builder.create<arith::RemUIOp>(loc, v, stride);
+  v = builder.create<arith::DivUIOp>(loc, v, stride);
+  return std::make_pair(v, rem);
+}
+
 //===----------------------------------------------------------------------===//
 // Sparse tensor loop emitter class implementations
 //===----------------------------------------------------------------------===//
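
The helpers above only build arith ops; the underlying index arithmetic is simple. The following is a standalone sketch with hypothetical names and plain integers (not code from this commit): sliceToTensor mirrors toSliceCoord, mapping a slice-relative coordinate to the underlying tensor via iv * stride + offset, and tensorToSlice mirrors fromSliceCoord, inverting it and returning a remainder that is nonzero when a stored coordinate does not land on the slice's stride.

// Standalone sketch (hypothetical names) of the arithmetic emitted above.
#include <cassert>
#include <cstdint>
#include <utility>

struct SliceDim {
  uint64_t offset, size, stride; // static slice metadata for one level
};

// Mirrors toSliceCoord: slice-relative coordinate -> underlying tensor coordinate.
uint64_t sliceToTensor(uint64_t iv, SliceDim s) {
  return iv * s.stride + s.offset; // iv = iv * stride + offset
}

// Mirrors fromSliceCoord: tensor coordinate -> (slice coordinate, remainder).
// A nonzero remainder means the coordinate does not land on the slice's stride.
std::pair<uint64_t, uint64_t> tensorToSlice(uint64_t iv, SliceDim s) {
  uint64_t shifted = iv - s.offset; // iv = (iv - offset) / stride
  return {shifted / s.stride, shifted % s.stride};
}

int main() {
  SliceDim dim{/*offset=*/1, /*size=*/4, /*stride=*/2};
  assert(sliceToTensor(2, dim) == 5); // slice index 2 -> tensor index 5
  assert(tensorToSlice(5, dim).first == 2 && tensorToSlice(5, dim).second == 0);
  assert(tensorToSlice(4, dim).second != 0); // tensor index 4 is skipped by stride 2
  return 0;
}
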
@@ -50,6 +94,10 @@ Value LoopEmitter::genAddress(OpBuilder &builder, Location loc, size_t tid,
                               size_t dim, Value iv) {
   Value p = dim == 0 ? constantIndex(builder, loc, 0) : pidxs[tid][dim - 1];
   Value mul = builder.create<arith::MulIOp>(loc, highs[tid][dim], p);
+  if (isSparseSlices[tid]) {
+    auto enc = getSparseTensorEncoding(tensors[tid].getType());
+    iv = toSliceCoord(builder, loc, iv, enc, dim);
+  }
   Value add = builder.create<arith::AddIOp>(loc, mul, iv);
   return add;
 }
@@ -67,6 +115,7 @@ void LoopEmitter::initialize(ValueRange tensors, StringAttr loopTag,
   this->hasOutput = hasOutput;
   this->isSparseOut = isSparseOut;
   this->tensors.assign(tensors.begin(), tensors.end());
+  this->isSparseSlices.assign(tensors.size(), false);
   this->dimTypes.assign(tensors.size(), std::vector<DimLevelType>());
   this->pidxs.assign(tensors.size(), std::vector<Value>());
   this->coord.assign(tensors.size(), std::vector<Value>());
@@ -87,10 +136,11 @@ void LoopEmitter::initialize(ValueRange tensors, StringAttr loopTag,
     auto enc = getSparseTensorEncoding(rtp);
     // We always treat sparse output tensor as dense so that we always iterate
     // it based on dim size.
-    if (enc && !(isOutputTensor(tid) && isSparseOut))
+    if (enc && !(isOutputTensor(tid) && isSparseOut)) {
+      isSparseSlices[tid] = enc.isSlice();
       for (auto dimTp : enc.getDimLevelType())
         dimTypes[tid].push_back(dimTp);
-    else
+    } else
       dimTypes[tid].assign(rank, DimLevelType::Dense);
 
     // Initialize using empty value.
@@ -218,7 +268,6 @@ Operation *LoopEmitter::enterLoopOverTensorAtDim(
     ArrayRef<size_t> dims, MutableArrayRef<Value> reduc, bool isParallel) {
   // TODO: support multiple return on parallel for?
   assert(!isParallel || reduc.size() <= 1);
-
   bool isSparseInput = false;
   size_t tid = tids.front(), dim = dims.front();
   for (auto [t, d] : llvm::zip(tids, dims)) {
@@ -239,10 +288,13 @@ Operation *LoopEmitter::enterLoopOverTensorAtDim(
     isSparseInput = isSparseInput || isSparse;
   }
 
+  auto enc = getSparseTensorEncoding(tensors[tid].getType());
+  // TODO: support dynamic slices.
   Value step = constantIndex(builder, loc, 1);
   Value lo = isSparseInput ? pidxs[tid][dim]      // current offset
-                           : loopSeqStack.back(); // univeral tid
+                           : loopSeqStack.back(); // universal index
   Value hi = highs[tid][dim];
+
   Operation *loop = nullptr;
   Value iv;
   if (isParallel) {
@@ -275,15 +327,64 @@ Operation *LoopEmitter::enterLoopOverTensorAtDim(
   }
   assert(loop && iv);
 
+  Value c;
   if (isSparseInput) {
     pidxs[tid][dim] = iv;
     // Generating a load on the indices array yields the coordinate.
     Value ptr = idxBuffer[tid][dim];
-    coord[tid][dim] = genIndexLoad(builder, loc, ptr, iv);
+    c = genIndexLoad(builder, loc, ptr, iv);
   } else {
     // Dense tensor, the coordinates is the inducation variable.
-    coord[tid][dim] = iv;
+    c = iv;
   }
+
+  if (isSparseSlices[tid] && isSparseInput) {
+    // For sparse level slices, we need to filter out invalid coordinates that
+    // are not included in the slice.
+    std::pair<Value, Value> trans = fromSliceCoord(builder, loc, c, enc, dim);
+    SmallVector<Type> types;
+    for (Value red : reduc)
+      types.push_back(red.getType());
+
+    // First, coord >= offset (TODO: seems unsigned >= 0 won't be folded, skip
+    // the check if the offset is zero).
+    auto geOff =
+        builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::uge, c,
+                                      getSliceOffset(builder, loc, enc, dim));
+    // Second, coords < length
+    auto ltLen = builder.create<arith::CmpIOp>(
+        loc, arith::CmpIPredicate::ult, trans.first,
+        getSliceSize(builder, loc, enc, dim));
+
+    // Third, rem == 0; confirmed that (a % 1) will be folded to 0
+    auto fitStride = builder.create<arith::CmpIOp>(
+        loc, arith::CmpIPredicate::eq, trans.second,
+        constantIndex(builder, loc, 0));
+
+    auto pred = builder.create<arith::AndIOp>(loc, geOff, ltLen);
+    pred = builder.create<arith::AndIOp>(loc, pred, fitStride);
+    bool hasReduc = !types.empty();
+    scf::IfOp ifOp =
+        builder.create<scf::IfOp>(loc, types, pred, /*else*/ hasReduc);
+    if (hasReduc) {
+      // scf.for (a) -> v
+      //  %s = scf.if (a) -> v
+      //    user-generated code.
+      //  else
+      //    yield a
+      //  yield %s
+      builder.create<scf::YieldOp>(loc, ifOp.getResults());
+      builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
+      // On mismatch.
+      builder.create<scf::YieldOp>(loc, reduc);
+    }
+    // Set the insertion point to matched branch.
+    builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
+    c = trans.first;
+  }
+
+  assert(c);
+  coord[tid][dim] = c;
   // NOTE: we can also prepare for next dim here in advance
   // Push the loop into stack
   loopStack.emplace_back(ArrayRef<size_t>(tid), ArrayRef<size_t>(dim), loop,
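
Putting the pieces together, the hunk above wraps the loop body in an scf.if so that only coordinates belonging to the slice reach user code, while any reduction value is yielded unchanged on a mismatch. A scalar model of that control flow, with made-up data and hypothetical names (not the commit's API), is sketched below.

// Scalar sketch of the emitted scf.for / scf.if structure for one compressed level.
#include <cstdint>
#include <cstdio>
#include <vector>

struct SliceDim { uint64_t offset, size, stride; };

double sumSliceEntries(const std::vector<uint64_t> &coords,
                       const std::vector<double> &values, SliceDim s) {
  double reduc = 0.0; // loop-carried value, like the scf.for iter_args
  for (size_t p = 0; p < coords.size(); ++p) { // scf.for over stored entries
    uint64_t c = coords[p];
    uint64_t shifted = c - s.offset;
    bool inSlice = c >= s.offset &&               // coord >= offset
                   shifted / s.stride < s.size && // translated coord < size
                   shifted % s.stride == 0;       // lands on the stride
    if (inSlice)   // scf.if "then" branch: run user code on the slice coordinate
      reduc += values[p];
    // else branch: yield reduc unchanged
  }
  return reduc;
}

int main() {
  // Stored column coordinates/values of one CSR row (illustrative data only).
  std::vector<uint64_t> coords = {0, 2, 5, 7};
  std::vector<double> values = {1.0, 2.0, 3.0, 4.0};
  // Columns 1, 3, 5, 7 form the slice (offset 1, size 4, stride 2), so only
  // the entries at columns 5 and 7 are kept: 3.0 + 4.0 = 7.0.
  std::printf("%f\n", sumSliceEntries(coords, values, {1, 4, 2}));
  return 0;
}
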

mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h

Lines changed: 13 additions & 10 deletions
@@ -259,22 +259,25 @@ class LoopEmitter {
   std::vector<std::vector<Value>> idxBuffer; // to_indices
   std::vector<Value> valBuffer;              // to_value
 
-  // Loop Stack, stores the information of all the nested loops that are
-  // alive.
+  /// Whether the sparse input is a slice.
+  std::vector<bool> isSparseSlices;
+
+  /// Loop Stack, stores the information of all the nested loops that are
+  /// alive.
   std::vector<LoopLevelInfo> loopStack;
 
-  // Loop Sequence Stack, stores the unversial index for the current loop
-  // sequence.
+  /// Loop Sequence Stack, stores the unversial index for the current loop
+  /// sequence.
   std::vector<Value> loopSeqStack;
 
-  // Maps AffineDimExpr to the index of the loop in loopStack.
-  // TODO: We should probably use a callback function here to make it more
-  // general.
+  /// Maps AffineDimExpr to the index of the loop in loopStack.
+  /// TODO: We should probably use a callback function here to make it more
+  /// general.
   std::vector<unsigned> sparsiferLoopLvlMap;
 
-  // TODO: not yet used, it should track the current level for each tensor
-  // to help eliminate `dim` paramters from above APIs.
-  // std::vector<size_t> curLv;
+  /// TODO: not yet used, it should track the current level for each tensor
+  /// to help eliminate `dim` paramters from above APIs.
+  /// std::vector<size_t> curLv;
 };
 
 } // namespace sparse_tensor

mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp

Lines changed: 41 additions & 7 deletions
@@ -1010,6 +1010,40 @@ class SparseConvertConverter : public OpConversionPattern<ConvertOp> {
   }
 };
 
+class SparseExtractSliceCoverter
+    : public OpConversionPattern<tensor::ExtractSliceOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(tensor::ExtractSliceOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto srcEnc = getSparseTensorEncoding(op.getSourceType());
+    auto dstEnc = getSparseTensorEncoding(op.getResult().getType());
+    if (!srcEnc && !dstEnc)
+      return failure();
+
+    // TODO: We should check these in ExtractSliceOp::verify.
+    assert(srcEnc && dstEnc && dstEnc.isSlice());
+    assert(srcEnc.getDimLevelType() == dstEnc.getDimLevelType());
+    assert(srcEnc.getDimOrdering() == dstEnc.getDimOrdering());
+    assert(srcEnc.getHigherOrdering() == dstEnc.getHigherOrdering());
+    assert(srcEnc.getPointerBitWidth() == dstEnc.getPointerBitWidth());
+    assert(srcEnc.getIndexBitWidth() == dstEnc.getIndexBitWidth());
+
+    // TODO: support dynamic slices.
+    for (int i = 0, e = op.getSourceType().getRank(); i < e; i++) {
+      assert(op.getStaticStrides()[i] == dstEnc.getStaticDimSliceStride(i));
+      assert(op.getStaticOffsets()[i] == dstEnc.getStaticDimSliceOffset(i));
+      assert(op.getStaticSizes()[i] == dstEnc.getStaticDimSliceSize(i));
+    }
+
+    // TODO: create a new specifer for slices (need to encode slice metadata).
+    // It does not matter now because only constant offset/stride are allowed.
+    rewriter.replaceOp(op, adaptor.getSource());
+    return success();
+  }
+};
+
 /// Sparse codegen rule for number of entries operator.
 class SparseNumberOfEntriesConverter
     : public OpConversionPattern<NumberOfEntriesOp> {
@@ -1133,13 +1167,13 @@ void mlir::populateSparseTensorCodegenPatterns(
     bool enableBufferInitialization) {
   patterns.add<SparsePackOpConverter, SparseReturnConverter,
                SparseCallConverter, SparseDimOpConverter, SparseCastConverter,
-               SparseTensorDeallocConverter, SparseTensorLoadConverter,
-               SparseExpandConverter, SparseCompressConverter,
-               SparseInsertConverter, SparseToPointersConverter,
-               SparseToIndicesConverter, SparseToIndicesBufferConverter,
-               SparseToValuesConverter, SparseConvertConverter,
-               SparseNumberOfEntriesConverter>(typeConverter,
-                                               patterns.getContext());
+               SparseTensorDeallocConverter, SparseExtractSliceCoverter,
+               SparseTensorLoadConverter, SparseExpandConverter,
+               SparseCompressConverter, SparseInsertConverter,
+               SparseToPointersConverter, SparseToIndicesConverter,
+               SparseToIndicesBufferConverter, SparseToValuesConverter,
+               SparseConvertConverter, SparseNumberOfEntriesConverter>(
+      typeConverter, patterns.getContext());
   patterns.add<SparseTensorAllocConverter>(typeConverter, patterns.getContext(),
                                            enableBufferInitialization);
 }
Lines changed: 94 additions & 0 deletions
@@ -0,0 +1,94 @@
+// DEFINE: %{option} = enable-runtime-library=false
+// DEFINE: %{command} = mlir-opt %s --sparse-compiler=%{option} | \
+// DEFINE: mlir-cpu-runner \
+// DEFINE:  -e entry -entry-point-result=void \
+// DEFINE:  -shared-libs=%mlir_lib_dir/libmlir_c_runner_utils%shlibext | \
+// DEFINE: FileCheck %s
+//
+// RUN: %{command}
+//
+
+// TODO: support slices on lib path
+#CSR = #sparse_tensor.encoding<{
+  dimLevelType = [ "dense", "compressed" ]
+}>
+
+#CSR_SLICE = #sparse_tensor.encoding<{
+  dimLevelType = [ "dense", "compressed" ],
+  slice = [ (1, 4, 1), (1, 4, 2) ]
+}>
+
+module {
+  func.func @foreach_print_non_slice(%A: tensor<4x4xf64, #CSR>) {
+    sparse_tensor.foreach in %A : tensor<4x4xf64, #CSR> do {
+    ^bb0(%1: index, %2: index, %v: f64) :
+      vector.print %1: index
+      vector.print %2: index
+      vector.print %v: f64
+    }
+    return
+  }
+
+  func.func @foreach_print_slice(%A: tensor<4x4xf64, #CSR_SLICE>) {
+    sparse_tensor.foreach in %A : tensor<4x4xf64, #CSR_SLICE> do {
+    ^bb0(%1: index, %2: index, %v: f64) :
+      vector.print %1: index
+      vector.print %2: index
+      vector.print %v: f64
+    }
+    return
+  }
+
+  func.func @entry() {
+    %c0 = arith.constant 0 : index
+    %sa = arith.constant dense<[
+        [ 0.0, 2.1, 0.0, 0.0, 0.0, 6.1, 0.0, 0.0 ],
+        [ 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0 ],
+        [ 0.0, 2.3, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0 ],
+        [ 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0 ],
+        [ 0.0, 0.0, 0.1, 0.0, 0.0, 2.1, 0.0, 0.0 ],
+        [ 0.0, 0.0, 0.0, 0.0, 3.1, 0.0, 0.0, 0.0 ],
+        [ 0.0, 2.3, 0.0, 0.0, 0.0, 0.0, 3.3, 0.0 ],
+        [ 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0 ]
+    ]> : tensor<8x8xf64>
+
+    %tmp = sparse_tensor.convert %sa : tensor<8x8xf64> to tensor<8x8xf64, #CSR>
+    %a = tensor.extract_slice %tmp[1, 1][4, 4][1, 2] : tensor<8x8xf64, #CSR> to
+                                                       tensor<4x4xf64, #CSR_SLICE>
+    // Foreach on sparse tensor slices directly
+    //
+    // CHECK: 1
+    // CHECK-NEXT: 0
+    // CHECK-NEXT: 2.3
+    // CHECK-NEXT: 2
+    // CHECK-NEXT: 3
+    // CHECK-NEXT: 1
+    // CHECK-NEXT: 3
+    // CHECK-NEXT: 2
+    // CHECK-NEXT: 2.1
+    //
+    call @foreach_print_slice(%a) : (tensor<4x4xf64, #CSR_SLICE>) -> ()
+
+    // FIXME: investigate why a tensor copy is inserted for this slice
+    // %dense = tensor.extract_slice %sa[1, 1][4, 4][1, 2] : tensor<8x8xf64> to
+    //                                                       tensor<4x4xf64>
+    // %b = sparse_tensor.convert %dense : tensor<4x4xf64> to tensor<4x4xf64, #CSR>
+    // // Foreach on sparse tensor instead of slice they should yield the same result.
+    // //
+    // // C_HECK-NEXT: 1
+    // // C_HECK-NEXT: 0
+    // // C_HECK-NEXT: 2.3
+    // // C_HECK-NEXT: 2
+    // // C_HECK-NEXT: 3
+    // // C_HECK-NEXT: 1
+    // // C_HECK-NEXT: 3
+    // // C_HECK-NEXT: 2
+    // // C_HECK-NEXT: 2.1
+    // //
+    // call @foreach_print_non_slice(%b) : (tensor<4x4xf64, #CSR>) -> ()
+    // bufferization.dealloc_tensor %b : tensor<4x4xf64, #CSR>
+
+    bufferization.dealloc_tensor %tmp : tensor<8x8xf64, #CSR>
+    return
+  }
+}
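
For the slice in this test, the offsets are (1, 1), the sizes (4, 4), and the strides (1, 2), so slice element (i, j) reads tensor element (1 + i, 1 + 2 * j): only stored entries whose row is in 1..4 and whose column is one of 1, 3, 5, 7 survive the filter, which yields exactly the three values the CHECK lines expect. A small standalone C++ recomputation of that trace (an illustration, not part of the test suite):

// Recomputes the expected output of @foreach_print_slice from the dense 8x8
// input, using slice offsets (1, 1), sizes (4, 4), strides (1, 2).
#include <cstdio>

int main() {
  const double sa[8][8] = {
      {0.0, 2.1, 0.0, 0.0, 0.0, 6.1, 0.0, 0.0},
      {0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0},
      {0.0, 2.3, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0},
      {0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0},
      {0.0, 0.0, 0.1, 0.0, 0.0, 2.1, 0.0, 0.0},
      {0.0, 0.0, 0.0, 0.0, 3.1, 0.0, 0.0, 0.0},
      {0.0, 2.3, 0.0, 0.0, 0.0, 0.0, 3.3, 0.0},
      {0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0}};
  for (int i = 0; i < 4; i++)     // slice row i -> tensor row 1 + i
    for (int j = 0; j < 4; j++) { // slice col j -> tensor col 1 + 2 * j
      double v = sa[1 + i][1 + 2 * j];
      if (v != 0.0)               // nonzeros only, like sparse_tensor.foreach
        std::printf("%d\n%d\n%g\n", i, j, v);
    }
  // Prints (1, 0, 2.3), (2, 3, 1), (3, 2, 2.1), matching the CHECK lines above.
  return 0;
}
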
