7
7
// ===----------------------------------------------------------------------===//
8
8
9
9
#include " mlir/Dialect/Linalg/IR/Linalg.h"
10
+ #include " mlir/Dialect/Linalg/Transforms/Transforms.h"
10
11
#include " mlir/Dialect/Tensor/IR/Tensor.h"
11
12
#include " mlir/Dialect/Tensor/Transforms/Transforms.h"
12
13
#include " mlir/Dialect/Utils/IndexingUtils.h"
@@ -197,7 +198,9 @@ struct SimplifyUnPackToCollapseShape : public OpRewritePattern<UnPackOp> {
197
198
// / Fold a `pad` -> `pack` into `pack` if they have the same padding values and
198
199
// / the pad op has zero low paddings, or if `pack` has no padding values.
199
200
struct FoldPadWithPackOp : public OpRewritePattern <PackOp> {
200
- using OpRewritePattern<PackOp>::OpRewritePattern;
201
+ public:
202
+ FoldPadWithPackOp (MLIRContext *context, ControlFoldIntoPackUnpackFn controlFn)
203
+ : OpRewritePattern<PackOp>(context), controlFn(std::move(controlFn)) {}
201
204
202
205
LogicalResult matchAndRewrite (PackOp packOp,
203
206
PatternRewriter &rewriter) const override {
@@ -206,6 +209,10 @@ struct FoldPadWithPackOp : public OpRewritePattern<PackOp> {
206
209
if (!padOp || padOp.getNofold () || !padOp.hasZeroLowPad ())
207
210
return failure ();
208
211
212
+ // User controlled folding function.
213
+ if (controlFn && !controlFn (&packOp.getSourceMutable ()))
214
+ return failure ();
215
+
209
216
Value constantPaddingValue = padOp.getConstantPaddingValue ();
210
217
if (!constantPaddingValue)
211
218
return failure ();
@@ -220,20 +227,31 @@ struct FoldPadWithPackOp : public OpRewritePattern<PackOp> {
220
227
packOp.getOuterDimsPerm ());
221
228
return success ();
222
229
}
230
+
231
+ private:
232
+ ControlFoldIntoPackUnpackFn controlFn;
223
233
};
224
234
225
235
// / Fold a `unpack` -> `extract_slice` into the `unpack` since it already
226
236
// / has extract_slice semantics.
227
237
struct FoldUnpackWithExtractSliceOp
228
238
: public OpRewritePattern<tensor::ExtractSliceOp> {
229
- using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;
239
+ public:
240
+ FoldUnpackWithExtractSliceOp (MLIRContext *context,
241
+ ControlFoldIntoPackUnpackFn controlFn)
242
+ : OpRewritePattern<tensor::ExtractSliceOp>(context),
243
+ controlFn (std::move(controlFn)) {}
230
244
231
245
LogicalResult matchAndRewrite (tensor::ExtractSliceOp sliceOp,
232
246
PatternRewriter &rewriter) const override {
233
247
auto unpackOp = sliceOp.getSource ().getDefiningOp <UnPackOp>();
234
248
if (!unpackOp)
235
249
return failure ();
236
250
251
+ // User controlled folding function.
252
+ if (controlFn && !controlFn (&sliceOp.getSourceMutable ()))
253
+ return failure ();
254
+
237
255
if (sliceOp.getResultType ().getRank () != unpackOp.getDestType ().getRank ()) {
238
256
return rewriter.notifyMatchFailure (
239
257
sliceOp, " rank-reduced folding is not supported" );
@@ -255,6 +273,9 @@ struct FoldUnpackWithExtractSliceOp
255
273
unpackOp.getMixedTiles (), unpackOp.getOuterDimsPerm ());
256
274
return success ();
257
275
}
276
+
277
+ private:
278
+ ControlFoldIntoPackUnpackFn controlFn;
258
279
};
259
280
260
281
// Applies 'permutation' on 'inVec' and stores the result in resVec.
@@ -284,7 +305,12 @@ static bool checkAndPermute(ArrayRef<int64_t> permutation,
284
305
// / semantics.
285
306
struct FoldProducerPackWithConsumerLinalgTransposeOp
286
307
: public OpInterfaceRewritePattern<linalg::LinalgOp> {
287
- using OpInterfaceRewritePattern<linalg::LinalgOp>::OpInterfaceRewritePattern;
308
+
309
+ public:
310
+ FoldProducerPackWithConsumerLinalgTransposeOp (
311
+ MLIRContext *context, ControlFoldIntoPackUnpackFn controlFn)
312
+ : OpInterfaceRewritePattern<linalg::LinalgOp>(context),
313
+ controlFn (std::move(controlFn)) {}
288
314
289
315
LogicalResult matchAndRewrite (linalg::LinalgOp linalgOp,
290
316
PatternRewriter &rewriter) const override {
@@ -293,6 +319,10 @@ struct FoldProducerPackWithConsumerLinalgTransposeOp
293
319
if (!packOp)
294
320
return failure ();
295
321
322
+ // User controlled folding function.
323
+ if (controlFn && !controlFn (&linalgOp->getOpOperand (0 )))
324
+ return failure ();
325
+
296
326
FailureOr<SmallVector<int64_t >> maybePerm =
297
327
getTransposeOpPermutation (linalgOp);
298
328
if (failed (maybePerm))
@@ -331,20 +361,31 @@ struct FoldProducerPackWithConsumerLinalgTransposeOp
331
361
332
362
return success ();
333
363
}
364
+
365
+ private:
366
+ ControlFoldIntoPackUnpackFn controlFn;
334
367
};
335
368
336
369
// / Fold 'transpose' -> 'pack' into 'pack' since 'pack' already has transpose
337
370
// / semantics.
338
371
struct FoldConsumerPackWithProducerLinalgTransposeOp
339
372
: public OpRewritePattern<PackOp> {
340
- using OpRewritePattern<PackOp>::OpRewritePattern;
373
+
374
+ public:
375
+ FoldConsumerPackWithProducerLinalgTransposeOp (
376
+ MLIRContext *context, ControlFoldIntoPackUnpackFn controlFn)
377
+ : OpRewritePattern<PackOp>(context), controlFn(std::move(controlFn)) {}
341
378
342
379
LogicalResult matchAndRewrite (PackOp packOp,
343
380
PatternRewriter &rewriter) const override {
344
381
auto linalgOp = packOp.getSource ().getDefiningOp <linalg::LinalgOp>();
345
382
if (!linalgOp)
346
383
return failure ();
347
384
385
+ // User controlled folding function.
386
+ if (controlFn && !controlFn (&packOp.getSourceMutable ()))
387
+ return failure ();
388
+
348
389
FailureOr<SmallVector<int64_t >> maybePerm =
349
390
getTransposeOpPermutation (linalgOp);
350
391
if (failed (maybePerm))
@@ -375,13 +416,21 @@ struct FoldConsumerPackWithProducerLinalgTransposeOp
375
416
376
417
return success ();
377
418
}
419
+
420
+ private:
421
+ ControlFoldIntoPackUnpackFn controlFn;
378
422
};
379
423
380
424
// / Fold 'unpack' -> 'transpose' into 'unpack' since 'unpack' already has
381
425
// / transpose semantics.
382
426
struct FoldProducerUnPackWithConsumerLinalgTransposeOp
383
427
: public OpInterfaceRewritePattern<linalg::LinalgOp> {
384
- using OpInterfaceRewritePattern<linalg::LinalgOp>::OpInterfaceRewritePattern;
428
+
429
+ public:
430
+ FoldProducerUnPackWithConsumerLinalgTransposeOp (
431
+ MLIRContext *context, ControlFoldIntoPackUnpackFn controlFn)
432
+ : OpInterfaceRewritePattern<linalg::LinalgOp>(context),
433
+ controlFn (std::move(controlFn)) {}
385
434
386
435
LogicalResult matchAndRewrite (linalg::LinalgOp linalgOp,
387
436
PatternRewriter &rewriter) const override {
@@ -390,6 +439,10 @@ struct FoldProducerUnPackWithConsumerLinalgTransposeOp
390
439
if (!unPackOp)
391
440
return failure ();
392
441
442
+ // User controlled folding function.
443
+ if (controlFn && !controlFn (&linalgOp->getOpOperand (0 )))
444
+ return failure ();
445
+
393
446
FailureOr<SmallVector<int64_t >> maybePerm =
394
447
getTransposeOpPermutation (linalgOp);
395
448
if (failed (maybePerm))
@@ -416,6 +469,9 @@ struct FoldProducerUnPackWithConsumerLinalgTransposeOp
416
469
417
470
return success ();
418
471
}
472
+
473
+ private:
474
+ ControlFoldIntoPackUnpackFn controlFn;
419
475
};
420
476
421
477
// / Fold 'transpose' -> 'unpack' into 'unpack' since 'unpack' already has
@@ -424,12 +480,21 @@ struct FoldConsumerUnPackWithProducerLinalgTransposeOp
424
480
: public OpRewritePattern<UnPackOp> {
425
481
using OpRewritePattern<UnPackOp>::OpRewritePattern;
426
482
483
+ public:
484
+ FoldConsumerUnPackWithProducerLinalgTransposeOp (
485
+ MLIRContext *context, ControlFoldIntoPackUnpackFn controlFn)
486
+ : OpRewritePattern<UnPackOp>(context), controlFn(std::move(controlFn)) {}
487
+
427
488
LogicalResult matchAndRewrite (UnPackOp unPackOp,
428
489
PatternRewriter &rewriter) const override {
429
490
auto linalgOp = unPackOp.getSource ().getDefiningOp <linalg::LinalgOp>();
430
491
if (!linalgOp)
431
492
return failure ();
432
493
494
+ // User controlled folding function.
495
+ if (controlFn && !controlFn (&unPackOp.getSourceMutable ()))
496
+ return failure ();
497
+
433
498
FailureOr<SmallVector<int64_t >> maybePerm =
434
499
getTransposeOpPermutation (linalgOp);
435
500
if (failed (maybePerm))
@@ -474,6 +539,9 @@ struct FoldConsumerUnPackWithProducerLinalgTransposeOp
474
539
475
540
return success ();
476
541
}
542
+
543
+ private:
544
+ ControlFoldIntoPackUnpackFn controlFn;
477
545
};
478
546
479
547
// / tensor.empty does not define any tensor contents, so an unpadded pack
@@ -521,13 +589,14 @@ struct FoldEmptyTensorWithUnPackOp : public OpRewritePattern<UnPackOp> {
521
589
522
590
} // namespace
523
591
524
- void populateFoldIntoPackAndUnpackPatterns (RewritePatternSet &patterns) {
592
+ void populateFoldIntoPackAndUnpackPatterns (
593
+ RewritePatternSet &patterns, const ControlFoldIntoPackUnpackFn &controlFn) {
525
594
patterns.insert <FoldUnpackWithExtractSliceOp, FoldPadWithPackOp,
526
595
FoldProducerPackWithConsumerLinalgTransposeOp,
527
596
FoldConsumerPackWithProducerLinalgTransposeOp,
528
597
FoldConsumerUnPackWithProducerLinalgTransposeOp,
529
598
FoldProducerUnPackWithConsumerLinalgTransposeOp>(
530
- patterns.getContext ());
599
+ patterns.getContext (), controlFn );
531
600
}
532
601
533
602
void populateSimplifyPackAndUnpackPatterns (RewritePatternSet &patterns) {
0 commit comments