18
18
#include " mlir/Dialect/Transform/IR/TransformDialect.h"
19
19
#include " mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
20
20
#include " mlir/Dialect/Utils/StaticValueUtils.h"
21
+ #include " mlir/Dialect/Utils/StructuredOpsUtils.h"
21
22
#include " mlir/IR/Dominance.h"
22
23
#include " mlir/IR/OpImplementation.h"
23
24
#include " mlir/Interfaces/TilingInterface.h"
@@ -60,8 +61,7 @@ template <typename Range>
60
61
static LogicalResult
61
62
applyTileAndFuseToAll (RewriterBase &rewriter, Operation *transformOp,
62
63
Range &&payloadOps, unsigned numLoops,
63
- ArrayRef<OpFoldResult> tileSizes,
64
- ArrayRef<int64_t > interchange, bool useForall,
64
+ scf::SCFTilingOptions tilingOptions,
65
65
TransformResults &transformResults) {
66
66
SmallVector<Operation *> tiledOps;
67
67
SmallVector<SmallVector<Operation *>> loopOps (numLoops);
@@ -83,12 +83,6 @@ applyTileAndFuseToAll(RewriterBase &rewriter, Operation *transformOp,
83
83
}
84
84
}
85
85
86
- scf::SCFTilingOptions tilingOptions;
87
- tilingOptions.setTileSizes (tileSizes).setInterchange (interchange);
88
- if (useForall) {
89
- tilingOptions.setLoopType (scf::SCFTilingOptions::LoopType::ForallOp);
90
- }
91
-
92
86
scf::SCFTileAndFuseOptions tileAndFuseOptions;
93
87
tileAndFuseOptions.setTilingOptions (tilingOptions);
94
88
@@ -157,10 +151,16 @@ transform::TestFuseAndYieldOp::apply(TransformRewriter &rewriter,
157
151
SmallVector<OpFoldResult> tileSizesOfr =
158
152
getAsIndexOpFoldResult (rewriter.getContext (), tileSizes);
159
153
154
+ scf::SCFTilingOptions tilingOptions;
155
+ tilingOptions.setTileSizes (tileSizesOfr).setInterchange (tileInterchange);
156
+ if (getUseForall ()) {
157
+ tilingOptions.setLoopType (scf::SCFTilingOptions::LoopType::ForallOp);
158
+ }
159
+
160
160
LogicalResult result = applyTileAndFuseToAll (
161
161
rewriter, getOperation (), state.getPayloadOps (getTarget ()),
162
- tileSizes.size () - llvm::count (tileSizes, 0 ), tileSizesOfr ,
163
- tileInterchange, getUseForall (), transformResults);
162
+ tileSizes.size () - llvm::count (tileSizes, 0 ), tilingOptions ,
163
+ transformResults);
164
164
return failed (result) ? DiagnosedSilenceableFailure::definiteFailure ()
165
165
: DiagnosedSilenceableFailure::success ();
166
166
}
@@ -399,6 +399,75 @@ void transform::TestFuseUsingForallOp::getEffects(
399
399
modifiesPayload (effects);
400
400
}
401
401
402
+ // ===----------------------------------------------------------------------===//
403
+ // TestTileAndFuseOuterParallelPartialReduction
404
+ // ===----------------------------------------------------------------------===//
405
+
406
+ DiagnosedSilenceableFailure
407
+ transform::TestTileAndFuseOuterParallelPartialReductionOp::apply (
408
+ TransformRewriter &rewriter, TransformResults &transformResults,
409
+ TransformState &state) {
410
+ auto target =
411
+ dyn_cast<TilingInterface>(*state.getPayloadOps (getRootOp ()).begin ());
412
+ if (!target) {
413
+ emitOpError (" expected root operation to implement `TilingInterface`" );
414
+ return DiagnosedSilenceableFailure::definiteFailure ();
415
+ }
416
+
417
+ SmallVector<unsigned > reductionDims =
418
+ extractFromIntegerArrayAttr<unsigned >(getReductionDims ());
419
+ if (reductionDims.empty ()) {
420
+ for (auto [index, iterator] :
421
+ llvm::enumerate (target.getLoopIteratorTypes ()))
422
+ if (iterator == utils::IteratorType::reduction)
423
+ reductionDims.push_back (index);
424
+ }
425
+
426
+ if (reductionDims.empty ()) {
427
+ emitOpError (
428
+ " no reduction dimension specified or found in the target operation" );
429
+ return DiagnosedSilenceableFailure::definiteFailure ();
430
+ }
431
+
432
+ SmallVector<int64_t > reductionTileSizes =
433
+ extractFromIntegerArrayAttr<int64_t >(getTileSizes ());
434
+ if (reductionTileSizes.size () != reductionDims.size ()) {
435
+ emitOpError (
436
+ " missing tile sizes for reduction dimensions that are to be tiled" );
437
+ return DiagnosedSilenceableFailure::definiteFailure ();
438
+ }
439
+
440
+ // Adjust tile sizes so that it corresponds to the reduction iterator types.
441
+ SmallVector<OpFoldResult> tileSizes;
442
+ int reductionTileSizeNum = 0 ;
443
+ OpFoldResult zero = rewriter.getIndexAttr (0 );
444
+ for (auto iterator : target.getLoopIteratorTypes ()) {
445
+ if (iterator == utils::IteratorType::parallel) {
446
+ tileSizes.push_back (zero);
447
+ continue ;
448
+ }
449
+ tileSizes.push_back (
450
+ rewriter.getIndexAttr (reductionTileSizes[reductionTileSizeNum++]));
451
+ }
452
+
453
+ scf::SCFTilingOptions tilingOptions;
454
+ tilingOptions.setTileSizes (tileSizes)
455
+ .setLoopType (scf::SCFTilingOptions::LoopType::ForallOp)
456
+ .setReductionTilingStrategy (
457
+ ReductionTilingStrategy::PartialReductionOuterParallel)
458
+ .setReductionDims (reductionDims);
459
+ if (auto mapping = getMapping ()) {
460
+ tilingOptions.setMapping (getMapping ().value ());
461
+ }
462
+
463
+ LogicalResult result = applyTileAndFuseToAll (
464
+ rewriter, getOperation (), state.getPayloadOps (getRootOp ()),
465
+ /* numLoops =*/ 1 , tilingOptions, transformResults);
466
+
467
+ return failed (result) ? DiagnosedSilenceableFailure::definiteFailure ()
468
+ : DiagnosedSilenceableFailure::success ();
469
+ }
470
+
402
471
#define GET_OP_CLASSES
403
472
#include " TestTilingInterfaceTransformOps.cpp.inc"
404
473
0 commit comments