Skip to content

Commit ace5108

Browse files
authored
feat(linalg): add a way to pass controlFn to foldIntoPackUnpackPatterns (#143685)
This PR adds a mechanism, so that downstream consumers can pass in control functions for the application of these patterns. This change shouldn't affect any consumers of this method that do not specify a controlFn. The controlFn always gets the source operand of the consumer in each of the patterns as a parameter. In IREE, we (will) use it to control preventing folding patterns that would inhibit fusion. See IREE issue [#20896](iree-org/iree#20896) for more details.
1 parent f9413e1 commit ace5108

File tree

4 files changed

+165
-10
lines changed

4 files changed

+165
-10
lines changed

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1984,10 +1984,15 @@ void populateDecomposeWinogradOpsPatterns(RewritePatternSet &patterns);
19841984
/// convert to a `linalg.dot`.
19851985
void populateContractionOpRankReducingPatterns(RewritePatternSet &patterns);
19861986

1987+
/// Function type which is used to control folding operations like `tensor.pad`
1988+
/// and `tensor.extract_slice` into linalg.pack/unpack ops.
1989+
using ControlFoldIntoPackUnpackFn = std::function<bool(OpOperand *opOperand)>;
19871990
/// Populates `patterns` with patterns that fold operations like `tensor.pad`
19881991
/// and `tensor.extract_slice` into `tensor.pack` and `tensor.unpack` operations
19891992
/// respectively.
1990-
void populateFoldIntoPackAndUnpackPatterns(RewritePatternSet &patterns);
1993+
void populateFoldIntoPackAndUnpackPatterns(
1994+
RewritePatternSet &patterns,
1995+
const ControlFoldIntoPackUnpackFn &controlFn = nullptr);
19911996

19921997
/// Populates `patterns` with patterns that fold operations like `linalg.pack`
19931998
/// and `linalg.unpack` into `tensor.empty`.

mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp

Lines changed: 76 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
//===----------------------------------------------------------------------===//
88

99
#include "mlir/Dialect/Linalg/IR/Linalg.h"
10+
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
1011
#include "mlir/Dialect/Tensor/IR/Tensor.h"
1112
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
1213
#include "mlir/Dialect/Utils/IndexingUtils.h"
@@ -197,7 +198,9 @@ struct SimplifyUnPackToCollapseShape : public OpRewritePattern<UnPackOp> {
197198
/// Fold a `pad` -> `pack` into `pack` if they have the same padding values and
198199
/// the pad op has zero low paddings, or if `pack` has no padding values.
199200
struct FoldPadWithPackOp : public OpRewritePattern<PackOp> {
200-
using OpRewritePattern<PackOp>::OpRewritePattern;
201+
public:
202+
FoldPadWithPackOp(MLIRContext *context, ControlFoldIntoPackUnpackFn controlFn)
203+
: OpRewritePattern<PackOp>(context), controlFn(std::move(controlFn)) {}
201204

202205
LogicalResult matchAndRewrite(PackOp packOp,
203206
PatternRewriter &rewriter) const override {
@@ -206,6 +209,10 @@ struct FoldPadWithPackOp : public OpRewritePattern<PackOp> {
206209
if (!padOp || padOp.getNofold() || !padOp.hasZeroLowPad())
207210
return failure();
208211

212+
// User controlled folding function.
213+
if (controlFn && !controlFn(&packOp.getSourceMutable()))
214+
return failure();
215+
209216
Value constantPaddingValue = padOp.getConstantPaddingValue();
210217
if (!constantPaddingValue)
211218
return failure();
@@ -220,20 +227,31 @@ struct FoldPadWithPackOp : public OpRewritePattern<PackOp> {
220227
packOp.getOuterDimsPerm());
221228
return success();
222229
}
230+
231+
private:
232+
ControlFoldIntoPackUnpackFn controlFn;
223233
};
224234

225235
/// Fold a `unpack` -> `extract_slice` into the `unpack` since it already
226236
/// has extract_slice semantics.
227237
struct FoldUnpackWithExtractSliceOp
228238
: public OpRewritePattern<tensor::ExtractSliceOp> {
229-
using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;
239+
public:
240+
FoldUnpackWithExtractSliceOp(MLIRContext *context,
241+
ControlFoldIntoPackUnpackFn controlFn)
242+
: OpRewritePattern<tensor::ExtractSliceOp>(context),
243+
controlFn(std::move(controlFn)) {}
230244

231245
LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
232246
PatternRewriter &rewriter) const override {
233247
auto unpackOp = sliceOp.getSource().getDefiningOp<UnPackOp>();
234248
if (!unpackOp)
235249
return failure();
236250

251+
// User controlled folding function.
252+
if (controlFn && !controlFn(&sliceOp.getSourceMutable()))
253+
return failure();
254+
237255
if (sliceOp.getResultType().getRank() != unpackOp.getDestType().getRank()) {
238256
return rewriter.notifyMatchFailure(
239257
sliceOp, "rank-reduced folding is not supported");
@@ -255,6 +273,9 @@ struct FoldUnpackWithExtractSliceOp
255273
unpackOp.getMixedTiles(), unpackOp.getOuterDimsPerm());
256274
return success();
257275
}
276+
277+
private:
278+
ControlFoldIntoPackUnpackFn controlFn;
258279
};
259280

260281
// Applies 'permutation' on 'inVec' and stores the result in resVec.
@@ -284,7 +305,12 @@ static bool checkAndPermute(ArrayRef<int64_t> permutation,
284305
/// semantics.
285306
struct FoldProducerPackWithConsumerLinalgTransposeOp
286307
: public OpInterfaceRewritePattern<linalg::LinalgOp> {
287-
using OpInterfaceRewritePattern<linalg::LinalgOp>::OpInterfaceRewritePattern;
308+
309+
public:
310+
FoldProducerPackWithConsumerLinalgTransposeOp(
311+
MLIRContext *context, ControlFoldIntoPackUnpackFn controlFn)
312+
: OpInterfaceRewritePattern<linalg::LinalgOp>(context),
313+
controlFn(std::move(controlFn)) {}
288314

289315
LogicalResult matchAndRewrite(linalg::LinalgOp linalgOp,
290316
PatternRewriter &rewriter) const override {
@@ -293,6 +319,10 @@ struct FoldProducerPackWithConsumerLinalgTransposeOp
293319
if (!packOp)
294320
return failure();
295321

322+
// User controlled folding function.
323+
if (controlFn && !controlFn(&linalgOp->getOpOperand(0)))
324+
return failure();
325+
296326
FailureOr<SmallVector<int64_t>> maybePerm =
297327
getTransposeOpPermutation(linalgOp);
298328
if (failed(maybePerm))
@@ -331,20 +361,31 @@ struct FoldProducerPackWithConsumerLinalgTransposeOp
331361

332362
return success();
333363
}
364+
365+
private:
366+
ControlFoldIntoPackUnpackFn controlFn;
334367
};
335368

336369
/// Fold 'transpose' -> 'pack' into 'pack' since 'pack' already has transpose
337370
/// semantics.
338371
struct FoldConsumerPackWithProducerLinalgTransposeOp
339372
: public OpRewritePattern<PackOp> {
340-
using OpRewritePattern<PackOp>::OpRewritePattern;
373+
374+
public:
375+
FoldConsumerPackWithProducerLinalgTransposeOp(
376+
MLIRContext *context, ControlFoldIntoPackUnpackFn controlFn)
377+
: OpRewritePattern<PackOp>(context), controlFn(std::move(controlFn)) {}
341378

342379
LogicalResult matchAndRewrite(PackOp packOp,
343380
PatternRewriter &rewriter) const override {
344381
auto linalgOp = packOp.getSource().getDefiningOp<linalg::LinalgOp>();
345382
if (!linalgOp)
346383
return failure();
347384

385+
// User controlled folding function.
386+
if (controlFn && !controlFn(&packOp.getSourceMutable()))
387+
return failure();
388+
348389
FailureOr<SmallVector<int64_t>> maybePerm =
349390
getTransposeOpPermutation(linalgOp);
350391
if (failed(maybePerm))
@@ -375,13 +416,21 @@ struct FoldConsumerPackWithProducerLinalgTransposeOp
375416

376417
return success();
377418
}
419+
420+
private:
421+
ControlFoldIntoPackUnpackFn controlFn;
378422
};
379423

380424
/// Fold 'unpack' -> 'transpose' into 'unpack' since 'unpack' already has
381425
/// transpose semantics.
382426
struct FoldProducerUnPackWithConsumerLinalgTransposeOp
383427
: public OpInterfaceRewritePattern<linalg::LinalgOp> {
384-
using OpInterfaceRewritePattern<linalg::LinalgOp>::OpInterfaceRewritePattern;
428+
429+
public:
430+
FoldProducerUnPackWithConsumerLinalgTransposeOp(
431+
MLIRContext *context, ControlFoldIntoPackUnpackFn controlFn)
432+
: OpInterfaceRewritePattern<linalg::LinalgOp>(context),
433+
controlFn(std::move(controlFn)) {}
385434

386435
LogicalResult matchAndRewrite(linalg::LinalgOp linalgOp,
387436
PatternRewriter &rewriter) const override {
@@ -390,6 +439,10 @@ struct FoldProducerUnPackWithConsumerLinalgTransposeOp
390439
if (!unPackOp)
391440
return failure();
392441

442+
// User controlled folding function.
443+
if (controlFn && !controlFn(&linalgOp->getOpOperand(0)))
444+
return failure();
445+
393446
FailureOr<SmallVector<int64_t>> maybePerm =
394447
getTransposeOpPermutation(linalgOp);
395448
if (failed(maybePerm))
@@ -416,6 +469,9 @@ struct FoldProducerUnPackWithConsumerLinalgTransposeOp
416469

417470
return success();
418471
}
472+
473+
private:
474+
ControlFoldIntoPackUnpackFn controlFn;
419475
};
420476

421477
/// Fold 'transpose' -> 'unpack' into 'unpack' since 'unpack' already has
@@ -424,12 +480,21 @@ struct FoldConsumerUnPackWithProducerLinalgTransposeOp
424480
: public OpRewritePattern<UnPackOp> {
425481
using OpRewritePattern<UnPackOp>::OpRewritePattern;
426482

483+
public:
484+
FoldConsumerUnPackWithProducerLinalgTransposeOp(
485+
MLIRContext *context, ControlFoldIntoPackUnpackFn controlFn)
486+
: OpRewritePattern<UnPackOp>(context), controlFn(std::move(controlFn)) {}
487+
427488
LogicalResult matchAndRewrite(UnPackOp unPackOp,
428489
PatternRewriter &rewriter) const override {
429490
auto linalgOp = unPackOp.getSource().getDefiningOp<linalg::LinalgOp>();
430491
if (!linalgOp)
431492
return failure();
432493

494+
// User controlled folding function.
495+
if (controlFn && !controlFn(&unPackOp.getSourceMutable()))
496+
return failure();
497+
433498
FailureOr<SmallVector<int64_t>> maybePerm =
434499
getTransposeOpPermutation(linalgOp);
435500
if (failed(maybePerm))
@@ -474,6 +539,9 @@ struct FoldConsumerUnPackWithProducerLinalgTransposeOp
474539

475540
return success();
476541
}
542+
543+
private:
544+
ControlFoldIntoPackUnpackFn controlFn;
477545
};
478546

479547
/// tensor.empty does not define any tensor contents, so an unpadded pack
@@ -521,13 +589,14 @@ struct FoldEmptyTensorWithUnPackOp : public OpRewritePattern<UnPackOp> {
521589

522590
} // namespace
523591

524-
void populateFoldIntoPackAndUnpackPatterns(RewritePatternSet &patterns) {
592+
void populateFoldIntoPackAndUnpackPatterns(
593+
RewritePatternSet &patterns, const ControlFoldIntoPackUnpackFn &controlFn) {
525594
patterns.insert<FoldUnpackWithExtractSliceOp, FoldPadWithPackOp,
526595
FoldProducerPackWithConsumerLinalgTransposeOp,
527596
FoldConsumerPackWithProducerLinalgTransposeOp,
528597
FoldConsumerUnPackWithProducerLinalgTransposeOp,
529598
FoldProducerUnPackWithConsumerLinalgTransposeOp>(
530-
patterns.getContext());
599+
patterns.getContext(), controlFn);
531600
}
532601

533602
void populateSimplifyPackAndUnpackPatterns(RewritePatternSet &patterns) {

mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
// RUN: mlir-opt -split-input-file -test-linalg-transform-patterns=test-fold-into-pack-and-unpack %s | FileCheck %s
2+
// RUN: mlir-opt -split-input-file -test-linalg-transform-patterns=test-fold-into-pack-and-unpack-control %s | FileCheck %s --check-prefix=CONTROL
23

34
func.func @fold_unpack_slice(%arg0 : tensor<?x?x8x4xf32>, %arg1 : tensor<?x?xf32>,
45
%arg2 : index, %arg3 : index) -> tensor<?x?xf32> {
@@ -373,6 +374,36 @@ func.func @linalg_transpose_linalg.pack_fold(%arg0: tensor<56x57x1x64xf32>) -> t
373374

374375
// -----
375376

377+
func.func @linalg_transpose_linalg.pack_fold_multi_result(%arg0: tensor<56x57x1x64xf32>) -> (tensor<1x56x57x64xf32>, tensor<1x57x56x2x32xf32>) {
378+
%0 = tensor.empty() : tensor<1x56x57x64xf32>
379+
%transposed = linalg.transpose
380+
ins(%arg0 : tensor<56x57x1x64xf32>)
381+
outs(%0 : tensor<1x56x57x64xf32>)
382+
permutation = [2, 0, 1, 3]
383+
384+
%1 = tensor.empty() : tensor<1x57x56x2x32xf32>
385+
%pack = linalg.pack %transposed
386+
outer_dims_perm = [0, 2, 1, 3]
387+
inner_dims_pos = [3]
388+
inner_tiles = [32]
389+
into %1 : tensor<1x56x57x64xf32> -> tensor<1x57x56x2x32xf32>
390+
return %transposed, %pack : tensor<1x56x57x64xf32>, tensor<1x57x56x2x32xf32>
391+
}
392+
// CHECK-LABEL: func @linalg_transpose_linalg.pack_fold_multi_result(
393+
// CHECK-SAME: %[[ARG0:.+]]: tensor<56x57x1x64xf32>)
394+
// CHECK: %[[TRANSPOSE:.+]] = linalg.transpose
395+
// CHECK: %[[PACK:.+]] = linalg.pack %[[ARG0]]
396+
// CHECK-SAME: outer_dims_perm = [2, 1, 0, 3]
397+
// CHECK: return %[[TRANSPOSE]], %[[PACK]]
398+
399+
// CONTROL-LABEL: func @linalg_transpose_linalg.pack_fold_multi_result(
400+
// CONTROL: %[[TRANSPOSE:.+]] = linalg.transpose
401+
// CONTROL: %[[PACK:.+]] = linalg.pack %[[TRANSPOSE]]
402+
// CONTROL-SAME: outer_dims_perm = [0, 2, 1, 3]
403+
// CONTROL: return %[[TRANSPOSE]], %[[PACK]]
404+
405+
// -----
406+
376407
func.func @linalg_transpose_linalg.pack_fold_with_padding(%arg0: tensor<56x57x1x55xf32>, %padding: f32) -> tensor<1x57x56x2x32xf32> {
377408
%0 = tensor.empty() : tensor<1x56x57x55xf32>
378409
%transpose = linalg.transpose
@@ -550,6 +581,36 @@ func.func @linalg_transpose_linalg.unpack_fold(%arg0: tensor<1x1x4x16xi32>) -> t
550581

551582
// -----
552583

584+
func.func @linalg_transpose_linalg.unpack_fold_multi_result(%arg0: tensor<1x1x4x16xi32>) -> (tensor<1x1x16x4xi32>, tensor<16x4xi32>) {
585+
%0 = tensor.empty() : tensor<1x1x16x4xi32>
586+
%transposed = linalg.transpose ins(%arg0 : tensor<1x1x4x16xi32>)
587+
outs(%0 : tensor<1x1x16x4xi32>)
588+
permutation = [1, 0, 3, 2]
589+
%1 = tensor.empty() : tensor<16x4xi32>
590+
%unpack = linalg.unpack %transposed
591+
outer_dims_perm = [0, 1]
592+
inner_dims_pos = [0, 1]
593+
inner_tiles = [16, 4] into
594+
%1 : tensor<1x1x16x4xi32> -> tensor<16x4xi32>
595+
return %transposed, %unpack : tensor<1x1x16x4xi32>, tensor<16x4xi32>
596+
}
597+
//CHECK-LABEL: func.func @linalg_transpose_linalg.unpack_fold_multi_result(
598+
// CHECK-SAME: %[[ARG0:.+]]: tensor<1x1x4x16xi32>)
599+
// CHECK: %[[TRANSPOSE:.+]] = linalg.transpose
600+
// CHECK: %[[UNPACK:.+]] = linalg.unpack %[[ARG0]]
601+
// CHECK-SAME: outer_dims_perm = [1, 0]
602+
// CHECK: return %[[TRANSPOSE]], %[[UNPACK]]
603+
// CHECK: }
604+
605+
//CONTROL-LABEL: func.func @linalg_transpose_linalg.unpack_fold_multi_result(
606+
// CONTROL: %[[TRANSPOSE:.+]] = linalg.transpose
607+
// CONTROL: %[[UNPACK:.+]] = linalg.unpack %[[TRANSPOSE]]
608+
// CONTROL-SAME: outer_dims_perm = [0, 1]
609+
// CONTROL: return %[[TRANSPOSE]], %[[UNPACK]]
610+
// CONTROL: }
611+
612+
// -----
613+
553614
func.func @linalg_transpose_linalg.unpack_fold_partial_tile(%arg0: tensor<1x1x4x16xi32>) -> tensor<15x3xi32> {
554615
%0 = tensor.empty() : tensor<1x1x16x4xi32>
555616
%transposed = linalg.transpose ins(%arg0 : tensor<1x1x4x16xi32>)

mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,11 @@ struct TestLinalgTransforms
130130
*this, "test-fold-into-pack-and-unpack",
131131
llvm::cl::desc("Test folding ops into linalg.pack and linalg.unpack"),
132132
llvm::cl::init(false)};
133+
Option<bool> testFoldIntoPackAndUnpackWithControlFn{
134+
*this, "test-fold-into-pack-and-unpack-control",
135+
llvm::cl::desc(
136+
"Test controlling folding ops into linalg.pack and linalg.unpack"),
137+
llvm::cl::init(false)};
133138
Option<bool> testSimplifyPackUnpackPatterns{
134139
*this, "test-simplify-pack-unpack-patterns",
135140
llvm::cl::desc("Test patterns to simplify linalg.pack and linalg.unpack"),
@@ -222,9 +227,11 @@ static void applyDecomposeWinogradOps(func::FuncOp funcOp) {
222227
(void)applyPatternsGreedily(funcOp, std::move(patterns));
223228
}
224229

225-
static void applyFoldIntoPackAndUnpackPatterns(Operation *rootOp) {
230+
static void applyFoldIntoPackAndUnpackPatterns(
231+
Operation *rootOp,
232+
linalg::ControlFoldIntoPackUnpackFn controlFn = nullptr) {
226233
RewritePatternSet patterns(rootOp->getContext());
227-
linalg::populateFoldIntoPackAndUnpackPatterns(patterns);
234+
linalg::populateFoldIntoPackAndUnpackPatterns(patterns, controlFn);
228235
(void)applyPatternsGreedily(rootOp, std::move(patterns));
229236
}
230237

@@ -263,6 +270,19 @@ void TestLinalgTransforms::runOnOperation() {
263270
Operation *rootOp = getOperation();
264271
if (testFoldIntoPackAndUnpack)
265272
applyFoldIntoPackAndUnpackPatterns(rootOp);
273+
if (testFoldIntoPackAndUnpackWithControlFn) {
274+
linalg::ControlFoldIntoPackUnpackFn controlFn = [](OpOperand *opOperand) {
275+
Operation *producer = opOperand->get().getDefiningOp();
276+
Operation *consumer = opOperand->getOwner();
277+
// If we have a pack/unpack consumer and a producer that has multiple
278+
// uses, do not apply the folding patterns.
279+
if (isa<linalg::PackOp, linalg::UnPackOp>(consumer) &&
280+
isa<TilingInterface>(producer) && !producer->hasOneUse())
281+
return false;
282+
return true;
283+
};
284+
applyFoldIntoPackAndUnpackPatterns(rootOp, controlFn);
285+
}
266286
if (testSimplifyPackUnpackPatterns)
267287
applySimplifyPackUnpackPatterns(rootOp);
268288
}

0 commit comments

Comments
 (0)