@@ -39,20 +39,10 @@ static bool isZero(Value v) {
39
39
return false ;
40
40
}
41
41
42
- using LoopIndexToRangeIndexMap = DenseMap<int , int >;
43
-
44
- // Creates a number of ranges equal to the number of non-zero in `tileSizes`.
45
- // One for each loop of the LinalgOp that is tiled. The `tileSizes` argument has
46
- // one entry per surrounding loop. It uses zero as the convention that a
47
- // particular loop is not tiled. This convention simplifies implementations by
48
- // avoiding affine map manipulations.
49
- // The returned ranges correspond to the loop ranges, in the proper order, that
50
- // are tiled and for which new loops will be created. Also the function returns
51
- // a map from loop indices of the LinalgOp to the corresponding non-empty range
52
- // indices of newly created loops.
53
- static std::tuple<SmallVector<Range, 4 >, LoopIndexToRangeIndexMap>
54
- makeTiledLoopRanges (RewriterBase &b, Location loc, AffineMap map,
55
- ValueRange allShapeSizes, ValueRange allTileSizes) {
42
+ std::tuple<SmallVector<Range, 4 >, LoopIndexToRangeIndexMap>
43
+ mlir::linalg::makeTiledLoopRanges (RewriterBase &b, Location loc, AffineMap map,
44
+ ValueRange allShapeSizes,
45
+ ValueRange allTileSizes) {
56
46
assert (allTileSizes.size () == map.getNumResults ());
57
47
// Apply `map` to get shape sizes in loop order.
58
48
auto shapeSizes = applyMapToValues (b, loc, map, allShapeSizes);
@@ -78,59 +68,9 @@ makeTiledLoopRanges(RewriterBase &b, Location loc, AffineMap map,
78
68
return std::make_tuple (res, loopIndexToRangeIndex);
79
69
}
80
70
81
- // All indices returned by IndexOp should be invariant with respect to tiling.
82
- // Therefore, if an operation is tiled, we have to transform the indices
83
- // accordingly, i.e. offset them by the values of the corresponding induction
84
- // variables that are captured implicitly in the body of the op.
85
- //
86
- // Example. `linalg.generic` before tiling:
87
- //
88
- // #id_2d = (i, j) -> (i, j)
89
- // #pointwise_2d_trait = {
90
- // indexing_maps = [#id_2d, #id_2d],
91
- // iterator_types = ["parallel", "parallel"]
92
- // }
93
- // linalg.generic #pointwise_2d_trait %operand, %result {
94
- // ^bb0(%operand_in: f32, %result_in: f32):
95
- // %i = linalg.index 0 : index
96
- // %j = linalg.index 1 : index
97
- // <some operations that use %i, %j>
98
- // }: memref<50x100xf32>, memref<50x100xf32>
99
- //
100
- // After tiling pass with tiles sizes 10 and 25:
101
- //
102
- // #strided = (i, j)[s0, s1, s2] -> (i * s1 + s0 + j * s2)
103
- //
104
- // %c1 = arith.constant 1 : index
105
- // %c0 = arith.constant 0 : index
106
- // %c25 = arith.constant 25 : index
107
- // %c10 = arith.constant 10 : index
108
- // operand_dim_0 = dim %operand, 0 : memref<50x100xf32>
109
- // operand_dim_1 = dim %operand, 1 : memref<50x100xf32>
110
- // scf.for %k = %c0 to operand_dim_0 step %c10 {
111
- // scf.for %l = %c0 to operand_dim_1 step %c25 {
112
- // %4 = std.subview %operand[%k, %l][%c10, %c25][%c1, %c1]
113
- // : memref<50x100xf32> to memref<?x?xf32, #strided>
114
- // %5 = std.subview %result[%k, %l][%c10, %c25][%c1, %c1]
115
- // : memref<50x100xf32> to memref<?x?xf32, #strided>
116
- // linalg.generic pointwise_2d_trait %4, %5 {
117
- // ^bb0(%operand_in: f32, %result_in: f32):
118
- // %i = linalg.index 0 : index
119
- // %j = linalg.index 1 : index
120
- // // Indices `k` and `l` are implicitly captured in the body.
121
- // %transformed_i = arith.addi %i, %k : index // index `i` is offset by %k
122
- // %transformed_j = arith.addi %j, %l : index // index `j` is offset by %l
123
- // // Every use of %i, %j is replaced with %transformed_i, %transformed_j
124
- // <some operations that use %transformed_i, %transformed_j>
125
- // }: memref<?x?xf32, #strided>, memref<?x?xf32, #strided>
126
- // }
127
- // }
128
- //
129
- // TODO: Investigate whether mixing implicit and explicit indices
130
- // does not lead to losing information.
131
- static void
132
- transformIndexOps (RewriterBase &b, LinalgOp op, SmallVectorImpl<Value> &ivs,
133
- const LoopIndexToRangeIndexMap &loopIndexToRangeIndex) {
71
+ void mlir::linalg::transformIndexOps (
72
+ RewriterBase &b, LinalgOp op, SmallVectorImpl<Value> &ivs,
73
+ const LoopIndexToRangeIndexMap &loopIndexToRangeIndex) {
134
74
SmallVector<Value> allIvs (op.getNumLoops (), nullptr );
135
75
for (auto &en : enumerate(allIvs)) {
136
76
auto rangeIndex = loopIndexToRangeIndex.find (en.index ());
0 commit comments