From a1e180f655fa769439bb982a43264731f6500380 Mon Sep 17 00:00:00 2001
From: James Newling
Date: Tue, 8 Jul 2025 14:24:31 -0700
Subject: [PATCH 1/2] squash commits for linearization changes: gradual
 rewrite patterns

---
 .../Vector/Transforms/VectorRewritePatterns.h |   63 +-
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp      |   46 +-
 .../Vector/Transforms/VectorLinearize.cpp     | 1744 +++++++++++------
 .../Vector/linearize-subject-to-bitwidth.mlir |   58 -
 mlir/test/Dialect/Vector/linearize.mlir       |  480 -----
 .../linearize-subject-to-bitwidth.mlir        |   73 +
 .../Dialect/Vector/linearize/linearize.mlir   |  782 ++++++++
 .../linearize/rank-reduce-strided-ops.mlir    |  195 ++
 mlir/test/lib/Dialect/Vector/CMakeLists.txt   |    1 +
 .../Dialect/Vector/TestVectorLinearize.cpp    |  268 +++
 .../Dialect/Vector/TestVectorTransforms.cpp   |  159 --
 mlir/tools/mlir-opt/mlir-opt.cpp              |    3 +
 12 files changed, 2561 insertions(+), 1311 deletions(-)
 delete mode 100644 mlir/test/Dialect/Vector/linearize-subject-to-bitwidth.mlir
 delete mode 100644 mlir/test/Dialect/Vector/linearize.mlir
 create mode 100644 mlir/test/Dialect/Vector/linearize/linearize-subject-to-bitwidth.mlir
 create mode 100644 mlir/test/Dialect/Vector/linearize/linearize.mlir
 create mode 100644 mlir/test/Dialect/Vector/linearize/rank-reduce-strided-ops.mlir
 create mode 100644 mlir/test/lib/Dialect/Vector/TestVectorLinearize.cpp

diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index 3dc7d38440ca5..9cec163ee36a0 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -348,12 +348,12 @@ void populateVectorTransferDropUnitDimsPatterns(RewritePatternSet &patterns,
 void populateDropUnitDimWithShapeCastPatterns(RewritePatternSet &patterns,
                                               PatternBenefit benefit = 1);
 
-/// Collect a set of patterns to flatten n-D vector transfers on contiguous
-/// memref.
+/// Collect a set of patterns to flatten/linearize n-D vector transfers on
+/// contiguous memref.
 ///
 /// These patterns insert memref.collapse_shape + vector.shape_cast patterns
-/// to transform multiple small n-D transfers into a larger 1-D transfer where
-/// the memref contiguity properties allow it.
+/// to transform an n-D transfer into a 1-D transfer where the memref
+/// contiguity properties allow it.
 ///
 /// Flattening is only applied if the bitwidth of the trailing vector dimension
 /// is smaller or equal to `targetVectorBitwidth`.
@@ -362,6 +362,28 @@ void populateFlattenVectorTransferPatterns(
     unsigned targetVectorBitwidth = std::numeric_limits<unsigned>::max(),
     PatternBenefit benefit = 1);
 
+/// Collect a set of patterns to flatten/linearize operations on vectors.
+///
+/// These patterns insert vector.shape_cast to transform operations to have
+/// lower rank operands and results.
+///
+/// At the start of every pattern's `matchAndRewrite` call, `preCondition`
+/// is called. If it returns failure, the pattern is not applied.
+///
+/// TODO(newling) combine this API with `populateFlattenVectorTransferPatterns`.
+void populateForVectorLinearize(
+    RewritePatternSet &patterns,
+    const std::function<LogicalResult(Operation *)> &preCondition =
+        [](Operation *) { return success(); },
+    PatternBenefit benefit = 1);
+
+/// Collect a set of patterns to rewrite vector.extract_strided_slice and
+/// vector.insert_strided_slice operations to have the lowest possible rank.
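+/// For example, a rank-2 vector.extract_strided_slice from vector<4x8xi8> to
+/// vector<2x8xi8> can become a rank-1 one from vector<32xi8> to vector<16xi8>.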
+/// This is done by using shape_cast to combine consecutive dimensions whose +/// memory is contiguous. +void populateForStridedRankReduction(RewritePatternSet &patterns, + PatternBenefit benefit = 1); + /// Collect a set of patterns that bubble up/down bitcast ops. /// /// These patterns move vector.bitcast ops to be before insert ops or after @@ -408,39 +430,6 @@ void populateVectorNarrowTypeRewritePatterns(RewritePatternSet &patterns, void populateVectorTransposeNarrowTypeRewritePatterns( RewritePatternSet &patterns, PatternBenefit benefit = 1); -/// Initialize `typeConverter` and `conversionTarget` for vector linearization. -/// -/// Definition: here 'linearization' means converting a single operation with -/// 1+ vector operand/result of rank>1, into a new single operation whose -/// vector operands and results are all of rank<=1. -/// -/// This function registers (1) which operations are legal, and hence should not -/// be linearized, (2) what the converted types are (rank-1 vectors) and how to -/// materialze the conversion (with shape_cast) -/// -/// Note: the set of legal operations can be extended by a user if for example -/// certain rank>1 vectors are considered valid, by adding additional -/// dynamically legal ops to `conversionTarget`. -/// -/// Further note: the choice to use a dialect conversion design for -/// linearization is to make it easy to reuse generic structural type -/// conversions for linearizing scf/cf/func operations -void populateForVectorLinearize(TypeConverter &typeConverter, - ConversionTarget &conversionTarget); - -/// Populates `patterns` for ND vector (N >= 2) linearization. This currently -/// contains patterns for converting ConstantLike, Vectorizable, and -/// vector::BitCast ops. -void populateVectorLinearizeBasePatterns(const TypeConverter &, - const ConversionTarget &, - RewritePatternSet &patterns); - -/// Populates `patterns` for linearizing ND (N >= 2) vector operations -/// to 1D vector shuffle operations. -void populateVectorLinearizeShuffleLikeOpsPatterns(const TypeConverter &, - const ConversionTarget &, - RewritePatternSet &patterns); - } // namespace vector } // namespace mlir diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index 214d2ba7e1b8e..90d78eddb861f 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -3876,7 +3876,11 @@ Type OuterProductOp::getExpectedMaskType() { static Type inferStridedSliceOpResultType(VectorType vectorType, ArrayAttr offsets, ArrayAttr sizes, ArrayAttr strides) { - assert(offsets.size() == sizes.size() && offsets.size() == strides.size()); + + assert(offsets.size() == sizes.size() && + "offsets and sizes must be same size"); + assert(offsets.size() == strides.size() && + "offsets and strides must be same size"); SmallVector shape; shape.reserve(vectorType.getRank()); unsigned idx = 0; @@ -5896,13 +5900,21 @@ OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) { VectorType resultType = getType(); - // No-op shape cast. 
-  if (getSource().getType() == resultType)
-    return getSource();
+  // y = shape_cast(shape_cast(shape_cast(x)))
+  //   -> shape_cast(x) # if x and y have different types
+  //   -> x             # if x and y have the same type
+  ShapeCastOp parent = *this;
+  while (auto precedingShapeCast =
+             parent.getSource().getDefiningOp<ShapeCastOp>()) {
+    parent = precedingShapeCast;
+  }
+
+  if (parent.getSource().getType() == resultType)
+    return parent.getSource();
 
-  // shape_cast(shape_cast(x)) -> shape_cast(x)
-  if (auto precedingShapeCast = getSource().getDefiningOp<ShapeCastOp>()) {
-    setOperand(precedingShapeCast.getSource());
+  if (parent != *this) {
+    setOperand(parent.getSource());
     return getResult();
   }
 
@@ -5922,14 +5934,20 @@ OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
     return bcastOp.getSource();
   }
 
-  // shape_cast(constant) -> constant
-  if (auto splatAttr =
-          llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource()))
-    return splatAttr.reshape(getType());
+  Attribute attr = adaptor.getSource();
+  if (attr) {
+    // shape_cast(constant) -> constant
+    if (auto splatAttr = llvm::dyn_cast<SplatElementsAttr>(attr))
+      return splatAttr.reshape(getType());
 
-  // shape_cast(poison) -> poison
-  if (llvm::dyn_cast_if_present<ub::PoisonAttr>(adaptor.getSource())) {
-    return ub::PoisonAttr::get(getContext());
+    if (auto dstElementsAttr = dyn_cast<DenseElementsAttr>(attr)) {
+      return dstElementsAttr.reshape(getType());
+    }
+
+    // shape_cast(poison) -> poison
+    if (llvm::dyn_cast<ub::PoisonAttr>(attr)) {
+      return ub::PoisonAttr::get(getContext());
+    }
   }
 
   return {};
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index 7cac1cbafdd64..ffb94c1b65c94 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -10,11 +10,13 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "mlir/Dialect/UB/IR/UBOps.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
+#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
+#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/Operation.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
@@ -22,117 +24,111 @@
 #include "llvm/ADT/ArrayRef.h"
 #include
 #include
-#include
 
 using namespace mlir;
 
-static FailureOr<Attribute>
-linearizeConstAttr(Location loc, ConversionPatternRewriter &rewriter,
-                   VectorType resType, Attribute value) {
+namespace {
 
-  if (auto dstElementsAttr = dyn_cast<DenseElementsAttr>(value)) {
-    if (resType.isScalable() && !isa<SplatElementsAttr>(value))
-      return rewriter.notifyMatchFailure(
-          loc,
-          "Cannot linearize a constant scalable vector that's not a splat");
+/// Transform `values` to have 1 fewer element. Do this by combining the
+/// element at index `index` with the preceding element, using the combining
+/// function `combiningFunction`.
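+///
+/// For example, with a subtraction combiner, values = (8, 3, 4) and index = 1
+/// give (5, 4).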
+template +static void collapseWithPrevious(Container &values, unsigned index, + const Combiner &combiningFunction) { + + assert(values.size() > 1 && "values has fewer than 2 elements"); + assert(index > 0 && index < values.size() && + "index not in range [1, rank(values))"); + + auto combined = combiningFunction(values[index - 1], values[index]); + values[index - 1] = std::move(combined); + std::copy(values.begin() + index + 1, values.end(), values.begin() + index); + values.pop_back(); +} - return dstElementsAttr.reshape(resType); - } +/// Examples: +/// values = (2, 3, 4) index = 0 ===> assertion failure +/// values = (2, 3, 4) index = 1 ===> (6, 4) +/// values = (2, 3, 4) index = 2 ===> (2, 12) +static void collapseMul(SmallVector &values, unsigned index) { + + auto combiner = [](int64_t a, int64_t b) { return a * b; }; + return collapseWithPrevious(values, index, combiner); +} - if (auto poisonAttr = dyn_cast(value)) - return poisonAttr; +/// Examples: +/// values = (true, false, false) index = 0 ===> assertion failure +/// values = (true, false, false) index = 1 ===> (true, false) +/// values = (true, false, false) index = 2 ===> (true, false) +static void collapseOr(SmallVector &values, unsigned index) { - return rewriter.notifyMatchFailure(loc, "unsupported attr type"); + auto combiner = [](bool a, bool b) { return a || b; }; + return collapseWithPrevious(values, index, combiner); } -namespace { +/// Collapse dimension `dim` and the preceding dimension into a single +/// dimension, if possible. If not possible, return `vectorType`. +static VectorType getReducedType(VectorType vectorType, unsigned dim) { -struct LinearizeConstantLike final - : OpTraitConversionPattern { - using OpTraitConversionPattern::OpTraitConversionPattern; + if (!vectorType || vectorType.getRank() <= 1) + return vectorType; - LinearizeConstantLike(const TypeConverter &typeConverter, - MLIRContext *context, PatternBenefit benefit = 1) - : OpTraitConversionPattern(typeConverter, context, benefit) {} - LogicalResult - matchAndRewrite(Operation *op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const override { - Location loc = op->getLoc(); - if (op->getNumResults() != 1) - return rewriter.notifyMatchFailure(loc, "expected 1 result"); - - const TypeConverter &typeConverter = *getTypeConverter(); - auto resType = - typeConverter.convertType(op->getResult(0).getType()); - assert(resType && "expected 1-D vector type"); - - StringAttr attrName = rewriter.getStringAttr("value"); - Attribute value = op->getAttr(attrName); - if (!value) - return rewriter.notifyMatchFailure(loc, "no 'value' attr"); - - FailureOr newValue = - linearizeConstAttr(loc, rewriter, resType, value); - if (failed(newValue)) - return failure(); - - FailureOr convertResult = - convertOpResultTypes(op, /*operands=*/{}, typeConverter, rewriter); - if (failed(convertResult)) - return failure(); - - Operation *newOp = *convertResult; - newOp->setAttr(attrName, *newValue); - rewriter.replaceOp(op, newOp); - return success(); - } -}; + ArrayRef scalableDims = vectorType.getScalableDims(); + assert(scalableDims.size() > 1 && "rank and mask size not same size"); -struct LinearizeVectorizable final - : OpTraitConversionPattern { - using OpTraitConversionPattern::OpTraitConversionPattern; + // 2 scalable dimensions cannot be collapsed together. 
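+  // (For example, the two dims of vector<[4]x[4]xf32> cannot be merged:
+  // its 16 * vscale * vscale elements are not expressible as a single
+  // scalable dimension.)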
+ if (scalableDims[dim - 1] && scalableDims[dim]) + return vectorType; -public: - LinearizeVectorizable(const TypeConverter &typeConverter, - MLIRContext *context, PatternBenefit benefit = 1) - : OpTraitConversionPattern(typeConverter, context, benefit) {} - LogicalResult - matchAndRewrite(Operation *op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const override { - FailureOr newOp = - convertOpResultTypes(op, operands, *getTypeConverter(), rewriter); - if (failed(newOp)) - return failure(); - - rewriter.replaceOp(op, (*newOp)->getResults()); - return success(); - } -}; + SmallVector newMask(vectorType.getScalableDims()); + collapseOr(newMask, dim); + + SmallVector newShape(vectorType.getShape()); + collapseMul(newShape, dim); -template -static bool stridesAllOne(TOp op) { - static_assert( - std::is_same_v || - std::is_same_v, - "expected vector.extract_strided_slice or vector.insert_strided_slice"); - ArrayAttr strides = op.getStrides(); - return llvm::all_of(strides, isOneInteger); + return VectorType::get(newShape, vectorType.getElementType(), newMask); } -/// Convert an array of attributes into a vector of integers, if possible. -static FailureOr> intsFromArrayAttr(ArrayAttr attrs) { - if (!attrs) +/// Collapse the final 2 dimensions of `vectorType`, if possible. +/// If not possible, return `vectorType`. +static VectorType getReducedType(VectorType vectorType) { + + if (!vectorType || vectorType.getRank() < 2) + return vectorType; + + return getReducedType(vectorType, vectorType.getRank() - 1); +} + +/// Collapse all the dimensions of `vectorType` into a single dimension, if +/// possible. +static FailureOr getRankOneType(VectorType vectorType) { + + // Multiple scalable dimensions cannot be collapsed together. + if (!vectorType || vectorType.getNumScalableDims() > 1) return failure(); - SmallVector ints; - ints.reserve(attrs.size()); - for (auto attr : attrs) { - if (auto intAttr = dyn_cast(attr)) { - ints.push_back(intAttr.getInt()); - } else { - return failure(); - } - } - return ints; + + VectorType rankOneType = + VectorType::get({vectorType.getNumElements()}, + vectorType.getElementType(), vectorType.isScalable()); + + return rankOneType; +} + +/// If `value` is a vector type of a rank other than 1, use a shape_cast to +/// get a vector of rank 1, if possible. +static FailureOr getCollapsedToRankOne(Value value, + PatternRewriter &rewriter) { + + auto vectorType = dyn_cast(value.getType()); + if (!vectorType) + return failure(); + + FailureOr rankOneType = getRankOneType(vectorType); + if (failed(rankOneType)) + return failure(); + + return rewriter.createOrFold(value.getLoc(), + rankOneType.value(), value); } /// Consider inserting a vector of shape `small` into a vector of shape `large`, @@ -145,22 +141,31 @@ static FailureOr> intsFromArrayAttr(ArrayAttr attrs) { /// /// The length of the returned vector is equal to the number of elements in /// the shape `small` (i.e. the product of dimensions of `small`). -SmallVector static getStridedSliceInsertionIndices( - ArrayRef small, ArrayRef large, - ArrayRef offsets) { - - // Example of alignment between, `large`, `small` and `offsets`: - // large = 4, 5, 6, 7, 8 - // small = 1, 6, 7, 8 - // offsets = 2, 3, 0 - // - // `offsets` has implicit trailing 0s, `small` has implicit leading 1s. +/// +/// Possible input for `large`, `small` and `offsets`: +/// large = 4, 5, 6, 7, 8 +/// small = 1, 6, 7, 8 +/// offsets = 2, 3, 0 +/// +/// `small` and `large` must not have more elements than `large`. 
If `offsets` +/// has fewer elements than `large`, it has implicit trailing 0s. If `small` has +/// fewer elements than `large`, it has implicit leading 1s. So the example +/// above is equivalent to +/// +/// large = 4, 5, 6, 7, 8 +/// small = 1, 1, 6, 7, 8 +/// offsets = 2, 3, 0, 0, 0 +static SmallVector +getStridedSliceInsertionIndices(ArrayRef small, + ArrayRef large, + ArrayRef offsets) { + assert((large.size() >= small.size()) && "rank of 'large' cannot be lower than rank of 'small'"); assert((large.size() >= offsets.size()) && "rank of 'large' cannot be lower than the number of offsets"); - unsigned delta = large.size() - small.size(); - unsigned nOffsets = offsets.size(); + const unsigned delta = large.size() - small.size(); + const unsigned nOffsets = offsets.size(); auto getSmall = [&](int64_t i) -> int64_t { return i >= delta ? small[i - delta] : 1; }; @@ -169,14 +174,15 @@ SmallVector static getStridedSliceInsertionIndices( }; // Using 2 vectors of indices, at each iteration populate the updated set of - // indices based on the old set of indices, and the size of the small vector - // in the current iteration. + // indices based on the old set of indices, and the size of the small + // vector in the current iteration. SmallVector indices{0}; + const int largeRank = large.size(); int64_t stride = 1; - for (int i = large.size() - 1; i >= 0; --i) { - int64_t currentSize = indices.size(); - int64_t smallSize = getSmall(i); - int64_t nextSize = currentSize * smallSize; + for (int i = largeRank - 1; i >= 0; --i) { + const int64_t currentSize = indices.size(); + const int64_t smallSize = getSmall(i); + const int64_t nextSize = currentSize * smallSize; SmallVector nextIndices(nextSize); int64_t *base = nextIndices.begin(); int64_t offset = getOffset(i) * stride; @@ -193,563 +199,1175 @@ SmallVector static getStridedSliceInsertionIndices( return indices; } -/// This pattern converts a vector.extract_strided_slice operation into a -/// vector.shuffle operation that has a rank-1 (linearized) operand and result. +/// Combine the first 2 elements of `position` into a single element, if +/// possible. The positions are merged based on the shape of `vectorType`. +/// The returned value specifies if `position` changes. +static bool collapseFront(SmallVector &position, + VectorType vectorType, PatternRewriter &rewriter) { + + if (position.size() <= 1) + return false; + + assert(vectorType && "expected a vector type"); + assert(vectorType.getRank() > 1 && + "vectorType must have rank no less than size of 'position'"); + + Attribute attributeDimZero = dyn_cast(position[0]); + Attribute attributeDimOne = dyn_cast(position[1]); + + // We don't currently support combining dynamic positions: + if (!attributeDimZero || !attributeDimOne) + return false; + + int64_t intDimZero = cast(attributeDimZero).getInt(); + int64_t intDimOne = cast(attributeDimOne).getInt(); + + int64_t newLeadingPos = intDimZero * vectorType.getDimSize(1) + intDimOne; + IntegerAttr leadingPos = rewriter.getI64IntegerAttr(newLeadingPos); + position[1] = leadingPos; + position.erase(position.begin()); + return true; +} + +/// Return true if the mask operation has 0 or 1 non-unit dimensions. 
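+/// For example, true for vector<1x8x1xi1> (one non-unit dimension) and for
+/// vector<1x1xi1> (none); false for vector<2x8xi1>.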
+static bool
+isCreateMaskWithAtMostOneNonUnit(vector::CreateMaskOp createMaskOp) {
+  ArrayRef<int64_t> shape = createMaskOp.getType().getShape();
+  bool multipleNonUnitDim =
+      llvm::count_if(shape, [](int64_t dim) { return dim > 1; }) > 1;
+  return !multipleNonUnitDim;
+}
+
+/// The return type encapsulating the types of the 'collapsed' operation.
+struct StridedSliceTriple {
+  VectorType small;
+  VectorType large;
+  SmallVector<int64_t> offsets;
+};
+
+/// Find the innermost dimension `dim` such that an insert_strided_slice or
+/// extract_strided_slice can be rewritten by collapsing dimensions `dim` and
+/// `dim` - 1. If such a dimension is found, return the collapsed small and
+/// large types, together with the updated offsets, as a `StridedSliceTriple`.
+/// If no such dimension is found, return failure.
+///
+/// The assumptions on the sizes of `small`, `largeType`, and `offsets` are
+/// the same as for the function `getStridedSliceInsertionIndices`, please see
+/// the example there.
+static FailureOr<StridedSliceTriple>
+collapseInnerMostPossible(VectorType smallType, const VectorType largeType,
+                          const ArrayRef<int64_t> offsets) {
+
+  assert(largeType.getRank() >= smallType.getRank() &&
+         "rank of 'small' must not exceed rank of 'large'");
+
+  // Rank-1 cannot be reduced to rank-0.
+  if (largeType.getRank() <= 1)
+    return failure();
+
+  // Prepend to the small type so that it has the same rank as the large type.
+  // Doing this upfront requires data copies before confirming that we won't
+  // return failure, but simplifies the logic significantly and so is deemed
+  // worth it.
+  smallType = [&]() {
+    const int64_t dr = largeType.getRank() - smallType.getRank();
+    SmallVector<int64_t> shape(smallType.getShape());
+    SmallVector<bool> scale(smallType.getScalableDims());
+    shape.insert(shape.begin(), dr, 1);
+    scale.insert(scale.begin(), dr, false);
+    return VectorType::get(shape, smallType.getElementType(), scale);
+  }();
+
+  const ArrayRef<int64_t> smallShape = smallType.getShape();
+  const ArrayRef<int64_t> largeShape = largeType.getShape();
+  const ArrayRef<bool> scalableDims = largeType.getScalableDims();
+
+  // The algorithm iterates through the dimensions of the small type, from the
+  // back (innermost dimension) to the front. When the remaining prefix is all
+  // 1's, the condition for collapsibility is more relaxed. Specifically, when
+  // the prefix is not all 1's, the corresponding sizes of the large and small
+  // types must match. To detect all 1's, we keep track of the product of
+  // dimensions visited and compare it to the total number of elements in the
+  // small type:
+  const int64_t totalElementsInSmall = smallType.getNumElements();
+
+  int64_t suffixElementsInSmall = 1;
+  for (int si = smallType.getRank() - 1; si > 0; --si) {
+
+    suffixElementsInSmall *= smallShape[si];
+    if ((suffixElementsInSmall != totalElementsInSmall) &&
+        (smallShape[si] != largeShape[si]))
+      continue;
+
+    // Can only collapse scalable dims if the resulting collapsed dimension is
+    // the same size in the 2 vectors.
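+    // (For example, small = vector<2x[4]xf32> and large = vector<2x[4]xf32>
+    // can both collapse to vector<[8]xf32>; small = vector<1x[4]xf32> with
+    // large = vector<2x[4]xf32> cannot.)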
+ if (scalableDims[si] || scalableDims[si - 1]) { + if (smallShape[si] != largeShape[si] || + smallShape[si - 1] != largeShape[si - 1]) + continue; + } + + const VectorType flatLarge = getReducedType(largeType, si); + if (flatLarge == largeType) + continue; + + VectorType flatSmall = getReducedType(smallType, si); + SmallVector flatOffsets(offsets); + flatOffsets.resize(largeType.getRank(), 0); + flatOffsets[si - 1] *= largeShape[si]; + flatOffsets[si - 1] += flatOffsets[si]; + flatOffsets.erase(flatOffsets.begin() + si); + return StridedSliceTriple{flatSmall, flatLarge, flatOffsets}; + } + return failure(); +} + +/// Convert an array of attributes into a vector of integers. +static FailureOr> intsFromArrayAttr(ArrayAttr attributes) { + + if (!attributes || llvm::any_of(attributes, [](Attribute a) { + return !isa(a); + })) + return failure(); + + SmallVector asIntegers; + asIntegers.reserve(attributes.size()); + for (auto attr : attributes) + asIntegers.push_back(cast(attr).getInt()); + + return asIntegers; +} + +/// Return `value` with dimensions `dim` and its preceding dimension combined, +/// if possible. Otherwise return `value`. +static Value getReducedValue(PatternRewriter &rewriter, Value value, + unsigned dim) { + + VectorType vectorType = dyn_cast(value.getType()); + if (!vectorType) + return value; + + VectorType reducedType = getReducedType(vectorType, dim); + return rewriter.createOrFold(value.getLoc(), reducedType, + value); +} + +/// Reduce the inner two dimensions of `value` using a shape_cast, if possible. +static Value getReducedValue(PatternRewriter &rewriter, Value value) { + + VectorType vectorType = dyn_cast(value.getType()); + if (!vectorType || vectorType.getRank() <= 1) + return value; + + return getReducedValue(rewriter, value, vectorType.getRank() - 1); +} + +/// Reduce the innermost 2 dimensions of values in `values` using a shape_cast, +/// otherwise retain the original value. +static SmallVector getReducedValues(ValueRange values, + PatternRewriter &rewriter) { + + SmallVector replacements; + replacements.reserve(values.size()); + for (auto val : values) + replacements.push_back(getReducedValue(rewriter, val)); + + return replacements; +} + +using PreCondition = std::function; + +/// This class automates the running of a user provided matcher at the start of +/// `matchAndRewrite`. Classes that inherit from it must implement +/// `postConditionMatchAndRewrite` instead of `matchAndRewrite`. +template +struct OpRewritePatternWithPreCondition : OpRewritePattern { + OpRewritePatternWithPreCondition(MLIRContext *context, const PreCondition &p, + PatternBenefit benefit = 1) + : OpRewritePattern(context, benefit), preCondition(p) {} + +private: + LogicalResult matchAndRewrite(TOp op, PatternRewriter &rewriter) const final { + if (failed(preCondition(op))) + return rewriter.notifyMatchFailure(op, "the precondition failed"); + return postConditionMatchAndRewrite(op, rewriter); + } + + virtual LogicalResult + postConditionMatchAndRewrite(TOp op, PatternRewriter &rewriter) const = 0; + + PreCondition preCondition; +}; + +/// Linearize the innermost 2 dimensions of a vector.bitcast /// -/// For example, the following: +/// BEFORE: +/// %b = vector.bitcast %arg0 : vector<1x3x[8]xi8> to vector<1x3x[2]xi32> /// -/// ``` -/// vector.extract_strided_slice %source -/// { offsets = [..], strides = [..], sizes = [..] 
} -/// ``` -/// -/// is converted to : -/// ``` -/// %source_1d = vector.shape_cast %source -/// %out_1d = vector.shuffle %source_1d, %source_1d [ shuffle_indices_1d ] -/// %out_nd = vector.shape_cast %out_1d -/// ``` -/// -/// `shuffle_indices_1d` is computed using the offsets and sizes of the original -/// vector.extract_strided_slice operation. -struct LinearizeVectorExtractStridedSlice final - : public mlir::OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - LinearizeVectorExtractStridedSlice(const TypeConverter &typeConverter, - MLIRContext *context, - PatternBenefit benefit = 1) - : OpConversionPattern(typeConverter, context, benefit) {} +/// AFTER: +/// %0 = vector.shape_cast %arg0 : vector<1x3x[8]xi8> to vector<[24]xi8> +/// %1 = vector.bitcast %0 : vector<[24]xi8> to vector<[6]xi32> +/// %b = vector.shape_cast %1 : vector<[6]xi32> to vector<1x3x[2]xi32> +struct CollapseInnerVectorBitCast final + : OpRewritePatternWithPreCondition { + + CollapseInnerVectorBitCast(MLIRContext *context, const PreCondition &p, + PatternBenefit benefit = 1) + : OpRewritePatternWithPreCondition(context, p, + benefit) {} LogicalResult - matchAndRewrite(vector::ExtractStridedSliceOp extractStridedSliceOp, - OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { + postConditionMatchAndRewrite(vector::BitCastOp bitCast, + PatternRewriter &rewriter) const final { + + VectorType preType = bitCast.getResultVectorType(); + VectorType postType = getReducedType(preType); + if (postType == preType) + return rewriter.notifyMatchFailure(bitCast, "result type is irreducible"); + Value source = getReducedValue(rewriter, bitCast.getSource()); + Value newBitCast = + rewriter.create(bitCast.getLoc(), postType, source); + rewriter.replaceOpWithNewOp(bitCast, preType, + newBitCast); + return success(); + } +}; - VectorType flatOutputType = getTypeConverter()->convertType( - extractStridedSliceOp.getType()); - assert(flatOutputType && "vector type expected"); +/// Linearize the innermost 2 dimensions of a vectorizable operation. +/// +/// BEFORE: +/// %s = math.sin %0 : vector<4x3x2xi8> +/// +/// AFTER: +/// %1 = vector.shape_cast %0 : vector<4x3x2xi8> to vector<4x6xi8> +/// %2 = math.sin %1 : vector<4x6xi8> +/// %s = vector.shape_cast %2 : vector<4x6xi8> to vector<4x3x2xi8> +struct CollapseInnerVectorizable final + : OpTraitRewritePattern { + using OpTraitRewritePattern::OpTraitRewritePattern; - // Expect a legalization failure if the strides are not all 1 (if ever the - // verifier for extract_strided_slice allows non-1 strides). 
- if (!stridesAllOne(extractStridedSliceOp)) { - return rewriter.notifyMatchFailure( - extractStridedSliceOp, - "extract_strided_slice with strides != 1 not supported"); - } +public: + CollapseInnerVectorizable(MLIRContext *context, const PreCondition &p, + PatternBenefit benefit) + : OpTraitRewritePattern(context, benefit), + preCondition(p) {} + LogicalResult matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const override { - FailureOr> offsets = - intsFromArrayAttr(extractStridedSliceOp.getOffsets()); - if (failed(offsets)) { - return rewriter.notifyMatchFailure(extractStridedSliceOp, - "failed to get integer offsets"); + if (failed(preCondition(op))) + return rewriter.notifyMatchFailure(op, "precondition failed"); + + if (op->getNumResults() != 1) + return rewriter.notifyMatchFailure(op, "does not have 1 result"); + + auto preType = dyn_cast(op->getResult(0).getType()); + if (!preType) + return rewriter.notifyMatchFailure(op, "unique result is not a vector"); + + VectorType postType = getReducedType(preType); + if (postType == preType) + return rewriter.notifyMatchFailure(op, "result has an irreducible type"); + + OperationState newOpState(op->getLoc(), op->getName()); + newOpState.addOperands(getReducedValues(op->getOperands(), rewriter)); + newOpState.addTypes(postType); + newOpState.addAttributes(op->getAttrs()); + Operation *newOp = rewriter.create(newOpState); + + rewriter.replaceOpWithNewOp(op, preType, + newOp->getResult(0)); + + return success(); + } + +private: + PreCondition preCondition; +}; + +/// Linearize the innermost 2 dimensions of a vector.shuffle operation. +/// +/// BEFORE: +/// %shuffle_2d = vector.shuffle %v1_2d, %v2_2d [ shuffle_indices ] +/// +/// AFTER: +/// %v1_1d = vector.shape_cast %v1_2d : [...] +/// %v2_1d = vector.shape_cast %v2_2d : [...] +/// %shuffle_1d = vector.shuffle %v1_1d, %v2_1d [ shuffle_indices_1d ] +/// %shuffle_2d = vector.shape_cast %shuffle_1d : [...] +/// +/// Where `shuffle_indices_1d` are computed by expanding `shuffle_indices`. +struct CollapseInnerVectorShuffle final + : OpRewritePatternWithPreCondition { + + CollapseInnerVectorShuffle(MLIRContext *context, const PreCondition &p, + PatternBenefit benefit = 1) + : OpRewritePatternWithPreCondition(context, p, + benefit) {} + + LogicalResult + postConditionMatchAndRewrite(vector::ShuffleOp shuffleOp, + PatternRewriter &rewriter) const final { + VectorType preType = shuffleOp.getResultVectorType(); + VectorType postType = getReducedType(preType); + if (postType == preType) + return rewriter.notifyMatchFailure(shuffleOp, "irreducible type"); + SmallVector newOperands = + getReducedValues(shuffleOp.getOperands(), rewriter); + const ArrayRef oldMask = shuffleOp.getMask(); + const ArrayRef v1Shape = shuffleOp.getV1VectorType().getShape(); + + // Only if the outermost dimension is being collapsed does the mask get + // modified: + auto factor = v1Shape.size() > 2 ? 1 : v1Shape.back(); + SmallVector indices(oldMask.size() * factor); + for (auto [i, value] : llvm::enumerate(oldMask)) { + auto *iter = indices.begin() + factor * i; + std::iota(iter, iter + factor, factor * value); } + auto newShuffle = rewriter.create( + shuffleOp.getLoc(), postType, newOperands[0], newOperands[1], indices); + rewriter.replaceOpWithNewOp(shuffleOp, preType, + newShuffle.getResult()); + return success(); + } +}; - ArrayRef inputShape = - extractStridedSliceOp.getSourceVectorType().getShape(); +/// Collapse the 2 innermost dimensions of a vector.extract_strided_slice that +/// can be collapsed. 
+///
+/// BEFORE:
+///   %o = vector.extract_strided_slice %arg0 [...]
+///        vector<4x8xi8> to vector<2x8xi8>
+///
+/// AFTER:
+///   %0 = vector.shape_cast %arg0 : vector<4x8xi8> to vector<32xi8>
+///   %1 = vector.extract_strided_slice %0 [...] vector<32xi8> to vector<16xi8>
+///   %o = vector.shape_cast %1 : vector<16xi8> to vector<2x8xi8>
+///
+/// Note that this pattern will collapse the first pair of successive
+/// dimensions that it can, starting from the 2 innermost dimensions and
+/// working to the outermost 2 dimensions. If no such pair of dimensions is
+/// found, the pattern fails to match.
+struct CollapseInnerExtractStrided final
+    : public OpRewritePatternWithPreCondition<vector::ExtractStridedSliceOp> {
+  CollapseInnerExtractStrided(MLIRContext *context, const PreCondition &p,
+                              PatternBenefit benefit = 1)
+      : OpRewritePatternWithPreCondition<vector::ExtractStridedSliceOp>(
+            context, p, benefit) {}
+
+  LogicalResult
+  postConditionMatchAndRewrite(vector::ExtractStridedSliceOp extractOp,
+                               PatternRewriter &rewriter) const final {
+
+    FailureOr<SmallVector<int64_t>> maybeIntOffsets =
+        intsFromArrayAttr(extractOp.getOffsets());
+    if (failed(maybeIntOffsets))
+      return rewriter.notifyMatchFailure(extractOp,
+                                         "failed to obtain integer offsets");
+
+    FailureOr<StridedSliceTriple> updated = collapseInnerMostPossible(
+        extractOp.getType(), extractOp.getSourceVectorType(),
+        maybeIntOffsets.value());
+    if (failed(updated))
+      return rewriter.notifyMatchFailure(extractOp,
+                                         "failed to collapse any dimensions");
+
+    auto flatIn = rewriter.createOrFold<vector::ShapeCastOp>(
+        extractOp.getLoc(), updated->large, extractOp.getVector());
+
+    auto replacement = rewriter.create<vector::ExtractStridedSliceOp>(
+        extractOp.getLoc(), flatIn, updated->offsets, updated->small.getShape(),
+        SmallVector<int64_t>(updated->offsets.size(), 1));
+
+    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
+        extractOp, extractOp.getType(), replacement);
+
+    return success();
+  }
+};
+
+/// Collapse the 2 innermost dimensions of a vector.insert_strided_slice that
+/// can be collapsed.
+///
+/// BEFORE:
+///   %o = vector.insert_strided_slice %arg0, %arg1 [...]
+///        vector<2x2xi8> into vector<3x2xi8>
+///
+/// AFTER:
+///   %0 = vector.shape_cast %arg0 : vector<2x2xi8> to vector<4xi8>
+///   %1 = vector.shape_cast %arg1 : vector<3x2xi8> to vector<6xi8>
+///   %2 = vector.insert_strided_slice %0, %1 [...]
vector<4xi8> into vector<6xi8> +/// %o = vector.shape_cast %2 : vector<6xi8> to vector<3x2xi8> +struct CollapseInnerInsertStrided final + : public OpRewritePatternWithPreCondition { + CollapseInnerInsertStrided(MLIRContext *context, const PreCondition &p, + PatternBenefit benefit = 1) + : OpRewritePatternWithPreCondition( + context, p, benefit) {} + + LogicalResult + postConditionMatchAndRewrite(vector::InsertStridedSliceOp insertOp, + PatternRewriter &rewriter) const final { + + FailureOr> maybeIntOffsets = + intsFromArrayAttr(insertOp.getOffsets()); + if (failed(maybeIntOffsets)) + return rewriter.notifyMatchFailure(insertOp, + "failed to obtain integer offsets"); + + FailureOr updated = + collapseInnerMostPossible(insertOp.getSourceVectorType(), + insertOp.getType(), maybeIntOffsets.value()); + + if (failed(updated)) + return rewriter.notifyMatchFailure(insertOp, + "failed to collapse any dimensions"); + + Value shapeCast = rewriter.createOrFold( + insertOp.getLoc(), updated->small, insertOp.getValueToStore()); + + Value collapsedDst = rewriter.createOrFold( + insertOp.getLoc(), updated->large, insertOp.getDest()); + + auto replacement = rewriter.create( + insertOp.getLoc(), shapeCast, collapsedDst, updated->offsets, + SmallVector(updated->offsets.size(), 1)); + + rewriter.replaceOpWithNewOp( + insertOp, insertOp.getType(), replacement); + + return success(); + } +}; + +/// Collapse the 2 innermost dimensions of a vector.extract /// -/// where shuffle_indices_1d in this case is -/// [0, 1, 2, 3, 4, 5, 12, 13, 14, 15, 10, 11]. -/// ^^^^^^^^^^^^^^ -/// to_store_1d +/// BEFORE: +/// %o = vector.extract %arg0[1, 2] : vector<5x7xi8> from vector<2x3x5x7xi8> /// -struct LinearizeVectorInsertStridedSlice final - : public mlir::OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - LinearizeVectorInsertStridedSlice(const TypeConverter &typeConverter, - MLIRContext *context, - PatternBenefit benefit = 1) - : OpConversionPattern(typeConverter, context, benefit) {} +/// AFTER: +/// %0 = vector.shape_cast %arg0 : vector<2x3x5x7xi8> to vector<2x3x35xi8> +/// %1 = vector.extract %0[1, 2] : vector<35xi8> from vector<2x3x35xi8> +/// %o = vector.shape_cast %1 : vector<35xi8> to vector<5x7xi8> +struct CollapseInnerExtract final + : public OpRewritePatternWithPreCondition { + + CollapseInnerExtract(MLIRContext *context, const PreCondition &p, + PatternBenefit benefit = 1) + : OpRewritePatternWithPreCondition(context, p, + benefit) {} LogicalResult - matchAndRewrite(vector::InsertStridedSliceOp insertStridedSliceOp, - OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { + postConditionMatchAndRewrite(vector::ExtractOp extractOp, + PatternRewriter &rewriter) const final { - // Expect a legalization failure if the strides are not all 1 (if ever the - // verifier for insert_strided_slice allows non-1 strides). 
- if (!stridesAllOne(insertStridedSliceOp)) { + auto vectorType = dyn_cast(extractOp.getType()); + if (!vectorType) return rewriter.notifyMatchFailure( - insertStridedSliceOp, - "insert_strided_slice with strides != 1 not supported"); - } + extractOp, "result type is scalar, cannot collapse inner dimensions"); - VectorType inputType = insertStridedSliceOp.getValueToStore().getType(); - ArrayRef inputShape = inputType.getShape(); + VectorType reducedType = getReducedType(vectorType); + if (reducedType == vectorType) + return rewriter.notifyMatchFailure(extractOp, + "result type is irreducible"); - VectorType outputType = insertStridedSliceOp.getType(); - ArrayRef outputShape = outputType.getShape(); - int64_t nOutputElements = outputType.getNumElements(); + Value reducedIn = getReducedValue(rewriter, extractOp.getVector()); - FailureOr> offsets = - intsFromArrayAttr(insertStridedSliceOp.getOffsets()); - if (failed(offsets)) { - return rewriter.notifyMatchFailure(insertStridedSliceOp, - "failed to get integer offsets"); - } - SmallVector sliceIndices = getStridedSliceInsertionIndices( - inputShape, outputShape, offsets.value()); + Value reducedExtract = rewriter.create( + extractOp.getLoc(), reducedIn, extractOp.getMixedPosition()); - SmallVector indices(nOutputElements); - std::iota(indices.begin(), indices.end(), 0); - for (auto [index, sliceIndex] : llvm::enumerate(sliceIndices)) { - indices[sliceIndex] = index + nOutputElements; - } + rewriter.replaceOpWithNewOp(extractOp, vectorType, + reducedExtract); - Value flatToStore = adaptor.getValueToStore(); - Value flatDest = adaptor.getDest(); - rewriter.replaceOpWithNewOp(insertStridedSliceOp, - flatDest.getType(), flatDest, - flatToStore, indices); return success(); } }; -/// This pattern converts the ShuffleOp that works on nD (n > 1) -/// vectors to a ShuffleOp that works on linearized vectors. -/// Following, -/// vector.shuffle %v1, %v2 [ shuffle_indices ] -/// is converted to : -/// %v1_1d = vector.shape_cast %v1 -/// %v2_1d = vector.shape_cast %v2 -/// %out_1d = vector.shuffle %v1_1d, %v2_1d [ shuffle_indices_1d ] -/// %out_nd = vector.shape_cast %out_1d -// `shuffle_indices_1d` is computed using the sizes and `shuffle_indices` -/// of the original shuffle operation. 
-struct LinearizeVectorShuffle final - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - LinearizeVectorShuffle(const TypeConverter &typeConverter, - MLIRContext *context, PatternBenefit benefit = 1) - : OpConversionPattern(typeConverter, context, benefit) {} +/// Collapse the 2 innermost dimensions of a vector.insert +/// +/// BEFORE: +/// %o = vector.insert %arg0, %arg1[1, 2] : vector<5x7xi8> into +/// vector<2x3x5x7xi8> +/// +/// AFTER: +/// %0 = vector.shape_cast %arg0 : vector<5x7xi8> to vector<35xi8> +/// %1 = vector.shape_cast %arg1 : vector<2x3x5x7xi8> to vector<2x3x35xi8> +/// %2 = vector.insert %0, %1[1, 2] : vector<35xi8> into vector<2x3x35xi8> +/// %o = vector.shape_cast %2 : vector<2x3x35xi8> to vector<2x3x5x7xi8> +struct CollapseInnerInsert final + : public OpRewritePatternWithPreCondition { + + CollapseInnerInsert(MLIRContext *context, const PreCondition &p, + PatternBenefit benefit = 1) + : OpRewritePatternWithPreCondition(context, p, + benefit) {} LogicalResult - matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - VectorType dstType = - getTypeConverter()->convertType(shuffleOp.getType()); - assert(dstType && "vector type destination expected."); - - Value vec1 = adaptor.getV1(); - Value vec2 = adaptor.getV2(); - int shuffleSliceLen = 1; - int rank = shuffleOp.getV1().getType().getRank(); - - // If rank > 1, we need to do the shuffle in the granularity of slices - // instead of scalars. Size of the slice is equal to the rank-1 innermost - // dims. Mask of the shuffle op specifies which slice to take from the - // outermost dim. - if (rank > 1) { - llvm::ArrayRef shape = shuffleOp.getV1().getType().getShape(); - for (unsigned i = 1; i < shape.size(); ++i) { - shuffleSliceLen *= shape[i]; - } - } + postConditionMatchAndRewrite(vector::InsertOp insertOp, + PatternRewriter &rewriter) const final { - // For each value in the mask, we generate the indices of the source vectors - // that need to be shuffled to the destination vector. If shuffleSliceLen > - // 1 we need to shuffle the slices (consecutive shuffleSliceLen number of - // elements) instead of scalars. - ArrayRef mask = shuffleOp.getMask(); - int64_t totalSizeOfShuffledElmnts = mask.size() * shuffleSliceLen; - llvm::SmallVector indices(totalSizeOfShuffledElmnts); - for (auto [i, value] : llvm::enumerate(mask)) { - std::iota(indices.begin() + shuffleSliceLen * i, - indices.begin() + shuffleSliceLen * (i + 1), - shuffleSliceLen * value); - } + auto toInsertType = dyn_cast(insertOp.getValueToStoreType()); + if (!toInsertType) + return rewriter.notifyMatchFailure( + insertOp, + "value to insert is scalar, canot collapse inner dimensions"); + + VectorType reducedType = getReducedType(toInsertType); + if (reducedType == toInsertType) + return rewriter.notifyMatchFailure( + insertOp, "value to insert has an irreducible type"); + + Value reducedToStore = + getReducedValue(rewriter, insertOp.getValueToStore()); + Value reducedDst = getReducedValue(rewriter, insertOp.getDest()); + + auto reducedInsert = rewriter.create( + insertOp.getLoc(), reducedToStore, reducedDst, + insertOp.getMixedPosition()); + + rewriter.replaceOpWithNewOp( + insertOp, insertOp.getType(), reducedInsert); - rewriter.replaceOpWithNewOp(shuffleOp, dstType, vec1, - vec2, indices); return success(); } }; -/// This pattern converts the ExtractOp to a ShuffleOp that works on a -/// linearized vector. 
-/// Following, -/// vector.extract %source [ position ] -/// is converted to : -/// %source_1d = vector.shape_cast %source -/// %out_1d = vector.shuffle %source_1d, %source_1d [ shuffle_indices_1d ] -/// %out_nd = vector.shape_cast %out_1d -/// `shuffle_indices_1d` is computed using the position of the original extract. -struct LinearizeVectorExtract final - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - LinearizeVectorExtract(const TypeConverter &typeConverter, - MLIRContext *context, PatternBenefit benefit = 1) - : OpConversionPattern(typeConverter, context, benefit) {} +/// Collapse the outermost 2 dimensions of a vector.extract +/// +/// BEFORE: +/// %o = vector.extract %arg0[1, 2] : vector<5x7xi8> from vector<2x3x5x7xi8> +/// +/// AFTER: +/// %0 = vector.shape_cast %arg0 : vector<2x3x5x7xi8> to vector<6x5x7xi8> +/// %o = vector.extract %0[5] : vector<5x7xi8> from vector<6x5x7xi8> +struct CollapseOuterExtract final + : public OpRewritePatternWithPreCondition { + + CollapseOuterExtract(MLIRContext *context, const PreCondition &p, + PatternBenefit benefit = 1) + : OpRewritePatternWithPreCondition(context, p, + benefit) {} + LogicalResult - matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - // Skip if result is not a vector type - if (!isa(extractOp.getType())) - return rewriter.notifyMatchFailure(extractOp, - "scalar extract not supported"); - Type dstTy = getTypeConverter()->convertType(extractOp.getType()); - assert(dstTy && "expected 1-D vector type"); + postConditionMatchAndRewrite(vector::ExtractOp extractOp, + PatternRewriter &rewriter) const final { - // Dynamic position is not supported. - if (extractOp.hasDynamicPosition()) - return rewriter.notifyMatchFailure(extractOp, - "dynamic position is not supported."); + VectorType srcType = extractOp.getVector().getType(); - llvm::ArrayRef shape = extractOp.getVector().getType().getShape(); - int64_t size = extractOp.getVector().getType().getNumElements(); + SmallVector position = extractOp.getMixedPosition(); + if (!collapseFront(position, srcType, rewriter)) + return rewriter.notifyMatchFailure( + extractOp, "failed to collapse the outermost 2 dimensions"); - // Compute linearized offset. 
- int64_t linearizedOffset = 0; - llvm::ArrayRef offsets = extractOp.getStaticPosition(); - for (auto [i, off] : llvm::enumerate(offsets)) { - size /= shape[i]; - linearizedOffset += offsets[i] * size; - } + Value reducedIn = getReducedValue(rewriter, extractOp.getVector(), 1); + rewriter.replaceOpWithNewOp(extractOp, reducedIn, + position); - llvm::SmallVector indices(size); - std::iota(indices.begin(), indices.end(), linearizedOffset); - rewriter.replaceOpWithNewOp( - extractOp, dstTy, adaptor.getVector(), adaptor.getVector(), indices); + return success(); + } +}; + +/// Collapse the outermost 2 dimensions of a vector.insert +/// +/// BEFORE: +/// %o = vector.insert %arg0, %arg1[1, 2] : vector<5x7xi8> into +/// vector<2x3x5x7xi8> +/// +/// AFTER: +/// %0 = vector.shape_cast %arg1 : vector<5x7xi8> to vector<6x5x7xi8> +/// %1 = vector.insert %arg0, %0[5] : vector<5x7xi8> into vector<6x5x7xi8> +/// %o = vector.shape_cast %1 : vector<6x5x7xi8> to vector<2x3x5x7xi8> +struct CollapseOuterInsert final + : public OpRewritePatternWithPreCondition { + + CollapseOuterInsert(MLIRContext *context, const PreCondition &p, + PatternBenefit benefit = 1) + : OpRewritePatternWithPreCondition(context, p, + benefit) {} + + LogicalResult + postConditionMatchAndRewrite(vector::InsertOp insertOp, + PatternRewriter &rewriter) const final { + + VectorType dstType = insertOp.getDestVectorType(); + + SmallVector position = insertOp.getMixedPosition(); + if (!collapseFront(position, dstType, rewriter)) + return rewriter.notifyMatchFailure( + insertOp, "failed to collapse the outermost 2 dimensions"); + + Value reducedIn = getReducedValue(rewriter, insertOp.getDest(), 1); + + Value newInsert = rewriter.create( + insertOp.getLoc(), insertOp.getValueToStore(), reducedIn, position); + + rewriter.replaceOpWithNewOp( + insertOp, insertOp.getType(), newInsert); return success(); } }; -/// This pattern linearizes `vector.insert` operations. It generates a 1-D -/// version of the `vector.insert` operation when inserting a scalar into a -/// vector. It generates a 1-D `vector.shuffle` operation when inserting a -/// vector into another vector. +/// Collapse the outermost 2 dimensions of a vector.splat /// -/// Example #1: +/// BEFORE: +/// %o = vector.splat %arg0 : vector<2x3x5xi8> /// -/// %0 = vector.insert %source, %destination[0] : -/// vector<2x4xf32> into vector<2x2x4xf32> +/// AFTER: +/// %0 = vector.splat %arg0 : vector<2x15xi8> +/// %o = vector.shape_cast %0 : vector<2x15xi8> to vector<2x3x5xi8> +struct CollapseInnerSplat final + : public OpRewritePatternWithPreCondition { + + CollapseInnerSplat(MLIRContext *context, const PreCondition &p, + PatternBenefit benefit = 1) + : OpRewritePatternWithPreCondition(context, p, benefit) { + } + + LogicalResult + postConditionMatchAndRewrite(vector::SplatOp splatOp, + PatternRewriter &rewriter) const final { + + auto splatType = splatOp.getType(); + auto reducedType = getReducedType(splatType); + if (reducedType == splatType) + return rewriter.notifyMatchFailure(splatOp, "splat type is irreducible"); + + rewriter.replaceOpWithNewOp(splatOp, reducedType, + splatOp.getOperand()); + + return success(); + } +}; + +/// Convert an vector.insert (rank-2 to rank-1) to a vector.shuffle /// -/// is converted to: +/// Conversion of higher rank vector.insert operations to vector.shuffle +/// require rank-reducing patterns to be applied first. 
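+///
+/// For example, inserting vector<2xi8> into vector<3x2xi8> at position [1]
+/// uses the shuffle mask [0, 1, 6, 7, 4, 5], where the values >= 6 select
+/// elements of the inserted vector.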
/// -/// %0 = vector.shape_cast %source : vector<2x4xf32> to vector<8xf32> -/// %1 = vector.shape_cast %destination : -/// vector<2x2x4xf32> to vector<16xf32> -/// %2 = vector.shuffle %1, %0 [16, 17, 18, 19, 20, 21, 22, 23 -/// 8, 9, 10, 11, 12, 13, 14, 15] : -/// vector<16xf32>, vector<8xf32> -/// %3 = vector.shape_cast %2 : vector<16xf32> to vector<2x2x4xf32> +/// BEFORE +/// %insert_2d = vector.insert %src %dst [ position ] /// -/// Example #2: +/// AFTER +/// %src_1d = vector.shape_cast %src : [...] +/// %dst_1d = vector.shape_cast %dst : [...] +/// %out_1d = vector.shuffle %dst_1d, %src_1d [ shuffle_indices ] +/// %out_2d = vector.shape_cast %out_1d : [...] /// -/// %0 = vector.insert %source, %destination[1, 2]: f32 into vector<2x4xf32> +/// `shuffle_indices` is computed from `position`. +struct ConvertInsertToShuffle final + : public OpRewritePatternWithPreCondition { + + ConvertInsertToShuffle(MLIRContext *context, const PreCondition &p, + PatternBenefit benefit = 1) + : OpRewritePatternWithPreCondition(context, p, + benefit) {} + + LogicalResult + postConditionMatchAndRewrite(vector::InsertOp insertOp, + PatternRewriter &rewriter) const final { + + if (insertOp.getDestVectorType().isScalable()) + return rewriter.notifyMatchFailure( + insertOp, "conversion to shuffle not possible with scalable vectors"); + + const Value toStore = insertOp.getValueToStore(); + auto toInsertType = dyn_cast(toStore.getType()); + if (!toInsertType || toInsertType.getRank() != 1) + return rewriter.notifyMatchFailure( + insertOp, + "this pattern only handles the case where rank-1 vectors are stored"); + + VectorType dstType = insertOp.getType(); + if (dstType.getRank() != 2) + return rewriter.notifyMatchFailure( + insertOp, "this pattern only handles the case where rank-2 vectors " + "are inserted into"); + + int64_t offset = insertOp.getStaticPosition()[0]; + if (offset == ShapedType::kDynamic) + return rewriter.notifyMatchFailure( + insertOp, "conversion to shuffle requires all static positions"); + + int64_t nSmall = toInsertType.getNumElements(); + int64_t nLarge = dstType.getNumElements(); + + SmallVector mask(nLarge); + auto *iter = mask.begin() + offset * nSmall; + std::iota(mask.begin(), mask.end(), 0); + std::iota(iter, iter + nSmall, nLarge); + + VectorType collapsedType = getReducedType(dstType, 1); + assert(collapsedType != dstType && "rank-2 to rank-1 failed"); + + Value collapsedDst = getReducedValue(rewriter, insertOp.getDest(), 1); + Value shuffled = rewriter.create( + insertOp.getLoc(), collapsedType, collapsedDst, toStore, mask); + + rewriter.replaceOpWithNewOp(insertOp, dstType, + shuffled); + + return success(); + } +}; + +/// Convert a vector.extract (rank-2 to rank-1) to a vector.shuffle /// -/// is converted to: +/// Conversion of higher rank vector.extract operations to vector.shuffle +/// require rank-reducing patterns to be applied first. /// -/// %0 = vector.shape_cast %destination : vector<2x4xf32> to vector<8xf32> -/// %1 = vector.insert %source, %0[6]: f32 into vector<8xf32> -/// %2 = vector.shape_cast %1 : vector<8xf32> to vector<2x4xf32> +/// BEFORE: +/// %extract_1d = vector.extract %src_2d [ position ] /// -struct LinearizeVectorInsert final - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - LinearizeVectorInsert(const TypeConverter &typeConverter, - MLIRContext *context, PatternBenefit benefit = 1) - : OpConversionPattern(typeConverter, context, benefit) {} +/// AFTER: +/// %src_1d = vector.shape_cast %src_2d : [...] 
+/// %out_1d = vector.shuffle %src_1d, %src_1d [ shuffle_indices ] [...] +struct ConvertExtractToShuffle final + : public OpRewritePatternWithPreCondition { + + ConvertExtractToShuffle(MLIRContext *context, const PreCondition &p, + PatternBenefit benefit = 1) + : OpRewritePatternWithPreCondition(context, p, + benefit) {} + LogicalResult - matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - VectorType dstTy = getTypeConverter()->convertType( - insertOp.getDestVectorType()); - assert(dstTy && "vector type destination expected."); - - // Dynamic position is not supported. - if (insertOp.hasDynamicPosition()) - return rewriter.notifyMatchFailure(insertOp, - "dynamic position is not supported."); - auto srcTy = insertOp.getValueToStoreType(); - auto srcAsVec = dyn_cast(srcTy); - uint64_t srcSize = srcAsVec ? srcAsVec.getNumElements() : 1; - - auto dstShape = insertOp.getDestVectorType().getShape(); - const auto dstSize = insertOp.getDestVectorType().getNumElements(); - auto dstSizeForOffsets = dstSize; - - // Compute linearized offset. - int64_t linearizedOffset = 0; - auto offsetsNd = insertOp.getStaticPosition(); - for (auto [dim, offset] : llvm::enumerate(offsetsNd)) { - dstSizeForOffsets /= dstShape[dim]; - linearizedOffset += offset * dstSizeForOffsets; - } + postConditionMatchAndRewrite(vector::ExtractOp extractOp, + PatternRewriter &rewriter) const final { - Location loc = insertOp.getLoc(); - Value valueToStore = adaptor.getValueToStore(); + if (extractOp.getSourceVectorType().isScalable()) + return rewriter.notifyMatchFailure( + extractOp, + "conversion to shuffle not possible with scalable vectors"); - if (!isa(valueToStore.getType())) { - // Scalar case: generate a 1-D insert. - Value result = rewriter.createOrFold( - loc, valueToStore, adaptor.getDest(), linearizedOffset); - rewriter.replaceOp(insertOp, result); - return success(); - } + VectorType srcType = extractOp.getSourceVectorType(); + if (srcType.getRank() != 2) + return rewriter.notifyMatchFailure( + extractOp, "this pattern only handles the case where rank-2 vectors " + "are extracted from"); - // Vector case: generate a shuffle. - llvm::SmallVector indices(dstSize); - auto *origValsUntil = indices.begin(); - std::advance(origValsUntil, linearizedOffset); + auto dstType = dyn_cast(extractOp.getType()); + if (!dstType || dstType.getRank() != 1) + return rewriter.notifyMatchFailure( + extractOp, "this pattern only handles the case where rank-1 vectors " + "are extracted"); - // Original values that remain [0, offset). - std::iota(indices.begin(), origValsUntil, 0); - auto *newValsUntil = origValsUntil; - std::advance(newValsUntil, srcSize); - // New values [offset, offset+srcNumElements). 
- std::iota(origValsUntil, newValsUntil, dstSize); - // The rest of original values [offset+srcNumElements, end); - std::iota(newValsUntil, indices.end(), linearizedOffset + srcSize); + int64_t offset = extractOp.getStaticPosition()[0]; + if (offset == ShapedType::kDynamic) + return rewriter.notifyMatchFailure( + extractOp, "conversion to shuffle requires all static positions"); + + Value collapsedIn = getReducedValue(rewriter, extractOp.getVector(), 1); + int64_t nSmall = dstType.getNumElements(); + + SmallVector mask(nSmall); + std::iota(mask.begin(), mask.end(), offset * nSmall); - Value result = rewriter.createOrFold( - loc, dstTy, adaptor.getDest(), valueToStore, indices); + rewriter.replaceOpWithNewOp( + extractOp, dstType, collapsedIn, collapsedIn, mask); - rewriter.replaceOp(insertOp, result); return success(); } }; -/// This pattern converts the BitCastOp that works on nD (n > 1) -/// vectors to a BitCastOp that works on linearized vectors. -/// Following, -/// vector.bitcast %v1: vector<4x2xf32> to vector<4x4xf16> -/// is converted to : -/// %v1_1d = vector.shape_cast %v1: vector<4x2xf32> to vector<8xf32> -/// %out_1d = vector.bitcast %v1_1d: vector<8xf32> to vector<16xf16> -/// %out_nd = vector.shape_cast %out_1d: vector<16xf16> to vector<4x4xf16> -struct LinearizeVectorBitCast final - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - LinearizeVectorBitCast(const TypeConverter &typeConverter, - MLIRContext *context, PatternBenefit benefit = 1) - : OpConversionPattern(typeConverter, context, benefit) {} +/// BEFORE +/// %out_nd = vector.extract_strided_slice %source_nd +/// { offsets = [..], strides = [..], sizes = [..] } +/// +/// AFTER +/// %source_1d = vector.shape_cast %source_nd [...] +/// %out_1d = vector.shuffle %source_1d, %source_1d [ shuffle_indices_1d ] +/// %out_nd = vector.shape_cast %out_1d [...] +/// +/// `shuffle_indices_1d` is computed using the offsets and sizes of the +/// original vector.extract_strided_slice operation. 
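+///
+/// For example, extracting the slice with offsets = [1, 1] and sizes = [2, 2]
+/// from vector<3x4xi8> gives shuffle_indices_1d = [5, 6, 9, 10].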
+struct ConvertExtractStridedToShuffle final
+    : public OpRewritePatternWithPreCondition<vector::ExtractStridedSliceOp> {
+
+  ConvertExtractStridedToShuffle(MLIRContext *context, const PreCondition &p,
+                                 PatternBenefit benefit = 1)
+      : OpRewritePatternWithPreCondition<vector::ExtractStridedSliceOp>(
+            context, p, benefit) {}
+
   LogicalResult
-  matchAndRewrite(vector::BitCastOp castOp, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    auto resType = getTypeConverter()->convertType(castOp.getType());
-    assert(resType && "expected 1-D vector type");
-    rewriter.replaceOpWithNewOp<vector::BitCastOp>(castOp, resType,
-                                                   adaptor.getSource());
-    return mlir::success();
+  postConditionMatchAndRewrite(vector::ExtractStridedSliceOp extractOp,
+                               PatternRewriter &rewriter) const final {
+
+    if (extractOp.hasNonUnitStrides())
+      return rewriter.notifyMatchFailure(
+          extractOp, "conversion to shuffle requires unit strides");
+
+    if (extractOp.getSourceVectorType().isScalable())
+      return rewriter.notifyMatchFailure(
+          extractOp,
+          "conversion to shuffle not possible with scalable vectors");
+
+    VectorType extractType = extractOp.getType();
+
+    FailureOr<SmallVector<int64_t>> offsets =
+        intsFromArrayAttr(extractOp.getOffsets());
+    if (failed(offsets))
+      return rewriter.notifyMatchFailure(extractOp,
+                                         "failed to get integer offsets");
+
+    SmallVector<int64_t> indices = getStridedSliceInsertionIndices(
+        extractType.getShape(), extractOp.getSourceVectorType().getShape(),
+        offsets.value());
+
+    FailureOr<Value> flatIn =
+        getCollapsedToRankOne(extractOp.getVector(), rewriter);
+    FailureOr<VectorType> flatOutType = getRankOneType(extractType);
+
+    assert(succeeded(flatIn) && succeeded(flatOutType) &&
+           "failed to linearize input or type");
+
+    Value shuffled = rewriter.create<vector::ShuffleOp>(
+        extractOp.getLoc(), flatOutType.value(), flatIn.value(),
+        flatIn.value(), indices);
+
+    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(extractOp, extractType,
+                                                     shuffled);
+
+    return success();
   }
 };
 
-/// This pattern converts the SplatOp to work on a linearized vector.
-/// Following,
-///   vector.splat %value : vector<4x4xf32>
-/// is converted to:
-///   %out_1d = vector.splat %value : vector<16xf32>
-///   %out_nd = vector.shape_cast %out_1d : vector<16xf32> to vector<4x4xf32>
-struct LinearizeVectorSplat final
-    : public OpConversionPattern<vector::SplatOp> {
-  using OpConversionPattern::OpConversionPattern;
-
-  LinearizeVectorSplat(const TypeConverter &typeConverter, MLIRContext *context,
-                       PatternBenefit benefit = 1)
-      : OpConversionPattern(typeConverter, context, benefit) {}
+/// This pattern converts a vector.insert_strided_slice operation into a
+/// vector.shuffle operation that has rank-1 (linearized) operands and result.
+///
+/// BEFORE
+/// %0 = vector.insert_strided_slice %to_store, %into
+///        {offsets = [1, 0, 0, 0], strides = [1, 1]}
+///        : vector<2x2xi8> into vector<2x1x3x2xi8>
+/// AFTER
+/// %to_store_1d
+///     = vector.shape_cast %to_store : vector<2x2xi8> to vector<4xi8>
+/// %into_1d = vector.shape_cast %into : vector<2x1x3x2xi8> to vector<12xi8>
+/// %out_1d = vector.shuffle %into_1d, %to_store_1d [ shuffle_indices_1d ]
+/// %out_nd = vector.shape_cast %out_1d : vector<12xi8> to vector<2x1x3x2xi8>
+///
+/// where shuffle_indices_1d in this case is
+///     [0, 1, 2, 3, 4, 5, 12, 13, 14, 15, 10, 11].
+///                       ^^^^^^^^^^^^^^
+///                       to_store_1d
+struct ConvertInsertStridedToShuffle final
+    : public OpRewritePatternWithPreCondition<vector::InsertStridedSliceOp> {
+
+  ConvertInsertStridedToShuffle(MLIRContext *context, const PreCondition &p,
+                                PatternBenefit benefit = 1)
+      : OpRewritePatternWithPreCondition<vector::InsertStridedSliceOp>(
+            context, p, benefit) {}
 
   LogicalResult
-  matchAndRewrite(vector::SplatOp splatOp, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    auto dstTy = getTypeConverter()->convertType(splatOp.getType());
-    if (!dstTy)
-      return rewriter.notifyMatchFailure(splatOp, "cannot convert type.");
-    rewriter.replaceOpWithNewOp<vector::SplatOp>(splatOp, adaptor.getInput(),
-                                                 dstTy);
+  postConditionMatchAndRewrite(vector::InsertStridedSliceOp insertOp,
+                               PatternRewriter &rewriter) const final {
+
+    if (insertOp.hasNonUnitStrides())
+      return rewriter.notifyMatchFailure(
+          insertOp, "conversion to shuffle requires unit strides");
+
+    if (insertOp.getSourceVectorType().isScalable())
+      return rewriter.notifyMatchFailure(
+          insertOp, "conversion to shuffle not possible with scalable vectors");
+
+    TypedValue<VectorType> toStore = insertOp.getValueToStore();
+    VectorType inputType = toStore.getType();
+    ArrayRef<int64_t> inputShape = inputType.getShape();
+
+    VectorType outputType = insertOp.getType();
+    ArrayRef<int64_t> outputShape = outputType.getShape();
+    int64_t nOutputElements = outputType.getNumElements();
+
+    FailureOr<SmallVector<int64_t>> offsets =
+        intsFromArrayAttr(insertOp.getOffsets());
+    if (failed(offsets))
+      return rewriter.notifyMatchFailure(insertOp,
+                                         "failed to get integer offsets");
+
+    SmallVector<int64_t> sliceIndices = getStridedSliceInsertionIndices(
+        inputShape, outputShape, offsets.value());
+
+    SmallVector<int64_t> indices(nOutputElements);
+    std::iota(indices.begin(), indices.end(), 0);
+    for (auto [index, sliceIndex] : llvm::enumerate(sliceIndices)) {
+      indices[sliceIndex] = index + nOutputElements;
+    }
+
+    FailureOr<Value> flatToStore = getCollapsedToRankOne(toStore, rewriter);
+    assert(succeeded(flatToStore) && "failed to linearize value to store");
+
+    FailureOr<Value> flatDest =
+        getCollapsedToRankOne(insertOp.getDest(), rewriter);
+    assert(succeeded(flatDest) &&
+           "failed to linearize destination of insert strided slice");
+
+    FailureOr<VectorType> flatDestType = getRankOneType(outputType);
+    assert(succeeded(flatDestType) &&
+           "failed to get rank-1 type for destination of insert strided slice");
+
+    Value shuffled = rewriter.create<vector::ShuffleOp>(
+        insertOp.getLoc(), flatDestType.value(), flatDest.value(),
+        flatToStore.value(), indices);
+
+    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(insertOp, outputType,
+                                                     shuffled);
+
     return success();
   }
 };
 
 /// This pattern converts the CreateMaskOp to work on a linearized vector.
-/// It currently supports only 2D masks with a unit outer dimension.
-/// Following,
+///
+/// BEFORE:
 ///   vector.create_mask %arg0, %arg1 : vector<1x4xi1>
-/// is converted to:
+///
+/// AFTER:
 ///   %zero = arith.constant 0 : index
 ///   %cmpi = arith.cmpi sgt, %arg0, %zero : index
 ///   %index = arith.index_cast %cmpi : i1 to index
-///   %mul = arith.andi %index, %arg1 : index
+///   %mul = arith.muli %index, %arg1 : index
 ///   %mask = vector.create_mask %mul : vector<4xi1>
 ///   %shape_cast = vector.shape_cast %mask : vector<4xi1> to vector<1x4xi1>
-struct LinearizeVectorCreateMask final
-    : OpConversionPattern<vector::CreateMaskOp> {
-  using OpConversionPattern::OpConversionPattern;
+///
+/// There can be at most one non-unit dimension in the mask type.
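+///
+/// For example (an illustrative sketch based on the create_mask tests in this
+/// patch; constant folding may simplify the arithmetic further), a 4-D mask
+///
+///   vector.create_mask %arg0, %arg1, %c1, %arg0 : vector<1x16x1x1xi1>
+///
+/// becomes a rank-1 create_mask whose size operand is the product of the
+/// clamped unit-dimension operands and the one non-unit operand %arg1:
+///
+///   %cmp0 = arith.cmpi sgt, %arg0, %c0 : index
+///   %cmp1 = arith.cmpi sgt, %arg0, %c0 : index
+///   %m    = arith.muli %cmp0, %cmp1 : i1
+///   %idx  = arith.index_cast %m : i1 to index
+///   %size = arith.muli %idx, %arg1 : index
+///   %mask = vector.create_mask %size : vector<16xi1>
+///   %res  = vector.shape_cast %mask : vector<16xi1> to vector<1x16x1x1xi1>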
+struct SqueezeCreateMaskUnitDims final
+    : OpRewritePatternWithPreCondition<vector::CreateMaskOp> {
 
-  LinearizeVectorCreateMask(const TypeConverter &typeConverter,
-                            MLIRContext *context, PatternBenefit benefit = 1)
-      : OpConversionPattern(typeConverter, context, benefit) {}
+  SqueezeCreateMaskUnitDims(MLIRContext *context, const PreCondition &p,
+                            PatternBenefit benefit = 1)
+      : OpRewritePatternWithPreCondition<vector::CreateMaskOp>(context, p,
+                                                               benefit) {}
 
   LogicalResult
-  matchAndRewrite(vector::CreateMaskOp createMaskOp, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    Location loc = createMaskOp.getLoc();
-    VectorType srcTy = createMaskOp.getType();
-    auto srcShape = srcTy.getShape();
-    if (srcShape.size() != 2)
-      return rewriter.notifyMatchFailure(createMaskOp,
-                                         "only 2D mask is supported.");
-
-    if (srcShape[0] != 1)
+  postConditionMatchAndRewrite(vector::CreateMaskOp maskOp,
+                               PatternRewriter &rewriter) const final {
+
+    VectorType maskType = maskOp.getType();
+
+    if (!isCreateMaskWithAtMostOneNonUnit(maskOp))
+      return rewriter.notifyMatchFailure(
+          maskOp, "mask type must have at most one non-unit dimension");
+
+    Location loc = maskOp.getLoc();
+
+    FailureOr<VectorType> flatType = getRankOneType(maskType);
+    if (failed(flatType))
+      return rewriter.notifyMatchFailure(maskOp,
+                                         "failed to convert to rank-1 type");
+
+    if (flatType.value() == maskType)
       return rewriter.notifyMatchFailure(
-          createMaskOp, "only unit outer dimension is supported.");
-
-    auto dstTy = getTypeConverter()->convertType(srcTy);
-    if (!dstTy)
-      return rewriter.notifyMatchFailure(createMaskOp, "cannot convert type.");
-
-    // Compare the first operand with 0. If it is greater than 0, the
-    // corresponding mask element is set to true, otherwise false.
-    // The result of the comparison is then multiplied with
-    // the second operand of create_mask to get the 1D mask.
-    auto firstOperand = adaptor.getOperands().front();
-    auto zero = rewriter.create<mlir::arith::ConstantIndexOp>(loc, 0);
-    auto isNonZero = rewriter.createOrFold<mlir::arith::CmpIOp>(
-        loc, mlir::arith::CmpIPredicate::sgt, firstOperand, zero);
-    auto isNonZeroIndex = rewriter.createOrFold<mlir::arith::IndexCastOp>(
-        loc, rewriter.getIndexType(), isNonZero);
-    auto secondOperand = adaptor.getOperands().back();
-    auto maskSize = rewriter.createOrFold<mlir::arith::AndIOp>(
-        loc, rewriter.getIndexType(), isNonZeroIndex, secondOperand);
-
-    auto newMask =
-        rewriter.create<mlir::vector::CreateMaskOp>(loc, dstTy, maskSize);
-    rewriter.replaceOp(createMaskOp, newMask);
+          maskOp, "mask type is already rank linearized");
+
+    // First, get the product of (clamped) mask sizes in the unit dimensions.
+    Value prod = rewriter.create<arith::ConstantIntOp>(loc, 1, 1);
+    Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+    int nonUnitDim = -1;
+    for (unsigned i = 0; i < maskType.getRank(); ++i) {
+
+      Value dimRange = maskOp.getOperands()[i];
+      int64_t dimSize = maskType.getDimSize(i);
+      if (dimSize <= 1) {
+        Value nxt = rewriter.create<arith::CmpIOp>(
+            loc, arith::CmpIPredicate::sgt, dimRange, zero);
+        prod = rewriter.create<arith::MulIOp>(loc, prod, nxt);
+
+      } else {
+        assert(nonUnitDim == -1 && "at most 1 non-unit dimension expected");
+        nonUnitDim = i;
+      }
+    }
+
+    prod =
+        rewriter.create<arith::IndexCastOp>(loc, rewriter.getIndexType(), prod);
+
+    // Finally, multiply by the size in the dimension that is not unit.
+    if (nonUnitDim != -1) {
+      Value v = maskOp.getOperands()[nonUnitDim];
+      prod = rewriter.create<arith::MulIOp>(loc, prod, v);
+    }
+
+    Value newMask =
+        rewriter.create<vector::CreateMaskOp>(loc, flatType.value(), prod);
+
+    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(maskOp, maskType, newMask);
+
     return success();
   }
 };
 
-} // namespace
+enum class LinearizePattern {
+  // Patterns to collapse the 2 innermost dimensions:
+  CollapseInnerVectorizable,
+  CollapseInnerVectorBitCast,
+  CollapseInnerVectorShuffle,
+  CollapseInnerExtractStrided,
+  CollapseInnerInsertStrided,
+  CollapseInnerExtract,
+  CollapseInnerInsert,
+  CollapseInnerSplat,
+
+  // Patterns to collapse the 2 outermost dimensions:
+  CollapseOuterExtract,
+  CollapseOuterInsert,
+
+  // Patterns to convert ops to shuffle:
+  ConvertInsertToShuffle,
+  ConvertExtractToShuffle,
+  ConvertInsertStridedToShuffle,
+  ConvertExtractStridedToShuffle,
+
+  // Patterns to remove unit dimensions:
+  SqueezeCreateMaskUnitDims,
+
+  // The number of patterns in this enum:
+  N
+};
 
-/// This method defines the set of operations that are linearizable, and hence
-/// that are considered illegal for the conversion target.
-static bool isLinearizable(Operation *op) {
-
-  // Only ops that are in the vector dialect, are ConstantLike, or
-  // are Vectorizable might be linearized currently.
-  StringLiteral vectorDialect = vector::VectorDialect::getDialectNamespace();
-  StringRef opDialect = op->getDialect()->getNamespace();
-  bool supported = (opDialect == vectorDialect) ||
-                   op->hasTrait<OpTrait::ConstantLike>() ||
-                   op->hasTrait<OpTrait::Vectorizable>();
-  if (!supported)
-    return false;
+/// This class contains functions to control the set of linearization patterns
+/// to include for the conversion, and their priority.
+struct VectorLinearizePatterns {
 
-  return TypeSwitch<Operation *, bool>(op)
-      // As type legalization is done with vector.shape_cast, shape_cast
-      // itself cannot be linearized (will create new shape_casts to linearize
-      // ad infinitum).
-      .Case<vector::ShapeCastOp>([&](auto) { return false; })
-      // The operations
-      // - vector.extract_strided_slice
-      // - vector.extract
-      // - vector.insert_strided_slice
-      // - vector.insert
-      // are linearized to a rank-1 vector.shuffle by the current patterns.
-      // vector.shuffle only supports fixed size vectors, so it is impossible to
-      // use this approach to linearize these ops if they operate on scalable
-      // vectors.
-      .Case<vector::ExtractStridedSliceOp>(
-          [&](vector::ExtractStridedSliceOp extractOp) {
-            return !extractOp.getType().isScalable();
-          })
-      .Case<vector::InsertStridedSliceOp>(
-          [&](vector::InsertStridedSliceOp insertOp) {
-            return !insertOp.getType().isScalable();
-          })
-      .Case<vector::InsertOp>([&](vector::InsertOp insertOp) {
-        return !insertOp.getType().isScalable();
-      })
-      .Case<vector::ExtractOp>([&](vector::ExtractOp extractOp) {
-        return !extractOp.getSourceVectorType().isScalable();
-      })
-      .Default([&](auto) { return true; });
-}
+public:
+  /// By default all patterns are enabled and have benefit 1.
+  VectorLinearizePatterns() {
+    enabled.fill(true);
+    benefits.fill(PatternBenefit(1));
+  }
 
-void mlir::vector::populateForVectorLinearize(TypeConverter &typeConverter,
-                                              ConversionTarget &target) {
+  /// Add the patterns enabled for the conversion to `patterns`.
+  void addToPatternSet(RewritePatternSet &patterns,
+                       const PreCondition &pc) const;
 
-  auto convertType = [](Type type) -> std::optional<Type> {
-    VectorType vectorType = dyn_cast<VectorType>(type);
-    if (!vectorType || !isLinearizableVector(vectorType))
-      return type;
+  VectorLinearizePatterns &enable(LinearizePattern id, bool e = true) {
+    enabled[static_cast<unsigned>(id)] = e;
+    return *this;
+  }
 
-    VectorType linearizedType =
-        VectorType::get(vectorType.getNumElements(),
-                        vectorType.getElementType(), vectorType.isScalable());
-    return linearizedType;
-  };
-  typeConverter.addConversion(convertType);
+  VectorLinearizePatterns &enableAll(bool e = true) {
+    enabled.fill(e);
+    return *this;
+  }
 
-  auto materializeCast = [](OpBuilder &builder, Type type, ValueRange inputs,
-                            Location loc) -> Value {
-    if (inputs.size() != 1)
-      return nullptr;
+  bool isEnabled(LinearizePattern id) const {
+    return enabled[static_cast<unsigned>(id)];
+  }
 
-    Value value = inputs.front();
-    if (!isa<VectorType>(type) || !isa<VectorType>(value.getType()))
-      return nullptr;
+  PatternBenefit getBenefit(LinearizePattern id) const {
+    return benefits[static_cast<unsigned>(id)];
+  }
 
-    return builder.create<vector::ShapeCastOp>(loc, type, value);
-  };
-  typeConverter.addSourceMaterialization(materializeCast);
-  typeConverter.addTargetMaterialization(materializeCast);
-
-  target.markUnknownOpDynamicallyLegal(
-      [=](Operation *op) -> std::optional<bool> {
-        if (!isLinearizable(op))
-          return true;
-        // This will return true if, for all operand and result types `t`,
-        // convertType(t) = t. This is true if there are no rank>=2 vectors.
-        return typeConverter.isLegal(op);
-      });
+  VectorLinearizePatterns &setBenefits(PatternBenefit benefit) {
+    benefits.fill(benefit);
+    return *this;
+  }
+
+  VectorLinearizePatterns &setBenefit(LinearizePattern id,
+                                      PatternBenefit benefit) {
+    getBenefitRef(id) = benefit;
+    return *this;
+  }
+
+  VectorLinearizePatterns &incrementBenefit(LinearizePattern id,
+                                            unsigned inc = 1) {
+    getBenefitRef(id) = getBenefit(id).getBenefit() + inc;
+    return *this;
+  }
+
+private:
+  std::array<bool, static_cast<unsigned>(LinearizePattern::N)> enabled;
+  std::array<PatternBenefit, static_cast<unsigned>(LinearizePattern::N)>
+      benefits;
+
+  PatternBenefit &getBenefitRef(LinearizePattern id) {
+    unsigned idInt = static_cast<unsigned>(id);
+    assert(idInt < static_cast<unsigned>(LinearizePattern::N) &&
+           "invalid linearization pattern id");
+    return benefits[idInt];
+  }
+
+  template <typename PatternT>
+  void addIfEnabled(RewritePatternSet &patterns,
+                    std::function<LogicalResult(Operation *)> preCond,
+                    LinearizePattern id) const {
+    if (isEnabled(id)) {
+      patterns.add<PatternT>(patterns.getContext(), preCond, getBenefit(id));
+    }
+  }
+};
+
+void VectorLinearizePatterns::addToPatternSet(RewritePatternSet &patterns,
+                                              const PreCondition &pc) const {
+
+  using LP = LinearizePattern;
+
+  addIfEnabled<CollapseInnerVectorizable>(patterns, pc,
+                                          LP::CollapseInnerVectorizable);
+
+  addIfEnabled<CollapseInnerVectorBitCast>(patterns, pc,
+                                           LP::CollapseInnerVectorBitCast);
+
+  addIfEnabled<CollapseInnerVectorShuffle>(patterns, pc,
+                                           LP::CollapseInnerVectorShuffle);
+
+  addIfEnabled<CollapseInnerExtractStrided>(patterns, pc,
+                                            LP::CollapseInnerExtractStrided);
+
+  addIfEnabled<CollapseInnerInsertStrided>(patterns, pc,
+                                           LP::CollapseInnerInsertStrided);
+
+  addIfEnabled<CollapseInnerExtract>(patterns, pc, LP::CollapseInnerExtract);
+
+  addIfEnabled<CollapseInnerInsert>(patterns, pc, LP::CollapseInnerInsert);
+
+  addIfEnabled<CollapseOuterExtract>(patterns, pc, LP::CollapseOuterExtract);
+
+  addIfEnabled<CollapseOuterInsert>(patterns, pc, LP::CollapseOuterInsert);
+
+  addIfEnabled<CollapseInnerSplat>(patterns, pc, LP::CollapseInnerSplat);
+
+  addIfEnabled<ConvertInsertToShuffle>(patterns, pc,
+                                       LP::ConvertInsertToShuffle);
+
+  addIfEnabled<ConvertExtractToShuffle>(patterns, pc,
+                                        LP::ConvertExtractToShuffle);
+
+  addIfEnabled<ConvertInsertStridedToShuffle>(
+      patterns, pc, LP::ConvertInsertStridedToShuffle);
+
+  addIfEnabled<ConvertExtractStridedToShuffle>(
+      patterns, pc, LP::ConvertExtractStridedToShuffle);
+
+  addIfEnabled<SqueezeCreateMaskUnitDims>(patterns, pc,
+                                          LP::SqueezeCreateMaskUnitDims);
 }
 
-void mlir::vector::populateVectorLinearizeBasePatterns(
-    const TypeConverter &typeConverter, const ConversionTarget &target,
-    RewritePatternSet &patterns) {
-  patterns
-      .add<LinearizeConstantLike, LinearizeVectorizable, LinearizeVectorBitCast,
-           LinearizeVectorSplat, LinearizeVectorCreateMask>(
-          typeConverter, patterns.getContext());
+} // namespace
+
+void vector::populateForVectorLinearize(RewritePatternSet &patterns,
+                                        const PreCondition &preCondition,
+                                        PatternBenefit benefit) {
+  VectorLinearizePatterns vlp;
+
+  // We want to perform rank reduction as much as possible before converting to
+  // shuffle. We do this by setting the benefit of all the patterns that do not
+  // convert to shuffle to be 1 higher.
+  vlp.enableAll(true)
+      .setBenefits(benefit.getBenefit() + 1)
+      .setBenefit(LinearizePattern::ConvertExtractToShuffle, benefit)
+      .setBenefit(LinearizePattern::ConvertInsertToShuffle, benefit)
+      .setBenefit(LinearizePattern::ConvertExtractStridedToShuffle, benefit)
+      .setBenefit(LinearizePattern::ConvertInsertStridedToShuffle, benefit);
+
+  vlp.addToPatternSet(patterns, preCondition);
 }
 
-void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(
-    const TypeConverter &typeConverter, const ConversionTarget &target,
-    RewritePatternSet &patterns) {
-  patterns.add<LinearizeVectorShuffle, LinearizeVectorExtract,
-               LinearizeVectorInsert, LinearizeVectorExtractStridedSlice,
-               LinearizeVectorInsertStridedSlice>(typeConverter,
-                                                  patterns.getContext());
+void vector::populateForStridedRankReduction(RewritePatternSet &patterns,
+                                             PatternBenefit benefit) {
+  VectorLinearizePatterns()
+      .enableAll(false)
+      .setBenefits(benefit)
+      .enable(LinearizePattern::CollapseInnerExtractStrided)
+      .enable(LinearizePattern::CollapseInnerInsertStrided)
+      .addToPatternSet(patterns, [](auto x) { return success(); });
+}
diff --git a/mlir/test/Dialect/Vector/linearize-subject-to-bitwidth.mlir b/mlir/test/Dialect/Vector/linearize-subject-to-bitwidth.mlir
deleted file mode 100644
index 739fb2fb8b68b..0000000000000
--- a/mlir/test/Dialect/Vector/linearize-subject-to-bitwidth.mlir
+++ /dev/null
@@ -1,58 +0,0 @@
-// RUN: mlir-opt %s -split-input-file -test-bit-width-constrained-vector-linearize=target-vector-bitwidth=128 | FileCheck %s --check-prefixes=ALL,BW-128
-// RUN: mlir-opt %s -split-input-file -test-bit-width-constrained-vector-linearize=target-vector-bitwidth=0 | FileCheck %s --check-prefixes=ALL,BW-0
-
-// A vector<2x2xf32> has inner-most dimension with 64-bits. Check that at
-// bitwidth threshold 128 (>= 64), operations are linearized, and at
-// bitwidth threshold 0 (< 64), operations are not linearized.
-
-// ALL-LABEL: test_result_bitwidth_64
-func.func @test_result_bitwidth_64(%arg0: vector<2x2xf32>) -> vector<2x2xf32> {
-
-  // BW-128: arith.constant {{.*}} vector<4xf32>
-  // BW-0: arith.constant {{.*}} vector<2x2xf32>
-  %0 = arith.constant dense<[[1.0, 2.0], [3.0, 4.0]]> : vector<2x2xf32>
-
-  // BW-128: math.sin {{.*}} vector<4xf32>
-  // BW-0: math.sin {{.*}} vector<2x2xf32>
-  %1 = math.sin %arg0 : vector<2x2xf32>
-
-  return %0 : vector<2x2xf32>
-}
-
-// -----
-
-// The size of the 'index' type is backend specific, so we cannot guarantee that
-// the inner-most dimension below (of size 2*nbBits(index)) is below any bitwidth
-// threshold. Test that operations with vectors of index type are not linearized.
- -// ALL-LABEL: test_index_no_linearize -func.func @test_index_no_linearize(%arg0: vector<2x2xindex>, %arg1: vector<2x2xindex>) -> vector<2x2xindex> { - - // BW-128: %[[ADD:.*]] = arith.addi {{.*}} : vector<2x2xindex> - // BW-0: %[[ADD:.*]] = arith.addi {{.*}} : vector<2x2xindex> - %0 = arith.addi %arg0, %arg1 : vector<2x2xindex> - return %0 : vector<2x2xindex> -} - -// ----- - -// The logic for the insert op with regards to the bitwidth threshold is -// different to the other ops, so we test it here. Specifically, the logic -// is based on the bitwidth of the value to store. - -// ALL-LABEL: test_vector_insert -// ALL-SAME: (%[[DEST:.*]]: vector<2x8x4xf32>, %[[SRC:.*]]: vector<8x4xf32>) -> vector<2x8x4xf32> { -func.func @test_vector_insert(%arg0: vector<2x8x4xf32>, %arg1: vector<8x4xf32>) -> vector<2x8x4xf32> { - - // BW-128-DAG: %[[ARG_SRC:.*]] = vector.shape_cast %[[SRC]] : vector<8x4xf32> to vector<32xf32> - // BW-128-DAG: %[[ARG_DEST:.*]] = vector.shape_cast %[[DEST]] : vector<2x8x4xf32> to vector<64xf32> - // BW-128: %[[SHUFFLE:.*]] = vector.shuffle %[[ARG_DEST]], %[[ARG_SRC]] - // BW-128: %[[RES:.*]] = vector.shape_cast %[[SHUFFLE]] : vector<64xf32> to vector<2x8x4xf32> - // BW-128: return %[[RES]] : vector<2x8x4xf32> - - // BW-0: %[[RES:.*]] = vector.insert %[[SRC]], %[[DEST]] [0] : vector<8x4xf32> into vector<2x8x4xf32> - // BW-0: return %[[RES]] : vector<2x8x4xf32> - - %0 = vector.insert %arg1, %arg0[0]: vector<8x4xf32> into vector<2x8x4xf32> - return %0 : vector<2x8x4xf32> -} diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir deleted file mode 100644 index 894171500d9d6..0000000000000 --- a/mlir/test/Dialect/Vector/linearize.mlir +++ /dev/null @@ -1,480 +0,0 @@ -// RUN: mlir-opt %s -split-input-file -test-vector-linearize -verify-diagnostics | FileCheck %s - -// CHECK-LABEL: test_linearize -// CHECK-SAME: (%[[ORIG_ARG:.*]]: vector<2x2xf32>) -func.func @test_linearize(%arg0: vector<2x2xf32>) -> vector<2x2xf32> { - - // CHECK: %[[ARG:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<2x2xf32> to vector<4xf32> - // CHECK: %[[CST:.*]] = arith.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : vector<4xf32> - // CHECK: %[[RES:.*]] = vector.shape_cast %[[CST]] : vector<4xf32> to vector<2x2xf32> - %0 = arith.constant dense<[[1.0, 2.0], [3.0, 4.0]]> : vector<2x2xf32> - - // CHECK: %{{.*}} = math.sin %[[ARG]] : vector<4xf32> - %1 = math.sin %arg0 : vector<2x2xf32> - - // CHECK: %{{.*}} = arith.addf %[[ARG]], %[[CST]] : vector<4xf32> - %2 = arith.addf %arg0, %0 : vector<2x2xf32> - - // CHECK: return %[[RES]] : vector<2x2xf32> - return %0 : vector<2x2xf32> -} - -// ----- - -// CHECK-LABEL: test_linearize_poison -func.func @test_linearize_poison() -> vector<2x2xf32> { - - // CHECK: %[[POISON:.*]] = ub.poison : vector<4xf32> - // CHECK: %[[RES:.*]] = vector.shape_cast %[[POISON]] : vector<4xf32> to vector<2x2xf32> - %0 = ub.poison : vector<2x2xf32> - - // CHECK: return %[[RES]] : vector<2x2xf32> - return %0 : vector<2x2xf32> -} - -// ----- - -// CHECK-LABEL: test_partial_linearize -// CHECK-SAME: (%[[ORIG_ARG:.*]]: vector<2x2xf32>, %[[ORIG_ARG2:.*]]: vector<4x4xf32>) -func.func @test_partial_linearize(%arg0: vector<2x2xf32>, %arg1: vector<4x4xf32>) -> vector<2x2xf32> { - - // CHECK-DAG: %[[ARG:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<2x2xf32> to vector<4xf32> - // CHECK-DAG: %[[ARG2:.*]] = vector.shape_cast %[[ORIG_ARG2]] : vector<4x4xf32> to vector<16xf32> - // CHECK: %[[CST:.*]] = arith.constant dense<{{.*}}> : 
vector<4xf32> - // CHECK: %[[RES:.*]] = vector.shape_cast %[[CST]] : vector<4xf32> to vector<2x2xf32> - %0 = arith.constant dense<[[1.0, 2.0], [3.0, 4.0]]> : vector<2x2xf32> - - // CHECK: %[[C2:.*]] = arith.constant dense<{{.*}}> : vector<16xf32> - %5 = arith.constant dense<[[1.0, 2.0, 3.0, 4.0], [1.0, 2.0,3.0, 4.0], [1.0, 2.0, 3.0, 4.0], [1.0, 2.0, 5.0, 6.0]]> : vector<4x4xf32> - - // Arith and math ops are handled in generic way, check some of them - // CHECK: %[[SIN:.*]] = math.sin %[[ARG]] : vector<4xf32> - %1 = math.sin %arg0 : vector<2x2xf32> - - // CHECK: %[[SIN1:.*]] = math.sin %[[ARG2]] : vector<16xf32> - %6 = math.sin %arg1 : vector<4x4xf32> - - // CHECK: %{{.*}} = arith.addf %[[ARG]], %[[CST]] : vector<4xf32> - %2 = arith.addf %arg0, %0 : vector<2x2xf32> - - // CHECK: %[[ADD2:.*]] = arith.addf %[[ARG2]], %[[C2]] : vector<16xf32> - %7 = arith.addf %arg1, %5 : vector<4x4xf32> - - // CHECK: return %[[RES]] : vector<2x2xf32> - return %0 : vector<2x2xf32> -} - -// ----- - -// vectorizable operation (arith.mulf) with tensor result types. - -// CHECK-LABEL: test_tensor_no_linearize -func.func @test_tensor_no_linearize(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> (tensor<2x2xf32>, tensor<2x2xf32>) { - - // CHECK: %[[MULF:.*]] = arith.mulf %arg0, %arg1 : tensor<2x2xf32> - %0 = arith.mulf %arg0, %arg1 : tensor<2x2xf32> - - return %0, %arg0 : tensor<2x2xf32>, tensor<2x2xf32> -} - -// ----- - -// CHECK-LABEL: func.func @test_scalable_linearize( -// CHECK-SAME: %[[ARG_0:.*]]: vector<2x[2]xf32>) -> vector<2x[2]xf32> { -func.func @test_scalable_linearize(%arg0: vector<2x[2]xf32>) -> vector<2x[2]xf32> { - - // CHECK: %[[SC:.*]] = vector.shape_cast %[[ARG_0]] : vector<2x[2]xf32> to vector<[4]xf32> - // CHECK: %[[CST:.*]] = arith.constant dense<3.000000e+00> : vector<[4]xf32> - %0 = arith.constant dense<[[3., 3.], [3., 3.]]> : vector<2x[2]xf32> - - // CHECK: %[[SIN:.*]] = math.sin %[[SC]] : vector<[4]xf32> - %1 = math.sin %arg0 : vector<2x[2]xf32> - - // CHECK: %[[ADDF:.*]] = arith.addf %[[SIN]], %[[CST]] : vector<[4]xf32> - %2 = arith.addf %0, %1 : vector<2x[2]xf32> - - // CHECK: %[[RES:.*]] = vector.shape_cast %[[ADDF]] : vector<[4]xf32> to vector<2x[2]xf32> - // CHECK: return %[[RES]] : vector<2x[2]xf32> - return %2 : vector<2x[2]xf32> -} - -// ----- - -// CHECK-LABEL: func.func @test_scalable_no_linearize( -// CHECK-SAME: %[[VAL_0:.*]]: vector<[2]x[2]xf32>) -> vector<[2]x[2]xf32> { -func.func @test_scalable_no_linearize(%arg0: vector<[2]x[2]xf32>) -> vector<[2]x[2]xf32> { - - // CHECK: %[[CST:.*]] = arith.constant dense<2.000000e+00> : vector<[2]x[2]xf32> - %0 = arith.constant dense<[[2., 2.], [2., 2.]]> : vector<[2]x[2]xf32> - - // CHECK: %[[SIN:.*]] = math.sin %[[VAL_0]] : vector<[2]x[2]xf32> - %1 = math.sin %arg0 : vector<[2]x[2]xf32> - - // CHECK: %[[RES:.*]] = arith.addf %[[CST]], %[[SIN]] : vector<[2]x[2]xf32> - %2 = arith.addf %0, %1 : vector<[2]x[2]xf32> - - // CHECK: return %[[RES]] : vector<[2]x[2]xf32> - return %2 : vector<[2]x[2]xf32> -} - -// ----- - -// CHECK-LABEL: func.func @test_0d_vector -func.func @test_0d_vector() -> vector { - - // CHECK: %[[CST:.+]] = arith.constant dense<0.000000e+00> : vector - %0 = arith.constant dense<0.0> : vector - - // CHECK: return %[[CST]] - return %0 : vector -} - -// ----- - -// CHECK-LABEL: test_extract_strided_slice_2D -// CHECK-SAME: (%[[ORIG_ARG:.*]]: vector<4x8xf32>) -> vector<2x2xf32> { -func.func @test_extract_strided_slice_2D(%arg0 : vector<4x8xf32>) -> vector<2x2xf32> { - - // CHECK: %[[ARG:.*]] = vector.shape_cast %[[ORIG_ARG]] 
: vector<4x8xf32> to vector<32xf32> - // CHECK: %[[SHUFFLE:.*]] = vector.shuffle %[[ARG]], %[[ARG]] - // CHECK-SAME: [4, 5, 12, 13] : vector<32xf32>, vector<32xf32> - // CHECK: %[[RES:.*]] = vector.shape_cast %[[SHUFFLE]] : vector<4xf32> to vector<2x2xf32> - // CHECK: return %[[RES]] : vector<2x2xf32 - %0 = vector.extract_strided_slice %arg0 { sizes = [2, 2], strides = [1, 1], offsets = [0, 4]} - : vector<4x8xf32> to vector<2x2xf32> - return %0 : vector<2x2xf32> -} - -// ----- - -// CHECK-LABEL: func.func @test_extract_strided_slice_2D_scalable( -// CHECK-SAME: %[[VAL_0:.*]]: vector<4x[8]xf32>) -> vector<2x[8]xf32> { -func.func @test_extract_strided_slice_2D_scalable(%arg0: vector<4x[8]xf32>) -> vector<2x[8]xf32> { - - // CHECK-NOT: vector.shuffle - // CHECK-NOT: vector.shape_cast - // CHECK: %[[RES:.*]] = vector.extract_strided_slice %[[VAL_0]] - %0 = vector.extract_strided_slice %arg0 { sizes = [2, 8], strides = [1, 1], offsets = [1, 0] } : vector<4x[8]xf32> to vector<2x[8]xf32> - - // CHECK: return %[[RES]] : vector<2x[8]xf32> - return %0 : vector<2x[8]xf32> -} - -// ----- - -// CHECK-LABEL: test_extract_strided_slice_3D -// CHECK-SAME: (%[[ORIG_ARG:.*]]: vector<2x8x2xf32>) -> vector<1x4x2xf32> { -func.func @test_extract_strided_slice_3D(%arg0 : vector<2x8x2xf32>) -> vector<1x4x2xf32> { - - // CHECK: %[[ARG:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<2x8x2xf32> to vector<32xf32> - // CHECK: %[[SHUFFLE:.*]] = vector.shuffle %[[ARG]], %[[ARG]] - // CHECK-SAME: [20, 21, 22, 23, 24, 25, 26, 27] : vector<32xf32>, vector<32xf32> - // CHECK: %[[RES:.*]] = vector.shape_cast %[[SHUFFLE]] : vector<8xf32> to vector<1x4x2xf32> - // CHECK: return %[[RES]] : vector<1x4x2xf32> - %0 = vector.extract_strided_slice %arg0 { offsets = [1, 2], strides = [1, 1], sizes = [1, 4] } - : vector<2x8x2xf32> to vector<1x4x2xf32> - return %0 : vector<1x4x2xf32> -} - -// ----- - -// Test of insert_strided_slice -> shuffle. -// This is a contiguous insertion of 4 elements at offset 6 into a vector of 12 elements. -// CHECK-LABEL: insert_strided_slice_2D_into_4D -func.func @insert_strided_slice_2D_into_4D(%arg0 : vector<2x2xi8>, %arg1 : vector<2x1x3x2xi8>) -> vector<2x1x3x2xi8> { - -// CHECK-DAG: %[[ARG0:.*]] = vector.shape_cast {{.*}} to vector<4xi8> -// CHECK-DAG: %[[ARG1:.*]] = vector.shape_cast {{.*}} to vector<12xi8> -// CHECK: vector.shuffle %[[ARG1]], %[[ARG0]] -// CHECK-SAME: [0, 1, 2, 3, 4, 5, 12, 13, 14, 15, 10, 11] : vector<12xi8>, vector<4xi8> - %0 = vector.insert_strided_slice %arg0, %arg1 {offsets = [1, 0, 0, 0], strides = [1, 1]} : vector<2x2xi8> into vector<2x1x3x2xi8> - -// CHECK: %[[RES:.*]] = vector.shape_cast {{.*}} to vector<2x1x3x2xi8> -// CHECK: return %[[RES]] : vector<2x1x3x2xi8> - return %0 : vector<2x1x3x2xi8> -} - -// ----- - -// Test of insert_strided_slice -> shuffle. 
-// [[[0, 1], [2, 3], [4, 5]], [[6, 7], [8, 9], [10, 11]], [[12, 13], [14, 15]], [[16, 17]]] -// ^ ^ -// | | -// where the 2 elements are inserted into the 3x3x2 vector -// CHECK-LABEL: insert_strided_slice_3D -func.func @insert_strided_slice_3D(%arg0 : vector<1x2x1xi8>, %arg1 : vector<3x3x2xi8>) -> vector<3x3x2xi8> { - -// CHECK-DAG: %[[ARG0:.*]] = vector.shape_cast {{.*}} to vector<2xi8> -// CHECK-DAG: %[[ARG1:.*]] = vector.shape_cast {{.*}} to vector<18xi8> -// CHECK: vector.shuffle %[[ARG1]], %[[ARG0]] -// CHECK-SAME: [0, 1, 2, 3, 4, 5, 6, 7, 8, 18, 10, 19, 12, 13, 14, 15, 16, 17] : vector<18xi8>, vector<2xi8> - %0 = vector.insert_strided_slice %arg0, %arg1 {offsets = [1, 1, 1], sizes = [1, 2, 1], strides = [1, 1, 1]} : vector<1x2x1xi8> into vector<3x3x2xi8> - -// CHECK: %[[RES:.*]] = vector.shape_cast {{.*}} to vector<3x3x2xi8> -// CHECK: return %[[RES]] : vector<3x3x2xi8> - return %0 : vector<3x3x2xi8> -} - -// ----- - -// CHECK-LABEL: insert_strided_slice_2D_higher_offsets -func.func @insert_strided_slice_2D_higher_offsets(%arg0 : vector<2x1xi8>, %arg1 : vector<2x2xi8>, %arg2 : vector<5x2xi8>) -> vector<5x2xi8> { - - // CHECK: [0, 1, 2, 3, 10, 11, 12, 13, 8, 9] - // ^^^ ^^^ ^^^ ^^^ - // insertion indices - %0 = vector.insert_strided_slice %arg1, %arg2 {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<2x2xi8> into vector<5x2xi8> - - // CHECK: [0, 1, 2, 3, 10, 5, 11, 7, 8, 9] - // ^^^ ^^^ - %1 = vector.insert_strided_slice %arg0, %0 {offsets = [2, 0], sizes = [2, 1], strides = [1, 1]} : vector<2x1xi8> into vector<5x2xi8> - - // CHECK: [0, 1, 2, 3, 4, 5, 6, 10, 8, 11] - // ^^^ ^^^ - %2 = vector.insert_strided_slice %arg0, %1 {offsets = [3, 1], sizes = [2, 1], strides = [1, 1]} : vector<2x1xi8> into vector<5x2xi8> - - return %2 : vector<5x2xi8> -} - -// ----- - -// CHECK-LABEL: negative_insert_strided_slice_scalable -// CHECK-NOT: vector.shuffle -// CHECK: return -func.func @negative_insert_strided_slice_scalable(%arg0 : vector<1x[2]xi8>, %arg1 : vector<2x[2]xi8>) -> vector<2x[2]xi8> { - %0 = vector.insert_strided_slice %arg0, %arg1 {offsets = [0, 0], strides = [1,1]} : vector<1x[2]xi8> into vector<2x[2]xi8> - return %0 : vector<2x[2]xi8> -} - -// ----- - -// CHECK-LABEL: test_vector_shuffle -// CHECK-SAME: (%[[ORIG_ARG0:.*]]: vector<4x2xf32>, %[[ORIG_ARG1:.*]]: vector<4x2xf32>) -> vector<8x2xf32> { -func.func @test_vector_shuffle(%arg0: vector<4x2xf32>, %arg1: vector<4x2xf32>) -> vector<8x2xf32> { - - // CHECK-DAG: %[[ARG0:.*]] = vector.shape_cast %[[ORIG_ARG0]] : vector<4x2xf32> to vector<8xf32> - // CHECK-DAG: %[[ARG1:.*]] = vector.shape_cast %[[ORIG_ARG1]] : vector<4x2xf32> to vector<8xf32> - // CHECK: %[[SHUFFLE:.*]] = vector.shuffle %[[ARG0]], %[[ARG1]] - // CHECK-SAME: [0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15] : vector<8xf32>, vector<8xf32> - // CHECK: %[[RES:.*]] = vector.shape_cast %[[SHUFFLE]] : vector<16xf32> to vector<8x2xf32> - // CHECK: return %[[RES]] : vector<8x2xf32> - %0 = vector.shuffle %arg0, %arg1 [0, 4, 1, 5, 2, 6, 3, 7] : vector<4x2xf32>, vector<4x2xf32> - return %0 : vector<8x2xf32> -} - -// ----- - -// CHECK-LABEL: test_vector_extract -// CHECK-SAME: (%[[ORIG_ARG:.*]]: vector<2x8x2xf32>) -> vector<8x2xf32> { -func.func @test_vector_extract(%arg0: vector<2x8x2xf32>) -> vector<8x2xf32> { - - // CHECK: %[[ARG:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<2x8x2xf32> to vector<32xf32> - // CHECK: %[[SHUFFLE:.*]] = vector.shuffle %[[ARG]], %[[ARG]] - // CHECK-SAME: [16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31] : 
vector<32xf32>, vector<32xf32> - // CHECK: %[[RES:.*]] = vector.shape_cast %[[SHUFFLE]] : vector<16xf32> to vector<8x2xf32> - // CHECK: return %[[RES]] : vector<8x2xf32> - %0 = vector.extract %arg0[1]: vector<8x2xf32> from vector<2x8x2xf32> - return %0 : vector<8x2xf32> -} - -// ----- - -// CHECK-LABEL: func.func @test_vector_extract_scalable( -// CHECK-SAME: %[[VAL_0:.*]]: vector<2x8x[2]xf32>) -> vector<8x[2]xf32> { -func.func @test_vector_extract_scalable(%arg0: vector<2x8x[2]xf32>) -> vector<8x[2]xf32> { - - // CHECK-NOT: vector.shuffle - // CHECK-NOT: vector.shape_cast - // CHECK: %[[RES:.*]] = vector.extract %[[VAL_0]][1] : vector<8x[2]xf32> from vector<2x8x[2]xf32> - %0 = vector.extract %arg0[1]: vector<8x[2]xf32> from vector<2x8x[2]xf32> - - // CHECK: return %[[RES]] : vector<8x[2]xf32> - return %0 : vector<8x[2]xf32> -} - -// ----- - -// CHECK-LABEL: test_vector_insert_scalar -// CHECK-SAME: (%[[DEST:.*]]: vector<2x4xf32>, %[[SRC:.*]]: f32) -> vector<2x4xf32> { -func.func @test_vector_insert_scalar(%arg0: vector<2x4xf32>, %arg1: f32) -> vector<2x4xf32> { - - // CHECK: %[[DEST_1D:.*]] = vector.shape_cast %[[DEST]] : vector<2x4xf32> to vector<8xf32> - // CHECK: %[[INSERT_1D:.*]] = vector.insert %[[SRC]], %[[DEST_1D]] [6] : f32 into vector<8xf32> - // CHECK: %[[RES:.*]] = vector.shape_cast %[[INSERT_1D]] : vector<8xf32> to vector<2x4xf32> - // CHECK: return %[[RES]] : vector<2x4xf32> - %0 = vector.insert %arg1, %arg0[1, 2]: f32 into vector<2x4xf32> - return %0 : vector<2x4xf32> -} - -// ----- - -// CHECK-LABEL: test_vector_insert -// CHECK-SAME: (%[[DEST:.*]]: vector<2x8x4xf32>, %[[SRC:.*]]: vector<8x4xf32>) -> vector<2x8x4xf32> { -func.func @test_vector_insert(%arg0: vector<2x8x4xf32>, %arg1: vector<8x4xf32>) -> vector<2x8x4xf32> { - - // CHECK-DAG: %[[ARG_SRC:.*]] = vector.shape_cast %[[SRC]] : vector<8x4xf32> to vector<32xf32> - // CHECK-DAG: %[[ARG_DEST:.*]] = vector.shape_cast %[[DEST]] : vector<2x8x4xf32> to vector<64xf32> - // CHECK: %[[SHUFFLE:.*]] = vector.shuffle %[[ARG_DEST]], %[[ARG_SRC]] - // CHECK-SAME: [64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, - // CHECK-SAME: 88, 89, 90, 91, 92, 93, 94, 95, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, - // CHECK-SAME: 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63] : vector<64xf32>, vector<32xf32> - // CHECK: %[[RES:.*]] = vector.shape_cast %[[SHUFFLE]] : vector<64xf32> to vector<2x8x4xf32> - // CHECK: return %[[RES]] : vector<2x8x4xf32> - %0 = vector.insert %arg1, %arg0[0]: vector<8x4xf32> into vector<2x8x4xf32> - return %0 : vector<2x8x4xf32> -} - -// ----- - -// CHECK-LABEL: func.func @test_vector_insert_scalable( -// CHECK-SAME: %[[VAL_0:.*]]: vector<2x8x[4]xf32>, %[[VAL_1:.*]]: vector<8x[4]xf32>) -> vector<2x8x[4]xf32> { -func.func @test_vector_insert_scalable(%arg0: vector<2x8x[4]xf32>, %arg1: vector<8x[4]xf32>) -> vector<2x8x[4]xf32> { - - // CHECK-NOT: vector.shuffle - // CHECK-NOT: vector.shape_cast - // CHECK: %[[RES:.*]] = vector.insert %[[VAL_1]], %[[VAL_0]] [0] : vector<8x[4]xf32> into vector<2x8x[4]xf32> - - %0 = vector.insert %arg1, %arg0[0]: vector<8x[4]xf32> into vector<2x8x[4]xf32> - // CHECK: return %[[RES]] : vector<2x8x[4]xf32> - return %0 : vector<2x8x[4]xf32> -} - -// ----- - -// CHECK-LABEL: test_vector_extract_scalar -func.func @test_vector_extract_scalar(%idx : index) { - %cst = arith.constant dense<[1, 2, 3, 4]> : vector<4xi32> - - // CHECK-NOT: vector.shuffle - // CHECK: vector.extract - // CHECK-NOT: vector.shuffle - 
%0 = vector.extract %cst[%idx] : i32 from vector<4xi32> - return -} - -// ----- - -// CHECK-LABEL: test_vector_bitcast -// CHECK-SAME: %[[ARG_0:.*]]: vector<4x4xf32> -func.func @test_vector_bitcast(%arg0: vector<4x4xf32>) -> vector<4x8xf16> { - - // CHECK: %[[DOWNCAST:.*]] = vector.shape_cast %[[ARG_0]] : vector<4x4xf32> to vector<16xf32> - // CHECK: %[[BITCAST:.*]] = vector.bitcast %[[DOWNCAST]] : vector<16xf32> to vector<32xf16> - // CHECK: %[[UPCAST:.*]] = vector.shape_cast %[[BITCAST]] : vector<32xf16> to vector<4x8xf16> - %1 = vector.bitcast %arg0 : vector<4x4xf32> to vector<4x8xf16> - return %1 : vector<4x8xf16> -} - -// ----- - -// CHECK-LABEL: test_vector_bitcast -// CHECK-SAME: %[[ARG_0:.*]]: vector<4x2xf32> -func.func @test_vector_bitcast(%arg0: vector<4x2xf32>) -> vector<4x4xf16> { - - // CHECK: %[[DOWNCAST:.*]] = vector.shape_cast %[[ARG_0]] : vector<4x2xf32> to vector<8xf32> - // CHECK: %[[BITCAST:.*]] = vector.bitcast %[[DOWNCAST]] : vector<8xf32> to vector<16xf16> - // CHECK: %[[UPCAST:.*]] = vector.shape_cast %[[BITCAST]] : vector<16xf16> to vector<4x4xf16> - %1 = vector.bitcast %arg0 : vector<4x2xf32> to vector<4x4xf16> - return %1 : vector<4x4xf16> -} - -// ----- - -// CHECK-LABEL: test_vector_bitcast -// CHECK-SAME: %[[ARG_0:.*]]: vector<4x[2]xf32> -func.func @test_vector_bitcast(%arg0: vector<4x[2]xf32>) -> vector<4x[4]xf16> { - - // CHECK: %[[DOWNCAST:.*]] = vector.shape_cast %[[ARG_0]] : vector<4x[2]xf32> to vector<[8]xf32> - // CHECK: %[[BITCAST:.*]] = vector.bitcast %[[DOWNCAST]] : vector<[8]xf32> to vector<[16]xf16> - // CHECK: %[[UPCAST:.*]] = vector.shape_cast %[[BITCAST]] : vector<[16]xf16> to vector<4x[4]xf16> - %1 = vector.bitcast %arg0 : vector<4x[2]xf32> to vector<4x[4]xf16> - return %1 : vector<4x[4]xf16> -} - -// ----- - -// CHECK-LABEL: test_vector_bitcast -// CHECK-SAME: %[[ARG_0:.*]]: vector<[4]x2xf32> -func.func @test_vector_bitcast(%arg0: vector<[4]x2xf32>) -> vector<[4]x4xf16> { - - // CHECK: %[[DOWNCAST:.*]] = vector.shape_cast %[[ARG_0]] : vector<[4]x2xf32> to vector<[8]xf32> - // CHECK: %[[BITCAST:.*]] = vector.bitcast %[[DOWNCAST]] : vector<[8]xf32> to vector<[16]xf16> - // CHECK: %[[UPCAST:.*]] = vector.shape_cast %[[BITCAST]] : vector<[16]xf16> to vector<[4]x4xf16> - %1 = vector.bitcast %arg0 : vector<[4]x2xf32> to vector<[4]x4xf16> - return %1 : vector<[4]x4xf16> -} - -// ----- - -// CHECK-LABEL: test_linearize_across_for -func.func @test_linearize_across_for(%arg0 : vector<4xi8>) -> vector<4xi8> { - %0 = vector.shape_cast %arg0 : vector<4xi8> to vector<2x2xi8> - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c4 = arith.constant 4 : index - - // CHECK: scf.for {{.*}} -> (vector<4xi8>) - %1 = scf.for %i = %c0 to %c4 step %c1 iter_args(%arg1 = %0) -> (vector<2x2xi8>) { - - // CHECK: arith.addi {{.*}} : vector<4xi8> - %2 = arith.addi %arg1, %0 : vector<2x2xi8> - - // CHECK: scf.yield {{.*}} : vector<4xi8> - scf.yield %2 : vector<2x2xi8> - } - %3 = vector.shape_cast %1 : vector<2x2xi8> to vector<4xi8> - return %3 : vector<4xi8> -} - -// ----- - -// CHECK-LABEL: linearize_vector_splat -// CHECK-SAME: (%[[ARG:.*]]: i32) -> vector<4x2xi32> -func.func @linearize_vector_splat(%arg0: i32) -> vector<4x2xi32> { - - // CHECK: %[[SPLAT:.*]] = vector.splat %[[ARG]] : vector<8xi32> - // CHECK: %[[CAST:.*]] = vector.shape_cast %[[SPLAT]] : vector<8xi32> to vector<4x2xi32> - // CHECK: return %[[CAST]] : vector<4x2xi32> - %0 = vector.splat %arg0 : vector<4x2xi32> - return %0 : vector<4x2xi32> -} - -// ----- - -// CHECK-LABEL: 
linearize_scalable_vector_splat -// CHECK-SAME: (%[[ARG:.*]]: i32) -> vector<4x[2]xi32> -func.func @linearize_scalable_vector_splat(%arg0: i32) -> vector<4x[2]xi32> { - - // CHECK: %[[SPLAT:.*]] = vector.splat %[[ARG]] : vector<[8]xi32> - // CHECK: %[[CAST:.*]] = vector.shape_cast %[[SPLAT]] : vector<[8]xi32> to vector<4x[2]xi32> - // CHECK: return %[[CAST]] : vector<4x[2]xi32> - %0 = vector.splat %arg0 : vector<4x[2]xi32> - return %0 : vector<4x[2]xi32> - -} - -// ----- - -// CHECK-LABEL: linearize_create_mask -// CHECK-SAME: (%[[ARG0:.*]]: index, %[[ARG1:.*]]: index) -> vector<1x16xi1> -func.func @linearize_create_mask(%arg0 : index, %arg1 : index) -> vector<1x16xi1> { - - // CHECK: %[[C0:.*]] = arith.constant 0 : index - // CHECK: %[[CMP:.*]] = arith.cmpi sgt, %[[ARG0]], %[[C0]] : index - // CHECK: %[[INDEXCAST:.*]] = arith.index_cast %[[CMP]] : i1 to index - // CHECK: %[[MULI:.*]] = arith.andi %[[INDEXCAST]], %[[ARG1]] : index - // CHECK: %[[MASK_1D:.*]] = vector.create_mask %[[MULI]] : vector<16xi1> - // CHECK: %[[CAST:.*]] = vector.shape_cast %[[MASK_1D]] : vector<16xi1> to vector<1x16xi1> - // CHECK: return %[[CAST]] : vector<1x16xi1> - %0 = vector.create_mask %arg0, %arg1 : vector<1x16xi1> - return %0 : vector<1x16xi1> -} - -// ----- -// CHECK-LABEL: linearize_scalable_create_mask -func.func @linearize_scalable_create_mask(%arg0 : index, %arg1 : index) -> vector<1x[16]xi1> { - - // CHECK: %[[MASK_1D:.*]] = vector.create_mask {{%.*}} : vector<[16]xi1> - %0 = vector.create_mask %arg0, %arg1 : vector<1x[16]xi1> - return %0 : vector<1x[16]xi1> -} diff --git a/mlir/test/Dialect/Vector/linearize/linearize-subject-to-bitwidth.mlir b/mlir/test/Dialect/Vector/linearize/linearize-subject-to-bitwidth.mlir new file mode 100644 index 0000000000000..c4183c2767b43 --- /dev/null +++ b/mlir/test/Dialect/Vector/linearize/linearize-subject-to-bitwidth.mlir @@ -0,0 +1,73 @@ +// RUN: mlir-opt %s -split-input-file -test-bit-width-constrained-vector-linearize=target-vector-bitwidth=2048 | FileCheck %s --check-prefixes=ALL,BW-2048 +// RUN: mlir-opt %s -split-input-file -test-bit-width-constrained-vector-linearize=target-vector-bitwidth=128 | FileCheck %s --check-prefixes=ALL,BW-128 +// RUN: mlir-opt %s -split-input-file -test-bit-width-constrained-vector-linearize=target-vector-bitwidth=0 | FileCheck %s --check-prefixes=ALL,BW-0 + +// A vector<2x2xf32> has inner-most dimension with 64-bits. Check that at +// bitwidth threshold 128 (>= 64), operations are linearized, and at +// bitwidth threshold 0 (< 64), operations are not linearized. + +// ALL-LABEL: test_result_bitwidth_64 + +// BW-2048: arith.constant {{.*}} vector<4xf32> +// BW-2048: math.sin {{.*}} vector<4xf32> + +// BW-128: arith.constant {{.*}} vector<4xf32> +// BW-128: math.sin {{.*}} vector<4xf32> + +// BW-0: arith.constant {{.*}} vector<2x2xf32> +// BW-0: math.sin {{.*}} vector<2x2xf32> + +func.func @test_result_bitwidth_64(%arg0: vector<2x2xf32>) -> vector<2x2xf32> { + %0 = arith.constant dense<[[1.0, 2.0], [3.0, 4.0]]> : vector<2x2xf32> + %1 = arith.addf %arg0, %0 : vector<2x2xf32> + %2 = math.sin %1 : vector<2x2xf32> + return %2 : vector<2x2xf32> +} + +// ----- + +// The size of the 'index' type is backend specific, so we cannot guarantee that +// the inner-most dimension below (of size 2*nbBits(index)) is below any bitwidth +// threshold. Test that operations with vectors of index type are not linearized. 
+
+// ALL-LABEL: test_index_no_linearize
+
+// BW-2048: %[[ADD:.*]] = arith.addi {{.*}} : vector<2x2xindex>
+
+// BW-128: %[[ADD:.*]] = arith.addi {{.*}} : vector<2x2xindex>
+
+// BW-0: %[[ADD:.*]] = arith.addi {{.*}} : vector<2x2xindex>
+
+func.func @test_index_no_linearize(%arg0: vector<2x2xindex>, %arg1: vector<2x2xindex>) -> vector<2x2xindex> {
+  %0 = arith.addi %arg0, %arg1 : vector<2x2xindex>
+  return %0 : vector<2x2xindex>
+}
+
+// -----
+
+// The logic for the insert op with regard to the bitwidth threshold is
+// different from that of the other ops, so we test it here. Specifically, the
+// logic is based on the bitwidth of the value to store.
+
+// ALL-LABEL: test_vector_insert
+// ALL-SAME: (%[[DEST:.*]]: vector<2x8x4xf32>, %[[SRC:.*]]: vector<8x4xf32>) -> vector<2x8x4xf32> {
+
+// BW-2048-DAG: %[[ARG_SRC:.*]] = vector.shape_cast %[[SRC]] : vector<8x4xf32> to vector<32xf32>
+// BW-2048-DAG: %[[ARG_DEST:.*]] = vector.shape_cast %[[DEST]] : vector<2x8x4xf32> to vector<64xf32>
+// BW-2048: %[[SHUFFLE:.*]] = vector.shuffle %[[ARG_DEST]], %[[ARG_SRC]]
+// BW-2048: %[[RES:.*]] = vector.shape_cast %[[SHUFFLE]] : vector<64xf32> to vector<2x8x4xf32>
+// BW-2048: return %[[RES]] : vector<2x8x4xf32>
+
+// BW-128-DAG: %[[ARG_SRC:.*]] = vector.shape_cast %[[SRC]] : vector<8x4xf32> to vector<32xf32>
+// BW-128-DAG: %[[ARG_DEST:.*]] = vector.shape_cast %[[DEST]] : vector<2x8x4xf32> to vector<2x32xf32>
+// BW-128: %[[INSERT:.*]] = vector.insert %[[ARG_SRC]], %[[ARG_DEST]] [0] : vector<32xf32> into vector<2x32xf32>
+// BW-128: %[[RES:.*]] = vector.shape_cast %[[INSERT]] : vector<2x32xf32> to vector<2x8x4xf32>
+// BW-128: return %[[RES]] : vector<2x8x4xf32>
+
+// BW-0: %[[RES:.*]] = vector.insert %[[SRC]], %[[DEST]] [0] : vector<8x4xf32> into vector<2x8x4xf32>
+// BW-0: return %[[RES]] : vector<2x8x4xf32>
+
+func.func @test_vector_insert(%arg0: vector<2x8x4xf32>, %arg1: vector<8x4xf32>) -> vector<2x8x4xf32> {
+  %0 = vector.insert %arg1, %arg0[0]: vector<8x4xf32> into vector<2x8x4xf32>
+  return %0 : vector<2x8x4xf32>
+}
diff --git a/mlir/test/Dialect/Vector/linearize/linearize.mlir b/mlir/test/Dialect/Vector/linearize/linearize.mlir
new file mode 100644
index 0000000000000..a382c798b62bc
--- /dev/null
+++ b/mlir/test/Dialect/Vector/linearize/linearize.mlir
@@ -0,0 +1,782 @@
+// RUN: mlir-opt %s -split-input-file -test-vector-linearize -verify-diagnostics | FileCheck %s
+
+// **--------------------------------------------------------**
+// Tests of vectorizable ops
+// **--------------------------------------------------------**
+
+// Constant linearization happens here because of the vector.shape_cast folder.
+// The linearization of math.sin and arith.addf happens because of the
+// CollapseInnerVectorizable pattern.
+
+// CHECK-LABEL: linearize_constant_and_elementwise
+// CHECK-SAME: (%[[ORIG_ARG:.*]]: vector<2x2xf32>)
+// CHECK-DAG: %[[ARG:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<2x2xf32> to vector<4xf32>
+// CHECK-DAG: %[[CST:.*]] = arith.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : vector<4xf32>
+// CHECK: %[[SIN:.*]] = math.sin %[[ARG]] : vector<4xf32>
+// CHECK: %[[ADD:.*]] = arith.addf %[[SIN]], %[[CST]] : vector<4xf32>
+// CHECK: %[[RES:.*]] = vector.shape_cast %[[ADD]] : vector<4xf32> to vector<2x2xf32>
+// CHECK: return %[[RES]] : vector<2x2xf32>
+func.func @linearize_constant_and_elementwise(%arg0: vector<2x2xf32>) -> vector<2x2xf32> {
+  %0 = arith.constant dense<[[1., 2.], [3., 4.]]> : vector<2x2xf32>
+  %1 = math.sin %arg0 : vector<2x2xf32>
+  %2 = arith.addf %1, %0 : vector<2x2xf32>
+  return %2 : vector<2x2xf32>
+}
+
+// -----
+
+// The pattern CollapseInnerVectorizable is applied twice (4D->3D, then 3D->2D).
+
+func.func @linearize_elementwise_4D(%arg0 : vector<2x3x5x7xi8>, %arg1 : vector<2x3x5x7xi8>) -> vector<210xi8> {
+  // CHECK-LABEL: linearize_elementwise_4D
+  // CHECK-SAME: (%[[ARG0:.*]]: vector<2x3x5x7xi8>, %[[ARG1:.*]]: vector<2x3x5x7xi8>) -> vector<210xi8> {
+  // CHECK-DAG: %[[SC0:.*]] = vector.shape_cast %[[ARG0]] : vector<2x3x5x7xi8> to vector<210xi8>
+  // CHECK-DAG: %[[SC1:.*]] = vector.shape_cast %[[ARG1]] : vector<2x3x5x7xi8> to vector<210xi8>
+  // CHECK: %[[ADD:.*]] = arith.addi %[[SC0]], %[[SC1]] : vector<210xi8>
+  // CHECK: return %[[ADD]] : vector<210xi8>
+  %0 = arith.addi %arg0, %arg1 : vector<2x3x5x7xi8>
+  %1 = vector.shape_cast %0 : vector<2x3x5x7xi8> to vector<210xi8>
+  return %1 : vector<210xi8>
+}
+
+// -----
+
+// Poison linearization happens here because of the vector.shape_cast folder.
+
+// CHECK-LABEL: linearize_poison
+// CHECK: %[[POISON:.*]] = ub.poison : vector<4xf32>
+// CHECK: return %[[POISON]] : vector<4xf32>
+func.func @linearize_poison() -> vector<4xf32> {
+  %0 = ub.poison : vector<2x2xf32>
+  %1 = vector.shape_cast %0 : vector<2x2xf32> to vector<4xf32>
+  return %1 : vector<4xf32>
+}
+
+// -----
+
+// Check that linearization does not happen if the operands of the vectorizable operation are not vectors.
+
+// CHECK-LABEL: tensor_no_linearize
+// CHECK: %[[MULF:.*]] = arith.mulf %arg0, %arg1 : tensor<2x2xf32>
+func.func @tensor_no_linearize(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> (tensor<2x2xf32>, tensor<2x2xf32>) {
+  %0 = arith.mulf %arg0, %arg1 : tensor<2x2xf32>
+  return %0, %arg0 : tensor<2x2xf32>, tensor<2x2xf32>
+}
+
+// -----
+
+// Check that linearization happens as long as the combined dimensions include at most 1 scalable dimension.
+
+// CHECK-LABEL: func.func @scalable_linearize(
+// CHECK-SAME: %[[ARG_0:.*]]: vector<2x[2]xf32>) -> vector<2x[2]xf32> {
+// CHECK-DAG: %[[SC:.*]] = vector.shape_cast %[[ARG_0]] : vector<2x[2]xf32> to vector<[4]xf32>
+// CHECK-DAG: %[[CST:.*]] = arith.constant dense<3.000000e+00> : vector<[4]xf32>
+// CHECK: %[[SIN:.*]] = math.sin %[[SC]] : vector<[4]xf32>
+// CHECK: %[[ADDF:.*]] = arith.addf %[[SIN]], %[[CST]] : vector<[4]xf32>
+// CHECK: %[[RES:.*]] = vector.shape_cast %[[ADDF]] : vector<[4]xf32> to vector<2x[2]xf32>
+// CHECK: return %[[RES]] : vector<2x[2]xf32>
+func.func @scalable_linearize(%arg0: vector<2x[2]xf32>) -> vector<2x[2]xf32> {
+  %0 = arith.constant dense<[[3., 3.], [3., 3.]]> : vector<2x[2]xf32>
+  %1 = math.sin %arg0 : vector<2x[2]xf32>
+  %2 = arith.addf %0, %1 : vector<2x[2]xf32>
+  return %2 : vector<2x[2]xf32>
+}
+
+// -----
+
+// In this case there are 2 scalable dimensions; these cannot be combined.
+
+// CHECK-LABEL: func.func @scalable_no_linearize(
+// CHECK-SAME: %[[VAL_0:.*]]: vector<[2]x[2]xf32>) -> vector<[2]x[2]xf32> {
+// CHECK: %[[CST:.*]] = arith.constant dense<2.000000e+00> : vector<[2]x[2]xf32>
+// CHECK: %[[SIN:.*]] = math.sin %[[VAL_0]] : vector<[2]x[2]xf32>
+// CHECK: %[[RES:.*]] = arith.addf %[[SIN]], %[[CST]] : vector<[2]x[2]xf32>
+// CHECK: return %[[RES]] : vector<[2]x[2]xf32>
+func.func @scalable_no_linearize(%arg0: vector<[2]x[2]xf32>) -> vector<[2]x[2]xf32> {
+  %0 = arith.constant dense<[[2., 2.], [2., 2.]]> : vector<[2]x[2]xf32>
+  %1 = math.sin %arg0 : vector<[2]x[2]xf32>
+  %2 = arith.addf %1, %0 : vector<[2]x[2]xf32>
+  return %2 : vector<[2]x[2]xf32>
+}
+
+// -----
+
+// In this case, the innermost 2 dimensions can be combined, because only 1 of them is scalable.
+// However, on a subsequent application of the pattern the innermost 2 dimensions are both
+// scalable, so the pattern fails to collapse 2D -> 1D.
+
+// CHECK-LABEL: func.func @scalable_partial_linearize(
+// CHECK-SAME: %[[VAL_0:.*]]: vector<[2]x[2]x4xi8>) -> vector<[2]x[2]x4xi8> {
+// CHECK: %[[SC:.*]] = vector.shape_cast %[[VAL_0]] : vector<[2]x[2]x4xi8> to vector<[2]x[8]xi8>
+// CHECK: %[[ABS:.*]] = math.absi %[[SC]] : vector<[2]x[8]xi8>
+// CHECK: %[[RES:.*]] = vector.shape_cast %[[ABS]] : vector<[2]x[8]xi8> to vector<[2]x[2]x4xi8>
+// CHECK: return %[[RES]] : vector<[2]x[2]x4xi8>
+func.func @scalable_partial_linearize(%arg0: vector<[2]x[2]x4xi8>) -> vector<[2]x[2]x4xi8> {
+  %0 = math.absi %arg0 : vector<[2]x[2]x4xi8>
+  return %0 : vector<[2]x[2]x4xi8>
+}
+
+// -----
+
+// Check that rank-0 vectors are not converted to rank-1.
+
+// CHECK-LABEL: func.func @vector_0d
+// CHECK: %[[CST:.+]] = arith.constant dense<0.000000e+00> : vector<f32>
+// CHECK: return %[[CST]]
+func.func @vector_0d() -> vector<f32> {
+  %0 = arith.constant dense<0.> : vector<f32>
+  return %0 : vector<f32>
+}
+
+
+// -----
+
+// **--------------------------------------------------------**
+// Tests of vector.shuffle
+// [CollapseInnerVectorShuffle]
+// **--------------------------------------------------------**
+
+// CHECK-LABEL: vector_shuffle_2D
+// CHECK-SAME: (%[[ORIG_ARG0:.*]]: vector<4x2xf32>, %[[ORIG_ARG1:.*]]: vector<4x2xf32>) -> vector<8x2xf32> {
+// CHECK-DAG: %[[ARG0:.*]] = vector.shape_cast %[[ORIG_ARG0]] : vector<4x2xf32> to vector<8xf32>
+// CHECK-DAG: %[[ARG1:.*]] = vector.shape_cast %[[ORIG_ARG1]] : vector<4x2xf32> to vector<8xf32>
+// CHECK: %[[SHUFFLE:.*]] = vector.shuffle %[[ARG0]], %[[ARG1]]
+// CHECK-SAME: [0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15] : vector<8xf32>, vector<8xf32>
+// CHECK: %[[RES:.*]] = vector.shape_cast %[[SHUFFLE]] : vector<16xf32> to vector<8x2xf32>
+// CHECK: return %[[RES]] : vector<8x2xf32>
+func.func @vector_shuffle_2D(%arg0: vector<4x2xf32>, %arg1: vector<4x2xf32>) -> vector<8x2xf32> {
+  %0 = vector.shuffle %arg0, %arg1 [0, 4, 1, 5, 2, 6, 3, 7] : vector<4x2xf32>, vector<4x2xf32>
+  return %0 : vector<8x2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: vector_shuffle_5D
+// CHECK-SAME: (%[[ARG0:.*]]: vector<2x1x2x1x2xf32>, %[[ARG1:.*]]: vector<2x1x2x1x2xf32>)
+// CHECK-DAG: %[[SC0:.*]] = vector.shape_cast %[[ARG0]] : vector<2x1x2x1x2xf32> to vector<8xf32>
+// CHECK-DAG: %[[SC1:.*]] = vector.shape_cast %[[ARG1]] : vector<2x1x2x1x2xf32> to vector<8xf32>
+// CHECK: %[[SHUFFLE:.*]] = vector.shuffle %[[SC0]], %[[SC1]]
+// CHECK-SAME: [12, 13, 14, 15, 8, 9, 10, 11, 0, 1, 2, 3] : vector<8xf32>, vector<8xf32>
+// CHECK: %[[RES:.*]] = vector.shape_cast %[[SHUFFLE]] : vector<12xf32> to vector<3x1x2x1x2xf32>
+// CHECK: return %[[RES]] : vector<3x1x2x1x2xf32>
+func.func @vector_shuffle_5D(%arg0: vector<2x1x2x1x2xf32>, %arg1: vector<2x1x2x1x2xf32>) -> vector<3x1x2x1x2xf32> {
+  %0 = vector.shuffle %arg0, %arg1 [3, 2, 0] : vector<2x1x2x1x2xf32>, vector<2x1x2x1x2xf32>
+  return %0 : vector<3x1x2x1x2xf32>
+}
+
+// -----
+
+// **--------------------------------------------------------**
+// Tests of vector.bitcast
+// [CollapseInnerVectorBitCast]
+// **--------------------------------------------------------**
+
+// CHECK-LABEL: vector_bitcast_2D
+// CHECK-SAME: %[[ARG_0:.*]]: vector<4x4xf32>
+// CHECK: %[[DOWNCAST:.*]] = vector.shape_cast %[[ARG_0]] : vector<4x4xf32> to vector<16xf32>
+// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[DOWNCAST]] : vector<16xf32> to vector<32xf16>
+// CHECK: %[[UPCAST:.*]] = vector.shape_cast %[[BITCAST]] : vector<32xf16> to vector<4x8xf16>
+func.func @vector_bitcast_2D(%arg0: vector<4x4xf32>) -> vector<4x8xf16> {
+  %1 = vector.bitcast %arg0 : vector<4x4xf32> to vector<4x8xf16>
+  return %1 : vector<4x8xf16>
+}
+
+// -----
+
+// CHECK-LABEL: vector_bitcast_3D
+// CHECK-SAME: %[[ARG_0:.*]]: vector<10x4x2xf32>
+// CHECK: %[[DOWNCAST:.*]] = vector.shape_cast %[[ARG_0]] : vector<10x4x2xf32> to vector<80xf32>
+// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[DOWNCAST]] : vector<80xf32> to vector<160xf16>
+// CHECK: %[[UPCAST:.*]] = vector.shape_cast %[[BITCAST]] : vector<160xf16> to vector<10x4x4xf16>
+func.func @vector_bitcast_3D(%arg0: vector<10x4x2xf32>) -> vector<10x4x4xf16> {
+  %1 = vector.bitcast %arg0 : vector<10x4x2xf32> to vector<10x4x4xf16>
+  return %1 : vector<10x4x4xf16>
+}
+
+// -----
+
+// CHECK-LABEL: vector_bitcast_scalable
+// CHECK-SAME: %[[ARG_0:.*]]: vector<4x[2]xf32>
+// CHECK: %[[DOWNCAST:.*]] = vector.shape_cast %[[ARG_0]] : vector<4x[2]xf32> to vector<[8]xf32>
+// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[DOWNCAST]] : vector<[8]xf32> to vector<[16]xf16>
+// CHECK: %[[UPCAST:.*]] = vector.shape_cast %[[BITCAST]] : vector<[16]xf16> to vector<4x[4]xf16>
+func.func @vector_bitcast_scalable(%arg0: vector<4x[2]xf32>) -> vector<4x[4]xf16> {
+  %1 = vector.bitcast %arg0 : vector<4x[2]xf32> to vector<4x[4]xf16>
+  return %1 : vector<4x[4]xf16>
+}
+
+// -----
+
+// CHECK-LABEL: vector_bitcast_scalable_3D
+// CHECK-SAME: %[[ARG_0:.*]]: vector<1x3x[8]xi8>
+// CHECK: %[[DOWNCAST:.*]] = vector.shape_cast %[[ARG_0]] : vector<1x3x[8]xi8> to vector<[24]xi8>
+// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[DOWNCAST]] : vector<[24]xi8> to vector<[6]xi32>
+// CHECK: %[[UPCAST:.*]] = vector.shape_cast %[[BITCAST]] : vector<[6]xi32> to vector<1x3x[2]xi32>
+// CHECK: return %[[UPCAST]] : vector<1x3x[2]xi32>
+func.func @vector_bitcast_scalable_3D(%arg0: vector<1x3x[8]xi8>) -> vector<1x3x[2]xi32> {
+  %1 = vector.bitcast %arg0 : vector<1x3x[8]xi8> to vector<1x3x[2]xi32>
+  return %1 : vector<1x3x[2]xi32>
+}
+
+// -----
+
+// **--------------------------------------------------------**
+// Tests of vector.create_mask
+// [SqueezeCreateMaskUnitDims]
+// **--------------------------------------------------------**
+
+// CHECK-LABEL: linearize_create_mask_2D
+// CHECK-SAME: (%[[ARG0:.*]]: index, %[[ARG1:.*]]: index) -> vector<1x16xi1>
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[CMP:.*]] = arith.cmpi sgt, %[[ARG0]], %[[C0]] : index
+// CHECK: %[[INDEXCAST:.*]] = arith.index_cast %[[CMP]] : i1 to index
+// CHECK: %[[MULI:.*]] = arith.muli %[[INDEXCAST]], %[[ARG1]] : index
+// CHECK: %[[MASK_1D:.*]] = vector.create_mask %[[MULI]] : vector<16xi1>
+// CHECK: %[[CAST:.*]] = vector.shape_cast %[[MASK_1D]] : vector<16xi1> to vector<1x16xi1>
+// CHECK: return %[[CAST]] : vector<1x16xi1>
+func.func @linearize_create_mask_2D(%arg0 : index, %arg1 : index) -> vector<1x16xi1> {
+  %0 = vector.create_mask %arg0, %arg1 : vector<1x16xi1>
+  return %0 : vector<1x16xi1>
+}
+
+// -----
+
+// CHECK-LABEL: linearize_create_mask_4D
+// CHECK-SAME: (%[[ARG0:.*]]: index, %[[ARG1:.*]]: index) -> vector<1x16x1x1xi1> {
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[CMP0:.*]] = arith.cmpi sgt, %[[ARG0]], %[[C0]] : index
+// CHECK: %[[CMP1:.*]] = arith.cmpi sgt, %[[ARG0]], %[[C0]] : index
+// CHECK: %[[MULI:.*]] = arith.muli %[[CMP0]], %[[CMP1]] : i1
+// CHECK: %[[INDEXCAST:.*]] = arith.index_cast %[[MULI]] : i1 to index
+// CHECK: %[[MULI2:.*]] = arith.muli %[[INDEXCAST]], %[[ARG1]] : index
+// CHECK: %[[MASK_1D:.*]] = vector.create_mask %[[MULI2]] : vector<16xi1>
+// CHECK: %[[CAST:.*]] = vector.shape_cast %[[MASK_1D]] :
+// CHECK-SAME: vector<16xi1> to vector<1x16x1x1xi1>
+// CHECK: return %[[CAST]] : vector<1x16x1x1xi1>
+func.func @linearize_create_mask_4D(%arg0 : index, %arg1 : index) -> vector<1x16x1x1xi1> {
+  %cst1 = arith.constant 1 : index
+  %0 = vector.create_mask %arg0, %arg1, %cst1, %arg0 : vector<1x16x1x1xi1>
+  return %0 : vector<1x16x1x1xi1>
+}
+
+// -----
+
+// If any operand of the vector.create_mask is the constant 0, the arithmetic
+// is greatly simplified, because the mask will always be all false.
+ +// CHECK-LABEL: linearize_create_mask_4D_false +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[MASK_1D:.*]] = vector.create_mask %[[C0]] : vector<16xi1> +// CHECK: vector.shape_cast %[[MASK_1D]] : vector<16xi1> to vector<1x16x1x1xi1> +func.func @linearize_create_mask_4D_false(%arg0 : index, %arg1 : index) -> vector<1x16x1x1xi1> { + %cst0 = arith.constant 0 : index + %0 = vector.create_mask %arg0, %arg1, %cst0, %arg0 : vector<1x16x1x1xi1> + return %0 : vector<1x16x1x1xi1> +} + +// ----- + +// CHECK-LABEL: linearize_scalable_create_mask +// CHECK: vector.create_mask {{%.*}} : vector<[16]xi1> +func.func @linearize_scalable_create_mask(%arg0 : index, %arg1 : index) -> vector<1x[16]xi1> { + %0 = vector.create_mask %arg0, %arg1 : vector<1x[16]xi1> + return %0 : vector<1x[16]xi1> +} + +// ----- + +// The mask being created in this test has 2 dimensions that are not 1, so it is not linearized. + +// CHECK-LABEL: negative_create_mask +// CHECK: vector.create_mask {{.*}} vector<2x2xi1> +func.func @negative_create_mask(%arg0 : index, %arg1 : index) -> vector<2x2xi1> { + %0 = vector.create_mask %arg0, %arg1 : vector<2x2xi1> + return %0 : vector<2x2xi1> +} + +// ----- + +// **--------------------------------------------------------** +// Tests of scf.for +// **--------------------------------------------------------** + +// This test illustrates how type conversion can be used to linearize structured +// op types. + +// CHECK-LABEL: linearize_across_for +// CHECK: scf.for {{.*}} -> (vector<4xi8>) +// CHECK: arith.addi {{.*}} : vector<4xi8> +// CHECK: scf.yield {{.*}} : vector<4xi8> +func.func @linearize_across_for(%arg0 : vector<4xi8>) -> vector<4xi8> { + %0 = vector.shape_cast %arg0 : vector<4xi8> to vector<2x2xi8> + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %1 = scf.for %i = %c0 to %c4 step %c1 iter_args(%arg1 = %0) -> (vector<2x2xi8>) { + %2 = arith.addi %arg1, %0 : vector<2x2xi8> + scf.yield %2 : vector<2x2xi8> + } + %3 = vector.shape_cast %1 : vector<2x2xi8> to vector<4xi8> + return %3 : vector<4xi8> +} + +// ----- + +// **--------------------------------------------------------** +// Tests of vector.splat +// **--------------------------------------------------------** + +// CHECK-LABEL: linearize_vector_splat +// CHECK-SAME: (%[[ARG:.*]]: i32) -> vector<8xi32> +// CHECK: %[[SPLAT:.*]] = vector.splat %[[ARG]] : vector<8xi32> +// CHECK: return %[[SPLAT]] : vector<8xi32> +func.func @linearize_vector_splat(%arg0: i32) -> vector<8xi32> { + %0 = vector.splat %arg0 : vector<4x2xi32> + %1 = vector.shape_cast %0 : vector<4x2xi32> to vector<8xi32> + return %1 : vector<8xi32> +} + +// ----- + +// CHECK-LABEL: linearize_scalable_vector_splat +// CHECK-SAME: (%[[ARG:.*]]: i32) -> vector<[8]xi32> +// CHECK: %[[SPLAT:.*]] = vector.splat %[[ARG]] : vector<[8]xi32> +// CHECK: return %[[SPLAT]] : vector<[8]xi32> +func.func @linearize_scalable_vector_splat(%arg0: i32) -> vector<[8]xi32> { + %0 = vector.splat %arg0 : vector<4x[2]xi32> + %1 = vector.shape_cast %0 : vector<4x[2]xi32> to vector<[8]xi32> + return %1 : vector<[8]xi32> +} + + +// **--------------------------------------------------------** +// Tests of vector.insert +// **--------------------------------------------------------** + +// ----- + +// vector.insert where the destination is 1D vector is always unchanged. 
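+// There is no rank to reduce in that case: the destination is already rank 1, so
+// the patterns leave the operation as-is.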
+ +// CHECK-LABEL: insert_scalar_to_1D( +// CHECK-SAME: %[[A0:.*]]: i8, %[[A1:.*]]: vector<4xi8> +// CHECK: %[[IN0:.*]] = vector.insert %[[A0]], %[[A1]] [2] : i8 into vector<4xi8> +// CHECK: return %[[IN0]] : vector<4xi8> +func.func @insert_scalar_to_1D(%arg0 : i8, %arg1 : vector<4xi8>) -> vector<4xi8> +{ + %inserted = vector.insert %arg0, %arg1[2] : i8 into vector<4xi8> + return %inserted : vector<4xi8> +} + +// ----- + +// vector.insert of scalar always becomes insert of scalar into 1-D vector. +// +// CHECK-LABEL: insert_scalar_to_2D( +// CHECK-SAME: %[[A0:.*]]: i8, %[[A1:.*]]: vector<3x4xi8> +// CHECK: %[[SC0:.*]] = vector.shape_cast %[[A1]] : vector<3x4xi8> to vector<12xi8> +// CHECK: %[[IN0:.*]] = vector.insert %[[A0]], %[[SC0]] [9] : i8 into vector<12xi8> +// CHECK: %[[SC1:.*]] = vector.shape_cast %[[IN0]] : vector<12xi8> to vector<3x4xi8> +// CHECK: return %[[SC1]] : vector<3x4xi8> +func.func @insert_scalar_to_2D(%arg0 : i8, %arg1 : vector<3x4xi8>) -> vector<3x4xi8> +{ + %inserted = vector.insert %arg0, %arg1[2, 1] : i8 into vector<3x4xi8> + return %inserted : vector<3x4xi8> +} + +// ----- + +// Another test of inserting a scalar into a vector. + +// CHECK-LABEL: insert_scalar_to_2D_f32 +// CHECK-SAME: (%[[DEST:.*]]: vector<2x4xf32>, %[[SRC:.*]]: f32) -> vector<2x4xf32> { +// CHECK: %[[DEST_1D:.*]] = vector.shape_cast %[[DEST]] : vector<2x4xf32> to vector<8xf32> +// CHECK: %[[INSERT_1D:.*]] = vector.insert %[[SRC]], %[[DEST_1D]] [6] : f32 into vector<8xf32> +// CHECK: %[[RES:.*]] = vector.shape_cast %[[INSERT_1D]] : vector<8xf32> to vector<2x4xf32> +// CHECK: return %[[RES]] : vector<2x4xf32> +func.func @insert_scalar_to_2D_f32(%arg0: vector<2x4xf32>, %arg1: f32) -> vector<2x4xf32> { + + %0 = vector.insert %arg1, %arg0[1, 2]: f32 into vector<2x4xf32> + return %0 : vector<2x4xf32> +} + +// ----- + +// vector.insert where the source isn't a scalar. This case: 1D -> 2D. 
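+// In the shuffle below, the destination vector<3x4xi8> provides indices 0-11 and
+// the source vector<4xi8> provides indices 12-15. Inserting at [1] overwrites the
+// flattened positions 4-7 (row 1) with the source elements.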
+ +// CHECK-LABEL: insert_1D_to_2D( +// CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 12, 13, 14, 15, 8, 9, 10, 11] +func.func @insert_1D_to_2D(%arg0 : vector<4xi8>, %arg1 : vector<3x4xi8>) -> vector<3x4xi8> { + %inserted = vector.insert %arg0, %arg1[1] : vector<4xi8> into vector<3x4xi8> + return %inserted : vector<3x4xi8> +} + +// ----- + +// CHECK-LABEL: insert_2D_to_3D +// CHECK-SAME: (%[[DEST:.*]]: vector<2x8x4xf32>, %[[SRC:.*]]: vector<8x4xf32>) -> vector<2x8x4xf32> { +// CHECK-DAG: %[[ARG_SRC:.*]] = vector.shape_cast %[[SRC]] : vector<8x4xf32> to vector<32xf32> +// CHECK-DAG: %[[ARG_DEST:.*]] = vector.shape_cast %[[DEST]] : vector<2x8x4xf32> to vector<64xf32> +// CHECK: %[[SHUFFLE:.*]] = vector.shuffle %[[ARG_DEST]], %[[ARG_SRC]] +// CHECK-SAME: [64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, +// CHECK-SAME: 88, 89, 90, 91, 92, 93, 94, 95, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, +// CHECK-SAME: 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63] : vector<64xf32>, vector<32xf32> +// CHECK: %[[RES:.*]] = vector.shape_cast %[[SHUFFLE]] : vector<64xf32> to vector<2x8x4xf32> +// CHECK: return %[[RES]] : vector<2x8x4xf32> +func.func @insert_2D_to_3D(%arg0: vector<2x8x4xf32>, %arg1: vector<8x4xf32>) -> vector<2x8x4xf32> { + %0 = vector.insert %arg1, %arg0[0]: vector<8x4xf32> into vector<2x8x4xf32> + return %0 : vector<2x8x4xf32> +} + +// ----- + +// CHECK-LABEL: insert_2D_to_4D( +// CHECK-COUNT-2: shape_cast +// CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 16, 17, 18, 19] : +// CHECK-SAME: vector<16xi8>, vector<4xi8> +func.func @insert_2D_to_4D(%arg0 : vector<2x2xi8>, %arg1 : vector<2x2x2x2xi8>) -> vector<2x2x2x2xi8> { + %inserted = vector.insert %arg0, %arg1[1, 1] : vector<2x2xi8> into vector<2x2x2x2xi8> + return %inserted : vector<2x2x2x2xi8> +} + +// ----- + +// CHECK-LABEL: func.func @insert_scalable( +// CHECK-SAME: %[[ARG0:.*]]: vector<2x8x[4]xf32>, %[[ARG1:.*]]: vector<8x[4]xf32>) -> vector<2x8x[4]xf32> { +// CHECK-DAG: %[[SHAPE_CAST0:.*]] = vector.shape_cast %[[ARG0]] : vector<2x8x[4]xf32> to vector<2x[32]xf32> +// CHECK-DAG: %[[SHAPE_CAST1:.*]] = vector.shape_cast %[[ARG1]] : vector<8x[4]xf32> to vector<[32]xf32> +// CHECK: %[[INSERT:.*]] = vector.insert %[[SHAPE_CAST1]], %[[SHAPE_CAST0]] [0] : vector<[32]xf32> into vector<2x[32]xf32> +// CHECK: %[[RESULT:.*]] = vector.shape_cast %[[INSERT]] : vector<2x[32]xf32> to vector<2x8x[4]xf32> +// CHECK: return %[[RESULT]] : vector<2x8x[4]xf32> +func.func @insert_scalable(%arg0: vector<2x8x[4]xf32>, %arg1: vector<8x[4]xf32>) -> vector<2x8x[4]xf32> { + %0 = vector.insert %arg1, %arg0[0]: vector<8x[4]xf32> into vector<2x8x[4]xf32> + return %0 : vector<2x8x[4]xf32> +} + +// ----- + +// **--------------------------------------------------------** +// Tests of vector.extract +// **--------------------------------------------------------** + +// vector.extract where the source is 1D vector is always unchanged. 
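+// As with vector.insert, there is nothing to collapse when the source is already
+// rank 1.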
+ +// CHECK-LABEL: extract_scalar_from_1D( +// CHECK-SAME: %[[A0:.*]]: vector<4xi8> +// CHECK: %[[EX0:.*]] = vector.extract %[[A0]][2] : i8 from vector<4xi8> +// CHECK: return %[[EX0]] : i8 +func.func @extract_scalar_from_1D(%arg0 : vector<4xi8>) -> i8 { + %extracted = vector.extract %arg0[2] : i8 from vector<4xi8> + return %extracted : i8 +} + +// ----- + +// CHECK-LABEL: extract_scalar_from_1D_dynamic +// CHECK-NOT: vector.shuffle +// CHECK: vector.extract +// CHECK-NOT: vector.shuffle +func.func @extract_scalar_from_1D_dynamic(%idx : index) -> i32 { + %cst = arith.constant dense<[1, 2, 3, 4]> : vector<4xi32> + %0 = vector.extract %cst[%idx] : i32 from vector<4xi32> + return %0 : i32 +} + +// ----- + + +// CHECK-LABEL: extract_scalar_from_2D( +// CHECK-SAME: %[[A0:.*]]: vector<12xi8> +// CHECK: %[[EX0:.*]] = vector.extract %[[A0]][9] : i8 from vector<12xi8> +// CHECK: return %[[EX0]] : i8 +func.func @extract_scalar_from_2D(%arg0 : vector<12xi8>) -> i8 { + %sc = vector.shape_cast %arg0 : vector<12xi8> to vector<3x4xi8> + %extracted = vector.extract %sc[2, 1] : i8 from vector<3x4xi8> + return %extracted : i8 +} + +// ----- + +// CHECK-LABEL: extract_1D_from_2D( +// CHECK: vector.shuffle +// CHECK-SAME: [4, 5, 6, 7] : vector<12xi8>, vector<12xi8> +func.func @extract_1D_from_2D(%arg0 : vector<3x4xi8>) -> vector<4xi8> { + %extracted = vector.extract %arg0[1] : vector<4xi8> from vector<3x4xi8> + return %extracted : vector<4xi8> +} + +// ----- + +// CHECK-LABEL: extract_2D_from_3D +// CHECK-SAME: (%[[ORIG_ARG:.*]]: vector<2x8x2xf32>) -> vector<8x2xf32> { +// CHECK: %[[ARG:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<2x8x2xf32> to vector<32xf32> +// CHECK: %[[SHUFFLE:.*]] = vector.shuffle %[[ARG]], %[[ARG]] +// CHECK-SAME: [16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31] : vector<32xf32>, vector<32xf32> +// CHECK: %[[RES:.*]] = vector.shape_cast %[[SHUFFLE]] : vector<16xf32> to vector<8x2xf32> +// CHECK: return %[[RES]] : vector<8x2xf32> +func.func @extract_2D_from_3D(%arg0: vector<2x8x2xf32>) -> vector<8x2xf32> { + %0 = vector.extract %arg0[1]: vector<8x2xf32> from vector<2x8x2xf32> + return %0 : vector<8x2xf32> +} + +// ----- + +// CHECK-LABEL: extract_2D_from_4D( +// CHECK: vector.shuffle +// CHECK-SAME: [10, 11] : vector<24xi8>, vector<24xi8> +func.func @extract_2D_from_4D(%arg0 : vector<4x3x2x1xi8>) -> vector<2x1xi8> { + %extracted = vector.extract %arg0[1, 2] : vector<2x1xi8> from vector<4x3x2x1xi8> + return %extracted : vector<2x1xi8> +} + +// ----- + +// In this test, the dynamic extract dimension prevents linearization all the way to a shuffle operation. +// The outermost 2 and the innermost 2 dimensions are linearized, however. 
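+// Concretely, vector<5x4x3x2x1xi8> collapses to vector<20x3x2xi8>: the static
+// leading indices [2, 1] become 2 * 4 + 1 = 9, the dynamic index %idx is kept for
+// the middle dimension, and the trailing 2x1 source collapses to vector<2xi8>.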
+ +// CHECK-LABEL: extract_2D_from_5D_dynamic( +// CHECK-SAME: %[[ARG0:.*]]: vector<5x4x3x2x1xi8>, %[[IDX:.*]]: index) -> vector<2x1xi8> { +// CHECK: %[[SC:.*]] = vector.shape_cast %[[ARG0]] : vector<5x4x3x2x1xi8> to vector<20x3x2xi8> +// CHECK: %[[EXTRACT:.*]] = vector.extract %[[SC]][9, %[[IDX]]] : vector<2xi8> from vector<20x3x2xi8> +// CHECK: %[[RES:.*]] = vector.shape_cast %[[EXTRACT]] : vector<2xi8> to vector<2x1xi8> +func.func @extract_2D_from_5D_dynamic(%arg0 : vector<5x4x3x2x1xi8>, %idx : index) -> vector<2x1xi8> { + %extracted = vector.extract %arg0[2, 1, %idx] : vector<2x1xi8> from vector<5x4x3x2x1xi8> + return %extracted : vector<2x1xi8> +} + +// ----- + +// CHECK-LABEL: func.func @extract_scalable( +// CHECK-SAME: %[[ARG0:.*]]: vector<2x8x[2]xf32>) -> vector<8x[2]xf32> { +// CHECK: %[[SC:.*]] = vector.shape_cast %[[ARG0]] : vector<2x8x[2]xf32> to vector<2x[16]xf32> +// CHECK: %[[EXTRACT:.*]] = vector.extract %[[SC]][1] : vector<[16]xf32> from vector<2x[16]xf32> +// CHECK: %[[RES:.*]] = vector.shape_cast %[[EXTRACT]] : vector<[16]xf32> to vector<8x[2]xf32> +// CHECK: return %[[RES]] : vector<8x[2]xf32> +func.func @extract_scalable(%arg0: vector<2x8x[2]xf32>) -> vector<8x[2]xf32> { + %0 = vector.extract %arg0[1]: vector<8x[2]xf32> from vector<2x8x[2]xf32> + return %0 : vector<8x[2]xf32> +} + + +// **--------------------------------------------------------** +// Tests of vector.insert_strided_slice +// **--------------------------------------------------------** + +// ----- + +// Test of insert_strided_slice -> shuffle. +// This is a contiguous insertion of 4 elements at offset 6 into a vector of 12 elements. +// CHECK-LABEL: insert_strided_slice_2D_into_4D +// CHECK-DAG: %[[ARG0:.*]] = vector.shape_cast {{.*}} to vector<4xi8> +// CHECK-DAG: %[[ARG1:.*]] = vector.shape_cast {{.*}} to vector<12xi8> +// CHECK: vector.shuffle %[[ARG1]], %[[ARG0]] +// CHECK-SAME: [0, 1, 2, 3, 4, 5, 12, 13, 14, 15, 10, 11] : vector<12xi8>, vector<4xi8> +// CHECK: %[[RES:.*]] = vector.shape_cast {{.*}} to vector<2x1x3x2xi8> +// CHECK: return %[[RES]] : vector<2x1x3x2xi8> +func.func @insert_strided_slice_2D_into_4D(%arg0 : vector<2x2xi8>, %arg1 : vector<2x1x3x2xi8>) -> vector<2x1x3x2xi8> { + %0 = vector.insert_strided_slice %arg0, %arg1 {offsets = [1, 0, 0, 0], strides = [1, 1]} : vector<2x2xi8> into vector<2x1x3x2xi8> + return %0 : vector<2x1x3x2xi8> +} + +// ----- + +// Test of insert_strided_slice -> shuffle. 
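+// The source vector<1x2x1xi8> holds 2 elements; with offsets = [1, 1, 1] and
+// sizes = [1, 2, 1] they land at flattened positions 1*6 + 1*2 + 1 = 9 and
+// 1*6 + 2*2 + 1 = 11 of the destination, as pictured below: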
+// [[[0, 1], [2, 3], [4, 5]], [[6, 7], [8, 9], [10, 11]], [[12, 13], [14, 15], [16, 17]]]
+//                                         ^         ^
+//                                         |         |
+//                   where the 2 elements are inserted into the 3x3x2 vector
+// CHECK-LABEL: insert_strided_slice_3D
+// CHECK-DAG: %[[ARG0:.*]] = vector.shape_cast {{.*}} to vector<2xi8>
+// CHECK-DAG: %[[ARG1:.*]] = vector.shape_cast {{.*}} to vector<18xi8>
+// CHECK: vector.shuffle %[[ARG1]], %[[ARG0]]
+// CHECK-SAME: [0, 1, 2, 3, 4, 5, 6, 7, 8, 18, 10, 19, 12, 13, 14, 15, 16, 17] : vector<18xi8>, vector<2xi8>
+// CHECK: %[[RES:.*]] = vector.shape_cast {{.*}} to vector<3x3x2xi8>
+// CHECK: return %[[RES]] : vector<3x3x2xi8>
+func.func @insert_strided_slice_3D(%arg0 : vector<1x2x1xi8>, %arg1 : vector<3x3x2xi8>) -> vector<3x3x2xi8> {
+  %0 = vector.insert_strided_slice %arg0, %arg1 {offsets = [1, 1, 1], sizes = [1, 2, 1], strides = [1, 1, 1]} : vector<1x2x1xi8> into vector<3x3x2xi8>
+  return %0 : vector<3x3x2xi8>
+}
+
+// -----
+
+// CHECK-LABEL: insert_strided_slice_2D_higher_offsets
+// CHECK: [0, 1, 2, 3, 10, 11, 12, 13, 8, 9]
+//                    ^^^ ^^^ ^^^ ^^^
+//                     insertion indices
+// CHECK: [0, 1, 2, 3, 10, 5, 11, 7, 8, 9]
+//                    ^^^    ^^^
+//                     insertion indices
+// CHECK: [0, 1, 2, 3, 4, 5, 6, 10, 8, 11]
+//                             ^^^    ^^^
+//                     insertion indices
+func.func @insert_strided_slice_2D_higher_offsets(%arg0 : vector<2x1xi8>, %arg1 : vector<2x2xi8>, %arg2 : vector<5x2xi8>) -> vector<5x2xi8> {
+  %0 = vector.insert_strided_slice %arg1, %arg2 {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<2x2xi8> into vector<5x2xi8>
+  %1 = vector.insert_strided_slice %arg0, %0 {offsets = [2, 0], sizes = [2, 1], strides = [1, 1]} : vector<2x1xi8> into vector<5x2xi8>
+  %2 = vector.insert_strided_slice %arg0, %1 {offsets = [3, 1], sizes = [2, 1], strides = [1, 1]} : vector<2x1xi8> into vector<5x2xi8>
+  return %2 : vector<5x2xi8>
+}
+
+// -----
+
+// CHECK-LABEL: negative_insert_strided_slice_scalable_shapes_differ
+// CHECK-NOT: vector.shuffle
+// CHECK: return
+func.func @negative_insert_strided_slice_scalable_shapes_differ(%arg0 : vector<1x[2]xi8>, %arg1 : vector<2x[2]xi8>) -> vector<2x[2]xi8> {
+  %0 = vector.insert_strided_slice %arg0, %arg1 {offsets = [0, 0], strides = [1,1]} : vector<1x[2]xi8> into vector<2x[2]xi8>
+  return %0 : vector<2x[2]xi8>
+}
+
+// -----
+
+// CHECK-LABEL: insert_strided_slice_scalable_common_pre
+// CHECK-SAME: (%[[ARG0:.*]]: vector<3x1x[4]x2xi8>, %[[ARG1:.*]]: vector<3x1x[4]x5xi8>)
+// CHECK-DAG: %[[SC0:.*]] = vector.shape_cast %[[ARG0]] : vector<3x1x[4]x2xi8> to vector<[12]x2xi8>
+// CHECK-DAG: %[[SC1:.*]] = vector.shape_cast %[[ARG1]] : vector<3x1x[4]x5xi8> to vector<[12]x5xi8>
+// CHECK: %[[INSERT:.*]] = vector.insert_strided_slice %[[SC0]], %[[SC1]]
+// CHECK-SAME: {offsets = [0, 2], strides = [1, 1]} : vector<[12]x2xi8> into vector<[12]x5xi8>
+// CHECK: %[[RES:.*]] = vector.shape_cast %[[INSERT]] : vector<[12]x5xi8> to vector<3x1x[4]x5xi8>
+// CHECK: return %[[RES]] : vector<3x1x[4]x5xi8>
+
+func.func @insert_strided_slice_scalable_common_pre(%arg0 : vector<3x1x[4]x2xi8>, %arg1 : vector<3x1x[4]x5xi8>) -> vector<3x1x[4]x5xi8> {
+  %1 = vector.insert_strided_slice %arg0, %arg1 {offsets = [0, 0, 0, 2], strides = [1, 1, 1, 1]} : vector<3x1x[4]x2xi8> into vector<3x1x[4]x5xi8>
+  return %1 : vector<3x1x[4]x5xi8>
+}
+
+// -----
+
+// CHECK-LABEL: insert_strided_slice_scalable_common_post
+// CHECK-SAME: (%[[ARG0:.*]]: vector<1x[2]x3xi8>, %[[ARG1:.*]]: vector<5x[2]x3xi8>)
+// CHECK-DAG: %[[SC0:.*]] = vector.shape_cast %[[ARG0]] : vector<1x[2]x3xi8> to vector<1x[6]xi8>
+// CHECK-DAG: %[[SC1:.*]] = vector.shape_cast %[[ARG1]] : vector<5x[2]x3xi8> to vector<5x[6]xi8>
+// CHECK: %[[INSERT:.*]] = vector.insert_strided_slice %[[SC0]], %[[SC1]]
+// CHECK-SAME: {offsets = [3, 0], strides = [1, 1]} : vector<1x[6]xi8> into vector<5x[6]xi8>
+// CHECK: %[[RES:.*]] = vector.shape_cast %[[INSERT]] : vector<5x[6]xi8> to vector<5x[2]x3xi8>
+// CHECK: return %[[RES]] : vector<5x[2]x3xi8>
+
+func.func @insert_strided_slice_scalable_common_post(%arg0 : vector<1x[2]x3xi8>, %arg1 : vector<5x[2]x3xi8>) -> vector<5x[2]x3xi8> {
+  %1 = vector.insert_strided_slice %arg0, %arg1 {offsets = [3, 0, 0], strides = [1, 1, 1]} : vector<1x[2]x3xi8> into vector<5x[2]x3xi8>
+  return %1 : vector<5x[2]x3xi8>
+}
+
+// -----
+
+// CHECK-LABEL: insert_strided_slice_1D(
+// CHECK: shuffle {{.*}} [0, 8, 9, 3, 4, 5, 6, 7]
+func.func @insert_strided_slice_1D(%arg0 : vector<2xi8>, %arg1 : vector<8xi8>) -> vector<8xi8> {
+  %inserted = vector.insert_strided_slice %arg0, %arg1 {offsets = [1], strides = [1]} : vector<2xi8> into vector<8xi8>
+  return %inserted : vector<8xi8>
+}
+
+// -----
+
+// CHECK-LABEL: insert_strided_slice_4D_contiguous(
+// CHECK: vector.shuffle
+// CHECK-SAME: 52, 53, 120, 121
+// CHECK-SAME: 130, 131, 66, 67
+// CHECK-SAME: vector<120xi8>, vector<12xi8>
+
+func.func @insert_strided_slice_4D_contiguous(%arg0 : vector<1x2x2x3xi8>, %arg1 : vector<5x4x2x3xi8>) -> vector<5x4x2x3xi8> {
+  %inserted = vector.insert_strided_slice %arg0, %arg1 {offsets = [2, 1, 0, 0], strides = [1, 1, 1, 1]} : vector<1x2x2x3xi8> into vector<5x4x2x3xi8>
+  return %inserted : vector<5x4x2x3xi8>
+}
+
+// -----
+
+// This insert_strided_slice is not contiguous, and so it is always linearized to a 1D vector.shuffle.
+
+// CHECK-LABEL: insert_strided_slice_4D_noncontiguous(
+// CHECK: vector.shuffle
+// CHECK-SAME: [0, 1, 2, 8, 4, 5, 6, 9] : vector<8xi8>, vector<2xi8>
+
+func.func @insert_strided_slice_4D_noncontiguous(%arg0 : vector<1x2x1x1xi8>, %arg1 : vector<1x2x2x2xi8>) -> vector<1x2x2x2xi8> {
+  %inserted = vector.insert_strided_slice %arg0, %arg1 {offsets = [0, 0, 1, 1], strides = [1, 1, 1, 1]} : vector<1x2x1x1xi8> into vector<1x2x2x2xi8>
+  return %inserted : vector<1x2x2x2xi8>
+}
+
+// -----
+
+// **--------------------------------------------------------**
+// Tests of vector.extract_strided_slice
+// **--------------------------------------------------------**
+
+// CHECK-LABEL: extract_strided_slice_2D
+// CHECK-SAME: (%[[ORIG_ARG:.*]]: vector<4x8xf32>) -> vector<2x2xf32> {
+// CHECK: %[[ARG:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<4x8xf32> to vector<32xf32>
+// CHECK: %[[SHUFFLE:.*]] = vector.shuffle %[[ARG]], %[[ARG]]
+// CHECK-SAME: [4, 5, 12, 13] : vector<32xf32>, vector<32xf32>
+// CHECK: %[[RES:.*]] = vector.shape_cast %[[SHUFFLE]] : vector<4xf32> to vector<2x2xf32>
+// CHECK: return %[[RES]] : vector<2x2xf32>
+func.func @extract_strided_slice_2D(%arg0 : vector<4x8xf32>) -> vector<2x2xf32> {
+  %0 = vector.extract_strided_slice %arg0 { sizes = [2, 2], strides = [1, 1], offsets = [0, 4]}
+    : vector<4x8xf32> to vector<2x2xf32>
+  return %0 : vector<2x2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @extract_strided_slice_2D_scalable(
+// CHECK-SAME: %[[VAL_0:.*]]: vector<4x[8]xf32>) -> vector<2x[8]xf32> {
+// CHECK-NOT: vector.shuffle
+// CHECK-NOT: vector.shape_cast
+// CHECK: %[[RES:.*]] = vector.extract_strided_slice %[[VAL_0]]
+// CHECK: return %[[RES]] : vector<2x[8]xf32>
+func.func @extract_strided_slice_2D_scalable(%arg0: vector<4x[8]xf32>) -> vector<2x[8]xf32> {
+  %0 = vector.extract_strided_slice %arg0 { sizes = [2, 8], strides = [1, 1], offsets = [1, 0] } : vector<4x[8]xf32> to vector<2x[8]xf32>
+  return %0 : vector<2x[8]xf32>
+}
+
+// -----
+
+// CHECK-LABEL: extract_strided_slice_3D
+// CHECK-SAME: (%[[ORIG_ARG:.*]]: vector<2x8x2xf32>) -> vector<1x4x2xf32> {
+// CHECK: %[[ARG:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<2x8x2xf32> to vector<32xf32>
+// CHECK: %[[SHUFFLE:.*]] = vector.shuffle %[[ARG]], %[[ARG]]
+// CHECK-SAME: [20, 21, 22, 23, 24, 25, 26, 27] : vector<32xf32>, vector<32xf32>
+// CHECK: %[[RES:.*]] = vector.shape_cast %[[SHUFFLE]] : vector<8xf32> to vector<1x4x2xf32>
+// CHECK: return %[[RES]] : vector<1x4x2xf32>
+func.func @extract_strided_slice_3D(%arg0 : vector<2x8x2xf32>) -> vector<1x4x2xf32> {
+  %0 = vector.extract_strided_slice %arg0 { offsets = [1, 2], strides = [1, 1], sizes = [1, 4] }
+    : vector<2x8x2xf32> to vector<1x4x2xf32>
+  return %0 : vector<1x4x2xf32>
+}
+
+// -----
+
+
+// CHECK-LABEL: extract_strided_slice_1D(
+// CHECK: vector.shuffle {{.*}} [1, 2]
+func.func @extract_strided_slice_1D(%arg0 : vector<8xi8>) -> vector<2xi8> {
+  %extracted = vector.extract_strided_slice %arg0 {offsets = [1], sizes = [2], strides = [1]} : vector<8xi8> to vector<2xi8>
+  return %extracted : vector<2xi8>
+}
+
+// -----
+
+// CHECK-LABEL: extract_strided_slice_4D_contiguous_1(
+// CHECK: vector.shuffle
+// CHECK-SAME: [3, 4, 5]
+// CHECK-SAME: vector<6xi8>, vector<6xi8>
+func.func @extract_strided_slice_4D_contiguous_1(%arg0 : vector<2x1x3x1xi8>) -> vector<1x1x3x1xi8> {
+  %extracted = vector.extract_strided_slice %arg0 {offsets = [1, 0, 0, 0], sizes = [1, 1, 3, 1], strides = [1, 1, 1, 1]} : vector<2x1x3x1xi8> to vector<1x1x3x1xi8>
+  return %extracted : vector<1x1x3x1xi8>
+}
+
+// -----
+
+// CHECK-LABEL: extract_strided_slice_4D_contiguous_2(
+// CHECK: vector.shuffle
+// CHECK-SAME: [3, 4]
+// CHECK-SAME: vector<6xi8>, vector<6xi8>
+func.func @extract_strided_slice_4D_contiguous_2(%arg0 : vector<2x1x3x1xi8>) -> vector<1x1x2x1xi8> {
+  %extracted = vector.extract_strided_slice %arg0 {offsets = [1, 0, 0, 0], sizes = [1, 1, 2, 1], strides = [1, 1, 1, 1]} : vector<2x1x3x1xi8> to vector<1x1x2x1xi8>
+  return %extracted : vector<1x1x2x1xi8>
+}
+
+// -----
+
+// CHECK-LABEL: extract_strided_slice_4D_noncontiguous(
+// CHECK: vector.shuffle
+// CHECK-SAME: [0, 1, 3, 4]
+// CHECK-SAME: vector<6xi8>, vector<6xi8>
+func.func @extract_strided_slice_4D_noncontiguous(%arg0 : vector<2x1x3x1xi8>) -> vector<2x1x2x1xi8> {
+  %extracted = vector.extract_strided_slice %arg0 {offsets = [0, 0, 0, 0], sizes = [2, 1, 2, 1], strides = [1, 1, 1, 1]} : vector<2x1x3x1xi8> to vector<2x1x2x1xi8>
+  return %extracted : vector<2x1x2x1xi8>
+}
diff --git a/mlir/test/Dialect/Vector/linearize/rank-reduce-strided-ops.mlir b/mlir/test/Dialect/Vector/linearize/rank-reduce-strided-ops.mlir
new file mode 100644
index 0000000000000..363f532dd825b
--- /dev/null
+++ b/mlir/test/Dialect/Vector/linearize/rank-reduce-strided-ops.mlir
@@ -0,0 +1,195 @@
+// RUN: mlir-opt %s -split-input-file -test-rank-reduce-strided-slice-ops -verify-diagnostics | FileCheck %s
+
+
+// **---------------------------------------------**
+// Tests of vector.extract_strided_slice
+// **---------------------------------------------**
+
+
+// The 6 elements extracted are contiguous, so this can be expressed as a rank-1 vector.extract_strided_slice. 
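+// Rows 1 through 3 of vector<5x2xi8> occupy the flattened positions 2 through 7,
+// which is why the rank-1 slice below has offsets = [2] and sizes = [6].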
+ +// CHECK-LABEL: @extract_strided_slice_2D_to_1D( +// CHECK-SAME: %[[A:.*]]: vector<5x2xi8>) -> vector<3x2xi8> { +// CHECK: %[[SC:.*]] = vector.shape_cast %[[A]] : vector<5x2xi8> to vector<10xi8> +// CHECK: %[[EXTRACTED:.*]] = vector.extract_strided_slice %[[SC]] +// CHECK-SAME: {offsets = [2], sizes = [6], strides = [1]} : vector<10xi8> to vector<6xi8> +// CHECK: %[[CASTED:.*]] = vector.shape_cast %[[EXTRACTED]] : vector<6xi8> to vector<3x2xi8> +// CHECK: return %[[CASTED]] : vector<3x2xi8> +func.func @extract_strided_slice_2D_to_1D(%arg0 : vector<5x2xi8>) -> vector<3x2xi8> { + %extracted = vector.extract_strided_slice %arg0 {offsets = [1, 0], sizes = [3, 2], strides = [1, 1]} : vector<5x2xi8> to vector<3x2xi8> + return %extracted : vector<3x2xi8> +} + +// ----- + +// The 5 elements extracted are not contiguous, so this cannot be expressed as a rank-1 vector.extract_strided_slice. + +// CHECK-LABEL: @negative_extract_strided_slice_2D_to_1D( +// CHECK-SAME: %[[A:.*]]: vector<5x2xi8>) -> vector<5x1xi8> { +// CHECK: %[[EXTRACTED:.*]] = vector.extract_strided_slice %[[A]] +// CHECK: return %[[EXTRACTED]] : vector<5x1xi8> +func.func @negative_extract_strided_slice_2D_to_1D(%arg0 : vector<5x2xi8>) -> vector<5x1xi8> { + %extracted = vector.extract_strided_slice %arg0 {offsets = [0, 0], sizes = [5, 1], strides = [1, 1]} : vector<5x2xi8> to vector<5x1xi8> + return %extracted : vector<5x1xi8> +} + +// ----- + +// The 2 elements extracted are contiguous, so this can be expressed as a rank-1 vector.extract_strided_slice. + +// CHECK-LABEL: @extract_strided_slice_4D_leading_ones( +// CHECK-SAME: %[[A:.*]]: vector<2x1x3x1xi8>) -> vector<1x1x2x1xi8> { +// CHECK: %[[SC:.*]] = vector.shape_cast %[[A]] : vector<2x1x3x1xi8> to vector<6xi8> +// CHECK: %[[EXTRACTED:.*]] = vector.extract_strided_slice %[[SC]] +// CHECK-SAME: {offsets = [3], sizes = [2], strides = [1]} : vector<6xi8> to vector<2xi8> +// CHECK: %[[CASTED:.*]] = vector.shape_cast %[[EXTRACTED]] : vector<2xi8> to vector<1x1x2x1xi8> +// CHECK: return %[[CASTED]] : vector<1x1x2x1xi8> + +func.func @extract_strided_slice_4D_leading_ones(%arg0 : vector<2x1x3x1xi8>) -> vector<1x1x2x1xi8> { + %extracted = vector.extract_strided_slice %arg0 {offsets = [1, 0, 0, 0], sizes = [1, 1, 2, 1], strides = [1, 1, 1, 1]} : vector<2x1x3x1xi8> to vector<1x1x2x1xi8> + return %extracted : vector<1x1x2x1xi8> +} + +// ----- + +// CHECK-LABEL: @extract_strided_slice_4D_becomes_2D( +// CHECK-SAME: %[[A:.*]]: vector<8x7x6x5xi8>) -> vector<2x7x2x5xi8> { +// CHECK: %[[SC:.*]] = vector.shape_cast %[[A]] : vector<8x7x6x5xi8> to vector<56x30xi8> +// CHECK: %[[EXTRACTED:.*]] = vector.extract_strided_slice %[[SC]] +// CHECK-SAME: {offsets = [14, 5], sizes = [14, 10], strides = [1, 1]} : vector<56x30xi8> to vector<14x10xi8> +// CHECK: %[[CASTED:.*]] = vector.shape_cast %[[EXTRACTED]] : vector<14x10xi8> to vector<2x7x2x5xi8> +// CHECK: return %[[CASTED]] : vector<2x7x2x5xi8> +func.func @extract_strided_slice_4D_becomes_2D(%arg0 : vector<8x7x6x5xi8>) -> vector<2x7x2x5xi8> { + %extracted = vector.extract_strided_slice %arg0 {offsets = [2, 0, 1, 0], sizes = [2, 7, 2, 5], strides = [1, 1, 1, 1]} : vector<8x7x6x5xi8> to vector<2x7x2x5xi8> + return %extracted : vector<2x7x2x5xi8> +} + +// ----- + +// CHECK-LABEL: @test_extract_strided_slice_4D( +// CHECK-SAME: %[[ARG0:.*]]: vector<2x2x2x2xi8>) -> vector<1x2x1x2xi8> { +// CHECK: %[[SC:.*]] = vector.shape_cast %[[ARG0]] : vector<2x2x2x2xi8> to vector<4x4xi8> +// CHECK: %[[EXTRACTED:.*]] = vector.extract_strided_slice %[[SC]] +// 
CHECK-SAME: {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x4xi8> to vector<2x2xi8> +// CHECK: %[[CASTED:.*]] = vector.shape_cast %[[EXTRACTED]] : vector<2x2xi8> to vector<1x2x1x2xi8> +// CHECK: return %[[CASTED]] : vector<1x2x1x2xi8> +func.func @test_extract_strided_slice_4D(%arg0 : vector<2x2x2x2xi8>) -> vector<1x2x1x2xi8> { + %0 = vector.extract_strided_slice %arg0 + {offsets = [1, 0, 1, 0], + sizes = [1, 2, 1, 2], + strides = [1, 1, 1, 1]} : vector<2x2x2x2xi8> to vector<1x2x1x2xi8> + return %0 : vector<1x2x1x2xi8> +} + +// ----- + +// CHECK-LABEL: @extract_strided_slice_4D_becomes_3D( +// CHECK-SAME: %[[A:.*]]: vector<8x7x6x5xi8>) -> vector<8x2x6x2xi8> { +// CHECK: %[[SC:.*]] = vector.shape_cast %[[A]] : vector<8x7x6x5xi8> to vector<8x42x5xi8> +// CHECK: %[[EXTRACTED:.*]] = vector.extract_strided_slice %[[SC]] +// CHECK-SAME: {offsets = [0, 12, 1], sizes = [8, 12, 2], strides = [1, 1, 1]} : vector<8x42x5xi8> to vector<8x12x2xi8> +// CHECK: %[[CASTED:.*]] = vector.shape_cast %[[EXTRACTED]] : vector<8x12x2xi8> to vector<8x2x6x2xi8> +// CHECK: return %[[CASTED]] : vector<8x2x6x2xi8> + +func.func @extract_strided_slice_4D_becomes_3D(%arg0 : vector<8x7x6x5xi8>) -> vector<8x2x6x2xi8> { + %extracted = vector.extract_strided_slice %arg0 {offsets = [0, 2, 0, 1], sizes = [8, 2, 6, 2], strides = [1, 1, 1, 1]} : vector<8x7x6x5xi8> to vector<8x2x6x2xi8> + return %extracted : vector<8x2x6x2xi8> +} + +// ----- + +// CHECK-LABEL: @extract_strided_implicit( +// CHECK-SAME: %[[ARG:.*]]: vector<4x8x16xf32>) -> vector<1x8x16xf32> { +// CHECK: %[[SC0:.*]] = vector.shape_cast %[[ARG]] : vector<4x8x16xf32> to vector<512xf32> +// CHECK: %[[EXTRACTED:.*]] = vector.extract_strided_slice %[[SC0]] +// CHECK-SAME: {offsets = [256], sizes = [128], strides = [1]} : vector<512xf32> to vector<128xf32> +// CHECK: %[[CASTED:.*]] = vector.shape_cast %[[EXTRACTED]] : vector<128xf32> to vector<1x8x16xf32> +// CHECK: return %[[CASTED]] : vector<1x8x16xf32> +func.func @extract_strided_implicit(%arg0 : vector<4x8x16xf32>) -> vector<1x8x16xf32> { + %0 = vector.extract_strided_slice %arg0 + {offsets = [2], sizes = [1], strides = [1]}: + vector<4x8x16xf32> to vector<1x8x16xf32> + return %0 : vector<1x8x16xf32> +} + +// ----- + +// **---------------------------------------------** +// Tests of vector.insert_strided_slice +// **---------------------------------------------** + + +// CHECK-LABEL: @negative_insert_strided_slice( +// CHECK-SAME: %[[A:.*]]: vector<2x2xi8>, %[[B:.*]]: vector<2x1xi8>) -> vector<2x2xi8> { +// CHECK: %[[INSERTED:.*]] = vector.insert_strided_slice %[[B]], %[[A]] +// CHECK: return %[[INSERTED]] : vector<2x2xi8> +func.func @negative_insert_strided_slice(%arg0 : vector<2x2xi8>, %arg1 : vector<2x1xi8>) -> vector<2x2xi8> { + %inserted = vector.insert_strided_slice %arg1, %arg0 {offsets = [0, 1], strides = [1, 1]} : vector<2x1xi8> into vector<2x2xi8> + return %inserted : vector<2x2xi8> +} + +// ----- + +// CHECK-LABEL: @positive_insert_strided_slice( +// CHECK-SAME: %[[A:.*]]: vector<2x2xi8>, %[[B:.*]]: vector<1x2xi8>) -> vector<2x2xi8> { +// CHECK-DAG: %[[SCA:.*]] = vector.shape_cast %[[A]] : vector<2x2xi8> to vector<4xi8> +// CHECK-DAG: %[[SCB:.*]] = vector.shape_cast %[[B]] : vector<1x2xi8> to vector<2xi8> +// CHECK: %[[INSERTED:.*]] = vector.insert_strided_slice %[[SCB]], %[[SCA]] +// CHECK-SAME: {offsets = [0], strides = [1]} : vector<2xi8> into vector<4xi8> +// CHECK: %[[CASTED:.*]] = vector.shape_cast %[[INSERTED]] : vector<4xi8> to vector<2x2xi8> +// CHECK: return %[[CASTED]] : 
vector<2x2xi8> + +func.func @positive_insert_strided_slice(%arg0 : vector<2x2xi8>, %arg1 : vector<1x2xi8>) -> vector<2x2xi8> { + %inserted = vector.insert_strided_slice %arg1, %arg0 {offsets = [0, 0], strides = [1, 1]} : vector<1x2xi8> into vector<2x2xi8> + return %inserted : vector<2x2xi8> +} + +// ----- + +// CHECK-LABEL: @test_insert_strided_slice_4D( +// CHECK-SAME: %[[ARG0:.*]]: vector<2x2x2x2xi8>, %[[ARG1:.*]]: vector<1x2x1x2xi8>) -> vector<2x2x2x2xi8> { +// CHECK-DAG: %[[SC1:.*]] = vector.shape_cast %[[ARG1]] : vector<1x2x1x2xi8> to vector<2x2xi8> +// CHECK-DAG: %[[SC0:.*]] = vector.shape_cast %[[ARG0]] : vector<2x2x2x2xi8> to vector<4x4xi8> +// CHECK: %[[INSERTED:.*]] = vector.insert_strided_slice %[[SC1]], %[[SC0]] +// CHECK-SAME: {offsets = [2, 2], strides = [1, 1]} : vector<2x2xi8> into vector<4x4xi8> +// CHECK: %[[CASTED:.*]] = vector.shape_cast %[[INSERTED]] : vector<4x4xi8> to vector<2x2x2x2xi8> +// CHECK: return %[[CASTED]] : vector<2x2x2x2xi8> +func.func @test_insert_strided_slice_4D(%arg0 : vector<2x2x2x2xi8>, %arg1 : vector<1x2x1x2xi8>) -> vector<2x2x2x2xi8> { + %0 = vector.insert_strided_slice %arg1, %arg0 + {offsets = [1, 0, 1, 0], + strides = [1, 1, 1, 1]} : vector<1x2x1x2xi8> into vector<2x2x2x2xi8> + return %0 : vector<2x2x2x2xi8> +} + +// ----- + +// CHECK-LABEL: @test_insert_strided_implicit_2_into_3( +// CHECK-SAME: %[[ARG0:.*]]: vector<16x4x8xf32>, %[[ARG1:.*]]: vector<2x8xf32>) -> vector<16x4x8xf32> { +// CHECK-DAG: %[[SC1:.*]] = vector.shape_cast %[[ARG1]] : vector<2x8xf32> to vector<16xf32> +// CHECK-DAG: %[[SC0:.*]] = vector.shape_cast %[[ARG0]] : vector<16x4x8xf32> to vector<512xf32> +// CHECK: %[[INSERTED:.*]] = vector.insert_strided_slice %[[SC1]], %[[SC0]] +// CHECK-SAME: {offsets = [72], strides = [1]} : vector<16xf32> into vector<512xf32> +// CHECK: %[[CASTED:.*]] = vector.shape_cast %[[INSERTED]] : vector<512xf32> to vector<16x4x8xf32> +// CHECK: return %[[CASTED]] : vector<16x4x8xf32> + +func.func @test_insert_strided_implicit_2_into_3(%arg0 : vector<16x4x8xf32>, %arg1 : vector<2x8xf32>) -> vector<16x4x8xf32> { + %0 = vector.insert_strided_slice %arg1, %arg0 {offsets = [2, 1, 0], strides = [1, 1]}: + vector<2x8xf32> into vector<16x4x8xf32> + return %0 : vector<16x4x8xf32> +} + +// ----- + +// CHECK-LABEL: @test_insert_strided_implicit_1_into_3( +// CHECK-SAME: %[[ARG0:.*]]: vector<16x4x8xf32>, %[[ARG1:.*]]: vector<1xf32>) -> vector<16x4x8xf32> { +// CHECK: %[[SC:.*]] = vector.shape_cast %[[ARG0]] : vector<16x4x8xf32> to vector<512xf32> +// CHECK: %[[INSERTED:.*]] = vector.insert_strided_slice %[[ARG1]], %[[SC]] +// CHECK-SAME: {offsets = [72], strides = [1]} : vector<1xf32> into vector<512xf32> +// CHECK: %[[CASTED:.*]] = vector.shape_cast %[[INSERTED]] : vector<512xf32> to vector<16x4x8xf32> +// CHECK: return %[[CASTED]] : vector<16x4x8xf32> + +func.func @test_insert_strided_implicit_1_into_3(%arg0 : vector<16x4x8xf32>, %arg1 : vector<1xf32>) -> vector<16x4x8xf32> { + %0 = vector.insert_strided_slice %arg1, %arg0 {offsets = [2, 1, 0], strides = [1]}: + vector<1xf32> into vector<16x4x8xf32> + return %0 : vector<16x4x8xf32> +} diff --git a/mlir/test/lib/Dialect/Vector/CMakeLists.txt b/mlir/test/lib/Dialect/Vector/CMakeLists.txt index e16937029ac0e..1ce069599af43 100644 --- a/mlir/test/lib/Dialect/Vector/CMakeLists.txt +++ b/mlir/test/lib/Dialect/Vector/CMakeLists.txt @@ -1,6 +1,7 @@ # Exclude tests from libMLIR.so add_mlir_library(MLIRVectorTestPasses TestVectorTransforms.cpp + TestVectorLinearize.cpp EXCLUDE_FROM_LIBMLIR ) diff --git 
a/mlir/test/lib/Dialect/Vector/TestVectorLinearize.cpp b/mlir/test/lib/Dialect/Vector/TestVectorLinearize.cpp
new file mode 100644
index 0000000000000..011f25e8227f2
--- /dev/null
+++ b/mlir/test/lib/Dialect/Vector/TestVectorLinearize.cpp
@@ -0,0 +1,268 @@
+//===- TestVectorLinearize.cpp - Test Vector linearization ----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include <limits>
+
+#include "mlir/Analysis/SliceAnalysis.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/Dialect/SCF/Transforms/Patterns.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
+#include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
+#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
+#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+using namespace mlir::vector;
+
+namespace {
+
+struct TestVectorLinearize final
+    : public PassWrapper<TestVectorLinearize, OperationPass<>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorLinearize)
+
+  StringRef getArgument() const override { return "test-vector-linearize"; }
+  StringRef getDescription() const override {
+    return "Use shape_casts to ensure vector operands/results are rank <= 1";
+  }
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<vector::VectorDialect, scf::SCFDialect>();
+  }
+
+  void runOnOperation() override {
+    MLIRContext &context = getContext();
+    Operation *op = getOperation();
+
+    // Step 1: Run the linearization patterns.
+    //
+    // Note that we disable folding to prevent the extract(shape_cast) ->
+    // extract folder undoing linearization. Without disabling this, we can get
+    // into infinite loops.
+    {
+      RewritePatternSet patterns(&context);
+      populateForVectorLinearize(patterns);
+      GreedyRewriteConfig config;
+      config.enableFolding(false);
+      if (failed(applyPatternsGreedily(op, std::move(patterns), config)))
+        return signalPassFailure();
+    }
+
+    // Step 2: linearize SCF structured ops using type conversion.
+    {
+      TypeConverter typeConverter;
+      RewritePatternSet patterns(&context);
+      ConversionTarget target(context);
+
+      // Convert 'type' to a "legal" (rank-1) type.
+      auto convertType = [](Type type) -> std::optional<Type> {
+        VectorType vectorType = dyn_cast<VectorType>(type);
+        if (!vectorType || !isLinearizableVector(vectorType))
+          return type;
+
+        VectorType linearizedType = VectorType::get(vectorType.getNumElements(),
+                                                    vectorType.getElementType(),
+                                                    vectorType.isScalable());
+        return linearizedType;
+      };
+      typeConverter.addConversion(convertType);
+
+      // This function is used during legalization to create shape_casts between
+      // the legal rank-1 types and other types. For example, an scf.for
+      // iter_arg converted from vector<2x2xi8> to vector<4xi8> is reconciled
+      // with uses of the original 2-D type through shape_cast.
+      auto materializeCast = [](OpBuilder &builder, Type type,
+                                ValueRange inputs, Location loc) -> Value {
+        if (inputs.size() != 1)
+          return nullptr;
+
+        Value input = inputs.front();
+        if (!isa<VectorType>(type) || !isa<VectorType>(input.getType()))
+          return nullptr;
+
+        return builder.create<vector::ShapeCastOp>(loc, type, input);
+      };
+      typeConverter.addSourceMaterialization(materializeCast);
+      typeConverter.addTargetMaterialization(materializeCast);
+
+      // As we are here just illustrating how to use type conversion to
+      // linearize SCF operations, we consider all other operations already
+      // legal.
+      target.markUnknownOpDynamicallyLegal(
+          [=](Operation *op) -> std::optional<bool> {
+            if (scf::SCFDialect::getDialectNamespace() !=
+                op->getDialect()->getNamespace())
+              return true;
+
+            // This will return true if, for all operand and result types `t`,
+            // convertType(t) = t. This is true if there are no rank>=2 vectors.
+            return typeConverter.isLegal(op);
+          });
+
+      mlir::scf::populateSCFStructuralTypeConversionsAndLegality(
+          typeConverter, patterns, target);
+      if (failed(applyPartialConversion(op, target, std::move(patterns))))
+        return signalPassFailure();
+    }
+
+    // Step 3: Perform folding.
+    if (failed(applyPatternsGreedily(op, RewritePatternSet(&context))))
+      return signalPassFailure();
+  }
+};
+
+struct TestRankReduceStridedSliceOps final
+    : public PassWrapper<TestRankReduceStridedSliceOps, OperationPass<>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestRankReduceStridedSliceOps)
+
+  TestRankReduceStridedSliceOps() = default;
+  TestRankReduceStridedSliceOps(const TestRankReduceStridedSliceOps &pass) =
+      default;
+
+  StringRef getArgument() const override {
+    return "test-rank-reduce-strided-slice-ops";
+  }
+  StringRef getDescription() const override {
+    return "Test pass for rank-reducing strided slice ops.";
+  }
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<vector::VectorDialect>();
+  }
+
+  void runOnOperation() override {
+    RewritePatternSet patterns(&getContext());
+    populateForStridedRankReduction(patterns);
+    if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
+      return signalPassFailure();
+  }
+};
+
+struct TestVectorBitWidthLinearize final
+    : public PassWrapper<TestVectorBitWidthLinearize, OperationPass<>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorBitWidthLinearize)
+
+  TestVectorBitWidthLinearize() = default;
+  TestVectorBitWidthLinearize(const TestVectorBitWidthLinearize &pass)
+      : PassWrapper(pass) {}
+
+  StringRef getArgument() const override {
+    return "test-bit-width-constrained-vector-linearize";
+  }
+  StringRef getDescription() const override {
+    return "Linearizes ND vectors for N >= 2 into 1D vectors, with constraints "
+           "on the inner-most dimension's bit width. If the inner-most "
+           "dimension exceeds the threshold, the op is not linearized.";
+  }
+  Option<unsigned> targetVectorBitwidth{
+      *this, "target-vector-bitwidth",
+      llvm::cl::desc(
+          "Minimum vector bitwidth to enable the flattening transformation"),
+      llvm::cl::init(std::numeric_limits<unsigned>::max())};
+
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<vector::VectorDialect>();
+  }
+
+  void runOnOperation() override {
+    MLIRContext &context = getContext();
+    Operation *op = getOperation();
+
+    // Initialize the patterns with a pre-condition on the bit-width for
+    // linearization.
+    auto preCondition = [&](Operation *op) -> LogicalResult {
+      bool notLinearizable =
+          isNotLinearizableBecauseLargeInnerDimension(op, targetVectorBitwidth);
+      return notLinearizable ? failure() : success();
+    };
+    RewritePatternSet patterns(&context);
+    populateForVectorLinearize(patterns, preCondition);
+
+    // Apply the patterns, with folding disabled.
+    if (failed(
+            applyPatternsGreedily(op, std::move(patterns),
+                                  GreedyRewriteConfig().enableFolding(false))))
+      return signalPassFailure();
+
+    // Fold.
+    if (failed(applyPatternsGreedily(op, RewritePatternSet(&context))))
+      return signalPassFailure();
+  }
+
+  /// If `type` is a VectorType with trailing dimension of (bit) size greater
+  /// than or equal to `targetBitWidth`, its defining op is considered legal.
+  static bool
+  isNotLinearizableBecauseLargeInnerDimension(Type type,
+                                              unsigned targetBitWidth) {
+
+    VectorType vecType = dyn_cast<VectorType>(type);
+
+    // Not linearizable for reasons other than what this function checks.
+    if (!vecType || vecType.getRank() == 0)
+      return false;
+
+    // The width of the type 'index' is unbounded (and therefore potentially
+    // above the target width).
+    if (vecType.getElementType().isIndex())
+      return true;
+
+    unsigned finalDimSize = vecType.getShape().back();
+    unsigned nbBitsPerElm = vecType.getElementTypeBitWidth();
+    unsigned trailingVecDimBitWidth = finalDimSize * nbBitsPerElm;
+    return trailingVecDimBitWidth >= targetBitWidth;
+  }
+
+private:
+  static bool
+  isNotLinearizableBecauseLargeInnerDimension(Operation *op,
+                                              unsigned targetBitWidth) {
+    // Check on bitwidths.
+    SmallVector<std::pair<Type, unsigned>> toCheck =
+        getTypeBitWidthBoundPairs(op, targetBitWidth);
+    return std::any_of(toCheck.begin(), toCheck.end(),
+                       [&](std::pair<Type, unsigned> typeWidth) {
+                         return isNotLinearizableBecauseLargeInnerDimension(
+                             typeWidth.first, typeWidth.second);
+                       });
+  }
+
+  /// Get the set of operand/result types to check for sufficiently
+  /// small inner-most dimension size.
+  static SmallVector<std::pair<Type, unsigned>>
+  getTypeBitWidthBoundPairs(Operation *op, unsigned targetBitWidth) {
+
+    if (auto insertOp = dyn_cast<vector::InsertOp>(op)) {
+      unsigned w = targetBitWidth < std::numeric_limits<unsigned>::max()
+                       ? targetBitWidth + 1
+                       : targetBitWidth;
+      return {{insertOp.getValueToStoreType(), w}};
+    }
+
+    auto resultTypes = op->getResultTypes();
+    SmallVector<std::pair<Type, unsigned>> resultsWithBitWidth;
+    resultsWithBitWidth.reserve(resultTypes.size());
+    for (Type type : resultTypes) {
+      resultsWithBitWidth.push_back({type, targetBitWidth});
+    }
+    return resultsWithBitWidth;
+  }
+};
+
+} // namespace
+
+namespace mlir {
+namespace test {
+extern void registerTestVectorLinearize() {
+  PassRegistration<TestVectorLinearize>();
+  PassRegistration<TestVectorBitWidthLinearize>();
+  PassRegistration<TestRankReduceStridedSliceOps>();
+}
+} // namespace test
+} // namespace mlir
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index a7285ab8cb15a..6c2be985a52db 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -17,7 +17,6 @@
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
-#include "mlir/Dialect/SCF/Transforms/Patterns.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
@@ -840,160 +839,6 @@ struct TestVectorEmulateMaskedLoadStore final
   }
 };
 
-/// Get the set of operand/result types to check for sufficiently
-/// small inner-most dimension size.
-static SmallVector<std::pair<Type, unsigned>>
-getTypeBitWidthBoundPairs(Operation *op, unsigned targetBitWidth) {
-
-  if (auto insertOp = dyn_cast<vector::InsertOp>(op)) {
-    unsigned w = targetBitWidth < std::numeric_limits<unsigned>::max()
-                     ? targetBitWidth + 1
-                     : targetBitWidth;
-    return {{insertOp.getValueToStoreType(), w}};
-  }
-
-  auto resultTypes = op->getResultTypes();
-  SmallVector<std::pair<Type, unsigned>> resultsWithBitWidth;
-  resultsWithBitWidth.reserve(resultTypes.size());
-  for (Type type : resultTypes) {
-    resultsWithBitWidth.push_back({type, targetBitWidth});
-  }
-  return resultsWithBitWidth;
-}
-
-/// If `type` is VectorType with trailing dimension of (bit) size greater than
-/// or equal to `targetBitWidth`, its defining op is considered legal.
-static bool
-isNotLinearizableBecauseLargeInnerDimension(Type type,
-                                            unsigned targetBitWidth) {
-
-  VectorType vecType = dyn_cast<VectorType>(type);
-
-  // Not linearizable for reasons other than what this function checks.
-  if (!vecType || vecType.getRank() == 0)
-    return false;
-
-  // The width of the type 'index' is unbounded (and therefore potentially
-  // above the target width).
-  if (vecType.getElementType().isIndex())
-    return true;
-
-  unsigned finalDimSize = vecType.getShape().back();
-  unsigned nbBitsPerElm = vecType.getElementTypeBitWidth();
-  unsigned trailingVecDimBitWidth = finalDimSize * nbBitsPerElm;
-  return trailingVecDimBitWidth >= targetBitWidth;
-}
-
-static bool
-isNotLinearizableBecauseLargeInnerDimension(Operation *op,
-                                            unsigned targetBitWidth) {
-  // Check on bitwidths.
-  SmallVector<std::pair<Type, unsigned>> toCheck =
-      getTypeBitWidthBoundPairs(op, targetBitWidth);
-  return llvm::any_of(toCheck, [&](std::pair<Type, unsigned> typeWidth) {
-    return isNotLinearizableBecauseLargeInnerDimension(typeWidth.first,
-                                                       typeWidth.second);
-  });
-}
-
-void populateWithBitWidthConstraints(TypeConverter &typeConverter,
-                                     ConversionTarget &target,
-                                     unsigned targetBitWidth) {
-
-  // The general purpose definition of what ops are legal must come first.
-  populateForVectorLinearize(typeConverter, target);
-
-  // Extend the set of legal ops to include those with large inner-most
-  // dimensions on selected operands/results.
-  target.markUnknownOpDynamicallyLegal(
-      [=](Operation *op) -> std::optional<bool> {
-        if (isNotLinearizableBecauseLargeInnerDimension(op, targetBitWidth)) {
-          return true;
-        }
-        return {};
-      });
-}
-
-struct TestVectorBitWidthLinearize final
-    : public PassWrapper<TestVectorBitWidthLinearize, OperationPass<>> {
-  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorBitWidthLinearize)
-
-  TestVectorBitWidthLinearize() = default;
-  TestVectorBitWidthLinearize(const TestVectorBitWidthLinearize &pass)
-      : PassWrapper(pass) {}
-
-  StringRef getArgument() const override {
-    return "test-bit-width-constrained-vector-linearize";
-  }
-  StringRef getDescription() const override {
-    return "Linearizes ND vectors for N >= 2 into 1D vectors, with constraints "
-           "in inner-most dimension's bit width.";
-  }
-  void getDependentDialects(DialectRegistry &registry) const override {
-    registry.insert<vector::VectorDialect>();
-  }
-
-  Option<unsigned> targetVectorBitwidth{
-      *this, "target-vector-bitwidth",
-      llvm::cl::desc(
-          "Minimum vector bitwidth to enable the flattening transformation"),
-      llvm::cl::init(std::numeric_limits<unsigned>::max())};
-  void runOnOperation() override {
-    auto *context = &getContext();
-
-    TypeConverter typeConverter;
-    RewritePatternSet patterns(context);
-    ConversionTarget target(*context);
-
-    populateWithBitWidthConstraints(typeConverter, target,
-                                    targetVectorBitwidth);
-
-    vector::populateVectorLinearizeBasePatterns(typeConverter, target,
-                                                patterns);
-
-    vector::populateVectorLinearizeShuffleLikeOpsPatterns(typeConverter, target,
-                                                          patterns);
-
-    if (failed(applyPartialConversion(getOperation(), target,
-                                      std::move(patterns))))
-      return signalPassFailure();
-  }
-};
-
-struct TestVectorLinearize final
-    : public PassWrapper<TestVectorLinearize, OperationPass<>> {
-  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorLinearize)
-
-  TestVectorLinearize() = default;
-
-  StringRef getArgument() const override { return "test-vector-linearize"; }
-  StringRef getDescription() const override {
-    return "Linearizes ND vectors for N >= 2 into 1D vectors";
-  }
-  void getDependentDialects(DialectRegistry &registry) const override {
-    registry.insert<vector::VectorDialect>();
-  }
-
-  void runOnOperation() override {
-    MLIRContext &context = getContext();
-    TypeConverter converter;
-    RewritePatternSet patterns(&context);
-    ConversionTarget target(context);
-
-    vector::populateForVectorLinearize(converter, target);
-
-    vector::populateVectorLinearizeBasePatterns(converter, target, patterns);
-    vector::populateVectorLinearizeShuffleLikeOpsPatterns(converter, target,
-                                                          patterns);
-    mlir::scf::populateSCFStructuralTypeConversionsAndLegality(
-        converter, patterns, target);
-
-    if (failed(applyPartialConversion(getOperation(), target,
-                                      std::move(patterns))))
-      return signalPassFailure();
-  }
-};
-
 struct TestEliminateVectorMasks
     : public PassWrapper<TestEliminateVectorMasks,
                          OperationPass<func::FuncOp>> {
@@ -1065,10 +910,6 @@ void registerTestVectorLowerings() {
 
   PassRegistration<TestVectorEmulateMaskedLoadStore>();
 
-  PassRegistration<TestVectorLinearize>();
-
-  PassRegistration<TestVectorBitWidthLinearize>();
-
   PassRegistration<TestEliminateVectorMasks>();
 }
 } // namespace test
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 143a5e8e8f8dd..21d6573662ad3 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -156,6 +156,7 @@ void registerTestTopologicalSortAnalysisPass();
 void registerTestTransformDialectEraseSchedulePass();
 void registerTestPassStateExtensionCommunication();
 void registerTestVectorLowerings();
+void registerTestVectorLinearize();
 void registerTestVectorReductionToSPIRVDotProd();
 void registerTestVulkanRunnerPipeline();
 void registerTestWrittenToPass();
@@ -302,7 +303,9 @@ void registerTestPasses() { 
mlir::test::registerTestTransformDialectEraseSchedulePass(); mlir::test::registerTestPassStateExtensionCommunication(); mlir::test::registerTestVectorLowerings(); + mlir::test::registerTestVectorLinearize(); mlir::test::registerTestVectorReductionToSPIRVDotProd(); + mlir::test::registerTestVulkanRunnerPipeline(); mlir::test::registerTestWrittenToPass(); mlir::test::registerTestXeGPULowerings(); From 3776e58df90aded41b0d10e76f321fbc7df82ccd Mon Sep 17 00:00:00 2001 From: James Newling Date: Tue, 8 Jul 2025 14:37:22 -0700 Subject: [PATCH 2/2] test docs improvement --- .../Dialect/Vector/linearize/linearize.mlir | 30 +++++++++++-------- 1 file changed, 18 insertions(+), 12 deletions(-) diff --git a/mlir/test/Dialect/Vector/linearize/linearize.mlir b/mlir/test/Dialect/Vector/linearize/linearize.mlir index a382c798b62bc..fd0f46e786694 100644 --- a/mlir/test/Dialect/Vector/linearize/linearize.mlir +++ b/mlir/test/Dialect/Vector/linearize/linearize.mlir @@ -1,7 +1,8 @@ // RUN: mlir-opt %s -split-input-file -test-vector-linearize -verify-diagnostics | FileCheck %s // **--------------------------------------------------------** -// Tests of vectoriable ops +// Tests of vectorizable ops +// [CollapseInnerVectorizable] // **--------------------------------------------------------** // Constant linearization happens here because of the vector.shape_cast folder. @@ -323,6 +324,7 @@ func.func @linearize_across_for(%arg0 : vector<4xi8>) -> vector<4xi8> { // **--------------------------------------------------------** // Tests of vector.splat +// [CollapseInnerSplat] // **--------------------------------------------------------** // CHECK-LABEL: linearize_vector_splat @@ -348,9 +350,10 @@ func.func @linearize_scalable_vector_splat(%arg0: i32) -> vector<[8]xi32> { } -// **--------------------------------------------------------** -// Tests of vector.insert -// **--------------------------------------------------------** +// **-------------------------------------------------------------------** +// Tests of vector.insert +// [ConvertInsertToShuffle, CollapseInnerInsert, CollapseOuterInsert] +// **-------------------------------------------------------------------** // ----- @@ -453,9 +456,10 @@ func.func @insert_scalable(%arg0: vector<2x8x[4]xf32>, %arg1: vector<8x[4]xf32>) // ----- -// **--------------------------------------------------------** -// Tests of vector.extract -// **--------------------------------------------------------** +// **--------------------------------------------------------------------------** +// Tests of vector.extract +// [ConvertExtractToShuffle, CollapseInnerExtract, CollapseOuterExtract] +// **--------------------------------------------------------------------------** // vector.extract where the source is 1D vector is always unchanged. 
@@ -556,9 +560,10 @@ func.func @extract_scalable(%arg0: vector<2x8x[2]xf32>) -> vector<8x[2]xf32> { } -// **--------------------------------------------------------** -// Tests of vector.insert_strided_slice -// **--------------------------------------------------------** +// **------------------------------------------------------------** +// Tests of vector.insert_strided_slice +// [ConvertInsertStridedToShuffle, CollapseInnerInsertStride] +// **------------------------------------------------------------** // ----- @@ -693,9 +698,10 @@ func.func @insert_strided_slice_4D_noncontiguous(%arg0 : vector<1x2x1x1xi8>, %ar // ----- -// **--------------------------------------------------------** +// **----------------------------------------------------------------** // Tests of vector.extract_strided_slice -// **--------------------------------------------------------** +// [ConvertExtractStridedToShuffle, CollapseInnerExtractStride] +// **----------------------------------------------------------------** // CHECK-LABEL: extract_strided_slice_2D // CHECK-SAME: (%[[ORIG_ARG:.*]]: vector<4x8xf32>) -> vector<2x2xf32> {