Skip to content

Commit 7451e4c

Browse files
authored
[mlir][Vector] Support scalar 'vector.insert' in vector linearization (#146954)
This PR adds support for linearizing the insertion of a scalar element by rewriting the `vector.insert` op as a 1-D `vector.insert` on the linearized destination vector.
1 parent 26a766a commit 7451e4c

File tree

2 files changed

+70
-29
lines changed

2 files changed

+70
-29
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp

Lines changed: 55 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -445,16 +445,36 @@ struct LinearizeVectorExtract final
445445
}
446446
};
447447

448-
/// This pattern converts the InsertOp to a ShuffleOp that works on a
449-
/// linearized vector.
450-
/// Following,
451-
/// vector.insert %source %destination [ position ]
452-
/// is converted to :
453-
/// %source_1d = vector.shape_cast %source
454-
/// %destination_1d = vector.shape_cast %destination
455-
/// %out_1d = vector.shuffle %destination_1d, %source_1d [ shuffle_indices_1d
456-
/// ] %out_nd = vector.shape_cast %out_1d
457-
/// `shuffle_indices_1d` is computed using the position of the original insert.
448+
/// This pattern linearizes `vector.insert` operations. It generates a 1-D
449+
/// version of the `vector.insert` operation when inserting a scalar into a
450+
/// vector. It generates a 1-D `vector.shuffle` operation when inserting a
451+
/// vector into another vector.
452+
///
453+
/// Example #1:
454+
///
455+
/// %0 = vector.insert %source, %destination[0] :
456+
/// vector<2x4xf32> into vector<2x2x4xf32>
457+
///
458+
/// is converted to:
459+
///
460+
/// %0 = vector.shape_cast %source : vector<2x4xf32> to vector<8xf32>
461+
/// %1 = vector.shape_cast %destination :
462+
/// vector<2x2x4xf32> to vector<16xf32>
463+
/// %2 = vector.shuffle %1, %0 [16, 17, 18, 19, 20, 21, 22, 23,
464+
/// 8, 9, 10, 11, 12, 13, 14, 15] :
465+
/// vector<16xf32>, vector<8xf32>
466+
/// %3 = vector.shape_cast %2 : vector<16xf32> to vector<2x2x4xf32>
467+
///
468+
/// Example #2:
469+
///
470+
/// %0 = vector.insert %source, %destination[1, 2]: f32 into vector<2x4xf32>
471+
///
472+
/// is converted to:
473+
///
474+
/// %0 = vector.shape_cast %destination : vector<2x4xf32> to vector<8xf32>
475+
/// %1 = vector.insert %source, %0[6]: f32 into vector<8xf32>
476+
/// %2 = vector.shape_cast %1 : vector<8xf32> to vector<2x4xf32>
477+
///
458478
struct LinearizeVectorInsert final
459479
: public OpConversionPattern<vector::InsertOp> {
460480
using OpConversionPattern::OpConversionPattern;
@@ -468,48 +488,55 @@ struct LinearizeVectorInsert final
468488
insertOp.getDestVectorType());
469489
assert(dstTy && "vector type destination expected.");
470490

471-
// dynamic position is not supported
491+
// Dynamic position is not supported.
472492
if (insertOp.hasDynamicPosition())
473493
return rewriter.notifyMatchFailure(insertOp,
474494
"dynamic position is not supported.");
475495
auto srcTy = insertOp.getValueToStoreType();
476496
auto srcAsVec = dyn_cast<VectorType>(srcTy);
477-
uint64_t srcSize = 0;
478-
if (srcAsVec) {
479-
srcSize = srcAsVec.getNumElements();
480-
} else {
481-
return rewriter.notifyMatchFailure(insertOp,
482-
"scalars are not supported.");
483-
}
497+
uint64_t srcSize = srcAsVec ? srcAsVec.getNumElements() : 1;
484498

485499
auto dstShape = insertOp.getDestVectorType().getShape();
486500
const auto dstSize = insertOp.getDestVectorType().getNumElements();
487501
auto dstSizeForOffsets = dstSize;
488502

489-
// compute linearized offset
503+
// Compute linearized offset.
490504
int64_t linearizedOffset = 0;
491505
auto offsetsNd = insertOp.getStaticPosition();
492506
for (auto [dim, offset] : llvm::enumerate(offsetsNd)) {
493507
dstSizeForOffsets /= dstShape[dim];
494508
linearizedOffset += offset * dstSizeForOffsets;
495509
}
496510

