Skip to content

Commit 517cda1

Browse files
linuxlonelyeaglejoker-ephbanach-space
authored
[mlir][vector] Add foldInsertUseChain folder function to insert op (#147045)
When the result of an insert op is used as the dest of another insert op, and the subsequent insert op inserts at the same position as the previous insert op, replace the dest of the subsequent insert op with the dest of the previous insert op. This is valid because the value written by the previous insert op is overwritten, so it does not affect the subsequent insert op. --------- Co-authored-by: Mehdi Amini <joker.eph@gmail.com> Co-authored-by: Andrzej Warzyński <andrzej.warzynski@gmail.com>
1 parent d440809 commit 517cda1

File tree

2 files changed

+46
-2
lines changed

2 files changed

+46
-2
lines changed

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3334,7 +3334,6 @@ class InsertSplatToSplat final : public OpRewritePattern<InsertOp> {
33343334
return success();
33353335
}
33363336
};
3337-
33383337
} // namespace
33393338

33403339
static Attribute
@@ -3387,12 +3386,26 @@ foldDenseElementsAttrDestInsertOp(InsertOp insertOp, Attribute srcAttr,
33873386
return newAttr;
33883387
}
33893388

3389+
/// Folder to replace the `dest` operand of the insert op with the root dest of
3390+
/// the insert op use chain.
3391+
static Value foldInsertUseChain(InsertOp insertOp) {
3392+
auto destInsert = insertOp.getDest().getDefiningOp<InsertOp>();
3393+
if (!destInsert)
3394+
return {};
3395+
3396+
if (insertOp.getMixedPosition() != destInsert.getMixedPosition())
3397+
return {};
3398+
3399+
insertOp.setOperand(1, destInsert.getDest());
3400+
return insertOp.getResult();
3401+
}
3402+
33903403
void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results,
33913404
MLIRContext *context) {
33923405
results.add<InsertToBroadcast, BroadcastFolder, InsertSplatToSplat>(context);
33933406
}
33943407

3395-
OpFoldResult vector::InsertOp::fold(FoldAdaptor adaptor) {
3408+
OpFoldResult InsertOp::fold(FoldAdaptor adaptor) {
33963409
// Do not create constants with more than `vectorSizeFoldThreshold` elements,
33973410
// unless the source vector constant has a single use.
33983411
constexpr int64_t vectorSizeFoldThreshold = 256;
@@ -3407,6 +3420,8 @@ OpFoldResult vector::InsertOp::fold(FoldAdaptor adaptor) {
34073420
SmallVector<Value> operands = {getValueToStore(), getDest()};
34083421
auto inplaceFolded = extractInsertFoldConstantOp(*this, adaptor, operands);
34093422

3423+
if (auto res = foldInsertUseChain(*this))
3424+
return res;
34103425
if (auto res = foldPoisonIndexInsertExtractOp(
34113426
getContext(), adaptor.getStaticPosition(), kPoisonIndex))
34123427
return res;

mlir/test/Dialect/Vector/canonicalize.mlir

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3470,3 +3470,32 @@ func.func @fold_insert_constant_indices(%arg : vector<4x1xi32>) -> vector<4x1xi3
34703470
%res = vector.insert %1, %arg[%0, %0] : i32 into vector<4x1xi32>
34713471
return %res : vector<4x1xi32>
34723472
}
3473+
3474+
// -----
3475+
3476+
// CHECK-LABEL: @fold_insert_use_chain(
3477+
// CHECK-SAME: %[[ARG:.*]]: vector<4x4xf32>,
3478+
// CHECK-SAME: %[[VAL:.*]]: f32,
3479+
// CHECK-SAME: %[[POS:.*]]: index) -> vector<4x4xf32> {
3480+
// CHECK-NEXT: %[[RES:.*]] = vector.insert %[[VAL]], %[[ARG]] {{\[}}%[[POS]], 0] : f32 into vector<4x4xf32>
3481+
// CHECK-NEXT: return %[[RES]] : vector<4x4xf32>
3482+
// Three inserts of the same value, chained through their dest operand and all
// targeting position [%pos, 0]. The expected output above is a single insert
// of %val directly into %arg.
func.func @fold_insert_use_chain(%arg : vector<4x4xf32>, %val : f32, %pos: index) -> vector<4x4xf32> {
3483+
%v_0 = vector.insert %val, %arg[%pos, 0] : f32 into vector<4x4xf32>
3484+
%v_1 = vector.insert %val, %v_0[%pos, 0] : f32 into vector<4x4xf32>
3485+
%v_2 = vector.insert %val, %v_1[%pos, 0] : f32 into vector<4x4xf32>
3486+
return %v_2 : vector<4x4xf32>
3487+
}
3488+
3489+
// -----
3490+
3491+
// CHECK-LABEL: @no_fold_insert_use_chain_mismatch_static_position(
3492+
// CHECK-SAME: %[[ARG:.*]]: vector<4xf32>,
3493+
// CHECK-SAME: %[[VAL:.*]]: f32) -> vector<4xf32> {
3494+
// CHECK: %[[V_0:.*]] = vector.insert %[[VAL]], %[[ARG]] [0] : f32 into vector<4xf32>
3495+
// CHECK: %[[V_1:.*]] = vector.insert %[[VAL]], %[[V_0]] [1] : f32 into vector<4xf32>
3496+
// CHECK: return %[[V_1]] : vector<4xf32>
3497+
// Negative test: the two chained inserts target different static positions
// ([0] and [1]), so both inserts are expected to remain in the output above.
func.func @no_fold_insert_use_chain_mismatch_static_position(%arg : vector<4xf32>, %val : f32) -> vector<4xf32> {
3498+
%v_0 = vector.insert %val, %arg[0] : f32 into vector<4xf32>
3499+
%v_1 = vector.insert %val, %v_0[1] : f32 into vector<4xf32>
3500+
return %v_1 : vector<4xf32>
3501+
}

0 commit comments

Comments
 (0)