From 48b888330a93cf52abb35939db3cea5811334b29 Mon Sep 17 00:00:00 2001 From: Kunwar Grover Date: Wed, 2 Jul 2025 17:14:18 +0100 Subject: [PATCH 1/3] [mlir][Vector] Add canonicalization for extract_strided_slice(create_mask) --- mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 62 ++++++++++++++++++++-- mlir/test/Dialect/Vector/canonicalize.mlir | 17 ++++++ 2 files changed, 76 insertions(+), 3 deletions(-) diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index 1fb8c7a928e06..66c2fe50529d7 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -4081,6 +4081,62 @@ void ExtractStridedSliceOp::getOffsets(SmallVectorImpl &results) { namespace { +class StridedSliceCreateMaskFolder final + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + +public: + LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp, + PatternRewriter &rewriter) const override { + Location loc = extractStridedSliceOp.getLoc(); + // Return if 'extractStridedSliceOp' operand is not defined by a + // CreateMaskOp. + auto *defOp = extractStridedSliceOp.getVector().getDefiningOp(); + auto createMaskOp = dyn_cast_or_null(defOp); + if (!createMaskOp) + return failure(); + // Return if 'extractStridedSliceOp' has non-unit strides. + if (extractStridedSliceOp.hasNonUnitStrides()) + return failure(); + // Gather constant mask dimension sizes. + SmallVector maskDimSizes(createMaskOp.getOperands()); + // Gather strided slice offsets and sizes. + SmallVector sliceOffsets; + populateFromInt64AttrArray(extractStridedSliceOp.getOffsets(), + sliceOffsets); + SmallVector sliceSizes; + populateFromInt64AttrArray(extractStridedSliceOp.getSizes(), sliceSizes); + + // Compute slice of vector mask region. + SmallVector sliceMaskDimSizes; + sliceMaskDimSizes.reserve(maskDimSizes.size()); + for (auto [maskDimSize, sliceOffset, sliceSize] : + llvm::zip(maskDimSizes, sliceOffsets, sliceSizes)) { + // No need to clamp on min/max values, because create_mask has clamping + // semantics, i.e. the sliceMaskDimSize is allowed to be negative or + // greater than the vector dim size. + IntegerAttr offsetAttr = + rewriter.getIntegerAttr(maskDimSize.getType(), sliceOffset); + Value offset = rewriter.create(loc, offsetAttr); + Value sliceMaskDimSize = + rewriter.create(loc, maskDimSize, offset); + sliceMaskDimSizes.push_back(sliceMaskDimSize); + } + // Add unchanged dimensions. + if (sliceMaskDimSizes.size() < maskDimSizes.size()) { + for (size_t i = sliceMaskDimSizes.size(); i < maskDimSizes.size(); ++i) { + sliceMaskDimSizes.push_back(maskDimSizes[i]); + } + } + // Replace 'extractStridedSliceOp' with CreateMaskOp with sliced mask + // region. + rewriter.replaceOpWithNewOp( + extractStridedSliceOp, extractStridedSliceOp.getResult().getType(), + sliceMaskDimSizes); + return success(); + } +}; + // Pattern to rewrite an ExtractStridedSliceOp(ConstantMaskOp) to // ConstantMaskOp. class StridedSliceConstantMaskFolder final @@ -4279,9 +4335,9 @@ void ExtractStridedSliceOp::getCanonicalizationPatterns( RewritePatternSet &results, MLIRContext *context) { // Pattern to rewrite a ExtractStridedSliceOp(ConstantMaskOp) -> // ConstantMaskOp and ExtractStridedSliceOp(ConstantOp) -> ConstantOp. - results.add( - context); + results.add(context); } //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir index 0282e9cac5e02..c09dab8232900 100644 --- a/mlir/test/Dialect/Vector/canonicalize.mlir +++ b/mlir/test/Dialect/Vector/canonicalize.mlir @@ -361,6 +361,23 @@ func.func @extract_strided_slice_of_constant_mask() -> (vector<2x1xi1>) { // ----- +// CHECK-LABEL: func.func @extract_strided_slice_of_create_mask +// CHECK-SAME: (%[[DIM0:.+]]: index, %[[DIM1:.+]]: index) +func.func @extract_strided_slice_of_create_mask(%dim0: index, %dim1: index) -> (vector<2x2xi1>) { + %0 = vector.create_mask %dim0, %dim1 : vector<4x3xi1> + %1 = vector.extract_strided_slice %0 + {offsets = [2, 1], sizes = [2, 2], strides = [1, 1]} + : vector<4x3xi1> to vector<2x2xi1> + // CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index + // CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index + // CHECK-DAG: %[[A:.+]] = arith.subi %[[DIM0]], %[[C2]] + // CHECK-DAG: %[[B:.+]] = arith.subi %[[DIM1]], %[[C1]] + // CHECK: vector.create_mask %[[A]], %[[B]] : vector<2x2xi1> + return %1 : vector<2x2xi1> +} + +// ----- + // CHECK-LABEL: extract_strided_fold // CHECK-SAME: (%[[ARG:.*]]: vector<4x3xi1>) // CHECK-NEXT: return %[[ARG]] : vector<4x3xi1> From 8dc4dda3c0b671589951424e94b215921ab028e5 Mon Sep 17 00:00:00 2001 From: Kunwar Grover Date: Fri, 4 Jul 2025 11:52:38 +0100 Subject: [PATCH 2/3] address comments --- mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 25 ++++++++++++++++------ mlir/test/Dialect/Vector/canonicalize.mlir | 18 ++++++++++++++++ 2 files changed, 37 insertions(+), 6 deletions(-) diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index 66c2fe50529d7..6f166bc26811d 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -4081,6 +4081,18 @@ void ExtractStridedSliceOp::getOffsets(SmallVectorImpl &results) { namespace { +// Pattern to rewrite an ExtractStridedSliceOp(CreateMaskOp) to +// CreateMaskOp. +// +// Example: +// +// %mask = vector.create_mask %ub : vector<16xi1> +// %slice = vector.extract_strided_slice [%offset] [8] [1] +// +// to +// +// %new_ub = arith.subi %ub, %offset +// %mask = vector.create_mask %new_ub : vector<8xi1> class StridedSliceCreateMaskFolder final : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -4101,10 +4113,10 @@ class StridedSliceCreateMaskFolder final // Gather constant mask dimension sizes. SmallVector maskDimSizes(createMaskOp.getOperands()); // Gather strided slice offsets and sizes. - SmallVector sliceOffsets; + SmallVector sliceOffsets; populateFromInt64AttrArray(extractStridedSliceOp.getOffsets(), sliceOffsets); - SmallVector sliceSizes; + SmallVector sliceSizes; populateFromInt64AttrArray(extractStridedSliceOp.getSizes(), sliceSizes); // Compute slice of vector mask region. @@ -4124,7 +4136,8 @@ class StridedSliceCreateMaskFolder final } // Add unchanged dimensions. if (sliceMaskDimSizes.size() < maskDimSizes.size()) { - for (size_t i = sliceMaskDimSizes.size(); i < maskDimSizes.size(); ++i) { + for (size_t i = sliceMaskDimSizes.size(), e = maskDimSizes.size(); i < e; + ++i) { sliceMaskDimSizes.push_back(maskDimSizes[i]); } } @@ -4158,14 +4171,14 @@ class StridedSliceConstantMaskFolder final // Gather constant mask dimension sizes. ArrayRef maskDimSizes = constantMaskOp.getMaskDimSizes(); // Gather strided slice offsets and sizes. - SmallVector sliceOffsets; + SmallVector sliceOffsets; populateFromInt64AttrArray(extractStridedSliceOp.getOffsets(), sliceOffsets); - SmallVector sliceSizes; + SmallVector sliceSizes; populateFromInt64AttrArray(extractStridedSliceOp.getSizes(), sliceSizes); // Compute slice of vector mask region. - SmallVector sliceMaskDimSizes; + SmallVector sliceMaskDimSizes; sliceMaskDimSizes.reserve(maskDimSizes.size()); for (auto [maskDimSize, sliceOffset, sliceSize] : llvm::zip(maskDimSizes, sliceOffsets, sliceSizes)) { diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir index c09dab8232900..e05eb4b0ee5bb 100644 --- a/mlir/test/Dialect/Vector/canonicalize.mlir +++ b/mlir/test/Dialect/Vector/canonicalize.mlir @@ -378,6 +378,24 @@ func.func @extract_strided_slice_of_create_mask(%dim0: index, %dim1: index) -> ( // ----- +// CHECK-LABEL: func.func @extract_strided_slice_partial_of_create_mask +// CHECK-SAME: (%[[DIM0:.+]]: index, %[[DIM1:.+]]: index, %[[DIM2:.+]]: index) +func.func @extract_strided_slice_partial_of_create_mask( + %dim0: index, %dim1: index, %dim2 : index) -> (vector<2x2x8xi1>) { + %0 = vector.create_mask %dim0, %dim1, %dim2 : vector<4x3x8xi1> + %1 = vector.extract_strided_slice %0 + {offsets = [2, 1], sizes = [2, 2], strides = [1, 1]} + : vector<4x3x8xi1> to vector<2x2x8xi1> + // CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index + // CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index + // CHECK-DAG: %[[A:.+]] = arith.subi %[[DIM0]], %[[C2]] + // CHECK-DAG: %[[B:.+]] = arith.subi %[[DIM1]], %[[C1]] + // CHECK: vector.create_mask %[[A]], %[[B]], %[[DIM2]] : vector<2x2x8xi1> + return %1 : vector<2x2x8xi1> +} + +// ----- + // CHECK-LABEL: extract_strided_fold // CHECK-SAME: (%[[ARG:.*]]: vector<4x3xi1>) // CHECK-NEXT: return %[[ARG]] : vector<4x3xi1> From faee3e5820234df8ef06c0fecb33e6184eef1f41 Mon Sep 17 00:00:00 2001 From: Kunwar Grover Date: Wed, 9 Jul 2025 14:08:29 +0100 Subject: [PATCH 3/3] address comments --- mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index 6f166bc26811d..1d0f325621d1d 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -4103,8 +4103,8 @@ class StridedSliceCreateMaskFolder final Location loc = extractStridedSliceOp.getLoc(); // Return if 'extractStridedSliceOp' operand is not defined by a // CreateMaskOp. - auto *defOp = extractStridedSliceOp.getVector().getDefiningOp(); - auto createMaskOp = dyn_cast_or_null(defOp); + auto createMaskOp = + extractStridedSliceOp.getVector().getDefiningOp(); if (!createMaskOp) return failure(); // Return if 'extractStridedSliceOp' has non-unit strides. @@ -4122,6 +4122,9 @@ class StridedSliceCreateMaskFolder final // Compute slice of vector mask region. SmallVector sliceMaskDimSizes; sliceMaskDimSizes.reserve(maskDimSizes.size()); + // sliceOffsets.size() <= maskDimSizes.size(), so we use llvm::zip and + // only iterate on the leading dim sizes. The tail accounts for the + // remaining dim sizes. for (auto [maskDimSize, sliceOffset, sliceSize] : llvm::zip(maskDimSizes, sliceOffsets, sliceSizes)) { // No need to clamp on min/max values, because create_mask has clamping @@ -4135,12 +4138,9 @@ class StridedSliceCreateMaskFolder final sliceMaskDimSizes.push_back(sliceMaskDimSize); } // Add unchanged dimensions. - if (sliceMaskDimSizes.size() < maskDimSizes.size()) { - for (size_t i = sliceMaskDimSizes.size(), e = maskDimSizes.size(); i < e; - ++i) { - sliceMaskDimSizes.push_back(maskDimSizes[i]); - } - } + llvm::append_range( + sliceMaskDimSizes, + llvm::drop_begin(maskDimSizes, sliceMaskDimSizes.size())); // Replace 'extractStridedSliceOp' with CreateMaskOp with sliced mask // region. rewriter.replaceOpWithNewOp(