511+
Location loc = insertOp.getLoc();
512+
Value valueToStore = adaptor.getValueToStore();
513+
514+
if (!isa<VectorType>(valueToStore.getType())) {
515+
// Scalar case: generate a 1-D insert.
516+
Value result = rewriter.createOrFold<vector::InsertOp>(
517+
loc, valueToStore, adaptor.getDest(), linearizedOffset);
518+
rewriter.replaceOp(insertOp, result);
519+
return success();
520+
}
521+
522+
// Vector case: generate a shuffle.
497523
llvm::SmallVector<int64_t, 2> indices(dstSize);
498524
auto *origValsUntil = indices.begin();
499525
std::advance(origValsUntil, linearizedOffset);
500-
std::iota(indices.begin(), origValsUntil,
501-
0); // original values that remain [0, offset)
526+
527+
// Original values that remain [0, offset).
528+
std::iota(indices.begin(), origValsUntil, 0);
502529
auto *newValsUntil = origValsUntil;
503530
std::advance(newValsUntil, srcSize);
504-
std::iota(origValsUntil, newValsUntil,
505-
dstSize); // new values [offset, offset+srcNumElements)
506-
std::iota(newValsUntil, indices.end(),
507-
linearizedOffset + srcSize); // the rest of original values
508-
// [offset+srcNumElements, end)
531+
// New values [offset, offset+srcNumElements).
532+
std::iota(origValsUntil, newValsUntil, dstSize);
533+
// The rest of original values [offset+srcNumElements, end).
534+
std::iota(newValsUntil, indices.end(), linearizedOffset + srcSize);
509535

510-
rewriter.replaceOpWithNewOp<vector::ShuffleOp>(
511-
insertOp, dstTy, adaptor.getDest(), adaptor.getValueToStore(), indices);
536+
Value result = rewriter.createOrFold<vector::ShuffleOp>(
537+
loc, dstTy, adaptor.getDest(), valueToStore, indices);
512538

539+
rewriter.replaceOp(insertOp, result);
513540
return success();
514541
}
515542
};

mlir/test/Dialect/Vector/linearize.mlir

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -294,6 +294,20 @@ func.func @test_vector_extract_scalable(%arg0: vector<2x8x[2]xf32>) -> vector<8x
294294

295295
// -----
296296

297+
// CHECK-LABEL: test_vector_insert_scalar
298+
// CHECK-SAME: (%[[DEST:.*]]: vector<2x4xf32>, %[[SRC:.*]]: f32) -> vector<2x4xf32> {
299+
func.func @test_vector_insert_scalar(%arg0: vector<2x4xf32>, %arg1: f32) -> vector<2x4xf32> {
300+
301+
// CHECK: %[[DEST_1D:.*]] = vector.shape_cast %[[DEST]] : vector<2x4xf32> to vector<8xf32>
302+
// CHECK: %[[INSERT_1D:.*]] = vector.insert %[[SRC]], %[[DEST_1D]] [6] : f32 into vector<8xf32>
303+
// CHECK: %[[RES:.*]] = vector.shape_cast %[[INSERT_1D]] : vector<8xf32> to vector<2x4xf32>
304+
// CHECK: return %[[RES]] : vector<2x4xf32>
305+
%0 = vector.insert %arg1, %arg0[1, 2]: f32 into vector<2x4xf32>
306+
return %0 : vector<2x4xf32>
307+
}
308+
309+
// -----
310+
297311
// CHECK-LABEL: test_vector_insert
298312
// CHECK-SAME: (%[[DEST:.*]]: vector<2x8x4xf32>, %[[SRC:.*]]: vector<8x4xf32>) -> vector<2x8x4xf32> {
299313
func.func @test_vector_insert(%arg0: vector<2x8x4xf32>, %arg1: vector<8x4xf32>) -> vector<2x8x4xf32> {
@@ -444,7 +458,7 @@ func.func @linearize_scalable_vector_splat(%arg0: i32) -> vector<4x[2]xi32> {
444458
// CHECK-LABEL: linearize_create_mask
445459
// CHECK-SAME: (%[[ARG0:.*]]: index, %[[ARG1:.*]]: index) -> vector<1x16xi1>
446460
func.func @linearize_create_mask(%arg0 : index, %arg1 : index) -> vector<1x16xi1> {
447-
461+
448462
// CHECK: %[[C0:.*]] = arith.constant 0 : index
449463
// CHECK: %[[CMP:.*]] = arith.cmpi sgt, %[[ARG0]], %[[C0]] : index
450464
// CHECK: %[[INDEXCAST:.*]] = arith.index_cast %[[CMP]] : i1 to index

0 commit comments

Comments
 (0)