Skip to content

Commit c962038

Browse files
committed
[mlir][nfc] Expose linalg tiling helpers.
Differential Revision: https://reviews.llvm.org/D119330
1 parent fd0417a commit c962038

File tree

2 files changed

+76
-67
lines changed

2 files changed

+76
-67
lines changed

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -481,6 +481,75 @@ struct LinalgTransformationFilter {
481481
using TileSizeComputationFunction =
482482
std::function<SmallVector<Value, 4>(OpBuilder &, Operation *)>;
483483

484+
/// Creates a number of ranges equal to the number of non-zero entries in `tileSizes`.
485+
/// One for each loop of the LinalgOp that is tiled. The `tileSizes` argument
486+
/// has one entry per surrounding loop. It uses zero as the convention that a
487+
/// particular loop is not tiled. This convention simplifies implementations by
488+
/// avoiding affine map manipulations.
489+
/// The returned ranges correspond to the loop ranges, in the proper order, that
490+
/// are tiled and for which new loops will be created. Also the function returns
491+
/// a map from loop indices of the LinalgOp to the corresponding non-empty range
492+
/// indices of newly created loops.
493+
using LoopIndexToRangeIndexMap = DenseMap<int, int>;
494+
std::tuple<SmallVector<Range, 4>, LoopIndexToRangeIndexMap>
495+
makeTiledLoopRanges(RewriterBase &b, Location loc, AffineMap map,
496+
ValueRange allShapeSizes, ValueRange allTileSizes);
497+
498+
/// All indices returned by IndexOp should be invariant with respect to tiling.
499+
/// Therefore, if an operation is tiled, we have to transform the indices
500+
/// accordingly, i.e. offset them by the values of the corresponding induction
501+
/// variables that are captured implicitly in the body of the op.
502+
///
503+
/// Example. `linalg.generic` before tiling:
504+
///
505+
/// #id_2d = (i, j) -> (i, j)
506+
/// #pointwise_2d_trait = {
507+
/// indexing_maps = [#id_2d, #id_2d],
508+
/// iterator_types = ["parallel", "parallel"]
509+
/// }
510+
/// linalg.generic #pointwise_2d_trait %operand, %result {
511+
/// ^bb0(%operand_in: f32, %result_in: f32):
512+
/// %i = linalg.index 0 : index
513+
/// %j = linalg.index 1 : index
514+
/// <some operations that use %i, %j>
515+
/// }: memref<50x100xf32>, memref<50x100xf32>
516+
///
517+
/// After tiling pass with tile sizes 10 and 25:
518+
///
519+
/// #strided = (i, j)[s0, s1, s2] -> (i * s1 + s0 + j * s2)
520+
///
521+
/// %c1 = arith.constant 1 : index
522+
/// %c0 = arith.constant 0 : index
523+
/// %c25 = arith.constant 25 : index
524+
/// %c10 = arith.constant 10 : index
525+
/// operand_dim_0 = dim %operand, 0 : memref<50x100xf32>
526+
/// operand_dim_1 = dim %operand, 1 : memref<50x100xf32>
527+
/// scf.for %k = %c0 to operand_dim_0 step %c10 {
528+
/// scf.for %l = %c0 to operand_dim_1 step %c25 {
529+
/// %4 = std.subview %operand[%k, %l][%c10, %c25][%c1, %c1]
530+
/// : memref<50x100xf32> to memref<?x?xf32, #strided>
531+
/// %5 = std.subview %result[%k, %l][%c10, %c25][%c1, %c1]
532+
/// : memref<50x100xf32> to memref<?x?xf32, #strided>
533+
/// linalg.generic pointwise_2d_trait %4, %5 {
534+
/// ^bb0(%operand_in: f32, %result_in: f32):
535+
/// %i = linalg.index 0 : index
536+
/// %j = linalg.index 1 : index
537+
/// // Indices `k` and `l` are implicitly captured in the body.
538+
/// %transformed_i = arith.addi %i, %k : index // index `i` is offset by %k
539+
/// %transformed_j = arith.addi %j, %l : index // index `j` is offset by %l
540+
541+
/// // Every use of %i, %j is replaced with %transformed_i, %transformed_j
542+
/// <some operations that use %transformed_i, %transformed_j>
543+
/// }: memref<?x?xf32, #strided>, memref<?x?xf32, #strided>
544+
/// }
545+
/// }
546+
///
547+
/// TODO: Investigate whether mixing implicit and explicit indices
548+
/// does not lead to losing information.
549+
void transformIndexOps(RewriterBase &b, LinalgOp op,
550+
SmallVectorImpl<Value> &ivs,
551+
const LoopIndexToRangeIndexMap &loopIndexToRangeIndex);
552+
484553
/// Callback returning the padding value to use for a given OpOperand or failure
485554
/// for no padding. This should be a function of both the operation and the
486555
/// operand type.

mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp

Lines changed: 7 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -39,20 +39,10 @@ static bool isZero(Value v) {
3939
return false;
4040
}
4141

42-
using LoopIndexToRangeIndexMap = DenseMap<int, int>;
43-
44-
// Creates a number of ranges equal to the number of non-zero in `tileSizes`.
45-
// One for each loop of the LinalgOp that is tiled. The `tileSizes` argument has
46-
// one entry per surrounding loop. It uses zero as the convention that a
47-
// particular loop is not tiled. This convention simplifies implementations by
48-
// avoiding affine map manipulations.
49-
// The returned ranges correspond to the loop ranges, in the proper order, that
50-
// are tiled and for which new loops will be created. Also the function returns
51-
// a map from loop indices of the LinalgOp to the corresponding non-empty range
52-
// indices of newly created loops.
53-
static std::tuple<SmallVector<Range, 4>, LoopIndexToRangeIndexMap>
54-
makeTiledLoopRanges(RewriterBase &b, Location loc, AffineMap map,
55-
ValueRange allShapeSizes, ValueRange allTileSizes) {
42+
std::tuple<SmallVector<Range, 4>, LoopIndexToRangeIndexMap>
43+
mlir::linalg::makeTiledLoopRanges(RewriterBase &b, Location loc, AffineMap map,
44+
ValueRange allShapeSizes,
45+
ValueRange allTileSizes) {
5646
assert(allTileSizes.size() == map.getNumResults());
5747
// Apply `map` to get shape sizes in loop order.
5848
auto shapeSizes = applyMapToValues(b, loc, map, allShapeSizes);
@@ -78,59 +68,9 @@ makeTiledLoopRanges(RewriterBase &b, Location loc, AffineMap map,
7868
return std::make_tuple(res, loopIndexToRangeIndex);
7969
}
8070

81-
// All indices returned by IndexOp should be invariant with respect to tiling.
82-
// Therefore, if an operation is tiled, we have to transform the indices
83-
// accordingly, i.e. offset them by the values of the corresponding induction
84-
// variables that are captured implicitly in the body of the op.
85-
//
86-
// Example. `linalg.generic` before tiling:
87-
//
88-
// #id_2d = (i, j) -> (i, j)
89-
// #pointwise_2d_trait = {
90-
// indexing_maps = [#id_2d, #id_2d],
91-
// iterator_types = ["parallel", "parallel"]
92-
// }
93-
// linalg.generic #pointwise_2d_trait %operand, %result {
94-
// ^bb0(%operand_in: f32, %result_in: f32):
95-
// %i = linalg.index 0 : index
96-
// %j = linalg.index 1 : index
97-
// <some operations that use %i, %j>
98-
// }: memref<50x100xf32>, memref<50x100xf32>
99-
//
100-
// After tiling pass with tiles sizes 10 and 25:
101-
//
102-
// #strided = (i, j)[s0, s1, s2] -> (i * s1 + s0 + j * s2)
103-
//
104-
// %c1 = arith.constant 1 : index
105-
// %c0 = arith.constant 0 : index
106-
// %c25 = arith.constant 25 : index
107-
// %c10 = arith.constant 10 : index
108-
// operand_dim_0 = dim %operand, 0 : memref<50x100xf32>
109-
// operand_dim_1 = dim %operand, 1 : memref<50x100xf32>
110-
// scf.for %k = %c0 to operand_dim_0 step %c10 {
111-
// scf.for %l = %c0 to operand_dim_1 step %c25 {
112-
// %4 = std.subview %operand[%k, %l][%c10, %c25][%c1, %c1]
113-
// : memref<50x100xf32> to memref<?x?xf32, #strided>
114-
// %5 = std.subview %result[%k, %l][%c10, %c25][%c1, %c1]
115-
// : memref<50x100xf32> to memref<?x?xf32, #strided>
116-
// linalg.generic pointwise_2d_trait %4, %5 {
117-
// ^bb0(%operand_in: f32, %result_in: f32):
118-
// %i = linalg.index 0 : index
119-
// %j = linalg.index 1 : index
120-
// // Indices `k` and `l` are implicitly captured in the body.
121-
// %transformed_i = arith.addi %i, %k : index // index `i` is offset by %k
122-
// %transformed_j = arith.addi %j, %l : index // index `j` is offset by %l
123-
// // Every use of %i, %j is replaced with %transformed_i, %transformed_j
124-
// <some operations that use %transformed_i, %transformed_j>
125-
// }: memref<?x?xf32, #strided>, memref<?x?xf32, #strided>
126-
// }
127-
// }
128-
//
129-
// TODO: Investigate whether mixing implicit and explicit indices
130-
// does not lead to losing information.
131-
static void
132-
transformIndexOps(RewriterBase &b, LinalgOp op, SmallVectorImpl<Value> &ivs,
133-
const LoopIndexToRangeIndexMap &loopIndexToRangeIndex) {
71+
void mlir::linalg::transformIndexOps(
72+
RewriterBase &b, LinalgOp op, SmallVectorImpl<Value> &ivs,
73+
const LoopIndexToRangeIndexMap &loopIndexToRangeIndex) {
13474
SmallVector<Value> allIvs(op.getNumLoops(), nullptr);
13575
for (auto &en : enumerate(allIvs)) {
13676
auto rangeIndex = loopIndexToRangeIndex.find(en.index());

0 commit comments

Comments
 (0)