Commit a02010b

Author: Peiming Liu

[mlir][sparse] support sparsifying sparse kernels to sparse-iterator-based loop (#95858)

1 parent c67ecf3 · commit a02010b

40 files changed: +745, -474 lines

mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h

Lines changed: 1 addition & 1 deletion
@@ -89,7 +89,7 @@ class LevelSet {
     assert(i < 64);
     return (bits & (1 << i)) != 0;
   }
-
+  unsigned max() const { return 64 - llvm::countl_zero(bits); }
   unsigned count() const { return llvm::popcount(bits); }
   bool empty() const { return bits == 0; }
 };
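The new LevelSet::max() returns one past the index of the highest set bit, which the IterateOp verifier (in SparseTensorDialect.cpp below) compares against the iteration-space dimension. The following standalone sketch, deliberately not using the MLIR class, shows the same bit arithmetic with C++20 <bit>; the mask value 0b0110 is chosen only for illustration.

// Standalone illustration of the LevelSet bit formulas (not the MLIR class).
#include <bit>
#include <cstdint>
#include <cstdio>

int main() {
  uint64_t bits = 0b0110; // levels 1 and 2 marked as used
  unsigned max = 64 - std::countl_zero(bits); // one past the highest set bit: 3
  unsigned count = std::popcount(bits);       // number of set bits: 2
  std::printf("max = %u, count = %u\n", max, count);
  return 0;
}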

mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td

Lines changed: 24 additions & 4 deletions
@@ -1493,6 +1493,10 @@ def ExtractIterSpaceOp : SparseTensor_Op<"extract_iteration_space",
     ```
   }];
 
+  let arguments = (ins AnySparseTensor:$tensor,
+                       Optional<AnySparseIterator>:$parentIter,
+                       LevelAttr:$loLvl, LevelAttr:$hiLvl);
+  let results = (outs AnySparseIterSpace:$extractedSpace);
 
   let extraClassDeclaration = [{
     std::pair<Level, Level> getLvlRange() {
@@ -1506,10 +1510,20 @@ def ExtractIterSpaceOp : SparseTensor_Op<"extract_iteration_space",
     }
   }];
 
-  let arguments = (ins AnySparseTensor:$tensor,
-                       Optional<AnySparseIterator>:$parentIter,
-                       LevelAttr:$loLvl, LevelAttr:$hiLvl);
-  let results = (outs AnySparseIterSpace:$extractedSpace);
+  let builders = [
+    // Construct a 1-D iteration space.
+    OpBuilder<(ins "Value":$tensor, "Value":$parentIter,
+                   "sparse_tensor::Level":$loLvl),
+              [{
+                build($_builder, $_state, tensor, parentIter, loLvl, loLvl + 1);
+              }]>,
+    // Construct a 1-D root iteration space
+    OpBuilder<(ins "Value":$tensor),
+              [{
+                build($_builder, $_state, tensor, nullptr, 0);
+              }]>
+  ];
+
   let assemblyFormat = "$tensor (`at` $parentIter^)? `lvls` `=` custom<LevelRange>($loLvl, $hiLvl) "
                        " attr-dict `:` type($tensor) (`,` type($parentIter)^)? "
                        "`->` qualified(type($extractedSpace))";
@@ -1594,6 +1608,12 @@ def IterateOp : SparseTensor_Op<"iterate",
   let results = (outs Variadic<AnyType>:$results);
   let regions = (region SizedRegion<1>:$region);
 
+  let skipDefaultBuilders = 1;
+  let builders = [
+    OpBuilder<(ins "Value":$iterSpace, "ValueRange":$initArgs)>,
+    OpBuilder<(ins "Value":$iterSpace, "ValueRange":$initArgs, "LevelSet" :$crdUsedLvls)>
+  ];
+
   let extraClassDeclaration = [{
     unsigned getSpaceDim() {
       return getIterSpace().getType().getSpaceDim();
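The two convenience builders added above can be invoked directly from C++ codegen. A minimal usage sketch follows, assuming a caller that already has an OpBuilder, a Location, a sparse tensor Value, and an initial loop-carried value; the helper name buildSimpleIteration is hypothetical, while the create<> calls mirror the builder signatures declared in the TableGen entries.

#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/IR/Builders.h"

using namespace mlir;
using namespace mlir::sparse_tensor;

// Extract the root (level 0) iteration space of `tensor` and create a
// `sparse_tensor.iterate` loop over it with one loop-carried value. The body
// block (iterator, coordinate, and init-arg arguments) is created by the
// builder; the loop body and its `sparse_tensor.yield` still have to be
// filled in by the caller.
static Value buildSimpleIteration(OpBuilder &builder, Location loc,
                                  Value tensor, Value init) {
  Value space = builder.create<ExtractIterSpaceOp>(loc, tensor);
  auto loop = builder.create<IterateOp>(loc, space, ValueRange{init});
  return loop.getResults().front();
}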

mlir/include/mlir/Dialect/SparseTensor/Pipelines/Passes.h

Lines changed: 16 additions & 1 deletion
@@ -51,6 +51,20 @@ struct SparsifierOptions : public PassPipelineOptions<SparsifierOptions> {
           mlir::SparseParallelizationStrategy::kAnyStorageAnyLoop,
           "any-storage-any-loop",
           "Enable sparse parallelization for any storage and loop."))};
+  PassOptions::Option<mlir::SparseEmitStrategy> emitStrategy{
+      *this, "sparse-emit-strategy",
+      ::llvm::cl::desc(
+          "Emit functional code or interfaces (to debug) for sparse loops"),
+      ::llvm::cl::init(mlir::SparseEmitStrategy::kFunctional),
+      llvm::cl::values(
+          clEnumValN(mlir::SparseEmitStrategy::kFunctional, "functional",
+                     "Emit functional code (with scf.for/while)."),
+          clEnumValN(mlir::SparseEmitStrategy::kSparseIterator,
+                     "sparse-iterator",
+                     "Emit (experimental) loops (with sparse.iterate)."),
+          clEnumValN(
+              mlir::SparseEmitStrategy::kDebugInterface, "debug-interface",
+              "Emit non-functional but easy-to-read interfaces to debug."))};
 
   PassOptions::Option<bool> enableRuntimeLibrary{
       *this, "enable-runtime-library",
@@ -143,7 +157,8 @@ struct SparsifierOptions : public PassPipelineOptions<SparsifierOptions> {
 
   /// Projects out the options for `createSparsificationPass`.
   SparsificationOptions sparsificationOptions() const {
-    return SparsificationOptions(parallelization, enableRuntimeLibrary);
+    return SparsificationOptions(parallelization, emitStrategy,
+                                 enableRuntimeLibrary);
  }
 
   /// Projects out the options for `createConvertVectorToLLVMPass`.
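To exercise the new strategy programmatically rather than through the pipeline option, the projected SparsificationOptions can be built directly. A hedged sketch, where the kNone parallelization choice and the surrounding pass-manager setup are picked purely for illustration:

#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Pass/PassManager.h"

using namespace mlir;

// Schedule sparsification with the experimental sparse-iterator emit
// strategy. Argument order follows the projection above:
// (parallelization, emitStrategy, enableRuntimeLibrary).
static void addSparsificationWithIterators(PassManager &pm) {
  SparsificationOptions options(SparseParallelizationStrategy::kNone,
                                SparseEmitStrategy::kSparseIterator,
                                /*enableRuntimeLibrary=*/true);
  pm.addPass(createSparsificationPass(options));
}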

mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h

Lines changed: 1 addition & 0 deletions
@@ -51,6 +51,7 @@ enum class ReinterpretMapScope {
 /// Defines a scope for reinterpret map pass.
 enum class SparseEmitStrategy {
   kFunctional,     // generate fully inlined (and functional) sparse iteration
+  kSparseIterator, // generate (experimental) loop using sparse iterator.
   kDebugInterface, // generate only place-holder for sparse iteration
 };

mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td

Lines changed: 3 additions & 1 deletion
@@ -163,7 +163,9 @@ def SparsificationPass : Pass<"sparsification", "ModuleOp"> {
             "mlir::SparseEmitStrategy::kFunctional",
             "Emit functional code or interfaces (to debug) for sparse loops", [{llvm::cl::values(
               clEnumValN(mlir::SparseEmitStrategy::kFunctional, "functional",
-                         "Emit functional code."),
+                         "Emit functional code (with scf.for/while)."),
+              clEnumValN(mlir::SparseEmitStrategy::kSparseIterator, "sparse-iterator",
+                         "Emit (experimental) loops (with sparse.iterate)."),
               clEnumValN(mlir::SparseEmitStrategy::kDebugInterface, "debug-interface",
                          "Emit non-functional but easy-to-read interfaces to debug."))}]>,
     Option<"enableRuntimeLibrary", "enable-runtime-library", "bool",

mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp

Lines changed: 38 additions & 0 deletions
@@ -2300,6 +2300,41 @@ void IterateOp::getCanonicalizationPatterns(mlir::RewritePatternSet &results,
   results.add<RemoveUnusedLvlCrds>(context);
 }
 
+void IterateOp::build(OpBuilder &builder, OperationState &odsState,
+                      Value iterSpace, ValueRange initArgs) {
+  unsigned rank = llvm::cast<IterSpaceType>(iterSpace.getType()).getSpaceDim();
+  // All ones.
+  LevelSet set((1 << rank) - 1);
+  return build(builder, odsState, iterSpace, initArgs, set);
+}
+
+void IterateOp::build(OpBuilder &builder, OperationState &odsState,
+                      Value iterSpace, ValueRange initArgs,
+                      LevelSet crdUsedLvls) {
+  OpBuilder::InsertionGuard guard(builder);
+
+  odsState.addOperands(iterSpace);
+  odsState.addOperands(initArgs);
+  odsState.getOrAddProperties<Properties>().crdUsedLvls =
+      builder.getIntegerAttr(builder.getIntegerType(64), crdUsedLvls);
+  Region *bodyRegion = odsState.addRegion();
+  odsState.addTypes(initArgs.getTypes());
+  Block *bodyBlock = builder.createBlock(bodyRegion);
+
+  // First argument, sparse iterator
+  bodyBlock->addArgument(
+      llvm::cast<IterSpaceType>(iterSpace.getType()).getIteratorType(),
+      odsState.location);
+
+  // Followed by a list of used coordinates.
+  for (unsigned i = 0, e = crdUsedLvls.count(); i < e; i++)
+    bodyBlock->addArgument(builder.getIndexType(), odsState.location);
+
+  // Followed by a list of user-provided loop arguments.
+  for (Value v : initArgs)
+    bodyBlock->addArgument(v.getType(), v.getLoc());
+}
+
 ParseResult IterateOp::parse(OpAsmParser &parser, OperationState &result) {
   OpAsmParser::Argument iterator;
   OpAsmParser::UnresolvedOperand iterSpace;
@@ -2384,6 +2419,9 @@ LogicalResult IterateOp::verify() {
     return emitOpError(
         "mismatch in number of loop-carried values and defined values");
   }
+  if (getCrdUsedLvls().max() > getSpaceDim())
+    return emitOpError("required out-of-bound coordinates");
+
   return success();
 }
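The second IterateOp builder above also accepts an explicit coordinate mask. A small sketch of what the created body block looks like under that builder, assuming loc, space (an iteration-space value), and init already exist in the caller; the assertion only restates the argument layout documented in the comments of IterateOp::build.

#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/IR/Builders.h"
#include <cassert>

using namespace mlir;
using namespace mlir::sparse_tensor;

static void buildMaskedIterate(OpBuilder &builder, Location loc, Value space,
                               Value init) {
  // Use only the coordinate of level 0: mask 0b1.
  LevelSet crdUsedLvls(0x1);
  auto loop = builder.create<IterateOp>(loc, space, ValueRange{init},
                                        crdUsedLvls);
  // Body arguments, in order: the sparse iterator, one index per set bit in
  // crdUsedLvls, then the loop-carried init args.
  Block &body = loop.getRegion().front();
  assert(body.getNumArguments() == 1 + crdUsedLvls.count() + 1);
  (void)body;
}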

mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp

Lines changed: 2 additions & 1 deletion
@@ -164,7 +164,6 @@ class SparseIterateOpConverter : public OneToNOpConversionPattern<IterateOp> {
     // replace sparse_tensor.yield with scf.yield.
     rewriter.eraseOp(yieldOp);
     rewriter.create<scf::YieldOp>(loc, yields);
-
     const OneToNTypeMapping &resultMapping = adaptor.getResultMapping();
     rewriter.replaceOp(
         op, whileOp.getResults().drop_front(it->getCursor().size()),
@@ -192,6 +191,8 @@ mlir::SparseIterationTypeConverter::SparseIterationTypeConverter() {
 
 void mlir::populateLowerSparseIterationToSCFPatterns(
     TypeConverter &converter, RewritePatternSet &patterns) {
+
+  IterateOp::getCanonicalizationPatterns(patterns, patterns.getContext());
   patterns.add<ExtractIterSpaceConverter, SparseIterateOpConverter>(
       converter, patterns.getContext());
 }

mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp

Lines changed: 5 additions & 1 deletion
@@ -1071,7 +1071,11 @@ static bool getAllTidLvlsInLatPoints(
   }
   // If we just need to one loop conditions and the conditions is not imposed on
   // non-unique level, the loop can be generated by a for loop.
-  return numloopCond == 1 && !hasNonUnique;
+  // Or, if we are generating sparse-iterator-based loops, we always generate
+  // `sparse_tensor.iterate` regardless whether the level is unique or not.
+  return numloopCond == 1 &&
+         (!hasNonUnique || env.options().sparseEmitStrategy ==
+                               SparseEmitStrategy::kSparseIterator);
 }
 
 /// Starts a loop sequence at given level. Returns true if

mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp

Lines changed: 6 additions & 0 deletions
@@ -159,6 +159,12 @@ class SparsificationAndBufferizationPass
       pm.addPass(createSparseGPUCodegenPass(0, enableRuntimeLibrary));
     pm.addPass(createSparseReinterpretMapPass(ReinterpretMapScope::kAll));
     pm.addPass(createSparsificationPass(sparsificationOptions));
+    if (sparsificationOptions.sparseEmitStrategy ==
+        SparseEmitStrategy::kSparseIterator) {
+      pm.addNestedPass<func::FuncOp>(createSparseSpaceCollapsePass());
+      pm.addNestedPass<func::FuncOp>(createLowerSparseIterationToSCFPass());
+    }
+
     pm.addNestedPass<func::FuncOp>(createStageSparseOperationsPass());
     pm.addPass(createLowerSparseOpsToForeachPass(enableRuntimeLibrary,
                                                  /*enableConvert=*/true));
