|
18 | 18 | #include "openvino/op/constant.hpp"
|
19 | 19 | #include "openvino/op/concat.hpp"
|
20 | 20 | #include "openvino/op/cos.hpp"
|
| 21 | +#include "openvino/op/multiply.hpp" |
21 | 22 | #include "openvino/op/sin.hpp"
|
22 | 23 | #include "openvino/op/reshape.hpp"
|
23 | 24 | #include "openvino/op/squeeze.hpp"
|
@@ -226,3 +227,68 @@ TEST_F(TransformationTestsF, IncreasePositionIdsReshapeAfterMatmul) {
|
226 | 227 | }
|
227 | 228 | comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES);
|
228 | 229 | }
|
| 230 | + |
| 231 | +TEST_F(TransformationTestsF, IncreasePositionIdsLongRoPE) { |
| 232 | + { |
| 233 | + auto input = std::make_shared<ov::op::v0::Parameter>(ov::element::i64, ov::PartialShape{ -1 }); |
| 234 | + auto input_convert = std::make_shared<ov::op::v0::Convert>(input, ov::element::i32); |
| 235 | + auto input_unsqueeze = std::make_shared<ov::op::v0::Unsqueeze>(input_convert, std::make_shared<ov::op::v0::Constant>(ov::element::i64, ov::Shape{3}, std::vector<int64_t>{-1, 1, 1})); |
| 236 | + auto input_convert_fp = std::make_shared<ov::op::v0::Convert>(input_unsqueeze, ov::element::f16); |
| 237 | + auto rotary_embd = std::make_shared<ov::op::v0::Parameter>(ov::element::f16, ov::PartialShape{-1, 48, 1}); |
| 238 | + |
| 239 | + auto matmul = std::make_shared<ov::op::v0::MatMul>(rotary_embd, input_convert_fp); |
| 240 | + auto reshape = std::make_shared<ov::op::v1::Reshape>(matmul, std::make_shared<ov::op::v0::Constant>(ov::element::i64, ov::Shape{3}, std::vector<int64_t>{0, 1, 48}), true); |
| 241 | + auto concat = std::make_shared<ov::op::v0::Concat>(ov::OutputVector{reshape, reshape}, 2); |
| 242 | + |
| 243 | + auto cos = std::make_shared<ov::op::v0::Cos>(concat); |
| 244 | + auto sin = std::make_shared<ov::op::v0::Sin>(concat); |
| 245 | + |
| 246 | + auto const_scale_const = ov::op::v0::Constant::create(ov::element::f16, ov::Shape{ 1, 1, 1 }, { 1.19043 }); |
| 247 | + auto const_scale = std::make_shared<ov::op::v1::Multiply>(cos, const_scale_const); |
| 248 | + auto sin_scale_const = ov::op::v0::Constant::create(ov::element::f16, ov::Shape{ 1, 1, 1 }, { 1.19043 }); |
| 249 | + auto sin_scale = std::make_shared<ov::op::v1::Multiply>(sin, sin_scale_const); |
| 250 | + |
| 251 | + auto cos_unsqueeze = std::make_shared<ov::op::v1::Reshape>(const_scale, std::make_shared<ov::op::v0::Constant>(ov::element::i64, ov::Shape{4}, std::vector<int64_t>{-1, 1, 1, 96}), true); |
| 252 | + auto sin_unsqueeze = std::make_shared<ov::op::v1::Reshape>(sin_scale, std::make_shared<ov::op::v0::Constant>(ov::element::i64, ov::Shape{4}, std::vector<int64_t>{-1, 1, 1, 96}), true); |
| 253 | + |
| 254 | + auto rope_input = std::make_shared<ov::op::v0::Parameter>(ov::element::f16, ov::PartialShape::dynamic(4)); |
| 255 | + auto rope = std::make_shared<ov::op::internal::RoPE>(ov::OutputVector{rope_input, cos_unsqueeze, sin_unsqueeze}, ov::op::internal::RoPE::Config()); |
| 256 | + |
| 257 | + model = std::make_shared<ov::Model>(ov::NodeVector{ rope }, ov::ParameterVector{ input, rope_input, rotary_embd }); |
| 258 | + manager.register_pass<IncreasePositionIdsPrecision>(); |
| 259 | + } |
| 260 | + { |
| 261 | + auto input = std::make_shared<ov::op::v0::Parameter>(ov::element::i64, ov::PartialShape{ -1 }); |
| 262 | + auto input_convert = std::make_shared<ov::op::v0::Convert>(input, ov::element::i32); |
| 263 | + auto input_unsqueeze = std::make_shared<ov::op::v0::Unsqueeze>(input_convert, std::make_shared<ov::op::v0::Constant>(ov::element::i64, ov::Shape{3}, std::vector<int64_t>{-1, 1, 1})); |
| 264 | + auto input_convert_fp = std::make_shared<ov::op::v0::Convert>(input_unsqueeze, ov::element::f16); |
| 265 | + auto rotary_embd = std::make_shared<ov::op::v0::Parameter>(ov::element::f16, ov::PartialShape{-1, 48, 1}); |
| 266 | + |
| 267 | + auto input_convert_f32 = std::make_shared<ov::op::v0::Convert>(input_unsqueeze, ov::element::f32); |
| 268 | + auto rotary_embd_convert_f32 = std::make_shared<ov::op::v0::Convert>(rotary_embd, ov::element::f32); |
| 269 | + |
| 270 | + auto matmul = std::make_shared<ov::op::v0::MatMul>(rotary_embd_convert_f32, input_convert_f32); |
| 271 | + auto reshape = std::make_shared<ov::op::v1::Reshape>(matmul, std::make_shared<ov::op::v0::Constant>(ov::element::i64, ov::Shape{3}, std::vector<int64_t>{0, 1, 48}), true); |
| 272 | + auto concat = std::make_shared<ov::op::v0::Concat>(ov::OutputVector{reshape, reshape}, 2); |
| 273 | + |
| 274 | + auto cos = std::make_shared<ov::op::v0::Cos>(concat); |
| 275 | + auto sin = std::make_shared<ov::op::v0::Sin>(concat); |
| 276 | + |
| 277 | + auto cos_convert = std::make_shared<ov::op::v0::Convert>(cos, ov::element::f16); |
| 278 | + auto sin_convert = std::make_shared<ov::op::v0::Convert>(sin, ov::element::f16); |
| 279 | + |
| 280 | + auto const_scale_const = ov::op::v0::Constant::create(ov::element::f16, ov::Shape{ 1, 1, 1 }, { 1.19043 }); |
| 281 | + auto const_scale = std::make_shared<ov::op::v1::Multiply>(cos_convert, const_scale_const); |
| 282 | + auto sin_scale_const = ov::op::v0::Constant::create(ov::element::f16, ov::Shape{ 1, 1, 1 }, { 1.19043 }); |
| 283 | + auto sin_scale = std::make_shared<ov::op::v1::Multiply>(sin_convert, sin_scale_const); |
| 284 | + |
| 285 | + auto cos_unsqueeze = std::make_shared<ov::op::v1::Reshape>(const_scale, std::make_shared<ov::op::v0::Constant>(ov::element::i64, ov::Shape{4}, std::vector<int64_t>{-1, 1, 1, 96}), true); |
| 286 | + auto sin_unsqueeze = std::make_shared<ov::op::v1::Reshape>(sin_scale, std::make_shared<ov::op::v0::Constant>(ov::element::i64, ov::Shape{4}, std::vector<int64_t>{-1, 1, 1, 96}), true); |
| 287 | + |
| 288 | + auto rope_input = std::make_shared<ov::op::v0::Parameter>(ov::element::f16, ov::PartialShape::dynamic(4)); |
| 289 | + auto rope = std::make_shared<ov::op::internal::RoPE>(ov::OutputVector{rope_input, cos_unsqueeze, sin_unsqueeze}, ov::op::internal::RoPE::Config()); |
| 290 | + |
| 291 | + model_ref = std::make_shared<ov::Model>(ov::NodeVector{ rope }, ov::ParameterVector{ input, rope_input, rotary_embd }); |
| 292 | + } |
| 293 | + comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES); |
| 294 | +} |
0 commit comments