Skip to content

[WIP] Linearization without conversion #147433

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -348,12 +348,12 @@ void populateVectorTransferDropUnitDimsPatterns(RewritePatternSet &patterns,
void populateDropUnitDimWithShapeCastPatterns(RewritePatternSet &patterns,
PatternBenefit benefit = 1);

/// Collect a set of patterns to flatten n-D vector transfers on contiguous
/// memref.
/// Collect a set of patterns to flatten/linearize n-D vector transfers on
/// contiguous memref.
///
/// These patterns insert memref.collapse_shape + vector.shape_cast patterns
/// to transform multiple small n-D transfers into a larger 1-D transfer where
/// the memref contiguity properties allow it.
/// to transform a n-D transfer into a larger 1-D transfer where the memref
/// contiguity properties allow it.
///
/// Flattening is only applied if the bitwidth of the trailing vector dimension
/// is smaller or equal to `targetVectorBitwidth`.
Expand All @@ -362,6 +362,28 @@ void populateFlattenVectorTransferPatterns(
unsigned targetVectorBitwidth = std::numeric_limits<unsigned>::max(),
PatternBenefit benefit = 1);

/// Collect a set of patterns to flatten/linearize operations on vectors.
///
/// These patterns insert vector.shape_cast to transform operations to have
/// lower rank operands and results.
///
/// At the start of every pattern's `matchAndRewrite` call, `preCondition`
/// is called. If it returns failure, the pattern is not applied.
///
/// TODO(newling) combine this API with `populateFlattenVectorTransferPatterns`.
void populateForVectorLinearize(
RewritePatternSet &patterns,
const std::function<LogicalResult(Operation *)> &preCondition =
[](Operation *) { return success(); },
PatternBenefit benefit = 1);

/// Collect a set of patterns to rewrite vector.extract_strided_slice and
/// vector.insert_strided_slice operations to have the lowest possible rank.
/// This is done by using shape_cast to combine consecutive dimensions whose
/// memory is contiguous.
void populateForStridedRankReduction(RewritePatternSet &patterns,
PatternBenefit benefit = 1);

/// Collect a set of patterns that bubble up/down bitcast ops.
///
/// These patterns move vector.bitcast ops to be before insert ops or after
Expand Down Expand Up @@ -408,39 +430,6 @@ void populateVectorNarrowTypeRewritePatterns(RewritePatternSet &patterns,
void populateVectorTransposeNarrowTypeRewritePatterns(
RewritePatternSet &patterns, PatternBenefit benefit = 1);

/// Initialize `typeConverter` and `conversionTarget` for vector linearization.
///
/// Definition: here 'linearization' means converting a single operation with
/// 1+ vector operand/result of rank>1, into a new single operation whose
/// vector operands and results are all of rank<=1.
///
/// This function registers (1) which operations are legal, and hence should not
/// be linearized, (2) what the converted types are (rank-1 vectors) and how to
/// materialze the conversion (with shape_cast)
///
/// Note: the set of legal operations can be extended by a user if for example
/// certain rank>1 vectors are considered valid, by adding additional
/// dynamically legal ops to `conversionTarget`.
///
/// Further note: the choice to use a dialect conversion design for
/// linearization is to make it easy to reuse generic structural type
/// conversions for linearizing scf/cf/func operations
void populateForVectorLinearize(TypeConverter &typeConverter,
ConversionTarget &conversionTarget);

/// Populates `patterns` for ND vector (N >= 2) linearization. This currently
/// contains patterns for converting ConstantLike, Vectorizable, and
/// vector::BitCast ops.
void populateVectorLinearizeBasePatterns(const TypeConverter &,
const ConversionTarget &,
RewritePatternSet &patterns);

/// Populates `patterns` for linearizing ND (N >= 2) vector operations
/// to 1D vector shuffle operations.
void populateVectorLinearizeShuffleLikeOpsPatterns(const TypeConverter &,
const ConversionTarget &,
RewritePatternSet &patterns);

} // namespace vector
} // namespace mlir

Expand Down
46 changes: 32 additions & 14 deletions mlir/lib/Dialect/Vector/IR/VectorOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3876,7 +3876,11 @@ Type OuterProductOp::getExpectedMaskType() {
static Type inferStridedSliceOpResultType(VectorType vectorType,
ArrayAttr offsets, ArrayAttr sizes,
ArrayAttr strides) {
assert(offsets.size() == sizes.size() && offsets.size() == strides.size());

assert(offsets.size() == sizes.size() &&
"offsets and sizes must be same size");
assert(offsets.size() == strides.size() &&
"offsets and strides must be same size");
SmallVector<int64_t, 4> shape;
shape.reserve(vectorType.getRank());
unsigned idx = 0;
Expand Down Expand Up @@ -5896,13 +5900,21 @@ OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {

VectorType resultType = getType();

// No-op shape cast.
if (getSource().getType() == resultType)
return getSource();
// y = shape_cast(shape_cast(shape_cast(x)))
// -> shape_cast(x) # if x and y different types
// -> x # if x and y same type
// Value newSource = getSource();
ShapeCastOp parent = *this;
while (auto precedingShapeCast =
parent.getSource().getDefiningOp<ShapeCastOp>()) {
parent = precedingShapeCast;
}

if (parent.getSource().getType() == resultType)
return parent.getSource();

// shape_cast(shape_cast(x)) -> shape_cast(x)
if (auto precedingShapeCast = getSource().getDefiningOp<ShapeCastOp>()) {
setOperand(precedingShapeCast.getSource());
if (parent != *this) {
setOperand(parent.getSource());
return getResult();
}

Expand All @@ -5922,14 +5934,20 @@ OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
return bcastOp.getSource();
}

// shape_cast(constant) -> constant
if (auto splatAttr =
llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource()))
return splatAttr.reshape(getType());
Attribute attr = adaptor.getSource();
if (attr) {
// shape_cast(constant) -> constant
if (auto splatAttr = llvm::dyn_cast<SplatElementsAttr>(attr))
return splatAttr.reshape(getType());

// shape_cast(poison) -> poison
if (llvm::dyn_cast_if_present<ub::PoisonAttr>(adaptor.getSource())) {
return ub::PoisonAttr::get(getContext());
if (auto dstElementsAttr = dyn_cast<DenseElementsAttr>(attr)) {
return dstElementsAttr.reshape(getType());
}

// shape_cast(poison) -> poison
if (llvm::dyn_cast<ub::PoisonAttr>(attr)) {
return ub::PoisonAttr::get(getContext());
}
}

return {};
Expand Down
Loading