Skip to content

Commit 54a5cdc

Browse files
[Pass][EliminateConcatStridedSlice] Add issue fix for pass EliminateConcatStridedSlice (#29555)
### Details: Make EliminateConcatStridedSlice transformation more strict, add an additional concat axis check. Fixed the compile_model regression on GPU. ### Tickets: - CVS-164273 Signed-off-by: Zhai, Xuejun <xuejun.zhai@intel.com>
1 parent 70e3813 commit 54a5cdc

File tree

2 files changed

+62
-0
lines changed

2 files changed

+62
-0
lines changed

src/common/transformations/src/transformations/common_optimizations/nop_elimination.cpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -555,6 +555,24 @@ pass::EliminateConcatStridedSlice::EliminateConcatStridedSlice() {
555555
return false;
556556
}
557557

558+
// check that concatenated and split axis is the same
559+
auto check_axis = [concat_axis](const std::vector<int64_t>& masks) {
560+
for (size_t axis = 0; axis < masks.size(); ++axis) {
561+
if (masks[axis] != 1 && axis != static_cast<size_t>(concat_axis)) {
562+
return false;
563+
}
564+
if (masks[axis] != 0 && axis == static_cast<size_t>(concat_axis)) {
565+
return false;
566+
}
567+
}
568+
return true;
569+
};
570+
auto begin_mask = strided_slice_node->get_begin_mask();
571+
auto end_mask = strided_slice_node->get_end_mask();
572+
if (!check_axis(begin_mask) || !check_axis(end_mask)) {
573+
return false;
574+
}
575+
558576
auto begin_node = strided_slice_node->get_input_node_shared_ptr(1);
559577
const auto& begin_constant_node = ov::util::get_constant_from_source(begin_node);
560578
if (begin_constant_node == nullptr)

src/common/transformations/tests/common_optimizations/nop_elimination.cpp

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1641,6 +1641,50 @@ TEST_F(TransformationTestsF, EliminateConcatStridedSliceConcat) {
16411641
model_ref = std::make_shared<ov::Model>(ResultVector{result}, ParameterVector{param1, param2, param3});
16421642
}
16431643
}
1644+
1645+
TEST_F(TransformationTestsF, EliminateConcatStridedSliceConcatDiffAxis) {
1646+
{
1647+
int64_t axis = 2;
1648+
auto param1 = make_shared<ov::op::v0::Parameter>(element::f32, Shape{2, 10, 3});
1649+
auto param2 = make_shared<ov::op::v0::Parameter>(element::f32, Shape{2, 10, 4});
1650+
auto param3 = make_shared<ov::op::v0::Parameter>(element::f32, Shape{2, 10, 5});
1651+
auto concat = make_shared<ov::op::v0::Concat>(ov::as_output_vector({param1, param2, param3}), axis);
1652+
1653+
auto begin_const1 = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{3}, {0, 0, 0});
1654+
auto end_const1 = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{3}, {0, 3, 0});
1655+
auto strided_slice1 = std::make_shared<ov::op::v1::StridedSlice>(concat,
1656+
begin_const1,
1657+
end_const1,
1658+
std::vector<int64_t>{1, 0, 1},
1659+
std::vector<int64_t>{1, 0, 1});
1660+
auto relu = std::make_shared<op::v0::Relu>(strided_slice1);
1661+
1662+
auto result = std::make_shared<op::v0::Result>(relu);
1663+
model = std::make_shared<ov::Model>(ResultVector{result}, ParameterVector{param1, param2, param3});
1664+
manager.register_pass<ov::pass::EliminateConcatStridedSlice>();
1665+
}
1666+
{
1667+
int64_t axis = 2;
1668+
auto param1 = make_shared<ov::op::v0::Parameter>(element::f32, Shape{2, 10, 3});
1669+
auto param2 = make_shared<ov::op::v0::Parameter>(element::f32, Shape{2, 10, 4});
1670+
auto param3 = make_shared<ov::op::v0::Parameter>(element::f32, Shape{2, 10, 5});
1671+
auto concat = make_shared<ov::op::v0::Concat>(ov::as_output_vector({param1, param2, param3}), axis);
1672+
1673+
auto begin_const1 = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{3}, {0, 0, 0});
1674+
auto end_const1 = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{3}, {0, 3, 0});
1675+
auto strided_slice1 = std::make_shared<ov::op::v1::StridedSlice>(concat,
1676+
begin_const1,
1677+
end_const1,
1678+
std::vector<int64_t>{1, 0, 1},
1679+
std::vector<int64_t>{1, 0, 1});
1680+
auto relu = std::make_shared<op::v0::Relu>(strided_slice1);
1681+
1682+
auto result = std::make_shared<op::v0::Result>(relu);
1683+
1684+
model_ref = std::make_shared<ov::Model>(ResultVector{result}, ParameterVector{param1, param2, param3});
1685+
}
1686+
}
1687+
16441688
TEST_F(TransformationTestsF, EliminateStridedSlice) {
16451689
{
16461690
auto input = std::make_shared<op::v0::Parameter>(ov::element::f32,

0 commit comments

Comments
 (0)