Skip to content

Commit c223521

Browse files
[mlir][TilingInterface] Allow tile and fuse to work with ReductionTilingStrategy::PartialReductionOuterParallel. (#147593)
Since `scf::tileUsingSCF` is the core method used for tiling the root operation within `scf::tileConsumersAndFuseProducersUsingSCF`, the latter can fuse into any tiled loop generated using `scf::tileUsingSCF`. This patch adds a test for tiling a root operation using `ReductionTilingStrategy::PartialReductionOuterParallel` and fusing producers with it. Since this strategy generates a rank-reducing extract slice, `tensor::replaceExtractSliceWithTiledProducer`, which is the core method used for the fusion, was extended to handle rank-reducing slices. Also fix a small bug in the computation of the reduction induction variable (which needs to use `floorDiv` instead of `ceilDiv`). Signed-off-by: MaheshRavishankar <mahesh.ravishankar@gmail.com>
1 parent f924200 commit c223521

File tree

6 files changed

+183
-14
lines changed

6 files changed

+183
-14
lines changed

mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -681,7 +681,7 @@ getSplitReductionIvs(RewriterBase &rewriter, Location loc,
681681
splitReductionIvs.resize(reductionDims.size(), rewriter.getIndexAttr(0));
682682
AffineExpr s0, s1;
683683
bindSymbols(rewriter.getContext(), s0, s1);
684-
AffineExpr divExpr = s0.ceilDiv(s1);
684+
AffineExpr divExpr = s0.floorDiv(s1);
685685
int ivIndex = 0;
686686
if (reductionStrategy ==
687687
ReductionTilingStrategy::PartialReductionOuterParallel) {

mlir/lib/Dialect/Tensor/Transforms/SwapExtractSliceWithProducerPatterns.cpp

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,23 @@ FailureOr<TilingResult> tensor::replaceExtractSliceWithTiledProducer(
3939
if (failed(tiledResult))
4040
return failure();
4141

42+
// For cases where the slice was rank-reducing, create a rank-reducing slice
43+
// to get the same type back.
44+
llvm::SmallBitVector droppedDims = sliceOp.getDroppedDims();
45+
if (droppedDims.any()) {
46+
assert(tiledResult->tiledValues.size() == 1 &&
47+
"expected only a single tiled result value to replace the extract "
48+
"slice");
49+
SmallVector<OpFoldResult> offsets(sliceOp.getSourceType().getRank(),
50+
builder.getIndexAttr(0));
51+
SmallVector<OpFoldResult> strides(sliceOp.getSourceType().getRank(),
52+
builder.getIndexAttr(1));
53+
auto newSliceOp = builder.create<tensor::ExtractSliceOp>(
54+
sliceOp.getLoc(), sliceOp.getType(), tiledResult->tiledValues[0],
55+
offsets, sliceOp.getMixedSizes(), strides);
56+
tiledResult->tiledValues[0] = newSliceOp;
57+
}
58+
4259
return *tiledResult;
4360
}
4461

mlir/test/Dialect/Linalg/transform-tile-reduction.mlir

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -555,6 +555,7 @@ func.func @reduction_tile_parallel_using_tile_sizes(
555555
}
556556
// CHECK-DAG: #[[MAP0:.*]] = affine_map<()[s0] -> (s0 ceildiv 5)>
557557
// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0)[s0] -> (-d0 + s0, 5)>
558+
// CHECK-DAG: #[[MAP2:.*]] = affine_map<()[s0] -> (s0 floordiv 5)>
558559
// CHECK: func @reduction_tile_parallel_using_tile_sizes(%[[ARG0:.+]]: tensor<?x?xf32>, %[[ARG1:.+]]: tensor<?xf32>
559560
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
560561
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
@@ -566,7 +567,7 @@ func.func @reduction_tile_parallel_using_tile_sizes(
566567
// CHECK-SAME: outs(%[[E]] :
567568
// CHECK: %[[L:.*]] = scf.forall (%[[IV:.+]]) = (0) to (%[[D1]]) step (5) shared_outs(%[[ARG3:.+]] = %[[F]])
568569
// CHECK-DAG: %[[TS0:.+]] = affine.min #[[MAP1]](%[[IV]])[%[[D1]]]
569-
// CHECK-DAG: %[[INIT_OFFSET:.+]] = affine.apply #[[MAP0]]()[%[[IV]]]
570+
// CHECK-DAG: %[[INIT_OFFSET:.+]] = affine.apply #[[MAP2]]()[%[[IV]]]
570571
// CHECK-DAG: %[[INCHUNK:.+]] = tensor.extract_slice %[[ARG0]][0, %[[IV]]] [%[[D0]], %[[TS0]]] [1, 1]
571572
// CHECK-DAG: %[[ET:.+]] = tensor.extract_slice %[[ARG3]][0, %[[INIT_OFFSET]]] [%[[D0]], 1] [1, 1]
572573
// CHECK: %[[PARTIAL:.+]] = linalg.generic
@@ -619,7 +620,7 @@ module {
619620
}
620621
}
621622
}
622-
// CHECK-DAG: #[[MAP0:.*]] = affine_map<()[s0] -> (s0 ceildiv 64)>
623+
// CHECK-DAG: #[[MAP0:.*]] = affine_map<()[s0] -> (s0 floordiv 64)>
623624
// CHECK: func @reduction_using_forall_tile_single_of_multiple_reduction_inner(%[[ARG0:.+]]: tensor<86x128xf32>, %[[ARG1:.+]]: tensor<4096x86x128xf32>, %[[ARG2:.+]]: tensor<4096xf32>)
624625
// CHECK: %[[E:.*]] = tensor.empty() : tensor<4096x2xf32>
625626
// CHECK: %[[F:.*]] = linalg.fill
@@ -671,7 +672,7 @@ module {
671672
}
672673
}
673674
}
674-
// CHECK-DAG: #[[MAP0:.*]] = affine_map<()[s0] -> (s0 ceildiv 64)>
675+
// CHECK-DAG: #[[MAP0:.*]] = affine_map<()[s0] -> (s0 floordiv 64)>
675676
// CHECK: func @reduction_using_forall_tilesize_0_of_multiple_reduction_inner(%[[ARG0:.+]]: tensor<86x128xf32>, %[[ARG1:.+]]: tensor<4096x86x128xf32>, %[[ARG2:.+]]: tensor<4096xf32>)
676677
// CHECK: %[[E:.*]] = tensor.empty() : tensor<4096x2xf32>
677678
// CHECK: %[[F:.*]] = linalg.fill
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
// RUN: mlir-opt -transform-interpreter -cse -mlir-print-local-scope -split-input-file -verify-diagnostics %s | FileCheck %s
2+
3+
// Check that tile + fuse works with the partial reduction outer parallel strategy.
4+
5+
// Payload under test: a row-wise sum. The `linalg.generic` below reduces
// dimension d1 of %arg0 (iterator types ["parallel", "reduction"]) with
// `arith.addf` into a tensor<?xf32> init. That init is produced by a
// `linalg.fill` of 0.0 over a `tensor.empty`, so the fill is the producer of
// the generic's `outs` operand.
module{
6+
func.func @tile_and_fuse_with_partial_reduction_outer_parallel(
7+
%arg0 : tensor<?x?xf32>) -> tensor<?xf32> {
8+
%c0 = arith.constant 0 : index
9+
%cst = arith.constant 0.0 : f32
10+
%d0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
11+
%empty = tensor.empty(%d0) : tensor<?xf32>
12+
%fill = linalg.fill ins(%cst : f32) outs(%empty : tensor<?xf32>) -> tensor<?xf32>
13+
%generic = linalg.generic {
14+
indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>],
15+
iterator_types = ["parallel", "reduction"]}
16+
ins(%arg0 : tensor<?x?xf32>) outs(%fill : tensor<?xf32>) {
17+
^bb0(%b0 : f32, %b1 : f32):
18+
%0 = arith.addf %b0, %b1 : f32
19+
linalg.yield %0 : f32
20+
} -> tensor<?xf32>
21+
return %generic : tensor<?xf32>
22+
}
23+
}
24+
// Transform script for the payload above: matches the `linalg.generic` ops in
// the payload and applies the test tile-and-fuse op with a single tile size
// of 128. No `reduction_dims` are given, so the op implementation falls back
// to the reduction iterators it finds on the matched operation.
module attributes {transform.with_named_sequence} {
25+
transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
26+
%generic = transform.structured.match ops{["linalg.generic"]} in %arg1
27+
: (!transform.any_op) -> !transform.any_op
28+
%a, %loop = transform.test.tile_and_fuse_outer_parallel_partial_reduction
29+
%generic tile_sizes = [128]
30+
: (!transform.any_op) -> (!transform.any_op, !transform.any_op)
31+
transform.yield
32+
}
33+
}
34+
// CHECK-LABEL: func @tile_and_fuse_with_partial_reduction_outer_parallel(
35+
// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xf32>)
36+
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
37+
// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]]
38+
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
39+
// CHECK-DAG: %[[D1:.+]] = tensor.dim %[[ARG0]], %[[C1]]
40+
// CHECK: %[[REDUCTION_NUM:.+]] = affine.apply affine_map<()[s0] -> (s0 ceildiv 128)>()[%[[D1]]]
41+
// CHECK: %[[EMPTY:.+]] = tensor.empty(%[[D0]], %[[REDUCTION_NUM]])
42+
// CHECK: %[[FORALL:.+]] = scf.forall (%[[IV0:[a-zA-Z0-9]+]]) =
43+
// CHECK-SAME: shared_outs(%[[ITER_ARG:.+]] = %[[EMPTY]])
44+
// CHECK-DAG: %[[TILESIZE:.+]] = affine.min affine_map<(d0)[s0] -> (-d0 + s0, 128)>(%[[IV0]])[%[[D1]]]
45+
// CHECK-DAG: %[[REDUCTION_IV:.+]] = affine.apply affine_map<()[s0] -> (s0 floordiv 128)>()[%[[IV0]]]
46+
// CHECK-DAG: %[[ARG0_SLICE:.+]] = tensor.extract_slice %[[ARG0]][0, %[[IV0]]] [%[[D0]], %[[TILESIZE]]] [1, 1]
47+
// CHECK: %[[ITER_ARG_SLICE:.+]] = tensor.extract_slice %[[ITER_ARG]][0, %[[REDUCTION_IV]]] [%[[D0]], 1] [1, 1]
48+
// CHECK: %[[FILL:.+]] = linalg.fill
49+
// CHECK-SAME: outs(%[[ITER_ARG_SLICE]] : tensor<?x1xf32>)
50+
// CHECK: %[[REDUCING_SLICE:.+]] = tensor.extract_slice %[[FILL]][0, 0] [%[[D0]], 1] [1, 1] : tensor<?x1xf32> to tensor<?xf32>
51+
// CHECK: %[[GENERIC:.+]] = linalg.generic
52+
// CHECK-SAME: ins(%[[ARG0_SLICE]] :
53+
// CHECK-SAME: outs(%[[REDUCING_SLICE]] :
54+
// CHECK: tensor.parallel_insert_slice %[[GENERIC]] into %[[ITER_ARG]]
55+
// CHECK-SAME: [0, %[[REDUCTION_IV]]] [%[[D0]], 1] [1, 1]
56+
// CHECK: %[[REDUCE:.+]] = linalg.reduce
57+
// CHECK-SAME: ins(%[[FORALL]] :
58+
// CHECK: return %[[REDUCE]]

mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp

Lines changed: 79 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
1919
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
2020
#include "mlir/Dialect/Utils/StaticValueUtils.h"
21+
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
2122
#include "mlir/IR/Dominance.h"
2223
#include "mlir/IR/OpImplementation.h"
2324
#include "mlir/Interfaces/TilingInterface.h"
@@ -60,8 +61,7 @@ template <typename Range>
6061
static LogicalResult
6162
applyTileAndFuseToAll(RewriterBase &rewriter, Operation *transformOp,
6263
Range &&payloadOps, unsigned numLoops,
63-
ArrayRef<OpFoldResult> tileSizes,
64-
ArrayRef<int64_t> interchange, bool useForall,
64+
scf::SCFTilingOptions tilingOptions,
6565
TransformResults &transformResults) {
6666
SmallVector<Operation *> tiledOps;
6767
SmallVector<SmallVector<Operation *>> loopOps(numLoops);
@@ -83,12 +83,6 @@ applyTileAndFuseToAll(RewriterBase &rewriter, Operation *transformOp,
8383
}
8484
}
8585

86-
scf::SCFTilingOptions tilingOptions;
87-
tilingOptions.setTileSizes(tileSizes).setInterchange(interchange);
88-
if (useForall) {
89-
tilingOptions.setLoopType(scf::SCFTilingOptions::LoopType::ForallOp);
90-
}
91-
9286
scf::SCFTileAndFuseOptions tileAndFuseOptions;
9387
tileAndFuseOptions.setTilingOptions(tilingOptions);
9488

