-
Notifications
You must be signed in to change notification settings - Fork 14.4k
[mlir][Vector] Support scalar vector.extract
in VectorLinearize
#147440
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
Generate a linearized version of the `vector.extract` for these cases.
@llvm/pr-subscribers-mlir Author: Diego Caballero (dcaballe) ChangesIt generates a linearized version of the Full diff: https://github.com/llvm/llvm-project/pull/147440.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index 7cac1cbafdd64..8b232aafbca9d 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -395,15 +395,32 @@ struct LinearizeVectorShuffle final
}
};
-/// This pattern converts the ExtractOp to a ShuffleOp that works on a
-/// linearized vector.
-/// Following,
-/// vector.extract %source [ position ]
-/// is converted to :
-/// %source_1d = vector.shape_cast %source
-/// %out_1d = vector.shuffle %source_1d, %source_1d [ shuffle_indices_1d ]
-/// %out_nd = vector.shape_cast %out_1d
-/// `shuffle_indices_1d` is computed using the position of the original extract.
+/// This pattern linearizes `vector.extract` operations. It generates a 1-D
+/// version of the `vector.extract` operation when extracting a scalar from a
+/// vector. It generates a 1-D `vector.shuffle` operation when extracting a
+/// subvector from a larger vector.
+///
+/// Example #1:
+///
+/// %0 = vector.extract %arg0[1]: vector<8x2xf32> from vector<2x8x2xf32>
+///
+/// is converted to:
+///
+/// %0 = vector.shape_cast %arg0 : vector<2x8x2xf32> to vector<32xf32>
+/// %1 = vector.shuffle %0, %0 [16, 17, 18, 19, 20, 21, 22, 23,
+/// 24, 25, 26, 27, 28, 29, 30, 31] :
+/// vector<32xf32>, vector<32xf32>
+/// %2 = vector.shape_cast %1 : vector<16xf32> to vector<8x2xf32>
+///
+/// Example #2:
+///
+/// %0 = vector.extract %arg0[1, 2] : i32 from vector<2x4xi32>
+///
+/// is converted to:
+///
+/// %0 = vector.shape_cast %arg0 : vector<2x4xi32> to vector<8xi32>
+/// %1 = vector.extract %0[6] : i32 from vector<8xi32>
+///
struct LinearizeVectorExtract final
: public OpConversionPattern<vector::ExtractOp> {
using OpConversionPattern::OpConversionPattern;
@@ -413,10 +430,6 @@ struct LinearizeVectorExtract final
LogicalResult
matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- // Skip if result is not a vector type
- if (!isa<VectorType>(extractOp.getType()))
- return rewriter.notifyMatchFailure(extractOp,
- "scalar extract not supported");
Type dstTy = getTypeConverter()->convertType(extractOp.getType());
assert(dstTy && "expected 1-D vector type");
@@ -436,6 +449,16 @@ struct LinearizeVectorExtract final
linearizedOffset += offsets[i] * size;
}
+ if (!isa<VectorType>(extractOp.getType())) {
+ // Scalar case: generate a 1-D extract.
+ Value result = rewriter.createOrFold<vector::ExtractOp>(
+ extractOp.getLoc(), adaptor.getVector(), linearizedOffset);
+ rewriter.replaceOp(extractOp, result);
+ return success();
+ }
+
+ // Vector case: generate a shuffle.
+
llvm::SmallVector<int64_t, 2> indices(size);
std::iota(indices.begin(), indices.end(), linearizedOffset);
rewriter.replaceOpWithNewOp<vector::ShuffleOp>(
diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
index 894171500d9d6..cbc15f34918f6 100644
--- a/mlir/test/Dialect/Vector/linearize.mlir
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -264,6 +264,19 @@ func.func @test_vector_shuffle(%arg0: vector<4x2xf32>, %arg1: vector<4x2xf32>) -
// -----
+// CHECK-LABEL: test_vector_extract_scalar
+// CHECK-SAME: (%[[ARG:.*]]: vector<2x4xi32>) -> i32 {
+func.func @test_vector_extract_scalar(%arg0 : vector<2x4xi32>) -> i32 {
+
+ // CHECK: %[[SRC_1D:.*]] = vector.shape_cast %[[ARG]] : vector<2x4xi32> to vector<8xi32>
+ // CHECK: %[[EXTRACT_1D:.*]] = vector.extract %[[SRC_1D]][6] : i32 from vector<8xi32>
+ // CHECK: return %[[EXTRACT_1D]] : i32
+ %0 = vector.extract %arg0[1, 2] : i32 from vector<2x4xi32>
+ return %0 : i32
+}
+
+// -----
+
// CHECK-LABEL: test_vector_extract
// CHECK-SAME: (%[[ORIG_ARG:.*]]: vector<2x8x2xf32>) -> vector<8x2xf32> {
func.func @test_vector_extract(%arg0: vector<2x8x2xf32>) -> vector<8x2xf32> {
@@ -341,19 +354,6 @@ func.func @test_vector_insert_scalable(%arg0: vector<2x8x[4]xf32>, %arg1: vector
// -----
-// CHECK-LABEL: test_vector_extract_scalar
-func.func @test_vector_extract_scalar(%idx : index) {
- %cst = arith.constant dense<[1, 2, 3, 4]> : vector<4xi32>
-
- // CHECK-NOT: vector.shuffle
- // CHECK: vector.extract
- // CHECK-NOT: vector.shuffle
- %0 = vector.extract %cst[%idx] : i32 from vector<4xi32>
- return
-}
-
-// -----
-
// CHECK-LABEL: test_vector_bitcast
// CHECK-SAME: %[[ARG_0:.*]]: vector<4x4xf32>
func.func @test_vector_bitcast(%arg0: vector<4x4xf32>) -> vector<4x8xf16> {
|
@llvm/pr-subscribers-mlir-vector Author: Diego Caballero (dcaballe) ChangesIt generates a linearized version of the Full diff: https://github.com/llvm/llvm-project/pull/147440.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index 7cac1cbafdd64..8b232aafbca9d 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -395,15 +395,32 @@ struct LinearizeVectorShuffle final
}
};
-/// This pattern converts the ExtractOp to a ShuffleOp that works on a
-/// linearized vector.
-/// Following,
-/// vector.extract %source [ position ]
-/// is converted to :
-/// %source_1d = vector.shape_cast %source
-/// %out_1d = vector.shuffle %source_1d, %source_1d [ shuffle_indices_1d ]
-/// %out_nd = vector.shape_cast %out_1d
-/// `shuffle_indices_1d` is computed using the position of the original extract.
+/// This pattern linearizes `vector.extract` operations. It generates a 1-D
+/// version of the `vector.extract` operation when extracting a scalar from a
+/// vector. It generates a 1-D `vector.shuffle` operation when extracting a
+/// subvector from a larger vector.
+///
+/// Example #1:
+///
+/// %0 = vector.extract %arg0[1]: vector<8x2xf32> from vector<2x8x2xf32>
+///
+/// is converted to:
+///
+/// %0 = vector.shape_cast %arg0 : vector<2x8x2xf32> to vector<32xf32>
+/// %1 = vector.shuffle %0, %0 [16, 17, 18, 19, 20, 21, 22, 23,
+/// 24, 25, 26, 27, 28, 29, 30, 31] :
+/// vector<32xf32>, vector<32xf32>
+/// %2 = vector.shape_cast %1 : vector<16xf32> to vector<8x2xf32>
+///
+/// Example #2:
+///
+/// %0 = vector.extract %arg0[1, 2] : i32 from vector<2x4xi32>
+///
+/// is converted to:
+///
+/// %0 = vector.shape_cast %arg0 : vector<2x4xi32> to vector<8xi32>
+/// %1 = vector.extract %0[6] : i32 from vector<8xi32>
+///
struct LinearizeVectorExtract final
: public OpConversionPattern<vector::ExtractOp> {
using OpConversionPattern::OpConversionPattern;
@@ -413,10 +430,6 @@ struct LinearizeVectorExtract final
LogicalResult
matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- // Skip if result is not a vector type
- if (!isa<VectorType>(extractOp.getType()))
- return rewriter.notifyMatchFailure(extractOp,
- "scalar extract not supported");
Type dstTy = getTypeConverter()->convertType(extractOp.getType());
assert(dstTy && "expected 1-D vector type");
@@ -436,6 +449,16 @@ struct LinearizeVectorExtract final
linearizedOffset += offsets[i] * size;
}
+ if (!isa<VectorType>(extractOp.getType())) {
+ // Scalar case: generate a 1-D extract.
+ Value result = rewriter.createOrFold<vector::ExtractOp>(
+ extractOp.getLoc(), adaptor.getVector(), linearizedOffset);
+ rewriter.replaceOp(extractOp, result);
+ return success();
+ }
+
+ // Vector case: generate a shuffle.
+
llvm::SmallVector<int64_t, 2> indices(size);
std::iota(indices.begin(), indices.end(), linearizedOffset);
rewriter.replaceOpWithNewOp<vector::ShuffleOp>(
diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
index 894171500d9d6..cbc15f34918f6 100644
--- a/mlir/test/Dialect/Vector/linearize.mlir
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -264,6 +264,19 @@ func.func @test_vector_shuffle(%arg0: vector<4x2xf32>, %arg1: vector<4x2xf32>) -
// -----
+// CHECK-LABEL: test_vector_extract_scalar
+// CHECK-SAME: (%[[ARG:.*]]: vector<2x4xi32>) -> i32 {
+func.func @test_vector_extract_scalar(%arg0 : vector<2x4xi32>) -> i32 {
+
+ // CHECK: %[[SRC_1D:.*]] = vector.shape_cast %[[ARG]] : vector<2x4xi32> to vector<8xi32>
+ // CHECK: %[[EXTRACT_1D:.*]] = vector.extract %[[SRC_1D]][6] : i32 from vector<8xi32>
+ // CHECK: return %[[EXTRACT_1D]] : i32
+ %0 = vector.extract %arg0[1, 2] : i32 from vector<2x4xi32>
+ return %0 : i32
+}
+
+// -----
+
// CHECK-LABEL: test_vector_extract
// CHECK-SAME: (%[[ORIG_ARG:.*]]: vector<2x8x2xf32>) -> vector<8x2xf32> {
func.func @test_vector_extract(%arg0: vector<2x8x2xf32>) -> vector<8x2xf32> {
@@ -341,19 +354,6 @@ func.func @test_vector_insert_scalable(%arg0: vector<2x8x[4]xf32>, %arg1: vector
// -----
-// CHECK-LABEL: test_vector_extract_scalar
-func.func @test_vector_extract_scalar(%idx : index) {
- %cst = arith.constant dense<[1, 2, 3, 4]> : vector<4xi32>
-
- // CHECK-NOT: vector.shuffle
- // CHECK: vector.extract
- // CHECK-NOT: vector.shuffle
- %0 = vector.extract %cst[%idx] : i32 from vector<4xi32>
- return
-}
-
-// -----
-
// CHECK-LABEL: test_vector_bitcast
// CHECK-SAME: %[[ARG_0:.*]]: vector<4x4xf32>
func.func @test_vector_bitcast(%arg0: vector<4x4xf32>) -> vector<4x8xf16> {
|
It generates a linearized version of the
vector.extract
for scalar cases.