Skip to content

Commit ace1c83

Browse files
authored
[mlir][Vector] Support scalar vector.extract in VectorLinearize (#147440)
For scalar results, the pattern now generates a linearized (1-D) `vector.extract` instead of failing to match.
1 parent aec3016 commit ace1c83

File tree

2 files changed

+52
-28
lines changed

2 files changed

+52
-28
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp

Lines changed: 39 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -395,15 +395,32 @@ struct LinearizeVectorShuffle final
395395
}
396396
};
397397

398-
/// This pattern converts the ExtractOp to a ShuffleOp that works on a
399-
/// linearized vector.
400-
/// Following,
401-
/// vector.extract %source [ position ]
402-
/// is converted to :
403-
/// %source_1d = vector.shape_cast %source
404-
/// %out_1d = vector.shuffle %source_1d, %source_1d [ shuffle_indices_1d ]
405-
/// %out_nd = vector.shape_cast %out_1d
406-
/// `shuffle_indices_1d` is computed using the position of the original extract.
398+
/// This pattern linearizes `vector.extract` operations. It generates a 1-D
399+
/// version of the `vector.extract` operation when extracting a scalar from a
400+
/// vector. It generates a 1-D `vector.shuffle` operation when extracting a
401+
/// subvector from a larger vector.
402+
///
403+
/// Example #1:
404+
///
405+
/// %0 = vector.extract %arg0[1] : vector<8x2xf32> from vector<2x8x2xf32>
406+
///
407+
/// is converted to:
408+
///
409+
/// %0 = vector.shape_cast %arg0 : vector<2x8x2xf32> to vector<32xf32>
410+
/// %1 = vector.shuffle %0, %0 [16, 17, 18, 19, 20, 21, 22, 23,
411+
/// 24, 25, 26, 27, 28, 29, 30, 31] :
412+
/// vector<32xf32>, vector<32xf32>
413+
/// %2 = vector.shape_cast %1 : vector<16xf32> to vector<8x2xf32>
414+
///
415+
/// Example #2:
416+
///
417+
/// %0 = vector.extract %arg0[1, 2] : i32 from vector<2x4xi32>
418+
///
419+
/// is converted to:
420+
///
421+
/// %0 = vector.shape_cast %arg0 : vector<2x4xi32> to vector<8xi32>
422+
/// %1 = vector.extract %0[6] : i32 from vector<8xi32>
423+
///
407424
struct LinearizeVectorExtract final
408425
: public OpConversionPattern<vector::ExtractOp> {
409426
using OpConversionPattern::OpConversionPattern;
@@ -413,10 +430,6 @@ struct LinearizeVectorExtract final
413430
LogicalResult
414431
matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
415432
ConversionPatternRewriter &rewriter) const override {
416-
// Skip if result is not a vector type
417-
if (!isa<VectorType>(extractOp.getType()))
418-
return rewriter.notifyMatchFailure(extractOp,
419-
"scalar extract not supported");
420433
Type dstTy = getTypeConverter()->convertType(extractOp.getType());
421434
assert(dstTy && "expected 1-D vector type");
422435

@@ -436,10 +449,21 @@ struct LinearizeVectorExtract final
436449
linearizedOffset += offsets[i] * size;
437450
}
438451

452+
Value srcVector = adaptor.getVector();
453+
if (!isa<VectorType>(extractOp.getType())) {
454+
// Scalar case: generate a 1-D extract.
455+
Value result = rewriter.createOrFold<vector::ExtractOp>(
456+
extractOp.getLoc(), srcVector, linearizedOffset);
457+
rewriter.replaceOp(extractOp, result);
458+
return success();
459+
}
460+
461+
// Vector case: generate a shuffle.
462+
439463
llvm::SmallVector<int64_t, 2> indices(size);
440464
std::iota(indices.begin(), indices.end(), linearizedOffset);
441-
rewriter.replaceOpWithNewOp<vector::ShuffleOp>(
442-
extractOp, dstTy, adaptor.getVector(), adaptor.getVector(), indices);
465+
rewriter.replaceOpWithNewOp<vector::ShuffleOp>(extractOp, dstTy, srcVector,
466+
srcVector, indices);
443467

444468
return success();
445469
}

mlir/test/Dialect/Vector/linearize.mlir

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -264,6 +264,19 @@ func.func @test_vector_shuffle(%arg0: vector<4x2xf32>, %arg1: vector<4x2xf32>) -
264264

265265
// -----
266266

267+
// CHECK-LABEL: test_vector_extract_scalar
268+
// CHECK-SAME: (%[[ARG:.*]]: vector<2x4xi32>) -> i32 {
269+
func.func @test_vector_extract_scalar(%arg0 : vector<2x4xi32>) -> i32 {
270+
271+
// CHECK: %[[SRC_1D:.*]] = vector.shape_cast %[[ARG]] : vector<2x4xi32> to vector<8xi32>
272+
// CHECK: %[[EXTRACT_1D:.*]] = vector.extract %[[SRC_1D]][6] : i32 from vector<8xi32>
273+
// CHECK: return %[[EXTRACT_1D]] : i32
274+
%0 = vector.extract %arg0[1, 2] : i32 from vector<2x4xi32>
275+
return %0 : i32
276+
}
277+
278+
// -----
279+
267280
// CHECK-LABEL: test_vector_extract
268281
// CHECK-SAME: (%[[ORIG_ARG:.*]]: vector<2x8x2xf32>) -> vector<8x2xf32> {
269282
func.func @test_vector_extract(%arg0: vector<2x8x2xf32>) -> vector<8x2xf32> {
@@ -341,19 +354,6 @@ func.func @test_vector_insert_scalable(%arg0: vector<2x8x[4]xf32>, %arg1: vector
341354

342355
// -----
343356

344-
// CHECK-LABEL: test_vector_extract_scalar
345-
func.func @test_vector_extract_scalar(%idx : index) {
346-
%cst = arith.constant dense<[1, 2, 3, 4]> : vector<4xi32>
347-
348-
// CHECK-NOT: vector.shuffle
349-
// CHECK: vector.extract
350-
// CHECK-NOT: vector.shuffle
351-
%0 = vector.extract %cst[%idx] : i32 from vector<4xi32>
352-
return
353-
}
354-
355-
// -----
356-
357357
// CHECK-LABEL: test_vector_bitcast
358358
// CHECK-SAME: %[[ARG_0:.*]]: vector<4x4xf32>
359359
func.func @test_vector_bitcast(%arg0: vector<4x4xf32>) -> vector<4x8xf16> {

0 commit comments

Comments
 (0)