@@ -157,10 +151,16 @@ transform::TestFuseAndYieldOp::apply(TransformRewriter &rewriter,
157151
SmallVector<OpFoldResult> tileSizesOfr =
158152
getAsIndexOpFoldResult(rewriter.getContext(), tileSizes);
159153

154+
scf::SCFTilingOptions tilingOptions;
155+
tilingOptions.setTileSizes(tileSizesOfr).setInterchange(tileInterchange);
156+
if (getUseForall()) {
157+
tilingOptions.setLoopType(scf::SCFTilingOptions::LoopType::ForallOp);
158+
}
159+
160160
LogicalResult result = applyTileAndFuseToAll(
161161
rewriter, getOperation(), state.getPayloadOps(getTarget()),
162-
tileSizes.size() - llvm::count(tileSizes, 0), tileSizesOfr,
163-
tileInterchange, getUseForall(), transformResults);
162+
tileSizes.size() - llvm::count(tileSizes, 0), tilingOptions,
163+
transformResults);
164164
return failed(result) ? DiagnosedSilenceableFailure::definiteFailure()
165165
: DiagnosedSilenceableFailure::success();
166166
}
@@ -399,6 +399,75 @@ void transform::TestFuseUsingForallOp::getEffects(
399399
modifiesPayload(effects);
400400
}
401401

