[mlir][linalg] Add support for scalable vectorization of linalg.mmt4d #146531

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open · wants to merge 3 commits into main
@@ -2440,12 +2440,11 @@ def VectorizeOp : Op<Transform_Dialect, "structured.vectorize",
}];

let arguments = (ins TransformHandleTypeInterface:$target,
- Variadic<TransformAnyParamTypeOrAnyHandle>:$vector_sizes,
- DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:
-     $static_vector_sizes,
- OptionalAttr<UnitAttr>:$vectorize_nd_extract,
- DefaultValuedOptionalAttr<DenseBoolArrayAttr, "{}">:
-     $scalable_sizes);
+ Variadic<TransformAnyParamTypeOrAnyHandle>:$vector_sizes,
+ DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_vector_sizes,
+ OptionalAttr<UnitAttr>:$vectorize_nd_extract,
+ OptionalAttr<UnitAttr>:$assume_dynamic_dims_match_vec_sizes,
+ DefaultValuedOptionalAttr<DenseBoolArrayAttr, "{}">:$scalable_sizes);

let results = (outs);
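For illustration, the new unit attribute is set on the transform op like so (a hypothetical snippet; the mmt4d tests at the end of this diff exercise the same pattern):

transform.structured.vectorize %target vector_sizes [16, 16, 16, 8, [4], 1] {assume_dynamic_dims_match_vec_sizes} : !transform.any_op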

3 changes: 2 additions & 1 deletion mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -871,7 +871,8 @@ FailureOr<VectorizationResult>
vectorize(RewriterBase &rewriter, Operation *op,
ArrayRef<int64_t> inputVectorSizes = {},
ArrayRef<bool> inputScalableVecDims = {},
- bool vectorizeNDExtract = false, bool flatten1DDepthwiseConv = false);
+ bool vectorizeNDExtract = false, bool flatten1DDepthwiseConv = false,
+ bool assumeDynamicDimsMatchVecSizes = false);

/// Emit a suitable vector form for a Copy op with fully static shape.
LogicalResult vectorizeCopy(RewriterBase &builder, memref::CopyOp copyOp);
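For reference, a sketch of the extended entry point at a hypothetical call site (mmt4dOp and the sizes are assumptions mirroring the tests below; only the trailing flag is new in this PR):

// Vectorize with a scalable n0 tile. The caller vouches that every dynamic
// dim equals its (scalable) vector size, so masking can be skipped.
FailureOr<VectorizationResult> result = linalg::vectorize(
    rewriter, mmt4dOp,
    /*inputVectorSizes=*/{16, 16, 16, 8, 4, 1},
    /*inputScalableVecDims=*/{false, false, false, false, true, false},
    /*vectorizeNDExtract=*/false,
    /*flatten1DDepthwiseConv=*/false,
    /*assumeDynamicDimsMatchVecSizes=*/true);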
@@ -3921,7 +3921,8 @@ DiagnosedSilenceableFailure transform::VectorizeOp::apply(
}
FailureOr<VectorizationResult> vectorResults =
linalg::vectorize(rewriter, target, vectorSizes, getScalableSizes(),
- getVectorizeNdExtract().value_or(false));
+ getVectorizeNdExtract().value_or(false), false,
+ getAssumeDynamicDimsMatchVecSizes().value_or(false));
if (failed(vectorResults)) {
return mlir::emitSilenceableFailure(target->getLoc())
<< "Attempted to vectorize, but failed";
51 changes: 39 additions & 12 deletions mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
Contributor:
Some of the changes to comments are not relevant; can you revert them?

Contributor (author):
Not sure how that crept in. Fixed in this commit.

@@ -222,7 +222,8 @@ struct VectorizationState {
/// canonical vector shape for vectorization.
LogicalResult initState(RewriterBase &rewriter, LinalgOp linalgOp,
ArrayRef<int64_t> inputVectorSizes,
- ArrayRef<bool> inputScalableVecDims);
+ ArrayRef<bool> inputScalableVecDims,
+ bool assumeDynamicDimsMatchVecSizes = false);

/// Returns the canonical vector shape used to vectorize the iteration space.
ArrayRef<int64_t> getCanonicalVecShape() const { return canonicalVecShape; }
@@ -331,6 +332,14 @@ struct VectorizationState {
/// Global vectorization guard for the incoming rewriter. It's initialized
/// when the vectorization state is initialized.
OpBuilder::InsertionGuard rewriterGuard;

/// Do all dynamic dims match the corresponding vector sizes?
///
/// When a dynamic tensor/memref dimension matches the corresponding vector
/// dimension, masking can be safely skipped, despite the presence of dynamic
/// shapes. Use this flag with care and only for cases where you are
/// confident the assumption holds.
bool assumeDynamicDimsMatchVecSizes = false;
};

LogicalResult
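To make the new flag's contract concrete, a hedged MLIR sketch (not from this PR): with the flag set, a transfer on a dynamically shaped memref whose dynamic dim is guaranteed by the caller to equal the scalable vector size can be emitted without a mask:

// Caller guarantees dim 2 of %B equals vscale * 4 at runtime, so the read
// needs no vector.create_mask / vector.mask wrapper.
%vec_b = vector.transfer_read %B[%c0, %c0, %c0, %c0], %pad
    {in_bounds = [true, true, true, true]}
    : memref<16x16x?x1xf32>, vector<16x16x[4]x1xf32>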
@@ -367,10 +376,12 @@ VectorizationState::precomputeIterSpaceValueSizes(RewriterBase &rewriter,
/// Initializes the vectorization state, including the computation of the
/// canonical vector shape for vectorization.
// TODO: Move this to the constructor when we can remove the failure cases.
- LogicalResult
- VectorizationState::initState(RewriterBase &rewriter, LinalgOp linalgOp,
- ArrayRef<int64_t> inputVectorSizes,
- ArrayRef<bool> inputScalableVecDims) {
+ LogicalResult VectorizationState::initState(RewriterBase &rewriter,
+ LinalgOp linalgOp,
+ ArrayRef<int64_t> inputVectorSizes,
+ ArrayRef<bool> inputScalableVecDims,
+ bool assumeDimsMatchVec) {
+ assumeDynamicDimsMatchVecSizes = assumeDimsMatchVec;
// Initialize the insertion point.
rewriter.setInsertionPoint(linalgOp);

@@ -470,6 +481,21 @@ Value VectorizationState::getOrCreateMaskFor(
return Value();
}

if (assumeDynamicDimsMatchVecSizes) {
// Given that all _scalable vector sizes_ match the corresponding
// memref/tensor dim sizes, masking can be skipped provided that:
// * all vector sizes corresponding to dynamic dims are scalable.
if (llvm::all_of(llvm::zip(permutedStaticSizes, maskType.getScalableDims()),
[](auto it) {
return std::get<0>(it) == ShapedType::kDynamic
? std::get<1>(it)
: false;
}))
LDBG("Masking is not needed for masking map: " << maskingMap << "\n");
activeMaskCache[maskingMap] = Value();
return Value();
}

Comment on lines +490 to +492

Contributor:
Am I misinterpreting this, or does it imply "every size in permutedStaticSizes is dynamic and the corresponding dim is scalable"? Also, it only guards the debug print; it does not guard the following lines that set the mask to nothing and return.

Contributor:
+1
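A minimal sketch (not part of the PR) of the guarded form the reviewers appear to be asking for: treat a dim as safe when it is either static or dynamic with a scalable vector dim, and gate the cache update and early return on the check:

if (assumeDynamicDimsMatchVecSizes) {
  // Skip masking only when every *dynamic* dim maps to a scalable vector
  // dim; static dims are exempt from this particular check.
  bool maskNotNeeded = llvm::all_of(
      llvm::zip(permutedStaticSizes, maskType.getScalableDims()),
      [](auto it) {
        return std::get<0>(it) != ShapedType::kDynamic || std::get<1>(it);
      });
  if (maskNotNeeded) {
    LDBG("Masking is not needed for masking map: " << maskingMap << "\n");
    activeMaskCache[maskingMap] = Value();
    return Value();
  }
}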

// Permute the iteration space value sizes to compute the mask upper bounds.
SmallVector<Value> upperBounds =
applyPermutationMap(maskingMap, ArrayRef<Value>(iterSpaceValueSizes));
@@ -2479,7 +2505,8 @@ vectorizeScalableVectorPrecondition(Operation *op,
return success(isElementwise(linalgOp) || isa<linalg::MatmulOp>(op) ||
isa<linalg::MatmulTransposeAOp>(op) ||
isa<linalg::DepthwiseConv1DNwcWcOp>(op) ||
- isa<linalg::MatvecOp>(op) || hasReductionIterator(linalgOp));
+ isa<linalg::MatvecOp>(op) || isa<linalg::Mmt4DOp>(op) ||
+ hasReductionIterator(linalgOp));

Contributor:
I was wondering if there is a particular reason why this wouldn't work for batch_mmt4d as well?
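For context, the batched variant mentioned above would look roughly like this (illustrative shapes, assumed rather than taken from the PR; batch_mmt4d simply adds a leading batch dim to each operand):

linalg.batch_mmt4d ins(%A, %B : memref<2x16x16x8x1xf32>, memref<2x16x16x8x1xf32>)
                   outs(%C_in : memref<2x16x16x8x8xf32>)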
}

LogicalResult mlir::linalg::vectorizeOpPrecondition(
@@ -2535,11 +2562,10 @@ bool mlir::linalg::hasVectorizationImpl(Operation *op) {
tensor::InsertSliceOp>(op);
}

- FailureOr<VectorizationResult>
- mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
- ArrayRef<int64_t> inputVectorSizes,
- ArrayRef<bool> inputScalableVecDims,
- bool vectorizeNDExtract, bool flatten1DDepthwiseConv) {
+ FailureOr<VectorizationResult> mlir::linalg::vectorize(
+ RewriterBase &rewriter, Operation *op, ArrayRef<int64_t> inputVectorSizes,
+ ArrayRef<bool> inputScalableVecDims, bool vectorizeNDExtract,
+ bool flatten1DDepthwiseConv, bool assumeDynamicDimsMatchVecSizes) {
LDBG("Attempting to vectorize:\n" << *op << "\n");
LDBG("Input vector sizes: ");
LLVM_DEBUG(llvm::interleaveComma(inputVectorSizes, llvm::dbgs()));
@@ -2559,7 +2585,8 @@ mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
VectorizationState state(rewriter);
if (auto linalgOp = dyn_cast<linalg::LinalgOp>(op)) {
if (failed(state.initState(rewriter, linalgOp, inputVectorSizes,
- inputScalableVecDims))) {
+ inputScalableVecDims,
+ assumeDynamicDimsMatchVecSizes))) {
LDBG("Vectorization state couldn't be initialized\n");
return failure();
}
117 changes: 93 additions & 24 deletions mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir
@@ -840,6 +840,99 @@ module attributes {transform.with_named_sequence} {
}
}

// -----

///----------------------------------------------------------------------------------------
/// Tests for linalg.mmt4d
///----------------------------------------------------------------------------------------

func.func @mmt4d(%A: memref<16x16x8x1xf32>, %B: memref<16x16x8x1xf32>, %C_in: memref<16x16x8x8xf32>) {
linalg.mmt4d ins(%A, %B: memref<16x16x8x1xf32>, memref<16x16x8x1xf32>)
outs(%C_in: memref<16x16x8x8xf32>)
return
}

// CHECK-LABEL: func.func @mmt4d(
// CHECK-SAME: %[[A:.*]]: memref<16x16x8x1xf32>, %[[B:.*]]: memref<16x16x8x1xf32>, %[[C:.*]]: memref<16x16x8x8xf32>) {
// CHECK: %[[VEC_A:.*]] = vector.transfer_read %[[A]]{{.*}} : memref<16x16x8x1xf32>, vector<16x16x16x8x8x1xf32>
// CHECK: %[[VEC_B:.*]] = vector.transfer_read %[[B]]{{.*}} : memref<16x16x8x1xf32>, vector<16x16x16x8x8x1xf32>
// CHECK: %[[VEC_C:.*]] = vector.transfer_read %[[C]]{{.*}} : memref<16x16x8x8xf32>, vector<16x16x8x8xf32>
// CHECK: %[[MUL:.*]] = arith.mulf %[[VEC_A]], %[[VEC_B]] : vector<16x16x16x8x8x1xf32>
// CHECK: %[[RED:.*]] = vector.multi_reduction <add>, %[[MUL]], %[[VEC_C]] [2, 5] : vector<16x16x16x8x8x1xf32> to vector<16x16x8x8xf32>
// CHECK: vector.transfer_write %[[RED]], %[[C]]{{.*}} : vector<16x16x8x8xf32>, memref<16x16x8x8xf32>

module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%mmt4d = transform.structured.match ops{["linalg.mmt4d"]} in %arg1 : (!transform.any_op) -> !transform.any_op
transform.structured.vectorize %mmt4d : !transform.any_op
transform.yield
}
}

// -----

func.func @mmt4d_scalable(%A: memref<16x16x8x1xf32>, %B: memref<16x16x?x1xf32>, %C_in: memref<16x16x8x?xf32>) {
linalg.mmt4d ins(%A, %B: memref<16x16x8x1xf32>, memref<16x16x?x1xf32>)
outs(%C_in: memref<16x16x8x?xf32>)
return
}
// CHECK-LABEL: func.func @mmt4d_scalable(
// CHECK-SAME: %[[A:.*]]: memref<16x16x8x1xf32>,
// CHECK-SAME: %[[B:.*]]: memref<16x16x?x1xf32>,
// CHECK-SAME: %[[C_IN:.*]]: memref<16x16x8x?xf32>) {
// CHECK: %[[VAL_0:.*]] = arith.constant 16 : index
// CHECK: %[[VAL_1:.*]] = arith.constant 16 : index
// CHECK: %[[VAL_2:.*]] = arith.constant 16 : index
// CHECK: %[[C8:.*]] = arith.constant 8 : index
// CHECK: %[[C2:.*]] = arith.constant 2 : index
// CHECK: %[[DIM_2:.*]] = memref.dim %[[B]], %[[C2]] : memref<16x16x?x1xf32>
// CHECK: %[[VAL_6:.*]] = arith.constant 1 : index
// CHECK: %[[VEC_A:.*]] = vector.transfer_read %[[A]]{{.*}} : memref<16x16x8x1xf32>, vector<16x16x16x8x[4]x1xf32>
// CHECK: %[[MASK_1:.*]] = vector.create_mask %[[VAL_1]], %[[VAL_2]], %[[DIM_2]], %[[VAL_6]] : vector<16x16x[4]x1xi1>
// CHECK: %[[VEC_B:.*]] = vector.mask %[[MASK_1]] { vector.transfer_read %[[B]]{{.*}} : memref<16x16x?x1xf32>, vector<16x16x16x8x[4]x1xf32> } : vector<16x16x[4]x1xi1> -> vector<16x16x16x8x[4]x1xf32>
// CHECK: %[[MASK_2:.*]] = vector.create_mask %[[VAL_0]], %[[VAL_1]], %[[C8]], %[[DIM_2]] : vector<16x16x8x[4]xi1>
// CHECK: %[[VAL_15:.*]] = vector.mask %[[MASK_2]] { vector.transfer_read %[[C_IN]]{{.*}} : memref<16x16x8x?xf32>, vector<16x16x8x[4]xf32> } : vector<16x16x8x[4]xi1> -> vector<16x16x8x[4]xf32>
// CHECK: %[[VAL_16:.*]] = arith.mulf %[[VEC_A]], %[[VEC_B]] : vector<16x16x16x8x[4]x1xf32>
// CHECK: %[[MASK_3:.*]] = vector.create_mask %[[VAL_0]], %[[VAL_1]], %[[VAL_2]], %[[C8]], %[[DIM_2]], %[[VAL_6]] : vector<16x16x16x8x[4]x1xi1>
// CHECK: %[[VAL_18:.*]] = vector.mask %[[MASK_3]] { vector.multi_reduction <add>, %[[VAL_16]], %[[VAL_15]] [2, 5] : vector<16x16x16x8x[4]x1xf32> to vector<16x16x8x[4]xf32> } : vector<16x16x16x8x[4]x1xi1> -> vector<16x16x8x[4]xf32>
// CHECK: vector.mask %[[MASK_2]] { vector.transfer_write %[[VAL_18]], %[[C_IN]]{{.*}} : vector<16x16x8x[4]xf32>, memref<16x16x8x?xf32> } : vector<16x16x8x[4]xi1>


module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%mmt4d = transform.structured.match ops{["linalg.mmt4d"]} in %arg1 : (!transform.any_op) -> !transform.any_op
transform.structured.vectorize %mmt4d vector_sizes [16, 16, 16, 8, [4], 1] : !transform.any_op
transform.yield
}
}

// -----

func.func @mmt4d_scalable_with_assume(%A: memref<16x16x8x1xf32>, %B: memref<16x16x?x1xf32>, %C_in: memref<16x16x8x?xf32>) {
linalg.mmt4d ins(%A, %B: memref<16x16x8x1xf32>, memref<16x16x?x1xf32>)
outs(%C_in: memref<16x16x8x?xf32>)
return
}
// CHECK-LABEL: func.func @mmt4d_scalable_with_assume(
// CHECK-SAME: %[[A:.*]]: memref<16x16x8x1xf32>,
// CHECK-SAME: %[[B:.*]]: memref<16x16x?x1xf32>,
// CHECK-SAME: %[[C_IN:.*]]: memref<16x16x8x?xf32>) {
// CHECK-NOT: mask
// CHECK: %[[VEC_A:.*]] = vector.transfer_read %[[A]]{{.*}} : memref<16x16x8x1xf32>, vector<16x16x16x8x[4]x1xf32>
// CHECK: %[[VEC_B:.*]] = vector.transfer_read %[[B]]{{.*}} : memref<16x16x?x1xf32>, vector<16x16x16x8x[4]x1xf32>
// CHECK: %[[VAL_13:.*]] = vector.transfer_read %[[C_IN]]{{.*}} : memref<16x16x8x?xf32>, vector<16x16x8x[4]xf32>
// CHECK: %[[VAL_14:.*]] = arith.mulf %[[VEC_A]], %[[VEC_B]] : vector<16x16x16x8x[4]x1xf32>
// CHECK: %[[VAL_15:.*]] = vector.multi_reduction <add>, %[[VAL_14]], %[[VAL_13]] [2, 5] : vector<16x16x16x8x[4]x1xf32> to vector<16x16x8x[4]xf32>
// CHECK: vector.transfer_write %[[VAL_15]], %[[C_IN]]{{.*}} : vector<16x16x8x[4]xf32>, memref<16x16x8x?xf32>

module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%mmt4d = transform.structured.match ops{["linalg.mmt4d"]} in %arg1 : (!transform.any_op) -> !transform.any_op
transform.structured.vectorize %mmt4d vector_sizes [16, 16, 16, 8, [4], 1] {assume_dynamic_dims_match_vec_sizes} : !transform.any_op
transform.yield
}
}

///----------------------------------------------------------------------------------------
/// Tests for other Ops
///----------------------------------------------------------------------------------------
@@ -1094,30 +1187,6 @@ module attributes {transform.with_named_sequence} {
}
}

- // -----
-
- func.func @mmt4d(%A: memref<16x16x8x1xf32>, %B: memref<16x16x8x1xf32>, %C_in: memref<16x16x8x8xf32>) {
- linalg.mmt4d ins(%A, %B: memref<16x16x8x1xf32>, memref<16x16x8x1xf32>)
- outs(%C_in: memref<16x16x8x8xf32>)
- return
- }
-
- // CHECK-LABEL: func.func @mmt4d(
- // CHECK-SAME: %[[A:.*]]: memref<16x16x8x1xf32>, %[[B:.*]]: memref<16x16x8x1xf32>, %[[C:.*]]: memref<16x16x8x8xf32>) {
- // CHECK: %[[VEC_A:.*]] = vector.transfer_read %[[A]]{{.*}} : memref<16x16x8x1xf32>, vector<16x16x16x8x8x1xf32>
- // CHECK: %[[VEC_B:.*]] = vector.transfer_read %[[B]]{{.*}} : memref<16x16x8x1xf32>, vector<16x16x16x8x8x1xf32>
- // CHECK: %[[VEC_C:.*]] = vector.transfer_read %[[C]]{{.*}} : memref<16x16x8x8xf32>, vector<16x16x8x8xf32>
- // CHECK: %[[MUL:.*]] = arith.mulf %[[VEC_A]], %[[VEC_B]] : vector<16x16x16x8x8x1xf32>
- // CHECK: %[[RED:.*]] = vector.multi_reduction <add>, %[[MUL]], %[[VEC_C]] [2, 5] : vector<16x16x16x8x8x1xf32> to vector<16x16x8x8xf32>
- // CHECK: vector.transfer_write %[[RED]], %[[C]]{{.*}} : vector<16x16x8x8xf32>, memref<16x16x8x8xf32>
-
- module attributes {transform.with_named_sequence} {
- transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
- %mmt4d = transform.structured.match ops{["linalg.mmt4d"]} in %arg1 : (!transform.any_op) -> !transform.any_op
- transform.structured.vectorize %mmt4d : !transform.any_op
- transform.yield
- }
- }

// -----
