[mlir][linalg] Vectorize directly to a named contraction #147296
base: main
Conversation
Extends the linalg vectorizer with a path to lower contraction ops directly into `vector.contract`. The direct rewriting preserves high-level op semantics and provides a more progressive lowering than reconstructing the contraction from a multi-dimensional reduction. The added lowering focuses on named linalg ops and leverages their well-defined semantics to avoid complex precondition verification. The new path is optional and disabled by default to avoid changing the default vectorizer behavior.
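For reference, a minimal sketch of the new lowering applied to the static matmul case from the test added in this patch; transfer indices, in_bounds attributes, and the padding value are elided as "...":

// Input: a named contraction.
%0 = linalg.matmul
       ins(%A, %B : tensor<8x4xf32>, tensor<4x16xf32>)
       outs(%C : tensor<8x16xf32>) -> tensor<8x16xf32>

// After transform.structured.vectorize %0 {create_named_contraction}:
// operands are read as-is, and permutations/casts stay inside the contraction.
%a = vector.transfer_read %A ... : tensor<8x4xf32>, vector<8x4xf32>
%b = vector.transfer_read %B ... : tensor<4x16xf32>, vector<4x16xf32>
%c = vector.transfer_read %C ... : tensor<8x16xf32>, vector<8x16xf32>
%r = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
                                       affine_map<(d0, d1, d2) -> (d2, d1)>,
                                       affine_map<(d0, d1, d2) -> (d0, d1)>],
                      iterator_types = ["parallel", "parallel", "reduction"],
                      kind = #vector.kind<add>}
       %a, %b, %c : vector<8x4xf32>, vector<4x16xf32> into vector<8x16xf32>
vector.transfer_write %r, %C ... : vector<8x16xf32>, tensor<8x16xf32>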
@llvm/pr-subscribers-mlir-linalg

Author: Adam Siemieniuk (adam-smnk)

Changes

Extends the linalg vectorizer with a path to lower contraction ops directly into `vector.contract`. The direct rewriting preserves high-level op semantics and provides a more progressive lowering than reconstructing the contraction from a multi-dimensional reduction. The new path is optional and disabled by default to avoid changing the default vectorizer behavior.

Patch is 32.50 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/147296.diff

7 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 4360055e78691..3bbde1240286c 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -2445,6 +2445,8 @@ def VectorizeOp : Op<Transform_Dialect, "structured.vectorize",
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:
$static_vector_sizes,
OptionalAttr<UnitAttr>:$vectorize_nd_extract,
+ OptionalAttr<UnitAttr>:$flatten1D_depthwise_conv,
+ OptionalAttr<UnitAttr>:$create_named_contraction,
DefaultValuedOptionalAttr<DenseBoolArrayAttr, "{}">:
$scalable_sizes);
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 74280fdd82f4e..eee856f3eba68 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -876,11 +876,14 @@ struct VectorizationResult {
/// greater than or equal to their counterpart iteration space sizes, if static.
/// `inputVectorShapes` also allows the vectorization of operations with dynamic
/// shapes.
+/// Optionally, `createNamedContraction` can force compatible contractions to be
+/// vectorized directly to vector.contract operation.
FailureOr<VectorizationResult>
vectorize(RewriterBase &rewriter, Operation *op,
ArrayRef<int64_t> inputVectorSizes = {},
ArrayRef<bool> inputScalableVecDims = {},
- bool vectorizeNDExtract = false, bool flatten1DDepthwiseConv = false);
+ bool vectorizeNDExtract = false, bool flatten1DDepthwiseConv = false,
+ bool createNamedContraction = false);
/// Emit a suitable vector form for a Copy op with fully static shape.
LogicalResult vectorizeCopy(RewriterBase &builder, memref::CopyOp copyOp);
diff --git a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
index cc8421b23a074..9b765d0b8ede6 100644
--- a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
+++ b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
@@ -226,7 +226,8 @@ bool isLinearizableVector(VectorType type);
/// Note: all read offsets are set to 0.
Value createReadOrMaskedRead(OpBuilder &builder, Location loc, Value source,
ArrayRef<int64_t> inputVectorSizes, Value padValue,
- bool useInBoundsInsteadOfMasking = false);
+ bool useInBoundsInsteadOfMasking = false,
+ ArrayRef<bool> scalableDims = {});
/// Returns success if `inputVectorSizes` is a valid masking configuraion for
/// given `shape`, i.e., it meets:
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 7ec3f3445281a..1038615f21b17 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -3920,7 +3920,9 @@ DiagnosedSilenceableFailure transform::VectorizeOp::apply(
}
FailureOr<VectorizationResult> vectorResults =
linalg::vectorize(rewriter, target, vectorSizes, getScalableSizes(),
- getVectorizeNdExtract().value_or(false));
+ getVectorizeNdExtract().value_or(false),
+ getFlatten1DDepthwiseConv().value_or(false),
+ getCreateNamedContraction().value_or(false));
if (failed(vectorResults)) {
return mlir::emitSilenceableFailure(target->getLoc())
<< "Attempted to vectorize, but failed";
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 2864c5b807d69..bb1454cf43932 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -25,6 +25,7 @@
#include "mlir/Dialect/Vector/Interfaces/MaskableOpInterface.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
@@ -1681,10 +1682,13 @@ createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vecToStore,
return write;
// Compute the mask and mask the write Op.
- auto writeMaskType = VectorType::get(vecToStoreShape, builder.getI1Type());
+ auto writeMaskType = VectorType::get(vecToStoreShape, builder.getI1Type(),
+ vecToStoreType.getScalableDims());
SmallVector<OpFoldResult> destSizes =
- tensor::getMixedSizes(builder, loc, dest);
+ isa<MemRefType>(dest.getType())
+ ? memref::getMixedSizes(builder, loc, dest)
+ : tensor::getMixedSizes(builder, loc, dest);
SmallVector<OpFoldResult> maskSizes(destSizes.end() - vecToStoreRank,
destSizes.end());
@@ -2093,6 +2097,84 @@ vectorizeInsertSliceOpPrecondition(tensor::InsertSliceOp sliceOp,
return success();
}
+/// Vectorize a named linalg contraction op into:
+/// vector::TransferReadOp - Reads vectors from the operands
+/// vector::ContractionOp - Performs contraction
+/// vector::TransferWriteOp - Write the result vector back to the
+/// destination
+/// The operands shapes are preserved and loaded directly into vectors.
+/// Any further permutations or numerical casting remain within contraction.
+static LogicalResult
+vectorizeAsLinalgContraction(RewriterBase &rewriter, VectorizationState &state,
+ LinalgOp linalgOp,
+ SmallVectorImpl<Value> &newResults) {
+ Location loc = linalgOp.getLoc();
+ MLIRContext *ctx = linalgOp.getContext();
+
+ if (!isa<ContractionOpInterface>(linalgOp.getOperation()))
+ return failure();
+
+ OpOperand *outOperand = linalgOp.getDpsInitOperand(0);
+ Operation *reduceOp = matchLinalgReduction(outOperand);
+ auto maybeKind = getCombinerOpKind(reduceOp);
+ if (!maybeKind)
+ return failure();
+
+ // Check that all dimensions are present in the input operands.
+ // Arbitrary broadcasts are not supported by the vector contraction.
+ // Broadcasts are expected to be materialized before vectorization.
+ AffineMap lhsMap = linalgOp.getIndexingMapsArray()[0];
+ AffineMap rhsMap = linalgOp.getIndexingMapsArray()[1];
+ if (getUnusedDimsBitVector({lhsMap, rhsMap}).any())
+ return failure();
+
+ // Load operands.
+ SmallVector<Value> vecOperands;
+ for (OpOperand &opOperand : linalgOp->getOpOperands()) {
+ // The operand vector shape is computed by mapping the canonical vector
+ // shape to the operand's domain. Further permutations are left as a part of
+ // the contraction.
+ AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);
+ AffineMap readMap = AffineMap::getMultiDimIdentityMap(
+ indexingMap.getNumResults(), rewriter.getContext());
+ Type elemType = getElementTypeOrSelf(opOperand.get());
+ VectorType readType =
+ state.getCanonicalVecType(elemType, readMap.compose(indexingMap));
+
+ Value read = mlir::vector::createReadOrMaskedRead(
+ rewriter, loc, opOperand.get(), readType.getShape(),
+ /*padding=*/arith::getZeroConstant(rewriter, loc, elemType),
+ /*useInBoundsInsteadOfMasking=*/false, readType.getScalableDims());
+ vecOperands.push_back(read);
+ }
+
+ // Remap iterators from linalg to vector.
+ SmallVector<Attribute> iterAttrs;
+ auto iterators = linalgOp.getIteratorTypesArray();
+ for (utils::IteratorType iter : iterators) {
+ auto vecIter = iter == utils::IteratorType::parallel
+ ? vector::IteratorType::parallel
+ : vector::IteratorType::reduction;
+ iterAttrs.push_back(vector::IteratorTypeAttr::get(ctx, vecIter));
+ }
+
+ // Create contraction.
+ Value contractOp = rewriter.create<vector::ContractionOp>(
+ loc, /*lhs=*/vecOperands[0],
+ /*rhs=*/vecOperands[1], /*acc=*/vecOperands[2],
+ linalgOp.getIndexingMaps(), rewriter.getArrayAttr(iterAttrs), *maybeKind);
+
+ // Store result.
+ Operation *write =
+ createWriteOrMaskedWrite(rewriter, loc, contractOp, outOperand->get());
+
+ // Finalize.
+ if (!write->getResults().empty())
+ newResults.push_back(write->getResult(0));
+
+ return success();
+}
+
namespace {
enum class ConvOperationKind { Conv, Pool };
} // namespace
@@ -2528,11 +2610,10 @@ bool mlir::linalg::hasVectorizationImpl(Operation *op) {
tensor::InsertSliceOp>(op);
}
-FailureOr<VectorizationResult>
-mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
- ArrayRef<int64_t> inputVectorSizes,
- ArrayRef<bool> inputScalableVecDims,
- bool vectorizeNDExtract, bool flatten1DDepthwiseConv) {
+FailureOr<VectorizationResult> mlir::linalg::vectorize(
+ RewriterBase &rewriter, Operation *op, ArrayRef<int64_t> inputVectorSizes,
+ ArrayRef<bool> inputScalableVecDims, bool vectorizeNDExtract,
+ bool flatten1DDepthwiseConv, bool createNamedContraction) {
LDBG("Attempting to vectorize:\n" << *op << "\n");
LDBG("Input vector sizes: ");
LLVM_DEBUG(llvm::interleaveComma(inputVectorSizes, llvm::dbgs()));
@@ -2578,6 +2659,21 @@ mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
return failure();
}
+ // For simplicity, contraction vectorization is limited to linalg
+ // named ops. Generic op is ignored as not every arbitrary
+ // contraction body can be expressed by a vector.contract.
+ if (createNamedContraction &&
+ isa<ContractionOpInterface>(linalgOp.getOperation())) {
+ // Attempt vectorizing directly into a named contraction.
+ // In case of failure, fall back to the generic path.
+ LogicalResult res = vectorizeAsLinalgContraction(
+ rewriter, state, linalgOp, results);
+ if (succeeded(res))
+ return success();
+
+ LDBG("Failed to vectorize as a named contraction.\n");
+ }
+
LDBG("Vectorize generic by broadcasting to the canonical vector "
"shape\n");
diff --git a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
index 674a93d7c520e..594bfce57e598 100644
--- a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
+++ b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
@@ -320,14 +320,16 @@ Value vector::createReadOrMaskedRead(OpBuilder &builder, Location loc,
Value source,
ArrayRef<int64_t> inputVectorSizes,
Value padValue,
- bool useInBoundsInsteadOfMasking) {
+ bool useInBoundsInsteadOfMasking,
+ ArrayRef<bool> scalableDims) {
assert(!llvm::is_contained(inputVectorSizes, ShapedType::kDynamic) &&
"invalid input vector sizes");
auto sourceShapedType = cast<ShapedType>(source.getType());
auto sourceShape = sourceShapedType.getShape();
assert(sourceShape.size() == inputVectorSizes.size() &&
"expected same ranks.");
- auto vectorType = VectorType::get(inputVectorSizes, padValue.getType());
+ auto vectorType =
+ VectorType::get(inputVectorSizes, padValue.getType(), scalableDims);
assert(padValue.getType() == sourceShapedType.getElementType() &&
"expected same pad element type to match source element type");
int64_t readRank = inputVectorSizes.size();
@@ -352,9 +354,12 @@ Value vector::createReadOrMaskedRead(OpBuilder &builder, Location loc,
if (llvm::equal(inputVectorSizes, sourceShape) || useInBoundsInsteadOfMasking)
return transferReadOp;
SmallVector<OpFoldResult> mixedSourceDims =
- tensor::getMixedSizes(builder, loc, source);
+ isa<MemRefType>(source.getType())
+ ? memref::getMixedSizes(builder, loc, source)
+ : tensor::getMixedSizes(builder, loc, source);
- auto maskType = VectorType::get(inputVectorSizes, builder.getI1Type());
+ auto maskType =
+ VectorType::get(inputVectorSizes, builder.getI1Type(), scalableDims);
Value mask =
builder.create<vector::CreateMaskOp>(loc, maskType, mixedSourceDims);
return mlir::vector::maskOperation(builder, transferReadOp, mask)
diff --git a/mlir/test/Dialect/Linalg/vectorization/contraction-named.mlir b/mlir/test/Dialect/Linalg/vectorization/contraction-named.mlir
new file mode 100644
index 0000000000000..1831acf092afb
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/vectorization/contraction-named.mlir
@@ -0,0 +1,400 @@
+// RUN: mlir-opt %s -transform-interpreter -split-input-file | FileCheck %s
+
+func.func @matmul(%A: tensor<8x4xf32>, %B: tensor<4x16xf32>,
+ %C: tensor<8x16xf32>) -> tensor<8x16xf32> {
+ %0 = linalg.matmul
+ ins(%A, %B : tensor<8x4xf32>, tensor<4x16xf32>)
+ outs(%C: tensor<8x16xf32>) -> tensor<8x16xf32>
+ return %0 : tensor<8x16xf32>
+}
+
+// CHECK: #[[$MAP_A:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK: #[[$MAP_B:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+// CHECK: #[[$MAP_C:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-LABEL: func.func @matmul(
+// CHECK-SAME: %[[A:.*]]: tensor<8x4xf32>, %[[B:.*]]: tensor<4x16xf32>,
+// CHECK-SAME: %[[C:.*]]: tensor<8x16xf32>)
+// CHECK: %[[LOAD_A:.*]] = vector.transfer_read %[[A]]{{.*}}: tensor<8x4xf32>, vector<8x4xf32>
+// CHECK: %[[LOAD_B:.*]] = vector.transfer_read %[[B]]{{.*}}: tensor<4x16xf32>, vector<4x16xf32>
+// CHECK: %[[LOAD_C:.*]] = vector.transfer_read %[[C]]{{.*}}: tensor<8x16xf32>, vector<8x16xf32>
+// CHECK: %[[CONTRACT:.*]] = vector.contract
+// CHECK-SAME: indexing_maps = [#[[$MAP_A]], #[[$MAP_B]], #[[$MAP_C]]]
+// CHECK-SAME: kind = #vector.kind<add>
+// CHECK-SAME: %[[LOAD_A]], %[[LOAD_B]], %[[LOAD_C]]
+// CHECK: vector.transfer_write %[[CONTRACT]], %[[C]]{{.*}}: vector<8x16xf32>, tensor<8x16xf32>
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.structured.vectorize %0 {create_named_contraction} : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+func.func @matmul_dynamic(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>,
+ %C: tensor<?x?xf32>) -> tensor<?x?xf32> {
+ %0 = linalg.matmul
+ ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>)
+ outs(%C: tensor<?x?xf32>) -> tensor<?x?xf32>
+ return %0 : tensor<?x?xf32>
+}
+
+// CHECK: #[[$MAP_A:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK: #[[$MAP_B:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+// CHECK: #[[$MAP_C:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-LABEL: func.func @matmul_dynamic(
+// CHECK-SAME: %[[A:.*]]: tensor<?x?xf32>, %[[B:.*]]: tensor<?x?xf32>,
+// CHECK-SAME: %[[C:.*]]: tensor<?x?xf32>)
+// CHECK: %[[LOAD_A:.*]] = vector.mask{{.*}}{ vector.transfer_read %[[A]]{{.*}}: tensor<?x?xf32>, vector<8x4xf32>
+// CHECK: %[[LOAD_B:.*]] = vector.mask{{.*}}{ vector.transfer_read %[[B]]{{.*}}: tensor<?x?xf32>, vector<4x16xf32>
+// CHECK: %[[LOAD_C:.*]] = vector.mask{{.*}}{ vector.transfer_read %[[C]]{{.*}}: tensor<?x?xf32>, vector<8x16xf32>
+// CHECK: %[[CONTRACT:.*]] = vector.contract
+// CHECK-SAME: indexing_maps = [#[[$MAP_A]], #[[$MAP_B]], #[[$MAP_C]]]
+// CHECK-SAME: kind = #vector.kind<add>
+// CHECK-SAME: %[[LOAD_A]], %[[LOAD_B]], %[[LOAD_C]]
+// CHECK: vector.mask{{.*}}{ vector.transfer_write %[[CONTRACT]], %[[C]]{{.*}}: vector<8x16xf32>, tensor<?x?xf32>
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.structured.vectorize %0 vector_sizes [8, 16, 4]
+ {create_named_contraction} : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+func.func @matmul_dynamic_memref(%A: memref<?x?xf32>, %B: memref<?x?xf32>,
+ %C: memref<?x?xf32>) {
+ linalg.matmul
+ ins(%A, %B : memref<?x?xf32>, memref<?x?xf32>)
+ outs(%C: memref<?x?xf32>)
+ return
+}
+
+// CHECK: #[[$MAP_A:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK: #[[$MAP_B:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+// CHECK: #[[$MAP_C:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-LABEL: func.func @matmul_dynamic_memref(
+// CHECK-SAME: %[[A:.*]]: memref<?x?xf32>, %[[B:.*]]: memref<?x?xf32>,
+// CHECK-SAME: %[[C:.*]]: memref<?x?xf32>)
+// CHECK: %[[LOAD_A:.*]] = vector.mask{{.*}}{ vector.transfer_read %[[A]]{{.*}}: memref<?x?xf32>, vector<8x4xf32>
+// CHECK: %[[LOAD_B:.*]] = vector.mask{{.*}}{ vector.transfer_read %[[B]]{{.*}}: memref<?x?xf32>, vector<4x16xf32>
+// CHECK: %[[LOAD_C:.*]] = vector.mask{{.*}}{ vector.transfer_read %[[C]]{{.*}}: memref<?x?xf32>, vector<8x16xf32>
+// CHECK: %[[CONTRACT:.*]] = vector.contract
+// CHECK-SAME: indexing_maps = [#[[$MAP_A]], #[[$MAP_B]], #[[$MAP_C]]]
+// CHECK-SAME: kind = #vector.kind<add>
+// CHECK-SAME: %[[LOAD_A]], %[[LOAD_B]], %[[LOAD_C]]
+// CHECK: vector.mask{{.*}}{ vector.transfer_write %[[CONTRACT]], %[[C]]{{.*}}: vector<8x16xf32>, memref<?x?xf32>
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.structured.vectorize %0 vector_sizes [8, 16, 4]
+ {create_named_contraction} : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+func.func @matmul_scalable(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>,
+ %C: tensor<?x?xf32>) -> tensor<?x?xf32> {
+ %0 = linalg.matmul
+ ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>)
+ outs(%C: tensor<?x?xf32>) -> tensor<?x?xf32>
+ return %0 : tensor<?x?xf32>
+}
+
+// CHECK: #[[$MAP_A:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK: #[[$MAP_B:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+// CHECK: #[[$MAP_C:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-LABEL: func.func @matmul_scalable(
+// CHECK-SAME: %[[A:.*]]: tensor<?x?xf32>, %[[B:.*]]: tensor<?x?xf32>,
+// CHECK-SAME: %[[C:.*]]: tensor<?x?xf32>)
+// CHECK: %[[LOAD_A:.*]] = vector.mask{{.*}}{ vector.transfer_read %[[A]]{{.*}}: tensor<?x?xf32>, vector<8x4xf32>
+// CHECK: %[[LOAD_B:.*]] = vector.mask{{.*}}{ vector.transfer_read %[[B]]{{.*}}: tensor<?x?xf32>, vector<4x[16]xf32>
+// CHECK: %[[LOAD_C:.*]] = vector.mask{{.*}}{ vector.transfer_read %[[C]]{{.*}}: tensor<?x?xf32>, vector<8x[16]xf32>
+// CHECK: %[[CONTRACT:.*]] = vector.contract
+// CHECK-SAME: indexing_maps = [#[[$MAP_A]], #[[$MAP_B]], #[[$MAP_C]]]
+// CHECK-SAME: kind = #vector.kind<add>
+// CHECK-SAME: %[[LOAD_A]], %[[LOAD_B]], %[[LOAD_C]]
+// CHECK: vector.mask{{.*}}{ vector.transfer_write %[[CONTRACT]], %[[C]]{{.*}}: vector<8x[16]xf32>, tensor<?x?xf32>
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.structured.vectorize %0 vector_sizes [8, [16], 4]
+ {create_named_contraction} : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+func.func @matmul_transpose(%A: tensor<4x8xf32>, %B: tensor<16x4xf32>,
+ %C: tensor<8x16xf32>) -> tensor<8x16xf32> {
+ %0 = linalg.matmul
+ indexing_maps = [affine_map<(m, n, k) -> (k, m)>, // transpose A
+ affine_map<(m, n, k) -> (n, k)>, // transpose B
+ affine_...
[truncated]
@llvm/pr-subscribers-mlir Author: Adam Siemieniuk (adam-smnk) ChangesExtends linalg vectorizer with a path to lower contraction ops directly into The direct rewriting preserves high-level op semantics and provides more progressive lowering compared to reconstructing contraction back from multi dimensional reduction. The new path is optional and disabled by default to avoid changing the default vectorizer behavior. Patch is 32.50 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/147296.diff 7 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 4360055e78691..3bbde1240286c 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -2445,6 +2445,8 @@ def VectorizeOp : Op<Transform_Dialect, "structured.vectorize",
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:
$static_vector_sizes,
OptionalAttr<UnitAttr>:$vectorize_nd_extract,
+ OptionalAttr<UnitAttr>:$flatten1D_depthwise_conv,
+ OptionalAttr<UnitAttr>:$create_named_contraction,
DefaultValuedOptionalAttr<DenseBoolArrayAttr, "{}">:
$scalable_sizes);
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 74280fdd82f4e..eee856f3eba68 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -876,11 +876,14 @@ struct VectorizationResult {
/// greater than or equal to their counterpart iteration space sizes, if static.
/// `inputVectorShapes` also allows the vectorization of operations with dynamic
/// shapes.
+/// Optionally, `createNamedContraction` can force compatible contractions to be
+/// vectorized directly to vector.contract operation.
FailureOr<VectorizationResult>
vectorize(RewriterBase &rewriter, Operation *op,
ArrayRef<int64_t> inputVectorSizes = {},
ArrayRef<bool> inputScalableVecDims = {},
- bool vectorizeNDExtract = false, bool flatten1DDepthwiseConv = false);
+ bool vectorizeNDExtract = false, bool flatten1DDepthwiseConv = false,
+ bool createNamedContraction = false);
/// Emit a suitable vector form for a Copy op with fully static shape.
LogicalResult vectorizeCopy(RewriterBase &builder, memref::CopyOp copyOp);
diff --git a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
index cc8421b23a074..9b765d0b8ede6 100644
--- a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
+++ b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
@@ -226,7 +226,8 @@ bool isLinearizableVector(VectorType type);
/// Note: all read offsets are set to 0.
Value createReadOrMaskedRead(OpBuilder &builder, Location loc, Value source,
ArrayRef<int64_t> inputVectorSizes, Value padValue,
- bool useInBoundsInsteadOfMasking = false);
+ bool useInBoundsInsteadOfMasking = false,
+ ArrayRef<bool> scalableDims = {});
/// Returns success if `inputVectorSizes` is a valid masking configuraion for
/// given `shape`, i.e., it meets:
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 7ec3f3445281a..1038615f21b17 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -3920,7 +3920,9 @@ DiagnosedSilenceableFailure transform::VectorizeOp::apply(
}
FailureOr<VectorizationResult> vectorResults =
linalg::vectorize(rewriter, target, vectorSizes, getScalableSizes(),
- getVectorizeNdExtract().value_or(false));
+ getVectorizeNdExtract().value_or(false),
+ getFlatten1DDepthwiseConv().value_or(false),
+ getCreateNamedContraction().value_or(false));
if (failed(vectorResults)) {
return mlir::emitSilenceableFailure(target->getLoc())
<< "Attempted to vectorize, but failed";
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 2864c5b807d69..bb1454cf43932 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -25,6 +25,7 @@
#include "mlir/Dialect/Vector/Interfaces/MaskableOpInterface.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
@@ -1681,10 +1682,13 @@ createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vecToStore,
return write;
// Compute the mask and mask the write Op.
- auto writeMaskType = VectorType::get(vecToStoreShape, builder.getI1Type());
+ auto writeMaskType = VectorType::get(vecToStoreShape, builder.getI1Type(),
+ vecToStoreType.getScalableDims());
SmallVector<OpFoldResult> destSizes =
- tensor::getMixedSizes(builder, loc, dest);
+ isa<MemRefType>(dest.getType())
+ ? memref::getMixedSizes(builder, loc, dest)
+ : tensor::getMixedSizes(builder, loc, dest);
SmallVector<OpFoldResult> maskSizes(destSizes.end() - vecToStoreRank,
destSizes.end());
@@ -2093,6 +2097,84 @@ vectorizeInsertSliceOpPrecondition(tensor::InsertSliceOp sliceOp,
return success();
}
+/// Vectorize a named linalg contraction op into:
+/// vector::TransferReadOp - Reads vectors from the operands
+/// vector::ContractionOp - Performs contraction
+/// vector::TransferWriteOp - Write the result vector back to the
+/// destination
+/// The operands shapes are preserved and loaded directly into vectors.
+/// Any further permutations or numerical casting remain within contraction.
+static LogicalResult
+vectorizeAsLinalgContraction(RewriterBase &rewriter, VectorizationState &state,
+ LinalgOp linalgOp,
+ SmallVectorImpl<Value> &newResults) {
+ Location loc = linalgOp.getLoc();
+ MLIRContext *ctx = linalgOp.getContext();
+
+ if (!isa<ContractionOpInterface>(linalgOp.getOperation()))
+ return failure();
+
+ OpOperand *outOperand = linalgOp.getDpsInitOperand(0);
+ Operation *reduceOp = matchLinalgReduction(outOperand);
+ auto maybeKind = getCombinerOpKind(reduceOp);
+ if (!maybeKind)
+ return failure();
+
+ // Check that all dimensions are present in the input operands.
+ // Arbitrary broadcasts are not supported by the vector contraction.
+ // Broadcasts are expected to be materialized before vectorization.
+ AffineMap lhsMap = linalgOp.getIndexingMapsArray()[0];
+ AffineMap rhsMap = linalgOp.getIndexingMapsArray()[1];
+ if (getUnusedDimsBitVector({lhsMap, rhsMap}).any())
+ return failure();
+
+ // Load operands.
+ SmallVector<Value> vecOperands;
+ for (OpOperand &opOperand : linalgOp->getOpOperands()) {
+ // The operand vector shape is computed by mapping the canonical vector
+ // shape to the operand's domain. Further permutations are left as a part of
+ // the contraction.
+ AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);
+ AffineMap readMap = AffineMap::getMultiDimIdentityMap(
+ indexingMap.getNumResults(), rewriter.getContext());
+ Type elemType = getElementTypeOrSelf(opOperand.get());
+ VectorType readType =
+ state.getCanonicalVecType(elemType, readMap.compose(indexingMap));
+
+ Value read = mlir::vector::createReadOrMaskedRead(
+ rewriter, loc, opOperand.get(), readType.getShape(),
+ /*padding=*/arith::getZeroConstant(rewriter, loc, elemType),
+ /*useInBoundsInsteadOfMasking=*/false, readType.getScalableDims());
+ vecOperands.push_back(read);
+ }
+
+ // Remap iterators from linalg to vector.
+ SmallVector<Attribute> iterAttrs;
+ auto iterators = linalgOp.getIteratorTypesArray();
+ for (utils::IteratorType iter : iterators) {
+ auto vecIter = iter == utils::IteratorType::parallel
+ ? vector::IteratorType::parallel
+ : vector::IteratorType::reduction;
+ iterAttrs.push_back(vector::IteratorTypeAttr::get(ctx, vecIter));
+ }
+
+ // Create contraction.
+ Value contractOp = rewriter.create<vector::ContractionOp>(
+ loc, /*lhs=*/vecOperands[0],
+ /*rhs=*/vecOperands[1], /*acc=*/vecOperands[2],
+ linalgOp.getIndexingMaps(), rewriter.getArrayAttr(iterAttrs), *maybeKind);
+
+ // Store result.
+ Operation *write =
+ createWriteOrMaskedWrite(rewriter, loc, contractOp, outOperand->get());
+
+ // Finalize.
+ if (!write->getResults().empty())
+ newResults.push_back(write->getResult(0));
+
+ return success();
+}
+
namespace {
enum class ConvOperationKind { Conv, Pool };
} // namespace
@@ -2528,11 +2610,10 @@ bool mlir::linalg::hasVectorizationImpl(Operation *op) {
tensor::InsertSliceOp>(op);
}
-FailureOr<VectorizationResult>
-mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
- ArrayRef<int64_t> inputVectorSizes,
- ArrayRef<bool> inputScalableVecDims,
- bool vectorizeNDExtract, bool flatten1DDepthwiseConv) {
+FailureOr<VectorizationResult> mlir::linalg::vectorize(
+ RewriterBase &rewriter, Operation *op, ArrayRef<int64_t> inputVectorSizes,
+ ArrayRef<bool> inputScalableVecDims, bool vectorizeNDExtract,
+ bool flatten1DDepthwiseConv, bool createNamedContraction) {
LDBG("Attempting to vectorize:\n" << *op << "\n");
LDBG("Input vector sizes: ");
LLVM_DEBUG(llvm::interleaveComma(inputVectorSizes, llvm::dbgs()));
@@ -2578,6 +2659,21 @@ mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
return failure();
}
+ // For simplicity, contraction vectorization is limited to linalg
+ // named ops. Generic op is ignored as not every arbitrary
+ // contraction body can be expressed by a vector.contract.
+ if (createNamedContraction &&
+ isa<ContractionOpInterface>(linalgOp.getOperation())) {
+ // Attempt vectorizing directly into a named contraction.
+ // In case of failure, fall back to the generic path.
+ LogicalResult res = vectorizeAsLinalgContraction(
+ rewriter, state, linalgOp, results);
+ if (succeeded(res))
+ return success();
+
+ LDBG("Failed to vectorize as a named contraction.\n");
+ }
+
LDBG("Vectorize generic by broadcasting to the canonical vector "
"shape\n");
diff --git a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
index 674a93d7c520e..594bfce57e598 100644
--- a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
+++ b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
@@ -320,14 +320,16 @@ Value vector::createReadOrMaskedRead(OpBuilder &builder, Location loc,
Value source,
ArrayRef<int64_t> inputVectorSizes,
Value padValue,
- bool useInBoundsInsteadOfMasking) {
+ bool useInBoundsInsteadOfMasking,
+ ArrayRef<bool> scalableDims) {
assert(!llvm::is_contained(inputVectorSizes, ShapedType::kDynamic) &&
"invalid input vector sizes");
auto sourceShapedType = cast<ShapedType>(source.getType());
auto sourceShape = sourceShapedType.getShape();
assert(sourceShape.size() == inputVectorSizes.size() &&
"expected same ranks.");
- auto vectorType = VectorType::get(inputVectorSizes, padValue.getType());
+ auto vectorType =
+ VectorType::get(inputVectorSizes, padValue.getType(), scalableDims);
assert(padValue.getType() == sourceShapedType.getElementType() &&
"expected same pad element type to match source element type");
int64_t readRank = inputVectorSizes.size();
@@ -352,9 +354,12 @@ Value vector::createReadOrMaskedRead(OpBuilder &builder, Location loc,
if (llvm::equal(inputVectorSizes, sourceShape) || useInBoundsInsteadOfMasking)
return transferReadOp;
SmallVector<OpFoldResult> mixedSourceDims =
- tensor::getMixedSizes(builder, loc, source);
+ isa<MemRefType>(source.getType())
+ ? memref::getMixedSizes(builder, loc, source)
+ : tensor::getMixedSizes(builder, loc, source);
- auto maskType = VectorType::get(inputVectorSizes, builder.getI1Type());
+ auto maskType =
+ VectorType::get(inputVectorSizes, builder.getI1Type(), scalableDims);
Value mask =
builder.create<vector::CreateMaskOp>(loc, maskType, mixedSourceDims);
return mlir::vector::maskOperation(builder, transferReadOp, mask)
diff --git a/mlir/test/Dialect/Linalg/vectorization/contraction-named.mlir b/mlir/test/Dialect/Linalg/vectorization/contraction-named.mlir
new file mode 100644
index 0000000000000..1831acf092afb
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/vectorization/contraction-named.mlir
@@ -0,0 +1,400 @@
+// RUN: mlir-opt %s -transform-interpreter -split-input-file | FileCheck %s
+
+func.func @matmul(%A: tensor<8x4xf32>, %B: tensor<4x16xf32>,
+ %C: tensor<8x16xf32>) -> tensor<8x16xf32> {
+ %0 = linalg.matmul
+ ins(%A, %B : tensor<8x4xf32>, tensor<4x16xf32>)
+ outs(%C: tensor<8x16xf32>) -> tensor<8x16xf32>
+ return %0 : tensor<8x16xf32>
+}
+
+// CHECK: #[[$MAP_A:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK: #[[$MAP_B:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+// CHECK: #[[$MAP_C:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-LABEL: func.func @matmul(
+// CHECK-SAME: %[[A:.*]]: tensor<8x4xf32>, %[[B:.*]]: tensor<4x16xf32>,
+// CHECK-SAME: %[[C:.*]]: tensor<8x16xf32>)
+// CHECK: %[[LOAD_A:.*]] = vector.transfer_read %[[A]]{{.*}}: tensor<8x4xf32>, vector<8x4xf32>
+// CHECK: %[[LOAD_B:.*]] = vector.transfer_read %[[B]]{{.*}}: tensor<4x16xf32>, vector<4x16xf32>
+// CHECK: %[[LOAD_C:.*]] = vector.transfer_read %[[C]]{{.*}}: tensor<8x16xf32>, vector<8x16xf32>
+// CHECK: %[[CONTRACT:.*]] = vector.contract
+// CHECK-SAME: indexing_maps = [#[[$MAP_A]], #[[$MAP_B]], #[[$MAP_C]]]
+// CHECK-SAME: kind = #vector.kind<add>
+// CHECK-SAME: %[[LOAD_A]], %[[LOAD_B]], %[[LOAD_C]]
+// CHECK: vector.transfer_write %[[CONTRACT]], %[[C]]{{.*}}: vector<8x16xf32>, tensor<8x16xf32>
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.structured.vectorize %0 {create_named_contraction} : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+func.func @matmul_dynamic(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>,
+ %C: tensor<?x?xf32>) -> tensor<?x?xf32> {
+ %0 = linalg.matmul
+ ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>)
+ outs(%C: tensor<?x?xf32>) -> tensor<?x?xf32>
+ return %0 : tensor<?x?xf32>
+}
+
+// CHECK: #[[$MAP_A:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK: #[[$MAP_B:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+// CHECK: #[[$MAP_C:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-LABEL: func.func @matmul_dynamic(
+// CHECK-SAME: %[[A:.*]]: tensor<?x?xf32>, %[[B:.*]]: tensor<?x?xf32>,
+// CHECK-SAME: %[[C:.*]]: tensor<?x?xf32>)
+// CHECK: %[[LOAD_A:.*]] = vector.mask{{.*}}{ vector.transfer_read %[[A]]{{.*}}: tensor<?x?xf32>, vector<8x4xf32>
+// CHECK: %[[LOAD_B:.*]] = vector.mask{{.*}}{ vector.transfer_read %[[B]]{{.*}}: tensor<?x?xf32>, vector<4x16xf32>
+// CHECK: %[[LOAD_C:.*]] = vector.mask{{.*}}{ vector.transfer_read %[[C]]{{.*}}: tensor<?x?xf32>, vector<8x16xf32>
+// CHECK: %[[CONTRACT:.*]] = vector.contract
+// CHECK-SAME: indexing_maps = [#[[$MAP_A]], #[[$MAP_B]], #[[$MAP_C]]]
+// CHECK-SAME: kind = #vector.kind<add>
+// CHECK-SAME: %[[LOAD_A]], %[[LOAD_B]], %[[LOAD_C]]
+// CHECK: vector.mask{{.*}}{ vector.transfer_write %[[CONTRACT]], %[[C]]{{.*}}: vector<8x16xf32>, tensor<?x?xf32>
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.structured.vectorize %0 vector_sizes [8, 16, 4]
+ {create_named_contraction} : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+func.func @matmul_dynamic_memref(%A: memref<?x?xf32>, %B: memref<?x?xf32>,
+ %C: memref<?x?xf32>) {
+ linalg.matmul
+ ins(%A, %B : memref<?x?xf32>, memref<?x?xf32>)
+ outs(%C: memref<?x?xf32>)
+ return
+}
+
+// CHECK: #[[$MAP_A:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK: #[[$MAP_B:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+// CHECK: #[[$MAP_C:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-LABEL: func.func @matmul_dynamic_memref(
+// CHECK-SAME: %[[A:.*]]: memref<?x?xf32>, %[[B:.*]]: memref<?x?xf32>,
+// CHECK-SAME: %[[C:.*]]: memref<?x?xf32>)
+// CHECK: %[[LOAD_A:.*]] = vector.mask{{.*}}{ vector.transfer_read %[[A]]{{.*}}: memref<?x?xf32>, vector<8x4xf32>
+// CHECK: %[[LOAD_B:.*]] = vector.mask{{.*}}{ vector.transfer_read %[[B]]{{.*}}: memref<?x?xf32>, vector<4x16xf32>
+// CHECK: %[[LOAD_C:.*]] = vector.mask{{.*}}{ vector.transfer_read %[[C]]{{.*}}: memref<?x?xf32>, vector<8x16xf32>
+// CHECK: %[[CONTRACT:.*]] = vector.contract
+// CHECK-SAME: indexing_maps = [#[[$MAP_A]], #[[$MAP_B]], #[[$MAP_C]]]
+// CHECK-SAME: kind = #vector.kind<add>
+// CHECK-SAME: %[[LOAD_A]], %[[LOAD_B]], %[[LOAD_C]]
+// CHECK: vector.mask{{.*}}{ vector.transfer_write %[[CONTRACT]], %[[C]]{{.*}}: vector<8x16xf32>, memref<?x?xf32>
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.structured.vectorize %0 vector_sizes [8, 16, 4]
+ {create_named_contraction} : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+func.func @matmul_scalable(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>,
+ %C: tensor<?x?xf32>) -> tensor<?x?xf32> {
+ %0 = linalg.matmul
+ ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>)
+ outs(%C: tensor<?x?xf32>) -> tensor<?x?xf32>
+ return %0 : tensor<?x?xf32>
+}
+
+// CHECK: #[[$MAP_A:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK: #[[$MAP_B:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+// CHECK: #[[$MAP_C:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-LABEL: func.func @matmul_scalable(
+// CHECK-SAME: %[[A:.*]]: tensor<?x?xf32>, %[[B:.*]]: tensor<?x?xf32>,
+// CHECK-SAME: %[[C:.*]]: tensor<?x?xf32>)
+// CHECK: %[[LOAD_A:.*]] = vector.mask{{.*}}{ vector.transfer_read %[[A]]{{.*}}: tensor<?x?xf32>, vector<8x4xf32>
+// CHECK: %[[LOAD_B:.*]] = vector.mask{{.*}}{ vector.transfer_read %[[B]]{{.*}}: tensor<?x?xf32>, vector<4x[16]xf32>
+// CHECK: %[[LOAD_C:.*]] = vector.mask{{.*}}{ vector.transfer_read %[[C]]{{.*}}: tensor<?x?xf32>, vector<8x[16]xf32>
+// CHECK: %[[CONTRACT:.*]] = vector.contract
+// CHECK-SAME: indexing_maps = [#[[$MAP_A]], #[[$MAP_B]], #[[$MAP_C]]]
+// CHECK-SAME: kind = #vector.kind<add>
+// CHECK-SAME: %[[LOAD_A]], %[[LOAD_B]], %[[LOAD_C]]
+// CHECK: vector.mask{{.*}}{ vector.transfer_write %[[CONTRACT]], %[[C]]{{.*}}: vector<8x[16]xf32>, tensor<?x?xf32>
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.structured.vectorize %0 vector_sizes [8, [16], 4]
+ {create_named_contraction} : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+func.func @matmul_transpose(%A: tensor<4x8xf32>, %B: tensor<16x4xf32>,
+ %C: tensor<8x16xf32>) -> tensor<8x16xf32> {
+ %0 = linalg.matmul
+ indexing_maps = [affine_map<(m, n, k) -> (k, m)>, // transpose A
+ affine_map<(m, n, k) -> (n, k)>, // transpose B
+ affine_...
[truncated]
Some comments, but this looks great, very lean and to the point.
I'll let others review and approve, but for the record, it looks good to me already, thanks!
@@ -2445,6 +2445,8 @@ def VectorizeOp : Op<Transform_Dialect, "structured.vectorize",
      DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:
          $static_vector_sizes,
      OptionalAttr<UnitAttr>:$vectorize_nd_extract,
      OptionalAttr<UnitAttr>:$flatten1D_depthwise_conv,
Any special reason to add this unrelated option?
Just for completeness.
As I was already tweaking the transform op with a new option, this one, present in the linalg::vectorize API, was missing here.
Let me add a bit more context.
I added the option to flatten the depthwise convs as an optimisation for convs with a low channel dim count. While great for NEON (i.e., fixed-width vectors), it's something that's tricky to generalise to scalable vectors. So I deliberately avoided extending the support (I am waiting to see whether others find it useful).
I am fine with extending this Op, but if we do, we should also add more tests. Your call :)
Thanks for the insight 🙂
In this case, I'll remove the option to limit the scope of this PR.
// Check that all dimensions are present in the input operands.
// Arbitrary broadcasts are not supported by the vector contraction.
// Broadcasts are expected to be materialized before vectorization.
Would the broadcast vectorization work in tandem with this one? Or can you pinpoint vectorization on the contract and not on the surrounding code (say, via a transform)?
If the latter, then we may (at some point later) validate the producers and consumers to make sure the vectorization won't break anything around it.
Yes to both. You can also vectorize selectively (either in a pass or a transform).
In general, the current vectorizer rewrites one op at a time. It creates read and write ops at the boundaries exactly to ensure a seamless transition between a vectorized op and its consumers and producers.
At the tensor level, these read-write pairs can easily cancel out thanks to value semantics. With memrefs, cleanup is obviously trickier due to possible aliasing.
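To make the tensor-level folding concrete, here is a minimal hand-written sketch (not taken from the patch; the values, shapes, and indices are made up):

%w = vector.transfer_write %v, %t[%c0, %c0] {in_bounds = [true, true]} : vector<8x16xf32>, tensor<8x16xf32>
%r = vector.transfer_read %w[%c0, %c0], %pad {in_bounds = [true, true]} : tensor<8x16xf32>, vector<8x16xf32>
// Thanks to value semantics, %r can be replaced by %v and the write-read pair folds away.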
At the tensor level, these read-write pairs can easily cancel out thanks to value semantics. With memrefs, cleanup is obviously trickier due to possible aliasing.
Yes, that's my worry, but I guess it's up to the transform user to know memrefs are harder and adjust strategy. This won't be the only problem they'll have with memrefs anyway.
I also don't want to guess how to materialize such broadcasts. We could just pick one default broadcasting scheme (e.g., the canonical vector shape like in the linalg generic vectorizer) but it's likely to be suboptimal too.
Perhaps this broadcasting information could be encoded in the operand's layout? Something I could experiment with later.
I think I am missing "materialize" context in the vectorization concept. Is it a blocker to make the option default? What does materializing broadcasts mean? Is it breaking a matmul into something like broadcast(LHS)->matmul?
Is it a blocker to make the option default?
Not necessarily but it'd be good to align if broadcasts should be handled at all. If yes, then how.
What does materializing broadcasts mean? Is it breaking a matmul into something like broadcast(LHS)->matmul?
Correct. Today, broadcast semantics can't be preserved and kept within vector.contract, so the extra dimension has to be created somewhere.
I'll elaborate more in a separate answer.
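To illustrate what materializing the broadcast before vectorization could look like, a rough sketch (hand-written; the shapes and value names are made up and not taken from this patch):

%lhs2d = linalg.broadcast ins(%lhs : tensor<12xf32>)
           outs(%lhs_init : tensor<24x12xf32>) dimensions = [0]
%mm = linalg.matmul ins(%lhs2d, %rhs : tensor<24x12xf32>, tensor<12x25xf32>)
        outs(%acc : tensor<24x25xf32>) -> tensor<24x25xf32>
// The broadcast is now explicit and the matmul uses the default indexing maps.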
My terminology would be decompose, because it is what we use for pack/unpack/pad/etc ops upstream. If we break an op into a sequence of simpler ops, I'd call it decomposition. E.g., DecomposeGenericByUnfoldingPermutation,
llvm-project/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
Lines 1615 to 1701 in c57fe2f
/// Rewrite a tensor::PadOp into a sequence of EmptyOp, FillOp and
/// InsertSliceOp. For now, only constant padding values are supported.
struct DecomposePadOpPattern : public OpRewritePattern<tensor::PadOp> {
  DecomposePadOpPattern(MLIRContext *context, PatternBenefit benefit = 1)
      : OpRewritePattern<tensor::PadOp>(context, benefit) {}
  LogicalResult matchAndRewrite(tensor::PadOp padOp,
                                PatternRewriter &rewriter) const override;

protected:
  Value createFillOrGenerateOp(RewriterBase &rewriter, tensor::PadOp padOp,
                               Value dest,
                               const SmallVector<Value> &dynSizes) const;
};

/// Rewrites a linalg::PackOp into a sequence of:
///   * tensor::PadOp + linalg::TransposeOp + tensor::EmptyOp +
///     tensor::InsertSliceOp ops.
///
/// Requires that all the outer dims of the input linalg::PackOp are 1.
///
/// Before:
/// ```
///   %packed = linalg.pack %input
///     padding_value(%pad : f32)
///     inner_dims_pos = [1, 0]
///     inner_tiles = [2, %high]
///     into %output : tensor<5x1xf32> -> tensor<1x1x2x?xf32>
/// ```
///
/// After:
/// ```
///   // PadOp
///   %padded = tensor.pad %arg0 low[0, 0] high[%0, 1] {
///   ^bb0(...):
///     tensor.yield %arg2 : f32
///   } : tensor<5x1xf32> to tensor<?x2xf32>
///   // EmptyOp + TransposeOp
///   %empty = tensor.empty(%arg3) : tensor<2x?xf32>
///   %transposed = linalg.transpose
///     ins(%extracted_slice : tensor<?x2xf32>)
///     outs(%empty : tensor<2x?xf32>)
///     permutation = [1, 0]
///   // InsertSliceOp
///   %inserted_slice = tensor.insert_slice %transposed
///     into %arg1[0, 0, 0, 0] [1, 1, 2, %tile_dim_1] [1, 1, 1, 1]
///     : tensor<2x?xf32> into tensor<1x1x2x?xf32>
/// ```
struct DecomposeOuterUnitDimsPackOpPattern
    : public OpRewritePattern<linalg::PackOp> {
  using OpRewritePattern<linalg::PackOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(linalg::PackOp packOp,
                                PatternRewriter &rewriter) const override;
};

/// Rewrites a linalg::UnPackOp into a sequence of rank-reduced
///   * tensor::ExtractSliceOp + linalg::TransposeOp + tensor::InsertSliceOp
///
/// Requires that all the outer dims of the input linalg::PackOp are 1.
///
/// Before:
/// ```
///   %packed = linalg.unpack %input
///     inner_dims_pos = [1, 0]
///     inner_tiles = [2, 8]
///     into %output : tensor<1x1x2x8xf32> -> tensor<5x1xf32>
/// ```
///
/// After:
/// ```
///   // Rank-reduced extract to obtain the tile
///   %slice = tensor.extract_slice %arg0[0, 0, 0, 0] [1, 1, 2, 8] [1, 1, 1, 1]
///     : tensor<1x1x2x8xf32> to tensor<2x8xf32>
///   // EmptyOp + TransposeOp
///   %init = tensor.empty() : tensor<8x2xf32>
///   %transposed = linalg.transpose
///     ins(%extracted_slice : tensor<2x8xf32>)
///     outs(%0 : tensor<8x2xf32>) permutation = [1, 0]
///   // Extract a slice matching the specified output size
///   %result = tensor.extract_slice %transposed[0, 0] [5, 1] [1, 1]
///     : tensor<8x2xf32> to tensor<5x1xf32>
/// ```
struct DecomposeOuterUnitDimsUnPackOpPattern
    : public OpRewritePattern<linalg::UnPackOp> {
  using OpRewritePattern<linalg::UnPackOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(linalg::UnPackOp unpackOp,
                                PatternRewriter &rewriter) const override;
};
Very nice, thanks!
I've made some small suggestions inline, but nothing major. The approach that you are taking makes a lot of sense to me.
The new path is optional and disabled by default to avoid changing the default vectorizer behavior.
This is a good time to discuss the future direction for this.
@dcaballe has already hinted in Vector Dialect: Refactoring + Re-design ideas that we should look into removing vector.multi_reduction. I am not sure about that myself just yet, but avoiding vector.multi_reduction in the vectorization path makes a lot of sense to me (i.e., specifically for "Linalg contractions"). So:
- Why shouldn't we make it the default?
- What are the next steps for you?
I am happy to work with you to make this the default if we/you discover that something is missing. Otherwise, I would prepare people to flip the switch soonish.
@@ -2445,6 +2445,8 @@ def VectorizeOp : Op<Transform_Dialect, "structured.vectorize",
      DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:
          $static_vector_sizes,
      OptionalAttr<UnitAttr>:$vectorize_nd_extract,
      OptionalAttr<UnitAttr>:$flatten1D_depthwise_conv,
Note, I have recently refactored the vectorisation tests and grouped them as discussed here:
Specifically, this is what we have today:
mlir/test/Dialect/Linalg/vectorization/
├── conv.mlir
├── conv-flatten.mlir
├── conv-with-patterns.mlir
├── extract.mlir
├── extract-with-patterns.mlir
├── insert-slice.mlir
├── insert-slice-with-patterns.mlir
├── linalg-ops.mlir
├── linalg-ops-with-patterns.mlir
├── pad.mlir
├── pad-with-patterns.mlir
├── unsupported.mlir
So, contraction-named.mlir doesn't quite fit. However, one of these options would:

- contraction.mlir (1 file), or
- contraction-interface.mlir (1 file), or
- contract.mlir + matmul.mlir + batch_matmul.mlir + ... (multiple files)

Given that the updated logic looks at ContractionOpInterface, I vote for contraction-interface.mlir. Other names are fine too, I would just add a top-level comment to explain what this file covers (this is the first test file to focus on a specific interface, so it's worth clarifying).
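For such a file, one possible top-level comment, as an illustrative sketch only (the wording is not from the patch):

// Tests for vectorization of ops implementing ContractionOpInterface
// (e.g. linalg.matmul, linalg.batch_matmul, linalg.contract) into
// vector.contract via the create_named_contraction path.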
// -----

func.func @matmul_scalable(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>, |
[nit] Given that this is a variation of @matmul_dynamic ...

Suggested change:
-func.func @matmul_scalable(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>,
+func.func @matmul_dynamic_scalable(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>,
// -----

func.func @negative_matmul_broadcast(%A: tensor<4xf32>, %B: tensor<4x16xf32>, |
Perhaps add a note that "negative" in this case means: "Vectorization works, but you are not getting vector.contract as explicitly requested"?
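For example, the note could read something along these lines (illustrative wording only, not from the patch):

// Note: vectorization still succeeds here, but no vector.contract is
// produced even though create_named_contraction was requested.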
Thanks! The direction looks good to me. This is long overdue. I think we should discuss what is missing to make this the default path and work towards that. Do you know what is missing?
The approach looks okay to me because I'm -1 on recovering information to achieve the same thing when we already have enough information. It'd be good to see what is missing to make it the default, as I think it is usually bad to have boolean options, especially when they are linalg specific. It makes generalization harder if we want a vectorization interface in the future. I'd like to learn why it can't be the default and what is missing.
I also have a question about "materializing broadcast" in my inline comment.
// Check that all dimensions are present in the input operands.
// Arbitrary broadcasts are not supported by the vector contraction.
// Broadcasts are expected to be materialized before vectorization.
Thanks a lot for the feedback! This new lowering tries to preserve as much information as possible to avoid any need for reconstruction (like the current path that reconstructs a contraction back from a multi-dimensional reduction).

Generics

One point I want to address first is why we should ignore linalg.generic here. I think generics should be treated as such (well, generic) and be vectorized as they are today. One can always specialize a generic to capture more narrow behavior. Equally, linalg ops can be generalized before vectorization to retain the current vectorizer behavior. Thus, there is no obstacle here preventing enabling specialized vectorization by default.

Mixed precision

Both linalg and vector support the same mixed precision semantics where inputs are converted into the output precision before computation. Today, the casts are externalized by default, and folding them back into vector.contract can be done separately.

Broadcasts

Now, let's look into broadcast semantics present in linalg named ops.

%0 = linalg.matmul
indexing_maps = [affine_map<(m, n, k) -> (k)>, // broadcast LHS
affine_map<(m, n, k) -> (k)>, // broadcast RHS
affine_map<(m, n, k) -> (m, n)>]
ins(%arg0, %arg1 : tensor<12xf32>, tensor<12xf32>)
outs(%arg2: tensor<24x25xf32>) -> tensor<24x25xf32>

The op could be vectorized, for example, as:
%1 = vector.transfer_read %LHS : tensor<12xf32>, vector<12xf32>
%2 = vector.broadcast %1 : vector<12xf32> to vector<24x25x12xf32>
%3 = vector.transfer_read %RHS : tensor<12xf32>, vector<12xf32>
%4 = vector.transfer_read %ACC : tensor<24x25xf32>, vector<24x25xf32>
%5 = vector.contract ... kind = #vector.kind<add>} %2, %3, %4
: vector<24x25x12xf32>, vector<12xf32> into vector<24x25xf32> which is a valid realization. However, the 2D matmul semantics are lost and the computation turned it into an arbitrary contraction. At the moment, I went with the option to bail (fall back to generic behavior) in presence of broadcasts. This keeps the overall vectorization consistent between specialized lowering and generic+reconstruction. So, that makes for a fine default path as well. As a compromise, the same broadcast into canonical vector shape could be performed or a somewhat better broadcasting for contraction could be added to this specialized path. This makes specialized vs generic vectorization diverge a bit - it is fine but it might be better to give users time to adjust before making it the default. |
For the next steps, I'd love to hear your thoughts on how broadcasts should be handled.
Me neither 🙂
When we broadcast like this, the maps should be the default indexing maps for linalg.matmul.
I think there is ambiguity to begin with. Default maps are a reasonable ("canonical"?) choice in this case.
Does it not? To me, if the maps are not uniquely defined, that would be a "bug" (as in, if there's ambiguity, it should be removed). However, taking this example:

func.func @example(%A: memref<?x?x?xf32>, %B: memref<?x?xf32>, %C: memref<?x?x?xf32>) {
linalg.contract
indexing_maps = [affine_map<(batch, m, n, k) -> (batch, k, m)>,
affine_map<(batch, m, n, k) -> (k, n)>,
affine_map<(batch, m, n, k) -> (batch, m, n)>]
ins(%A, %B: memref<?x?x?xf32>, memref<?x?xf32>)
outs(%C: memref<?x?x?xf32>)
return
} "generalisation" does yield unique maps ( #map = affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>
#map1 = affine_map<(d0, d1, d2, d3) -> (d3, d2)>
#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
module {
func.func @generalize_matmul_buffer(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?xf32>, %arg2: memref<?x?x?xf32>) {
linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?x?xf32>) outs(%arg2 : memref<?x?x?xf32>) {
^bb0(%in: f32, %in_0: f32, %out: f32):
%0 = arith.mulf %in, %in_0 : f32
%1 = arith.addf %out, %0 : f32
linalg.yield %1 : f32
}
return
}
}

What am I missing here? 🤔
I think we're talking about two slightly different things. Let me try to clarify.
For the earlier case I took your statement:

When we broadcast like this, the maps should be the default indexing maps for linalg.matmul
as in resolving broadcasting for this matmul:

%0 = linalg.matmul
indexing_maps = [affine_map<(m, n, k) -> (k)>, // broadcast LHS
affine_map<(m, n, k) -> (k)>, // broadcast RHS
affine_map<(m, n, k) -> (m, n)>]

should yield

indexing_maps = [affine_map<(m, n, k) -> (m, k)>, // broadcast LHS
affine_map<(m, n, k) -> (k, n)>, // broadcast RHS
affine_map<(m, n, k) -> (m, n)>]

as these can be derived from the default indexing maps for linalg.matmul. I agree that's a valid strategy. However, for the same matmul,

indexing_maps = [affine_map<(m, n, k) -> (m, n, k)>,
affine_map<(m, n, k) -> (k)>,
affine_map<(m, n, k) -> (m, n)>]

is a valid interpretation for broadcast decomposition.
The impedance mismatch between linalg and vector contractions comes from the fact that vector.contract does not support arbitrary broadcasts; all dimensions have to be present in the input operands. Today, the vectorizer obviously can vectorize such linalg ops with broadcasts.
Personally, I lean toward forcing users to first decompose unsupported affine maps to not have to guess or compromise on an arbitrary default decomposition for these edge cases. All the default named op contractions and maps like these (note: it's not really a broadcast):

indexing_maps = [affine_map<(batch, m, n, k) -> (batch, k, m)>,
affine_map<(batch, m, n, k) -> (k, n)>,
affine_map<(batch, m, n, k) -> (batch, m, n)>]

are directly representable by vector.contract.
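For illustration, a hedged sketch of a vector.contract that carries such maps over unchanged (the vector shapes and value names below are made up, not taken from this patch):

%res = vector.contract {
    indexing_maps = [affine_map<(batch, m, n, k) -> (batch, k, m)>,
                     affine_map<(batch, m, n, k) -> (k, n)>,
                     affine_map<(batch, m, n, k) -> (batch, m, n)>],
    iterator_types = ["parallel", "parallel", "parallel", "reduction"],
    kind = #vector.kind<add>}
  %lhs, %rhs, %acc
  : vector<2x4x8xf32>, vector<4x16xf32> into vector<2x8x16xf32>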
Thanks for elaborating, Adam! Now I see what the source of confusion was for me. Basically, there is no ambiguity at the Linalg level; the ambiguity only appears when deciding how to decompose the broadcast before mapping onto vector.contract.
+1