diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp index 7cac1cbafdd64..8b232aafbca9d 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp @@ -395,15 +395,32 @@ struct LinearizeVectorShuffle final } }; -/// This pattern converts the ExtractOp to a ShuffleOp that works on a -/// linearized vector. -/// Following, -/// vector.extract %source [ position ] -/// is converted to : -/// %source_1d = vector.shape_cast %source -/// %out_1d = vector.shuffle %source_1d, %source_1d [ shuffle_indices_1d ] -/// %out_nd = vector.shape_cast %out_1d -/// `shuffle_indices_1d` is computed using the position of the original extract. +/// This pattern linearizes `vector.extract` operations. It generates a 1-D +/// version of the `vector.extract` operation when extracting a scalar from a +/// vector. It generates a 1-D `vector.shuffle` operation when extracting a +/// subvector from a larger vector. +/// +/// Example #1: +/// +/// %0 = vector.extract %arg0[1]: vector<8x2xf32> from vector<2x8x2xf32> +/// +/// is converted to: +/// +/// %0 = vector.shape_cast %arg0 : vector<2x8x2xf32> to vector<32xf32> +/// %1 = vector.shuffle %0, %0 [16, 17, 18, 19, 20, 21, 22, 23, +/// 24, 25, 26, 27, 28, 29, 30, 31] : +/// vector<32xf32>, vector<32xf32> +/// %2 = vector.shape_cast %1 : vector<16xf32> to vector<8x2xf32> +/// +/// Example #2: +/// +/// %0 = vector.extract %arg0[1, 2] : i32 from vector<2x4xi32> +/// +/// is converted to: +/// +/// %0 = vector.shape_cast %arg0 : vector<2x4xi32> to vector<8xi32> +/// %1 = vector.extract %0[6] : i32 from vector<8xi32> +/// struct LinearizeVectorExtract final : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; @@ -413,10 +430,6 @@ struct LinearizeVectorExtract final LogicalResult matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - // Skip if result is not a vector type - if (!isa(extractOp.getType())) - return rewriter.notifyMatchFailure(extractOp, - "scalar extract not supported"); Type dstTy = getTypeConverter()->convertType(extractOp.getType()); assert(dstTy && "expected 1-D vector type"); @@ -436,6 +449,16 @@ struct LinearizeVectorExtract final linearizedOffset += offsets[i] * size; } + if (!isa(extractOp.getType())) { + // Scalar case: generate a 1-D extract. + Value result = rewriter.createOrFold( + extractOp.getLoc(), adaptor.getVector(), linearizedOffset); + rewriter.replaceOp(extractOp, result); + return success(); + } + + // Vector case: generate a shuffle. + llvm::SmallVector indices(size); std::iota(indices.begin(), indices.end(), linearizedOffset); rewriter.replaceOpWithNewOp( diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir index 894171500d9d6..cbc15f34918f6 100644 --- a/mlir/test/Dialect/Vector/linearize.mlir +++ b/mlir/test/Dialect/Vector/linearize.mlir @@ -264,6 +264,19 @@ func.func @test_vector_shuffle(%arg0: vector<4x2xf32>, %arg1: vector<4x2xf32>) - // ----- +// CHECK-LABEL: test_vector_extract_scalar +// CHECK-SAME: (%[[ARG:.*]]: vector<2x4xi32>) -> i32 { +func.func @test_vector_extract_scalar(%arg0 : vector<2x4xi32>) -> i32 { + + // CHECK: %[[SRC_1D:.*]] = vector.shape_cast %[[ARG]] : vector<2x4xi32> to vector<8xi32> + // CHECK: %[[EXTRACT_1D:.*]] = vector.extract %[[SRC_1D]][6] : i32 from vector<8xi32> + // CHECK: return %[[EXTRACT_1D]] : i32 + %0 = vector.extract %arg0[1, 2] : i32 from vector<2x4xi32> + return %0 : i32 +} + +// ----- + // CHECK-LABEL: test_vector_extract // CHECK-SAME: (%[[ORIG_ARG:.*]]: vector<2x8x2xf32>) -> vector<8x2xf32> { func.func @test_vector_extract(%arg0: vector<2x8x2xf32>) -> vector<8x2xf32> { @@ -341,19 +354,6 @@ func.func @test_vector_insert_scalable(%arg0: vector<2x8x[4]xf32>, %arg1: vector // ----- -// CHECK-LABEL: test_vector_extract_scalar -func.func @test_vector_extract_scalar(%idx : index) { - %cst = arith.constant dense<[1, 2, 3, 4]> : vector<4xi32> - - // CHECK-NOT: vector.shuffle - // CHECK: vector.extract - // CHECK-NOT: vector.shuffle - %0 = vector.extract %cst[%idx] : i32 from vector<4xi32> - return -} - -// ----- - // CHECK-LABEL: test_vector_bitcast // CHECK-SAME: %[[ARG_0:.*]]: vector<4x4xf32> func.func @test_vector_bitcast(%arg0: vector<4x4xf32>) -> vector<4x8xf16> {