[mlir][linalg] Vectorize directly to a named contraction #147296

Open · wants to merge 1 commit into main
@@ -2445,6 +2445,8 @@ def VectorizeOp : Op<Transform_Dialect, "structured.vectorize",
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:
$static_vector_sizes,
OptionalAttr<UnitAttr>:$vectorize_nd_extract,
OptionalAttr<UnitAttr>:$flatten1D_depthwise_conv,
Member:

Any special reason to add this unrelated option?

Contributor Author:

Just for completeness.
Since I'm already tweaking the transform op with a new option, I noticed that this flag from the linalg::vectorize API was missing here.

Contributor:

Let me add a bit more context.

I added the option to flatten the depthwise convs as an optimisation for convs with a low channel dim count. While great for NEON (i.e., fixed-width vectors), it's something that's tricky to generalise to scalable vectors. So I deliberately avoided extending the support (I am waiting to see whether others find it useful).

I am fine with extending this Op, but if we do, we should also add more tests. Your call :)

Contributor Author:

Thanks for the insight 🙂
In this case, I'll remove the option to limit the scope of this PR.
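For reference, a minimal usage sketch of the new option through the transform dialect (the match target, vector sizes, and exact assembly format are illustrative assumptions, not taken from this PR's tests):

%matmul = transform.structured.match ops{["linalg.matmul"]} in %arg1
    : (!transform.any_op) -> !transform.any_op
transform.structured.vectorize %matmul vector_sizes [8, 16, 4]
    {create_named_contraction} : !transform.any_op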

OptionalAttr<UnitAttr>:$create_named_contraction,
DefaultValuedOptionalAttr<DenseBoolArrayAttr, "{}">:
$scalable_sizes);

5 changes: 4 additions & 1 deletion mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -876,11 +876,14 @@ struct VectorizationResult {
/// greater than or equal to their counterpart iteration space sizes, if static.
/// `inputVectorShapes` also allows the vectorization of operations with dynamic
/// shapes.
/// Optionally, `createNamedContraction` can force compatible contractions to be
/// vectorized directly to a vector.contract operation.
FailureOr<VectorizationResult>
vectorize(RewriterBase &rewriter, Operation *op,
ArrayRef<int64_t> inputVectorSizes = {},
ArrayRef<bool> inputScalableVecDims = {},
-          bool vectorizeNDExtract = false, bool flatten1DDepthwiseConv = false);
+          bool vectorizeNDExtract = false, bool flatten1DDepthwiseConv = false,
+          bool createNamedContraction = false);

/// Emit a suitable vector form for a Copy op with fully static shape.
LogicalResult vectorizeCopy(RewriterBase &builder, memref::CopyOp copyOp);
3 changes: 2 additions & 1 deletion mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
@@ -226,7 +226,8 @@ bool isLinearizableVector(VectorType type);
/// Note: all read offsets are set to 0.
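/// `scalableDims` marks which of the `inputVectorSizes` are scalable; e.g.,
/// sizes {4, 8} with scalableDims {false, true} produce a vector<4x[8]xf32>
/// read (for an f32 source) and a matching vector<4x[8]xi1> mask.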
Value createReadOrMaskedRead(OpBuilder &builder, Location loc, Value source,
ArrayRef<int64_t> inputVectorSizes, Value padValue,
-                             bool useInBoundsInsteadOfMasking = false);
+                             bool useInBoundsInsteadOfMasking = false,
+                             ArrayRef<bool> scalableDims = {});

/// Returns success if `inputVectorSizes` is a valid masking configuration for
/// given `shape`, i.e., it meets:
@@ -3920,7 +3920,9 @@ DiagnosedSilenceableFailure transform::VectorizeOp::apply(
}
FailureOr<VectorizationResult> vectorResults =
linalg::vectorize(rewriter, target, vectorSizes, getScalableSizes(),
-                        getVectorizeNdExtract().value_or(false));
+                        getVectorizeNdExtract().value_or(false),
+                        getFlatten1DDepthwiseConv().value_or(false),
+                        getCreateNamedContraction().value_or(false));
if (failed(vectorResults)) {
return mlir::emitSilenceableFailure(target->getLoc())
<< "Attempted to vectorize, but failed";
110 changes: 103 additions & 7 deletions mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -25,6 +25,7 @@
#include "mlir/Dialect/Vector/Interfaces/MaskableOpInterface.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
@@ -1681,10 +1682,13 @@ createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vecToStore,
return write;

// Compute the mask and mask the write Op.
-  auto writeMaskType = VectorType::get(vecToStoreShape, builder.getI1Type());
+  auto writeMaskType = VectorType::get(vecToStoreShape, builder.getI1Type(),
+                                       vecToStoreType.getScalableDims());

SmallVector<OpFoldResult> destSizes =
-      tensor::getMixedSizes(builder, loc, dest);
+      isa<MemRefType>(dest.getType())
+          ? memref::getMixedSizes(builder, loc, dest)
+          : tensor::getMixedSizes(builder, loc, dest);
SmallVector<OpFoldResult> maskSizes(destSizes.end() - vecToStoreRank,
destSizes.end());

@@ -2093,6 +2097,84 @@ vectorizeInsertSliceOpPrecondition(tensor::InsertSliceOp sliceOp,
return success();
}

/// Vectorize a named linalg contraction op into:
///   vector::TransferReadOp - Reads vectors from the operands
///   vector::ContractionOp - Performs contraction
///   vector::TransferWriteOp - Writes the result vector back to the
///                             destination
/// The operands' shapes are preserved and loaded directly into vectors.
/// Any further permutations or numerical casting remain within the
/// contraction.
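///
/// A rough sketch (assuming static shapes; masking and numerical casts
/// omitted):
/// ```
///   %res = linalg.matmul ins(%A, %B : tensor<8x4xf32>, tensor<4x16xf32>)
///                        outs(%C : tensor<8x16xf32>) -> tensor<8x16xf32>
/// ```
/// becomes:
/// ```
///   %lhs = vector.transfer_read %A[...] : tensor<8x4xf32>, vector<8x4xf32>
///   %rhs = vector.transfer_read %B[...] : tensor<4x16xf32>, vector<4x16xf32>
///   %acc = vector.transfer_read %C[...] : tensor<8x16xf32>, vector<8x16xf32>
///   %con = vector.contract {indexing_maps = [...], iterator_types = [...],
///            kind = #vector.kind<add>} %lhs, %rhs, %acc
///            : vector<8x4xf32>, vector<4x16xf32> into vector<8x16xf32>
///   %res = vector.transfer_write %con, %C[...]
///            : vector<8x16xf32>, tensor<8x16xf32>
/// ```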
static LogicalResult
vectorizeAsLinalgContraction(RewriterBase &rewriter, VectorizationState &state,
LinalgOp linalgOp,
SmallVectorImpl<Value> &newResults) {
Location loc = linalgOp.getLoc();
MLIRContext *ctx = linalgOp.getContext();

if (!isa<ContractionOpInterface>(linalgOp.getOperation()))
return failure();

OpOperand *outOperand = linalgOp.getDpsInitOperand(0);
Operation *reduceOp = matchLinalgReduction(outOperand);
auto maybeKind = getCombinerOpKind(reduceOp);
if (!maybeKind)
return failure();

// Check that all dimensions are present in the input operands.
// Arbitrary broadcasts are not supported by the vector contraction.
// Broadcasts are expected to be materialized before vectorization.
Member:

Would the broadcast vectorization work in tandem with this one? Or can you invoke vectorization on just the contract and not on the surrounding code (say, via a transform)?

If the latter, then we may (at some point later) want to validate the producers and consumers to make sure the vectorization won't break anything around it.

Contributor Author:

Yes to both. You can also vectorize selectively (either in a pass or a transform).

In general, the current vectorizer rewrites one op at a time. It creates read and write ops at the boundaries precisely to ensure a seamless transition between a vectorized op and its consumers and producers.
At the tensor level, these read-write pairs can easily cancel out thanks to value semantics. With memrefs, cleanup is obviously trickier due to possible aliasing.
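A minimal sketch of that tensor-level cancellation (names and shapes are illustrative; the folding itself comes from existing vector transfer canonicalization patterns):

%w = vector.transfer_write %vec, %t[%c0, %c0]
    : vector<8x16xf32>, tensor<8x16xf32>
%r = vector.transfer_read %w[%c0, %c0], %pad
    : tensor<8x16xf32>, vector<8x16xf32>
// %r folds to %vec; if %w then has no other users, the write can be
// dropped as well.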

Member:

> At the tensor level, these read-write pairs can easily cancel out thanks to value semantics. With memrefs, cleanup is obviously trickier due to possible aliasing.

Yes, that's my worry, but I guess it's up to the transform user to know memrefs are harder and adjust strategy. This won't be the only problem they'll have with memrefs anyway.

Contributor Author:

I also don't want to guess how to materialize such broadcasts. We could just pick one default broadcasting scheme (e.g., the canonical vector shape, as in the linalg generic vectorizer), but that's likely to be suboptimal too.

Perhaps this broadcasting information could be encoded in the operand's layout? Something I could experiment with later.

Contributor:

I think I am missing the "materialize" context in the vectorization concept. Is it a blocker to make the option default? What does materializing broadcasts mean? Is it breaking a matmul into something like broadcast(LHS)->matmul?

Contributor Author:

> Is it a blocker to make the option default?

Not necessarily, but it'd be good to align on whether broadcasts should be handled at all. If yes, then how.

> What does materializing broadcasts mean? Is it breaking a matmul into something like broadcast(LHS)->matmul?

Correct. Today, broadcast semantics can't be preserved within vector.contract, so the extra dimension has to be created somewhere.
I'll elaborate more in a separate answer.
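For illustration, a hedged sketch of such a materialization (shapes and ops are assumptions, not part of this PR). A contraction whose LHS is implicitly broadcast along m, i.e. with an LHS map affine_map<(m, n, k) -> (k)>, would first get an explicit broadcast, after which a regular matmul remains and this patch can vectorize it directly:

%lhs2d = linalg.broadcast ins(%lhs : tensor<4xf32>)
    outs(%lhsInit : tensor<8x4xf32>) dimensions = [0]
%res = linalg.matmul ins(%lhs2d, %rhs : tensor<8x4xf32>, tensor<4x16xf32>)
    outs(%acc : tensor<8x16xf32>) -> tensor<8x16xf32>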

Contributor:

My terminology would be "decompose", because that's what we use for pack/unpack/pad/etc. ops upstream. If we break an op into a sequence of simpler ops, I'd call it decomposition. E.g., DecomposeGenericByUnfoldingPermutation:

/// Rewrite a tensor::PadOp into a sequence of EmptyOp, FillOp and
/// InsertSliceOp. For now, only constant padding values are supported.
struct DecomposePadOpPattern : public OpRewritePattern<tensor::PadOp> {
DecomposePadOpPattern(MLIRContext *context, PatternBenefit benefit = 1)
: OpRewritePattern<tensor::PadOp>(context, benefit) {}
LogicalResult matchAndRewrite(tensor::PadOp padOp,
PatternRewriter &rewriter) const override;
protected:
Value createFillOrGenerateOp(RewriterBase &rewriter, tensor::PadOp padOp,
Value dest,
const SmallVector<Value> &dynSizes) const;
};
/// Rewrites a linalg::PackOp into a sequence of:
/// * tensor::PadOp + linalg::TransposeOp + tensor::EmptyOp +
/// tensor::InsertSliceOp ops.
///
/// Requires that all the outer dims of the input linalg::PackOp are 1.
///
/// Before:
/// ```
/// %packed = linalg.pack %input
/// padding_value(%pad : f32)
/// inner_dims_pos = [1, 0]
/// inner_tiles = [2, %high]
/// into %output : tensor<5x1xf32> -> tensor<1x1x2x?xf32>
/// ```
///
/// After:
/// ```
/// // PadOp
/// %padded = tensor.pad %arg0 low[0, 0] high[%0, 1] {
/// ^bb0(...):
/// tensor.yield %arg2 : f32
/// } : tensor<5x1xf32> to tensor<?x2xf32>
/// // EmptyOp + TransposeOp
/// %empty = tensor.empty(%arg3) : tensor<2x?xf32>
/// %transposed = linalg.transpose
/// ins(%extracted_slice : tensor<?x2xf32>)
/// outs(%empty : tensor<2x?xf32>)
/// permutation = [1, 0]
/// // InsertSliceOp
/// %inserted_slice = tensor.insert_slice %transposed
/// into %arg1[0, 0, 0, 0] [1, 1, 2, %tile_dim_1] [1, 1, 1, 1]
/// : tensor<2x?xf32> into tensor<1x1x2x?xf32>
/// ```
struct DecomposeOuterUnitDimsPackOpPattern
: public OpRewritePattern<linalg::PackOp> {
using OpRewritePattern<linalg::PackOp>::OpRewritePattern;
LogicalResult matchAndRewrite(linalg::PackOp packOp,
PatternRewriter &rewriter) const override;
};
/// Rewrites a linalg::UnPackOp into a sequence of rank-reduced
/// * tensor::ExtractSliceOp + linalg::TransposeOp + tensor::InsertSliceOp
///
/// Requires that all the outer dims of the input linalg::PackOp are 1.
///
/// Before:
/// ```
/// %packed = linalg.unpack %input
/// inner_dims_pos = [1, 0]
/// inner_tiles = [2, 8]
/// into %output : tensor<1x1x2x8xf32> -> tensor<5x1xf32>
/// ```
///
/// After:
/// ```
/// // Rank-reduced extract to obtain the tile
/// %slice = tensor.extract_slice %arg0[0, 0, 0, 0] [1, 1, 2, 8] [1, 1, 1, 1]
/// : tensor<1x1x2x8xf32> to tensor<2x8xf32>
/// // EmptyOp + TransposeOp
/// %init = tensor.empty() : tensor<8x2xf32>
/// %transposed = linalg.transpose
/// ins(%extracted_slice : tensor<2x8xf32>)
/// outs(%0 : tensor<8x2xf32>) permutation = [1, 0]
/// // Extract a slice matching the specified output size
/// %result = tensor.extract_slice %transposed[0, 0] [5, 1] [1, 1]
/// : tensor<8x2xf32> to tensor<5x1xf32>
/// ```
struct DecomposeOuterUnitDimsUnPackOpPattern
: public OpRewritePattern<linalg::UnPackOp> {
using OpRewritePattern<linalg::UnPackOp>::OpRewritePattern;
LogicalResult matchAndRewrite(linalg::UnPackOp unpackOp,
PatternRewriter &rewriter) const override;
};

AffineMap lhsMap = linalgOp.getIndexingMapsArray()[0];
AffineMap rhsMap = linalgOp.getIndexingMapsArray()[1];
if (getUnusedDimsBitVector({lhsMap, rhsMap}).any())
return failure();

// Load operands.
SmallVector<Value> vecOperands;
for (OpOperand &opOperand : linalgOp->getOpOperands()) {
// The operand vector shape is computed by mapping the canonical vector
// shape to the operand's domain. Further permutations are left as a part of
// the contraction.
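// E.g., for a plain matmul LHS with indexing map (m, n, k) -> (m, k) and
// canonical vector shape (M, N, K), the read type computed here is an MxK
// vector; a transposed operand keeps its own shape and the permutation is
// expressed by the map passed to vector.contract.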
AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);
AffineMap readMap = AffineMap::getMultiDimIdentityMap(
indexingMap.getNumResults(), rewriter.getContext());
Type elemType = getElementTypeOrSelf(opOperand.get());
VectorType readType =
state.getCanonicalVecType(elemType, readMap.compose(indexingMap));

Value read = mlir::vector::createReadOrMaskedRead(
rewriter, loc, opOperand.get(), readType.getShape(),
/*padding=*/arith::getZeroConstant(rewriter, loc, elemType),
/*useInBoundsInsteadOfMasking=*/false, readType.getScalableDims());
vecOperands.push_back(read);
}

// Remap iterators from linalg to vector.
SmallVector<Attribute> iterAttrs;
auto iterators = linalgOp.getIteratorTypesArray();
for (utils::IteratorType iter : iterators) {
auto vecIter = iter == utils::IteratorType::parallel
? vector::IteratorType::parallel
: vector::IteratorType::reduction;
iterAttrs.push_back(vector::IteratorTypeAttr::get(ctx, vecIter));
}

// Create contraction.
Value contractOp = rewriter.create<vector::ContractionOp>(
loc, /*lhs=*/vecOperands[0],
/*rhs=*/vecOperands[1], /*acc=*/vecOperands[2],
linalgOp.getIndexingMaps(), rewriter.getArrayAttr(iterAttrs), *maybeKind);

// Store result.
Operation *write =
createWriteOrMaskedWrite(rewriter, loc, contractOp, outOperand->get());

// Finalize.
if (!write->getResults().empty())
newResults.push_back(write->getResult(0));

return success();
}

namespace {
enum class ConvOperationKind { Conv, Pool };
} // namespace
@@ -2528,11 +2610,10 @@ bool mlir::linalg::hasVectorizationImpl(Operation *op) {
tensor::InsertSliceOp>(op);
}

-FailureOr<VectorizationResult>
-mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
-                        ArrayRef<int64_t> inputVectorSizes,
-                        ArrayRef<bool> inputScalableVecDims,
-                        bool vectorizeNDExtract, bool flatten1DDepthwiseConv) {
+FailureOr<VectorizationResult> mlir::linalg::vectorize(
+    RewriterBase &rewriter, Operation *op, ArrayRef<int64_t> inputVectorSizes,
+    ArrayRef<bool> inputScalableVecDims, bool vectorizeNDExtract,
+    bool flatten1DDepthwiseConv, bool createNamedContraction) {
LDBG("Attempting to vectorize:\n" << *op << "\n");
LDBG("Input vector sizes: ");
LLVM_DEBUG(llvm::interleaveComma(inputVectorSizes, llvm::dbgs()));
@@ -2578,6 +2659,21 @@ mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
return failure();
}

// For simplicity, contraction vectorization is limited to linalg named ops.
// Generic ops are ignored, as not every arbitrary contraction body can be
// expressed by a vector.contract.
if (createNamedContraction &&
isa<ContractionOpInterface>(linalgOp.getOperation())) {
// Attempt vectorizing directly into a named contraction.
// In case of failure, fall back to the generic path.
LogicalResult res = vectorizeAsLinalgContraction(
rewriter, state, linalgOp, results);
if (succeeded(res))
return success();

LDBG("Failed to vectorize as a named contraction.\n");
}

LDBG("Vectorize generic by broadcasting to the canonical vector "
"shape\n");

13 changes: 9 additions & 4 deletions mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
@@ -320,14 +320,16 @@ Value vector::createReadOrMaskedRead(OpBuilder &builder, Location loc,
Value source,
ArrayRef<int64_t> inputVectorSizes,
Value padValue,
-                                     bool useInBoundsInsteadOfMasking) {
+                                     bool useInBoundsInsteadOfMasking,
+                                     ArrayRef<bool> scalableDims) {
assert(!llvm::is_contained(inputVectorSizes, ShapedType::kDynamic) &&
"invalid input vector sizes");
auto sourceShapedType = cast<ShapedType>(source.getType());
auto sourceShape = sourceShapedType.getShape();
assert(sourceShape.size() == inputVectorSizes.size() &&
"expected same ranks.");
-  auto vectorType = VectorType::get(inputVectorSizes, padValue.getType());
+  auto vectorType =
+      VectorType::get(inputVectorSizes, padValue.getType(), scalableDims);
assert(padValue.getType() == sourceShapedType.getElementType() &&
"expected same pad element type to match source element type");
int64_t readRank = inputVectorSizes.size();
@@ -352,9 +354,12 @@ Value vector::createReadOrMaskedRead(OpBuilder &builder, Location loc,
if (llvm::equal(inputVectorSizes, sourceShape) || useInBoundsInsteadOfMasking)
return transferReadOp;
SmallVector<OpFoldResult> mixedSourceDims =
-      tensor::getMixedSizes(builder, loc, source);
+      isa<MemRefType>(source.getType())
+          ? memref::getMixedSizes(builder, loc, source)
+          : tensor::getMixedSizes(builder, loc, source);

-  auto maskType = VectorType::get(inputVectorSizes, builder.getI1Type());
+  auto maskType =
+      VectorType::get(inputVectorSizes, builder.getI1Type(), scalableDims);
Value mask =
builder.create<vector::CreateMaskOp>(loc, maskType, mixedSourceDims);
return mlir::vector::maskOperation(builder, transferReadOp, mask)