@@ -348,12 +348,12 @@ void populateVectorTransferDropUnitDimsPatterns(RewritePatternSet &patterns,
348
348
void populateDropUnitDimWithShapeCastPatterns (RewritePatternSet &patterns,
349
349
PatternBenefit benefit = 1 );
350
350
351
- // / Collect a set of patterns to flatten n-D vector transfers on contiguous
352
- // / memref.
351
+ // / Collect a set of patterns to flatten/linearize n-D vector transfers on
352
+ // / contiguous memref.
353
353
// /
354
354
// / These patterns insert memref.collapse_shape + vector.shape_cast patterns
355
- // / to transform multiple small n-D transfers into a larger 1-D transfer where
356
- // / the memref contiguity properties allow it.
355
+ // / to transform a n-D transfer into a larger 1-D transfer where the memref
356
+ // / contiguity properties allow it.
357
357
// /
358
358
// / Flattening is only applied if the bitwidth of the trailing vector dimension
359
359
// / is smaller or equal to `targetVectorBitwidth`.
@@ -362,6 +362,20 @@ void populateFlattenVectorTransferPatterns(
362
362
unsigned targetVectorBitwidth = std::numeric_limits<unsigned >::max(),
363
363
PatternBenefit benefit = 1);
364
364
365
+ // / Collect a set of patterns to flatten/linearize operations on vectors.
366
+ // / TODO(newling) combine API with `populateFlattenVectorTransferPatterns`.
367
+ void populateForVectorLinearize (
368
+ RewritePatternSet &patterns,
369
+ const std::function<LogicalResult(Operation *)> &preCondition =
370
+ [](Operation *) { return success (); },
371
+ PatternBenefit benefit = 1 );
372
+
373
+ // / Collect a set of patterns to rewrite vector.extract_strided_slice and
374
+ // / vector.insert_strided_slice operations to be as low-rank as possible.
375
+ // / This is done by using shape_cast to combine non-strided dimensions.
376
+ void populateForStridedRankReduction (RewritePatternSet &patterns,
377
+ PatternBenefit benefit = 1 );
378
+
365
379
// / Collect a set of patterns that bubble up/down bitcast ops.
366
380
// /
367
381
// / These patterns move vector.bitcast ops to be before insert ops or after
@@ -408,39 +422,6 @@ void populateVectorNarrowTypeRewritePatterns(RewritePatternSet &patterns,
408
422
void populateVectorTransposeNarrowTypeRewritePatterns (
409
423
RewritePatternSet &patterns, PatternBenefit benefit = 1 );
410
424
411
- // / Initialize `typeConverter` and `conversionTarget` for vector linearization.
412
- // /
413
- // / Definition: here 'linearization' means converting a single operation with
414
- // / 1+ vector operand/result of rank>1, into a new single operation whose
415
- // / vector operands and results are all of rank<=1.
416
- // /
417
- // / This function registers (1) which operations are legal, and hence should not
418
- // / be linearized, (2) what the converted types are (rank-1 vectors) and how to
419
- // / materialze the conversion (with shape_cast)
420
- // /
421
- // / Note: the set of legal operations can be extended by a user if for example
422
- // / certain rank>1 vectors are considered valid, by adding additional
423
- // / dynamically legal ops to `conversionTarget`.
424
- // /
425
- // / Further note: the choice to use a dialect conversion design for
426
- // / linearization is to make it easy to reuse generic structural type
427
- // / conversions for linearizing scf/cf/func operations
428
- void populateForVectorLinearize (TypeConverter &typeConverter,
429
- ConversionTarget &conversionTarget);
430
-
431
- // / Populates `patterns` for ND vector (N >= 2) linearization. This currently
432
- // / contains patterns for converting ConstantLike, Vectorizable, and
433
- // / vector::BitCast ops.
434
- void populateVectorLinearizeBasePatterns (const TypeConverter &,
435
- const ConversionTarget &,
436
- RewritePatternSet &patterns);
437
-
438
- // / Populates `patterns` for linearizing ND (N >= 2) vector operations
439
- // / to 1D vector shuffle operations.
440
- void populateVectorLinearizeShuffleLikeOpsPatterns (const TypeConverter &,
441
- const ConversionTarget &,
442
- RewritePatternSet &patterns);
443
-
444
425
} // namespace vector
445
426
} // namespace mlir
446
427
0 commit comments