402+
//===----------------------------------------------------------------------===//
403+
// TestTileAndFuseOuterParallelPartialReduction
404+
//===----------------------------------------------------------------------===//
405+
406+
/// Tiles the payload operations associated with `root_op` using
/// `ReductionTilingStrategy::PartialReductionOuterParallel` with `scf.forall`
/// loops, and fuses producers into the generated loop via
/// `applyTileAndFuseToAll`.
///
/// - `reduction_dims`: loop dimensions to tile. When empty, every loop of the
///   first payload operation whose iterator type is `reduction` is used.
/// - `tile_sizes`: one tile size per entry of the (possibly inferred)
///   reduction dimensions; `tile_sizes[i]` is applied to `reduction_dims[i]`.
///   All other loops get a tile size of 0.
/// - `mapping`: optional device mapping forwarded to the tiling options.
///
/// Returns a definite failure (after emitting an op error) when the handle has
/// no payload, the first payload op does not implement `TilingInterface`, no
/// reduction dimension is available, the number of tile sizes does not match
/// the number of reduction dimensions, or a requested dimension is not a
/// reduction loop of the root operation.
///
/// NOTE(review): only the first payload operation is inspected to derive the
/// iterator types, while `applyTileAndFuseToAll` receives all payload ops.
DiagnosedSilenceableFailure
transform::TestTileAndFuseOuterParallelPartialReductionOp::apply(
    TransformRewriter &rewriter, TransformResults &transformResults,
    TransformState &state) {
  auto payloadOps = state.getPayloadOps(getRootOp());
  // Dereferencing `begin()` of an empty payload range is undefined behavior,
  // so reject empty handles explicitly.
  if (payloadOps.empty()) {
    emitOpError("expected root handle to be associated with at least one "
                "payload operation");
    return DiagnosedSilenceableFailure::definiteFailure();
  }

  auto target = dyn_cast<TilingInterface>(*payloadOps.begin());
  if (!target) {
    emitOpError("expected root operation to implement `TilingInterface`");
    return DiagnosedSilenceableFailure::definiteFailure();
  }

  SmallVector<utils::IteratorType> iteratorTypes =
      target.getLoopIteratorTypes();

  // Use the explicitly requested reduction dimensions, or default to all
  // reduction loops of the root operation.
  SmallVector<unsigned> reductionDims =
      extractFromIntegerArrayAttr<unsigned>(getReductionDims());
  if (reductionDims.empty()) {
    for (auto [index, iterator] : llvm::enumerate(iteratorTypes))
      if (iterator == utils::IteratorType::reduction)
        reductionDims.push_back(index);
  }

  if (reductionDims.empty()) {
    emitOpError(
        "no reduction dimension specified or found in the target operation");
    return DiagnosedSilenceableFailure::definiteFailure();
  }

  SmallVector<int64_t> reductionTileSizes =
      extractFromIntegerArrayAttr<int64_t>(getTileSizes());
  if (reductionTileSizes.size() != reductionDims.size()) {
    emitOpError(
        "missing tile sizes for reduction dimensions that are to be tiled");
    return DiagnosedSilenceableFailure::definiteFailure();
  }

  // Expand the per-reduction-dimension tile sizes into one tile size per loop
  // of the root operation. Tile sizes are paired with `reductionDims` (whose
  // length was validated above) instead of with every non-parallel iterator,
  // so that tiling only a subset of the reduction loops cannot index past the
  // end of `reductionTileSizes`.
  SmallVector<OpFoldResult> tileSizes(iteratorTypes.size(),
                                      rewriter.getIndexAttr(0));
  for (auto [dim, tileSize] :
       llvm::zip_equal(reductionDims, reductionTileSizes)) {
    if (dim >= iteratorTypes.size() ||
        iteratorTypes[dim] != utils::IteratorType::reduction) {
      emitOpError("expected `reduction_dims` to refer to reduction loops of "
                  "the root operation");
      return DiagnosedSilenceableFailure::definiteFailure();
    }
    tileSizes[dim] = rewriter.getIndexAttr(tileSize);
  }

  scf::SCFTilingOptions tilingOptions;
  tilingOptions.setTileSizes(tileSizes)
      .setLoopType(scf::SCFTilingOptions::LoopType::ForallOp)
      .setReductionTilingStrategy(
          ReductionTilingStrategy::PartialReductionOuterParallel)
      .setReductionDims(reductionDims);
  if (auto mapping = getMapping()) {
    tilingOptions.setMapping(mapping.value());
  }

  LogicalResult result = applyTileAndFuseToAll(
      rewriter, getOperation(), payloadOps,
      /*numLoops =*/1, tilingOptions, transformResults);

  return failed(result) ? DiagnosedSilenceableFailure::definiteFailure()
                        : DiagnosedSilenceableFailure::success();
}
470+
402471
#define GET_OP_CLASSES
403472
#include "TestTilingInterfaceTransformOps.cpp.inc"
404473

mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,4 +126,28 @@ def TestFuseUsingForallOp : Op<Transform_Dialect, "test.fuse_using_forall",
126126
}];
127127
}
128128

129+
// Test-only transform op, registered in the Transform dialect under the
// mnemonic `test.tile_and_fuse_outer_parallel_partial_reduction`.
//
// Operands / attributes:
//   - $root_op:        handle to the operation(s) to tile.
//   - $reduction_dims: I64 array, defaults to `{}`; the C++ `apply` falls back
//                      to the root op's reduction iterators when it is empty.
//   - $tile_sizes:     I64 array, defaults to `{}`; `apply` requires one entry
//                      per reduction dimension.
//   - $mapping:        optional device mapping array, forwarded to the tiling
//                      options when present.
// Results:
//   - $tiled_ops: handle result.
//   - $loops:     variadic handle results.
//   - NOTE(review): presumably the tiled operations and the generated loops
//     respectively; the code that populates the results is not visible here.
def TestTileAndFuseOuterParallelPartialReductionOp : Op<
130+
Transform_Dialect, "test.tile_and_fuse_outer_parallel_partial_reduction",
131+
[FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
132+
DeclareOpInterfaceMethods<TransformOpInterface>,
133+
ReportTrackingListenerFailuresOpTrait]> {
134+
let description = [{
135+
Test operation to tile an operation using partial reduction with
136+
outer parallel strategy, and to fuse its producers.
137+
}];
138+
139+
let arguments = (ins TransformHandleTypeInterface:$root_op,
140+
DefaultValuedAttr<I64ArrayAttr, "{}">:$reduction_dims,
141+
DefaultValuedAttr<I64ArrayAttr, "{}">:$tile_sizes,
142+
OptionalAttr<DeviceMappingArrayAttr>:$mapping);
143+
144+
let results = (outs TransformHandleTypeInterface:$tiled_ops,
145+
Variadic<TransformHandleTypeInterface>:$loops);
146+
let assemblyFormat = [{
147+
$root_op (`reduction_dims` `=` $reduction_dims^)?
148+
(`tile_sizes` `=` $tile_sizes^)? (`mapping` `=` $mapping^)?
149+
attr-dict `:` functional-type(operands, results)
150+
}];
151+
}
152+
129153
#endif // TEST_TILINGINTERFACE_TRANSFORM_OPS

0 commit comments

Comments
 (0)