@@ -42,6 +42,50 @@ static Value genIndexLoad(OpBuilder &builder, Location loc, Value ptr,
   return load;
 }
 
+// TODO: Support dynamically sized slices.
+static Value getSliceOffset(OpBuilder &builder, Location loc,
+                            SparseTensorEncodingAttr enc, unsigned lvl) {
+  return constantIndex(builder, loc, *enc.getStaticLvlSliceOffset(lvl));
+}
+
+static Value getSliceSize(OpBuilder &builder, Location loc,
+                          SparseTensorEncodingAttr enc, unsigned lvl) {
+  return constantIndex(builder, loc, *enc.getStaticLvlSliceSize(lvl));
+}
+
+static Value getSliceStride(OpBuilder &builder, Location loc,
+                            SparseTensorEncodingAttr enc, unsigned lvl) {
+  return constantIndex(builder, loc, *enc.getStaticLvlSliceStride(lvl));
+}
+
+// Converts a coordinate relative to the slice to the coordinate relative
+// to the underlying tensor.
+static Value toSliceCoord(OpBuilder &builder, Location loc, Value v,
+                          SparseTensorEncodingAttr enc, unsigned lvl) {
+  Value stride = getSliceStride(builder, loc, enc, lvl);
+  Value offset = getSliceOffset(builder, loc, enc, lvl);
+  // iv = iv * stride + offset
+  v = builder.create<arith::MulIOp>(loc, v, stride);
+  v = builder.create<arith::AddIOp>(loc, v, offset);
+  return v;
+}
+
+// Converts a coordinate relative to the underlying tensor to the coordinate
+// relative to the slice; also returns the remainder of the stride division,
+// which is nonzero when the coordinate does not land on the slice.
+static std::pair<Value, Value> fromSliceCoord(OpBuilder &builder, Location loc,
+                                              Value v,
+                                              SparseTensorEncodingAttr enc,
+                                              unsigned lvl) {
+  Value stride = getSliceStride(builder, loc, enc, lvl);
+  Value offset = getSliceOffset(builder, loc, enc, lvl);
+  // iv = (iv - offset) / stride
+  v = builder.create<arith::SubIOp>(loc, v, offset);
+  Value rem = builder.create<arith::RemUIOp>(loc, v, stride);
+  v = builder.create<arith::DivUIOp>(loc, v, stride);
+  return std::make_pair(v, rem);
+}
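+
+// Example: with a static slice of offset 1 and stride 2, slice coordinate 3
+// maps to tensor coordinate 3 * 2 + 1 = 7; conversely, tensor coordinate 7
+// maps back to slice coordinate (7 - 1) / 2 = 3 with remainder 0, whereas
+// tensor coordinate 6 gives remainder 1 and hence lies outside the slice.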
+
 //===----------------------------------------------------------------------===//
 // Sparse tensor loop emitter class implementations
 //===----------------------------------------------------------------------===//
@@ -50,6 +94,10 @@ Value LoopEmitter::genAddress(OpBuilder &builder, Location loc, size_t tid,
                               size_t dim, Value iv) {
   Value p = dim == 0 ? constantIndex(builder, loc, 0) : pidxs[tid][dim - 1];
   Value mul = builder.create<arith::MulIOp>(loc, highs[tid][dim], p);
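+  // `iv` is a coordinate in slice space; for a sliced tensor, translate it
+  // to the coordinate of the underlying tensor before folding it into the
+  // address.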
+  if (isSparseSlices[tid]) {
+    auto enc = getSparseTensorEncoding(tensors[tid].getType());
+    iv = toSliceCoord(builder, loc, iv, enc, dim);
+  }
   Value add = builder.create<arith::AddIOp>(loc, mul, iv);
   return add;
 }
@@ -67,6 +115,7 @@ void LoopEmitter::initialize(ValueRange tensors, StringAttr loopTag,
   this->hasOutput = hasOutput;
   this->isSparseOut = isSparseOut;
   this->tensors.assign(tensors.begin(), tensors.end());
+  this->isSparseSlices.assign(tensors.size(), false);
   this->dimTypes.assign(tensors.size(), std::vector<DimLevelType>());
   this->pidxs.assign(tensors.size(), std::vector<Value>());
   this->coord.assign(tensors.size(), std::vector<Value>());
@@ -87,10 +136,11 @@ void LoopEmitter::initialize(ValueRange tensors, StringAttr loopTag,
   auto enc = getSparseTensorEncoding(rtp);
   // We always treat sparse output tensor as dense so that we always iterate
   // it based on dim size.
-  if (enc && !(isOutputTensor(tid) && isSparseOut))
+  if (enc && !(isOutputTensor(tid) && isSparseOut)) {
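+    // Remember whether this tensor is an encoded slice; sliced levels need
+    // coordinate translation and filtering when loops are emitted over them.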
+    isSparseSlices[tid] = enc.isSlice();
     for (auto dimTp : enc.getDimLevelType())
       dimTypes[tid].push_back(dimTp);
-  else
+  } else
     dimTypes[tid].assign(rank, DimLevelType::Dense);
 
   // Initialize using empty value.
@@ -218,7 +268,6 @@ Operation *LoopEmitter::enterLoopOverTensorAtDim(
     ArrayRef<size_t> dims, MutableArrayRef<Value> reduc, bool isParallel) {
   // TODO: support multiple return on parallel for?
   assert(!isParallel || reduc.size() <= 1);
-
   bool isSparseInput = false;
   size_t tid = tids.front(), dim = dims.front();
   for (auto [t, d] : llvm::zip(tids, dims)) {
@@ -239,10 +288,13 @@ Operation *LoopEmitter::enterLoopOverTensorAtDim(
     isSparseInput = isSparseInput || isSparse;
   }
 
+  auto enc = getSparseTensorEncoding(tensors[tid].getType());
+  // TODO: support dynamic slices.
   Value step = constantIndex(builder, loc, 1);
   Value lo = isSparseInput ? pidxs[tid][dim]      // current offset
-                           : loopSeqStack.back(); // univeral tid
+                           : loopSeqStack.back(); // universal index
   Value hi = highs[tid][dim];
+
   Operation *loop = nullptr;
   Value iv;
   if (isParallel) {
@@ -275,15 +327,64 @@ Operation *LoopEmitter::enterLoopOverTensorAtDim(
   }
   assert(loop && iv);
 
+  Value c;
   if (isSparseInput) {
     pidxs[tid][dim] = iv;
     // Generating a load on the indices array yields the coordinate.
     Value ptr = idxBuffer[tid][dim];
-    coord[tid][dim] = genIndexLoad(builder, loc, ptr, iv);
+    c = genIndexLoad(builder, loc, ptr, iv);
   } else {
     // Dense tensor, the coordinate is the induction variable.
-    coord[tid][dim] = iv;
+    c = iv;
   }
+
+  if (isSparseSlices[tid] && isSparseInput) {
+    // For sparse level slices, we need to filter out invalid coordinates that
+    // are not included in the slice.
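+    // A coordinate `c` of the underlying tensor belongs to the slice iff
+    // c >= offset, (c - offset) % stride == 0, and (c - offset) / stride is
+    // smaller than the slice size; the three comparisons below encode exactly
+    // these conditions.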
+    std::pair<Value, Value> trans = fromSliceCoord(builder, loc, c, enc, dim);
+    SmallVector<Type> types;
+    for (Value red : reduc)
+      types.push_back(red.getType());
+
+    // First, coord >= offset (TODO: an unsigned >= 0 comparison seems not to
+    // be folded, so skip the check when the offset is zero).
+    auto geOff =
+        builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::uge, c,
+                                      getSliceOffset(builder, loc, enc, dim));
+    // Second, coord < length.
+    auto ltLen = builder.create<arith::CmpIOp>(
+        loc, arith::CmpIPredicate::ult, trans.first,
+        getSliceSize(builder, loc, enc, dim));
+
+    // Third, rem == 0; confirmed that (a % 1) will be folded to 0.
+    auto fitStride = builder.create<arith::CmpIOp>(
+        loc, arith::CmpIPredicate::eq, trans.second,
+        constantIndex(builder, loc, 0));
+
+    auto pred = builder.create<arith::AndIOp>(loc, geOff, ltLen);
+    pred = builder.create<arith::AndIOp>(loc, pred, fitStride);
+    bool hasReduc = !types.empty();
+    scf::IfOp ifOp =
+        builder.create<scf::IfOp>(loc, types, pred, /*else*/ hasReduc);
+    if (hasReduc) {
+      // The generated structure is:
+      //   scf.for (a) -> v
+      //     %s = scf.if (a) -> v
+      //       user-generated code.
+      //     else
+      //       yield a
+      //     yield %s
+      builder.create<scf::YieldOp>(loc, ifOp.getResults());
+      builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
+      // On mismatch, pass the reduction values through unchanged.
+      builder.create<scf::YieldOp>(loc, reduc);
+    }
+    // Set the insertion point to the matched branch.
+    builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
+    c = trans.first;
+  }
+
+  assert(c);
+  coord[tid][dim] = c;
   // NOTE: we can also prepare for next dim here in advance
   // Push the loop into stack
   loopStack.emplace_back(ArrayRef<size_t>(tid), ArrayRef<size_t>(dim), loop,