diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index 1fb8c7a928e06..1d0f325621d1d 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -4081,6 +4081,75 @@ void ExtractStridedSliceOp::getOffsets(SmallVectorImpl &results) { namespace { +// Pattern to rewrite an ExtractStridedSliceOp(CreateMaskOp) to +// CreateMaskOp. +// +// Example: +// +// %mask = vector.create_mask %ub : vector<16xi1> +// %slice = vector.extract_strided_slice [%offset] [8] [1] +// +// to +// +// %new_ub = arith.subi %ub, %offset +// %mask = vector.create_mask %new_ub : vector<8xi1> +class StridedSliceCreateMaskFolder final + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + +public: + LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp, + PatternRewriter &rewriter) const override { + Location loc = extractStridedSliceOp.getLoc(); + // Return if 'extractStridedSliceOp' operand is not defined by a + // CreateMaskOp. + auto createMaskOp = + extractStridedSliceOp.getVector().getDefiningOp(); + if (!createMaskOp) + return failure(); + // Return if 'extractStridedSliceOp' has non-unit strides. + if (extractStridedSliceOp.hasNonUnitStrides()) + return failure(); + // Gather constant mask dimension sizes. + SmallVector maskDimSizes(createMaskOp.getOperands()); + // Gather strided slice offsets and sizes. + SmallVector sliceOffsets; + populateFromInt64AttrArray(extractStridedSliceOp.getOffsets(), + sliceOffsets); + SmallVector sliceSizes; + populateFromInt64AttrArray(extractStridedSliceOp.getSizes(), sliceSizes); + + // Compute slice of vector mask region. + SmallVector sliceMaskDimSizes; + sliceMaskDimSizes.reserve(maskDimSizes.size()); + // sliceOffsets.size() <= maskDimSizes.size(), so we use llvm::zip and + // only iterate on the leading dim sizes. The tail accounts for the + // remaining dim sizes. + for (auto [maskDimSize, sliceOffset, sliceSize] : + llvm::zip(maskDimSizes, sliceOffsets, sliceSizes)) { + // No need to clamp on min/max values, because create_mask has clamping + // semantics, i.e. the sliceMaskDimSize is allowed to be negative or + // greater than the vector dim size. + IntegerAttr offsetAttr = + rewriter.getIntegerAttr(maskDimSize.getType(), sliceOffset); + Value offset = rewriter.create(loc, offsetAttr); + Value sliceMaskDimSize = + rewriter.create(loc, maskDimSize, offset); + sliceMaskDimSizes.push_back(sliceMaskDimSize); + } + // Add unchanged dimensions. + llvm::append_range( + sliceMaskDimSizes, + llvm::drop_begin(maskDimSizes, sliceMaskDimSizes.size())); + // Replace 'extractStridedSliceOp' with CreateMaskOp with sliced mask + // region. + rewriter.replaceOpWithNewOp( + extractStridedSliceOp, extractStridedSliceOp.getResult().getType(), + sliceMaskDimSizes); + return success(); + } +}; + // Pattern to rewrite an ExtractStridedSliceOp(ConstantMaskOp) to // ConstantMaskOp. class StridedSliceConstantMaskFolder final @@ -4102,14 +4171,14 @@ class StridedSliceConstantMaskFolder final // Gather constant mask dimension sizes. ArrayRef maskDimSizes = constantMaskOp.getMaskDimSizes(); // Gather strided slice offsets and sizes. - SmallVector sliceOffsets; + SmallVector sliceOffsets; populateFromInt64AttrArray(extractStridedSliceOp.getOffsets(), sliceOffsets); - SmallVector sliceSizes; + SmallVector sliceSizes; populateFromInt64AttrArray(extractStridedSliceOp.getSizes(), sliceSizes); // Compute slice of vector mask region. - SmallVector sliceMaskDimSizes; + SmallVector sliceMaskDimSizes; sliceMaskDimSizes.reserve(maskDimSizes.size()); for (auto [maskDimSize, sliceOffset, sliceSize] : llvm::zip(maskDimSizes, sliceOffsets, sliceSizes)) { @@ -4279,9 +4348,9 @@ void ExtractStridedSliceOp::getCanonicalizationPatterns( RewritePatternSet &results, MLIRContext *context) { // Pattern to rewrite a ExtractStridedSliceOp(ConstantMaskOp) -> // ConstantMaskOp and ExtractStridedSliceOp(ConstantOp) -> ConstantOp. - results.add( - context); + results.add(context); } //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir index 0282e9cac5e02..e05eb4b0ee5bb 100644 --- a/mlir/test/Dialect/Vector/canonicalize.mlir +++ b/mlir/test/Dialect/Vector/canonicalize.mlir @@ -361,6 +361,41 @@ func.func @extract_strided_slice_of_constant_mask() -> (vector<2x1xi1>) { // ----- +// CHECK-LABEL: func.func @extract_strided_slice_of_create_mask +// CHECK-SAME: (%[[DIM0:.+]]: index, %[[DIM1:.+]]: index) +func.func @extract_strided_slice_of_create_mask(%dim0: index, %dim1: index) -> (vector<2x2xi1>) { + %0 = vector.create_mask %dim0, %dim1 : vector<4x3xi1> + %1 = vector.extract_strided_slice %0 + {offsets = [2, 1], sizes = [2, 2], strides = [1, 1]} + : vector<4x3xi1> to vector<2x2xi1> + // CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index + // CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index + // CHECK-DAG: %[[A:.+]] = arith.subi %[[DIM0]], %[[C2]] + // CHECK-DAG: %[[B:.+]] = arith.subi %[[DIM1]], %[[C1]] + // CHECK: vector.create_mask %[[A]], %[[B]] : vector<2x2xi1> + return %1 : vector<2x2xi1> +} + +// ----- + +// CHECK-LABEL: func.func @extract_strided_slice_partial_of_create_mask +// CHECK-SAME: (%[[DIM0:.+]]: index, %[[DIM1:.+]]: index, %[[DIM2:.+]]: index) +func.func @extract_strided_slice_partial_of_create_mask( + %dim0: index, %dim1: index, %dim2 : index) -> (vector<2x2x8xi1>) { + %0 = vector.create_mask %dim0, %dim1, %dim2 : vector<4x3x8xi1> + %1 = vector.extract_strided_slice %0 + {offsets = [2, 1], sizes = [2, 2], strides = [1, 1]} + : vector<4x3x8xi1> to vector<2x2x8xi1> + // CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index + // CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index + // CHECK-DAG: %[[A:.+]] = arith.subi %[[DIM0]], %[[C2]] + // CHECK-DAG: %[[B:.+]] = arith.subi %[[DIM1]], %[[C1]] + // CHECK: vector.create_mask %[[A]], %[[B]], %[[DIM2]] : vector<2x2x8xi1> + return %1 : vector<2x2x8xi1> +} + +// ----- + // CHECK-LABEL: extract_strided_fold // CHECK-SAME: (%[[ARG:.*]]: vector<4x3xi1>) // CHECK-NEXT: return %[[ARG]] : vector<4x3xi1>