diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td index d64f94a49f781..472df21cb464e 100644 --- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td +++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td @@ -2440,12 +2440,11 @@ def VectorizeOp : Op:$vector_sizes, - DefaultValuedOptionalAttr: - $static_vector_sizes, - OptionalAttr:$vectorize_nd_extract, - DefaultValuedOptionalAttr: - $scalable_sizes); + Variadic:$vector_sizes, + DefaultValuedOptionalAttr:$static_vector_sizes, + OptionalAttr:$vectorize_nd_extract, + OptionalAttr:$assume_dynamic_dims_match_vec_sizes, + DefaultValuedOptionalAttr:$scalable_sizes); let results = (outs); diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h index 2b4855f49695c..8ba4f8f218721 100644 --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -871,7 +871,8 @@ FailureOr vectorize(RewriterBase &rewriter, Operation *op, ArrayRef inputVectorSizes = {}, ArrayRef inputScalableVecDims = {}, - bool vectorizeNDExtract = false, bool flatten1DDepthwiseConv = false); + bool vectorizeNDExtract = false, bool flatten1DDepthwiseConv = false, + bool assumeDynamicDimsMatchVecSizes = false); /// Emit a suitable vector form for a Copy op with fully static shape. 
LogicalResult vectorizeCopy(RewriterBase &builder, memref::CopyOp copyOp); diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp index 8571d641e26d1..7e1911a56693f 100644 --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -3921,7 +3921,8 @@ DiagnosedSilenceableFailure transform::VectorizeOp::apply( } FailureOr vectorResults = linalg::vectorize(rewriter, target, vectorSizes, getScalableSizes(), - getVectorizeNdExtract().value_or(false)); + getVectorizeNdExtract().value_or(false), false, + getAssumeDynamicDimsMatchVecSizes().value_or(false)); if (failed(vectorResults)) { return mlir::emitSilenceableFailure(target->getLoc()) << "Attempted to vectorize, but failed"; diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp index b467114c72f7d..11ba3eebc128a 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp @@ -222,7 +222,8 @@ struct VectorizationState { /// canonical vector shape for vectorization. LogicalResult initState(RewriterBase &rewriter, LinalgOp linalgOp, ArrayRef inputVectorSizes, - ArrayRef inputScalableVecDims); + ArrayRef inputScalableVecDims, + bool assumeDynamicDimsMatchVecSizes = false); /// Returns the canonical vector shape used to vectorize the iteration space. ArrayRef getCanonicalVecShape() const { return canonicalVecShape; } @@ -331,6 +332,14 @@ struct VectorizationState { /// Global vectorization guard for the incoming rewriter. It's initialized /// when the vectorization state is initialized. OpBuilder::InsertionGuard rewriterGuard; + + /// Do all dynamic dims match the corresponding vector sizes? 
+ /// + /// When a dynamic tensor/memref dimension matches the corresponding vector + /// dimension, masking can be safely skipped, despite the presence of dynamic + /// shapes. Use this flag with care and only for cases where you are + /// confident the assumption holds. + bool assumeDynamicDimsMatchVecSizes = false; }; LogicalResult @@ -367,10 +376,12 @@ VectorizationState::precomputeIterSpaceValueSizes(RewriterBase &rewriter, /// Initializes the vectorization state, including the computation of the /// canonical vector shape for vectorization. // TODO: Move this to the constructor when we can remove the failure cases. -LogicalResult -VectorizationState::initState(RewriterBase &rewriter, LinalgOp linalgOp, - ArrayRef<int64_t> inputVectorSizes, - ArrayRef<bool> inputScalableVecDims) { +LogicalResult VectorizationState::initState(RewriterBase &rewriter, + LinalgOp linalgOp, + ArrayRef<int64_t> inputVectorSizes, + ArrayRef<bool> inputScalableVecDims, + bool assumeDimsMatchVec) { + assumeDynamicDimsMatchVecSizes = assumeDimsMatchVec; // Initialize the insertion point. rewriter.setInsertionPoint(linalgOp); @@ -470,6 +481,22 @@ Value VectorizationState::getOrCreateMaskFor( return Value(); } + if (assumeDynamicDimsMatchVecSizes) { + // Given that all _scalable vector sizes_ match the corresponding + // memref/tensor dim sizes, masking can be skipped provided that: + // * all vector sizes corresponding to dynamic dims are scalable. + if (llvm::all_of(llvm::zip(permutedStaticSizes, maskType.getScalableDims()), + [](auto it) { + return std::get<0>(it) == ShapedType::kDynamic + ? std::get<1>(it) + : true; + })) { + LDBG("Masking is not needed for masking map: " << maskingMap << "\n"); + activeMaskCache[maskingMap] = Value(); + return Value(); + } + } + // Permute the iteration space value sizes to compute the mask upper bounds. 
SmallVector upperBounds = applyPermutationMap(maskingMap, ArrayRef(iterSpaceValueSizes)); @@ -2479,7 +2505,8 @@ vectorizeScalableVectorPrecondition(Operation *op, return success(isElementwise(linalgOp) || isa(op) || isa(op) || isa(op) || - isa(op) || hasReductionIterator(linalgOp)); + isa(op) || isa(op) || + hasReductionIterator(linalgOp)); } LogicalResult mlir::linalg::vectorizeOpPrecondition( @@ -2535,11 +2562,10 @@ bool mlir::linalg::hasVectorizationImpl(Operation *op) { tensor::InsertSliceOp>(op); } -FailureOr -mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op, - ArrayRef inputVectorSizes, - ArrayRef inputScalableVecDims, - bool vectorizeNDExtract, bool flatten1DDepthwiseConv) { +FailureOr mlir::linalg::vectorize( + RewriterBase &rewriter, Operation *op, ArrayRef inputVectorSizes, + ArrayRef inputScalableVecDims, bool vectorizeNDExtract, + bool flatten1DDepthwiseConv, bool assumeDynamicDimsMatchVecSizes) { LDBG("Attempting to vectorize:\n" << *op << "\n"); LDBG("Input vector sizes: "); LLVM_DEBUG(llvm::interleaveComma(inputVectorSizes, llvm::dbgs())); @@ -2559,7 +2585,8 @@ mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op, VectorizationState state(rewriter); if (auto linalgOp = dyn_cast(op)) { if (failed(state.initState(rewriter, linalgOp, inputVectorSizes, - inputScalableVecDims))) { + inputScalableVecDims, + assumeDynamicDimsMatchVecSizes))) { LDBG("Vectorization state couldn't be initialized\n"); return failure(); } diff --git a/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir b/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir index 6722de817f6bf..188f03069938f 100644 --- a/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir +++ b/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir @@ -840,6 +840,99 @@ module attributes {transform.with_named_sequence} { } } +// ----- + +///---------------------------------------------------------------------------------------- +/// Tests for linalg.mmt4d 
+///---------------------------------------------------------------------------------------- + +func.func @mmt4d(%A: memref<16x16x8x1xf32>, %B: memref<16x16x8x1xf32>, %C_in: memref<16x16x8x8xf32>) { + linalg.mmt4d ins(%A, %B: memref<16x16x8x1xf32>, memref<16x16x8x1xf32>) + outs(%C_in: memref<16x16x8x8xf32>) + return +} + +// CHECK-LABEL: func.func @mmt4d( +// CHECK-SAME: %[[A:.*]]: memref<16x16x8x1xf32>, %[[B:.*]]: memref<16x16x8x1xf32>, %[[C:.*]]: memref<16x16x8x8xf32>) { +// CHECK: %[[VEC_A:.*]] = vector.transfer_read %[[A]]{{.*}} : memref<16x16x8x1xf32>, vector<16x16x16x8x8x1xf32> +// CHECK: %[[VEC_B:.*]] = vector.transfer_read %[[B]]{{.*}} : memref<16x16x8x1xf32>, vector<16x16x16x8x8x1xf32> +// CHECK: %[[VEC_C:.*]] = vector.transfer_read %[[C]]{{.*}} : memref<16x16x8x8xf32>, vector<16x16x8x8xf32> +// CHECK: %[[MUL:.*]] = arith.mulf %[[VEC_A]], %[[VEC_B]] : vector<16x16x16x8x8x1xf32> +// CHECK: %[[RED:.*]] = vector.multi_reduction , %[[MUL]], %[[VEC_C]] [2, 5] : vector<16x16x16x8x8x1xf32> to vector<16x16x8x8xf32> +// CHECK: vector.transfer_write %[[RED]], %[[C]]{{.*}} : vector<16x16x8x8xf32>, memref<16x16x8x8xf32> + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %mmt4d = transform.structured.match ops{["linalg.mmt4d"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.structured.vectorize %mmt4d : !transform.any_op + transform.yield + } +} + +// ----- + +func.func @mmt4d_scalable(%A: memref<16x16x8x1xf32>, %B: memref<16x16x?x1xf32>, %C_in: memref<16x16x8x?xf32>) { + linalg.mmt4d ins(%A, %B: memref<16x16x8x1xf32>, memref<16x16x?x1xf32>) + outs(%C_in: memref<16x16x8x?xf32>) + return +} +// CHECK-LABEL: func.func @mmt4d_scalable( +// CHECK-SAME: %[[A:.*]]: memref<16x16x8x1xf32>, +// CHECK-SAME: %[[B:.*]]: memref<16x16x?x1xf32>, +// CHECK-SAME: %[[C_IN:.*]]: memref<16x16x8x?xf32>) { +// CHECK: %[[VAL_0:.*]] = arith.constant 16 : index +// CHECK: 
%[[VAL_1:.*]] = arith.constant 16 : index +// CHECK: %[[VAL_2:.*]] = arith.constant 16 : index +// CHECK: %[[C8:.*]] = arith.constant 8 : index +// CHECK: %[[C2:.*]] = arith.constant 2 : index +// CHECK: %[[DIM_2:.*]] = memref.dim %[[B]], %[[C2]] : memref<16x16x?x1xf32> +// CHECK: %[[VAL_6:.*]] = arith.constant 1 : index +// CHECK: %[[VEC_A:.*]] = vector.transfer_read %[[A]]{{.*}} : memref<16x16x8x1xf32>, vector<16x16x16x8x[4]x1xf32> +// CHECK: %[[MASK_1:.*]] = vector.create_mask %[[VAL_1]], %[[VAL_2]], %[[DIM_2]], %[[VAL_6]] : vector<16x16x[4]x1xi1> +// CHECK: %[[VEC_B:.*]] = vector.mask %[[MASK_1]] { vector.transfer_read %[[B]]{{.*}} : memref<16x16x?x1xf32>, vector<16x16x16x8x[4]x1xf32> } : vector<16x16x[4]x1xi1> -> vector<16x16x16x8x[4]x1xf32> +// CHECK: %[[MASK_2:.*]] = vector.create_mask %[[VAL_0]], %[[VAL_1]], %[[C8]], %[[DIM_2]] : vector<16x16x8x[4]xi1> +// CHECK: %[[VAL_15:.*]] = vector.mask %[[MASK_2]] { vector.transfer_read %[[C_IN]]{{.*}} : memref<16x16x8x?xf32>, vector<16x16x8x[4]xf32> } : vector<16x16x8x[4]xi1> -> vector<16x16x8x[4]xf32> +// CHECK: %[[VAL_16:.*]] = arith.mulf %[[VEC_A]], %[[VEC_B]] : vector<16x16x16x8x[4]x1xf32> +// CHECK: %[[MASK_3:.*]] = vector.create_mask %[[VAL_0]], %[[VAL_1]], %[[VAL_2]], %[[C8]], %[[DIM_2]], %[[VAL_6]] : vector<16x16x16x8x[4]x1xi1> +// CHECK: %[[VAL_18:.*]] = vector.mask %[[MASK_3]] { vector.multi_reduction , %[[VAL_16]], %[[VAL_15]] [2, 5] : vector<16x16x16x8x[4]x1xf32> to vector<16x16x8x[4]xf32> } : vector<16x16x16x8x[4]x1xi1> -> vector<16x16x8x[4]xf32> +// CHECK: vector.mask %[[MASK_2]] { vector.transfer_write %[[VAL_18]], %[[C_IN]]{{.*}} : vector<16x16x8x[4]xf32>, memref<16x16x8x?xf32> } : vector<16x16x8x[4]xi1> + + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %mmt4d = transform.structured.match ops{["linalg.mmt4d"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.structured.vectorize %mmt4d 
vector_sizes [16, 16, 16, 8, [4], 1] : !transform.any_op + transform.yield + } +} + +// ----- + +func.func @mmt4d_scalable_with_assume(%A: memref<16x16x8x1xf32>, %B: memref<16x16x?x1xf32>, %C_in: memref<16x16x8x?xf32>) { + linalg.mmt4d ins(%A, %B: memref<16x16x8x1xf32>, memref<16x16x?x1xf32>) + outs(%C_in: memref<16x16x8x?xf32>) + return +} +// CHECK-LABEL: func.func @mmt4d_scalable_with_assume( +// CHECK-SAME: %[[A:.*]]: memref<16x16x8x1xf32>, +// CHECK-SAME: %[[B:.*]]: memref<16x16x?x1xf32>, +// CHECK-SAME: %[[C_IN:.*]]: memref<16x16x8x?xf32>) { +// CHECK-NOT: mask +// CHECK: %[[VEC_A:.*]] = vector.transfer_read %[[A]]{{.*}} : memref<16x16x8x1xf32>, vector<16x16x16x8x[4]x1xf32> +// CHECK: %[[VEC_B:.*]] = vector.transfer_read %[[B]]{{.*}} : memref<16x16x?x1xf32>, vector<16x16x16x8x[4]x1xf32> +// CHECK: %[[VAL_13:.*]] = vector.transfer_read %[[C_IN]]{{.*}} : memref<16x16x8x?xf32>, vector<16x16x8x[4]xf32> +// CHECK: %[[VAL_14:.*]] = arith.mulf %[[VEC_A]], %[[VEC_B]] : vector<16x16x16x8x[4]x1xf32> +// CHECK: %[[VAL_15:.*]] = vector.multi_reduction <add>, %[[VAL_14]], %[[VAL_13]] [2, 5] : vector<16x16x16x8x[4]x1xf32> to vector<16x16x8x[4]xf32> +// CHECK: vector.transfer_write %[[VAL_15]], %[[C_IN]]{{.*}} : vector<16x16x8x[4]xf32>, memref<16x16x8x?xf32> + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %mmt4d = transform.structured.match ops{["linalg.mmt4d"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.structured.vectorize %mmt4d vector_sizes [16, 16, 16, 8, [4], 1] {assume_dynamic_dims_match_vec_sizes} : !transform.any_op + transform.yield + } +} + ///---------------------------------------------------------------------------------------- /// Tests for other Ops ///---------------------------------------------------------------------------------------- @@ -1094,30 +1187,6 @@ module attributes {transform.with_named_sequence} { } } -// ----- - -func.func 
@mmt4d(%A: memref<16x16x8x1xf32>, %B: memref<16x16x8x1xf32>, %C_in: memref<16x16x8x8xf32>) { - linalg.mmt4d ins(%A, %B: memref<16x16x8x1xf32>, memref<16x16x8x1xf32>) - outs(%C_in: memref<16x16x8x8xf32>) - return -} - -// CHECK-LABEL: func.func @mmt4d( -// CHECK-SAME: %[[A:.*]]: memref<16x16x8x1xf32>, %[[B:.*]]: memref<16x16x8x1xf32>, %[[C:.*]]: memref<16x16x8x8xf32>) { -// CHECK: %[[VEC_A:.*]] = vector.transfer_read %[[A]]{{.*}} : memref<16x16x8x1xf32>, vector<16x16x16x8x8x1xf32> -// CHECK: %[[VEC_B:.*]] = vector.transfer_read %[[B]]{{.*}} : memref<16x16x8x1xf32>, vector<16x16x16x8x8x1xf32> -// CHECK: %[[VEC_C:.*]] = vector.transfer_read %[[C]]{{.*}} : memref<16x16x8x8xf32>, vector<16x16x8x8xf32> -// CHECK: %[[MUL:.*]] = arith.mulf %[[VEC_A]], %[[VEC_B]] : vector<16x16x16x8x8x1xf32> -// CHECK: %[[RED:.*]] = vector.multi_reduction , %[[MUL]], %[[VEC_C]] [2, 5] : vector<16x16x16x8x8x1xf32> to vector<16x16x8x8xf32> -// CHECK: vector.transfer_write %[[RED]], %[[C]]{{.*}} : vector<16x16x8x8xf32>, memref<16x16x8x8xf32> - -module attributes {transform.with_named_sequence} { - transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { - %mmt4d = transform.structured.match ops{["linalg.mmt4d"]} in %arg1 : (!transform.any_op) -> !transform.any_op - transform.structured.vectorize %mmt4d : !transform.any_op - transform.yield - } -} // -----