Skip to content

Commit 56bae89

Browse files
committed
squash
1 parent 02f60fd commit 56bae89

File tree

12 files changed

+2519
-1326
lines changed

12 files changed

+2519
-1326
lines changed

mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h

Lines changed: 18 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -348,12 +348,12 @@ void populateVectorTransferDropUnitDimsPatterns(RewritePatternSet &patterns,
348348
void populateDropUnitDimWithShapeCastPatterns(RewritePatternSet &patterns,
349349
PatternBenefit benefit = 1);
350350

351-
/// Collect a set of patterns to flatten n-D vector transfers on contiguous
352-
/// memref.
351+
/// Collect a set of patterns to flatten/linearize n-D vector transfers on
352+
/// contiguous memref.
353353
///
354354
/// These patterns insert memref.collapse_shape + vector.shape_cast patterns
355-
/// to transform multiple small n-D transfers into a larger 1-D transfer where
356-
/// the memref contiguity properties allow it.
355+
/// to transform a n-D transfer into a larger 1-D transfer where the memref
356+
/// contiguity properties allow it.
357357
///
358358
/// Flattening is only applied if the bitwidth of the trailing vector dimension
359359
/// is smaller or equal to `targetVectorBitwidth`.
@@ -362,6 +362,20 @@ void populateFlattenVectorTransferPatterns(
362362
unsigned targetVectorBitwidth = std::numeric_limits<unsigned>::max(),
363363
PatternBenefit benefit = 1);
364364

365+
/// Collect a set of patterns to flatten/linearize operations on vectors.
366+
/// TODO(newling) combine API with `populateFlattenVectorTransferPatterns`.
367+
void populateForVectorLinearize(
368+
RewritePatternSet &patterns,
369+
const std::function<LogicalResult(Operation *)> &preCondition =
370+
[](Operation *) { return success(); },
371+
PatternBenefit benefit = 1);
372+
373+
/// Collect a set of patterns to rewrite vector.extract_strided_slice and
374+
/// vector.insert_strided_slice operations to be as low-rank as possible.
375+
/// This is done by using shape_cast to combine non-strided dimensions.
376+
void populateForStridedRankReduction(RewritePatternSet &patterns,
377+
PatternBenefit benefit = 1);
378+
365379
/// Collect a set of patterns that bubble up/down bitcast ops.
366380
///
367381
/// These patterns move vector.bitcast ops to be before insert ops or after
@@ -408,39 +422,6 @@ void populateVectorNarrowTypeRewritePatterns(RewritePatternSet &patterns,
408422
void populateVectorTransposeNarrowTypeRewritePatterns(
409423
RewritePatternSet &patterns, PatternBenefit benefit = 1);
410424

411-
/// Initialize `typeConverter` and `conversionTarget` for vector linearization.
412-
///
413-
/// Definition: here 'linearization' means converting a single operation with
414-
/// 1+ vector operand/result of rank>1, into a new single operation whose
415-
/// vector operands and results are all of rank<=1.
416-
///
417-
/// This function registers (1) which operations are legal, and hence should not
418-
/// be linearized, (2) what the converted types are (rank-1 vectors) and how to
419-
/// materialze the conversion (with shape_cast)
420-
///
421-
/// Note: the set of legal operations can be extended by a user if for example
422-
/// certain rank>1 vectors are considered valid, by adding additional
423-
/// dynamically legal ops to `conversionTarget`.
424-
///
425-
/// Further note: the choice to use a dialect conversion design for
426-
/// linearization is to make it easy to reuse generic structural type
427-
/// conversions for linearizing scf/cf/func operations
428-
void populateForVectorLinearize(TypeConverter &typeConverter,
429-
ConversionTarget &conversionTarget);
430-
431-
/// Populates `patterns` for ND vector (N >= 2) linearization. This currently
432-
/// contains patterns for converting ConstantLike, Vectorizable, and
433-
/// vector::BitCast ops.
434-
void populateVectorLinearizeBasePatterns(const TypeConverter &,
435-
const ConversionTarget &,
436-
RewritePatternSet &patterns);
437-
438-
/// Populates `patterns` for linearizing ND (N >= 2) vector operations
439-
/// to 1D vector shuffle operations.
440-
void populateVectorLinearizeShuffleLikeOpsPatterns(const TypeConverter &,
441-
const ConversionTarget &,
442-
RewritePatternSet &patterns);
443-
444425
} // namespace vector
445426
} // namespace mlir
446427

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 32 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3861,7 +3861,11 @@ Type OuterProductOp::getExpectedMaskType() {
38613861
static Type inferStridedSliceOpResultType(VectorType vectorType,
38623862
ArrayAttr offsets, ArrayAttr sizes,
38633863
ArrayAttr strides) {
3864-
assert(offsets.size() == sizes.size() && offsets.size() == strides.size());
3864+
3865+
assert(offsets.size() == sizes.size() &&
3866+
"offsets and sizes must be same size");
3867+
assert(offsets.size() == strides.size() &&
3868+
"offsets and strides must be same size");
38653869
SmallVector<int64_t, 4> shape;
38663870
shape.reserve(vectorType.getRank());
38673871
unsigned idx = 0;
@@ -5881,13 +5885,21 @@ OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
58815885

58825886
VectorType resultType = getType();
58835887

5884-
// No-op shape cast.
5885-
if (getSource().getType() == resultType)
5886-
return getSource();
5888+
// y = shape_cast(shape_cast(shape_cast(x)))
5889+
// -> shape_cast(x) # if x and y different types
5890+
// -> x # if x and y same type
5891+
// Value newSource = getSource();
5892+
ShapeCastOp parent = *this;
5893+
while (auto precedingShapeCast =
5894+
parent.getSource().getDefiningOp<ShapeCastOp>()) {
5895+
parent = precedingShapeCast;
5896+
}
5897+
5898+
if (parent.getSource().getType() == resultType)
5899+
return parent.getSource();
58875900

5888-
// shape_cast(shape_cast(x)) -> shape_cast(x)
5889-
if (auto precedingShapeCast = getSource().getDefiningOp<ShapeCastOp>()) {
5890-
setOperand(precedingShapeCast.getSource());
5901+
if (parent != *this) {
5902+
setOperand(parent.getSource());
58915903
return getResult();
58925904
}
58935905

@@ -5907,14 +5919,20 @@ OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
59075919
return bcastOp.getSource();
59085920
}
59095921

5910-
// shape_cast(constant) -> constant
5911-
if (auto splatAttr =
5912-
llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource()))
5913-
return splatAttr.reshape(getType());
5922+
Attribute attr = adaptor.getSource();
5923+
if (attr) {
5924+
// shape_cast(constant) -> constant
5925+
if (auto splatAttr = llvm::dyn_cast<SplatElementsAttr>(attr))
5926+
return splatAttr.reshape(getType());
59145927

5915-
// shape_cast(poison) -> poison
5916-
if (llvm::dyn_cast_if_present<ub::PoisonAttr>(adaptor.getSource())) {
5917-
return ub::PoisonAttr::get(getContext());
5928+
if (auto dstElementsAttr = dyn_cast<DenseElementsAttr>(attr)) {
5929+
return dstElementsAttr.reshape(getType());
5930+
}
5931+
5932+
// shape_cast(poison) -> poison
5933+
if (llvm::dyn_cast<ub::PoisonAttr>(attr)) {
5934+
return ub::PoisonAttr::get(getContext());
5935+
}
59185936
}
59195937

59205938
return {};

0 commit comments

Comments
 (0)