Skip to content

Commit e1ec5a6

Browse files
[GPU] Force FP32 to avoid losing precision on long contexts during LongRoPE (#29591)
Backport of #29556 to releases/2025/1. Signed-off-by: Andrew Park <andrew.park@intel.com>
1 parent 383d470 commit e1ec5a6

File tree

2 files changed

+74
-2
lines changed

2 files changed

+74
-2
lines changed

src/plugins/intel_gpu/src/plugin/transformations/increase_position_ids_precision.cpp

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,13 +40,19 @@ IncreasePositionIdsPrecision::IncreasePositionIdsPrecision() {
4040
auto sin_reshape = wrap_type<ov::op::v1::Reshape>({sin, wrap_type<ov::op::v0::Constant>()});
4141
auto sin_squeeze = wrap_type<ov::op::v0::Squeeze>({sin, wrap_type<ov::op::v0::Constant>()});
4242
auto sin_unsqueeze = wrap_type<ov::op::v0::Unsqueeze>({sin, wrap_type<ov::op::v0::Constant>()});
43+
// Adjust scale factor to positional embedding for LongRoPE
44+
auto sin_multiply = wrap_type<ov::op::v1::Multiply>({sin, wrap_type<ov::op::v0::Constant>()});
45+
auto sin_multiply_reshape = wrap_type<ov::op::v1::Reshape>({sin_multiply, wrap_type<ov::op::v0::Constant>()});
4346

4447
auto cos_reshape = wrap_type<ov::op::v1::Reshape>({cos, wrap_type<ov::op::v0::Constant>()});
4548
auto cos_squeeze = wrap_type<ov::op::v0::Squeeze>({cos, wrap_type<ov::op::v0::Constant>()});
4649
auto cos_unsqueeze = wrap_type<ov::op::v0::Unsqueeze>({cos, wrap_type<ov::op::v0::Constant>()});
50+
// Adjust scale factor to positional embedding for LongRoPE
51+
auto cos_multiply = wrap_type<ov::op::v1::Multiply>({cos, wrap_type<ov::op::v0::Constant>()});
52+
auto cos_multiply_reshape = wrap_type<ov::op::v1::Reshape>({cos_multiply, wrap_type<ov::op::v0::Constant>()});
4753

48-
auto rope_sin_input = std::make_shared<Or>(OutputVector{sin_reshape, sin_squeeze, sin_unsqueeze, sin});
49-
auto rope_cos_input = std::make_shared<Or>(OutputVector{cos_reshape, cos_squeeze, cos_unsqueeze, cos});
54+
auto rope_sin_input = std::make_shared<Or>(OutputVector{sin_reshape, sin_squeeze, sin_unsqueeze, sin_multiply_reshape, sin});
55+
auto rope_cos_input = std::make_shared<Or>(OutputVector{cos_reshape, cos_squeeze, cos_unsqueeze, cos_multiply_reshape, cos});
5056

5157
auto rope = wrap_type<ov::op::internal::RoPE>({any_input(), rope_cos_input, rope_sin_input});
5258

src/plugins/intel_gpu/tests/unit/transformations/increase_position_ids_precision_test.cpp

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#include "openvino/op/constant.hpp"
1919
#include "openvino/op/concat.hpp"
2020
#include "openvino/op/cos.hpp"
21+
#include "openvino/op/multiply.hpp"
2122
#include "openvino/op/sin.hpp"
2223
#include "openvino/op/reshape.hpp"
2324
#include "openvino/op/squeeze.hpp"
@@ -226,3 +227,68 @@ TEST_F(TransformationTestsF, IncreasePositionIdsReshapeAfterMatmul) {
226227
}
227228
comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES);
228229
}
230+
231+
// Exercises IncreasePositionIdsPrecision on a LongRoPE-style graph, where the cos/sin
// outputs pass through Multiply(constant scale) -> Reshape before feeding RoPE.
// The first scope builds the input model and registers the pass; the second scope
// builds the reference model assigned to model_ref.
TEST_F(TransformationTestsF, IncreasePositionIdsLongRoPE) {
// Input model: the whole chain from the position ids to RoPE is f16.
232+
{
233+
auto input = std::make_shared<ov::op::v0::Parameter>(ov::element::i64, ov::PartialShape{ -1 });
234+
auto input_convert = std::make_shared<ov::op::v0::Convert>(input, ov::element::i32);
235+
auto input_unsqueeze = std::make_shared<ov::op::v0::Unsqueeze>(input_convert, std::make_shared<ov::op::v0::Constant>(ov::element::i64, ov::Shape{3}, std::vector<int64_t>{-1, 1, 1}));
236+
auto input_convert_fp = std::make_shared<ov::op::v0::Convert>(input_unsqueeze, ov::element::f16);
237+
auto rotary_embd = std::make_shared<ov::op::v0::Parameter>(ov::element::f16, ov::PartialShape{-1, 48, 1});
238+
// MatMul of the two f16 operands; its result is reshaped and concatenated with itself on axis 2.
239+
auto matmul = std::make_shared<ov::op::v0::MatMul>(rotary_embd, input_convert_fp);
240+
auto reshape = std::make_shared<ov::op::v1::Reshape>(matmul, std::make_shared<ov::op::v0::Constant>(ov::element::i64, ov::Shape{3}, std::vector<int64_t>{0, 1, 48}), true);
241+
auto concat = std::make_shared<ov::op::v0::Concat>(ov::OutputVector{reshape, reshape}, 2);
242+
243+
auto cos = std::make_shared<ov::op::v0::Cos>(concat);
244+
auto sin = std::make_shared<ov::op::v0::Sin>(concat);
245+
// Constant scale multiplied onto cos and sin.
// NOTE(review): `const_scale_const` / `const_scale` are the cos-branch counterparts of
// `sin_scale_const` / `sin_scale`; the names look like a typo for `cos_scale*`.
246+
auto const_scale_const = ov::op::v0::Constant::create(ov::element::f16, ov::Shape{ 1, 1, 1 }, { 1.19043 });
247+
auto const_scale = std::make_shared<ov::op::v1::Multiply>(cos, const_scale_const);
248+
auto sin_scale_const = ov::op::v0::Constant::create(ov::element::f16, ov::Shape{ 1, 1, 1 }, { 1.19043 });
249+
auto sin_scale = std::make_shared<ov::op::v1::Multiply>(sin, sin_scale_const);
250+
// Despite the `*_unsqueeze` names, both nodes below are v1::Reshape ops.
251+
auto cos_unsqueeze = std::make_shared<ov::op::v1::Reshape>(const_scale, std::make_shared<ov::op::v0::Constant>(ov::element::i64, ov::Shape{4}, std::vector<int64_t>{-1, 1, 1, 96}), true);
252+
auto sin_unsqueeze = std::make_shared<ov::op::v1::Reshape>(sin_scale, std::make_shared<ov::op::v0::Constant>(ov::element::i64, ov::Shape{4}, std::vector<int64_t>{-1, 1, 1, 96}), true);
253+
// RoPE takes (data, cos, sin) in that order, with a default-constructed Config.
254+
auto rope_input = std::make_shared<ov::op::v0::Parameter>(ov::element::f16, ov::PartialShape::dynamic(4));
255+
auto rope = std::make_shared<ov::op::internal::RoPE>(ov::OutputVector{rope_input, cos_unsqueeze, sin_unsqueeze}, ov::op::internal::RoPE::Config());
256+
257+
model = std::make_shared<ov::Model>(ov::NodeVector{ rope }, ov::ParameterVector{ input, rope_input, rotary_embd });
258+
manager.register_pass<IncreasePositionIdsPrecision>();
259+
}
// Reference model: same topology, but the MatMul..Cos/Sin section runs in f32.
// Presumably the TransformationTestsF fixture compares it with the transformed `model` - TODO confirm.
260+
{
261+
auto input = std::make_shared<ov::op::v0::Parameter>(ov::element::i64, ov::PartialShape{ -1 });
262+
auto input_convert = std::make_shared<ov::op::v0::Convert>(input, ov::element::i32);
263+
auto input_unsqueeze = std::make_shared<ov::op::v0::Unsqueeze>(input_convert, std::make_shared<ov::op::v0::Constant>(ov::element::i64, ov::Shape{3}, std::vector<int64_t>{-1, 1, 1}));
// NOTE(review): input_convert_fp is not consumed by any node below in this scope.
264+
auto input_convert_fp = std::make_shared<ov::op::v0::Convert>(input_unsqueeze, ov::element::f16);
265+
auto rotary_embd = std::make_shared<ov::op::v0::Parameter>(ov::element::f16, ov::PartialShape{-1, 48, 1});
266+
// Both MatMul operands are converted to f32 here.
267+
auto input_convert_f32 = std::make_shared<ov::op::v0::Convert>(input_unsqueeze, ov::element::f32);
268+
auto rotary_embd_convert_f32 = std::make_shared<ov::op::v0::Convert>(rotary_embd, ov::element::f32);
269+
270+
auto matmul = std::make_shared<ov::op::v0::MatMul>(rotary_embd_convert_f32, input_convert_f32);
271+
auto reshape = std::make_shared<ov::op::v1::Reshape>(matmul, std::make_shared<ov::op::v0::Constant>(ov::element::i64, ov::Shape{3}, std::vector<int64_t>{0, 1, 48}), true);
272+
auto concat = std::make_shared<ov::op::v0::Concat>(ov::OutputVector{reshape, reshape}, 2);
273+
274+
auto cos = std::make_shared<ov::op::v0::Cos>(concat);
275+
auto sin = std::make_shared<ov::op::v0::Sin>(concat);
276+
// Cos/Sin results are converted back to f16 before the scale Multiply, so the
// Multiply -> Reshape -> RoPE tail is built on f16 inputs as in the input model.
277+
auto cos_convert = std::make_shared<ov::op::v0::Convert>(cos, ov::element::f16);
278+
auto sin_convert = std::make_shared<ov::op::v0::Convert>(sin, ov::element::f16);
279+
// NOTE(review): same `const_scale*` naming as in the input model; likely meant `cos_scale*`.
280+
auto const_scale_const = ov::op::v0::Constant::create(ov::element::f16, ov::Shape{ 1, 1, 1 }, { 1.19043 });
281+
auto const_scale = std::make_shared<ov::op::v1::Multiply>(cos_convert, const_scale_const);
282+
auto sin_scale_const = ov::op::v0::Constant::create(ov::element::f16, ov::Shape{ 1, 1, 1 }, { 1.19043 });
283+
auto sin_scale = std::make_shared<ov::op::v1::Multiply>(sin_convert, sin_scale_const);
284+
// As above, these `*_unsqueeze` variables hold v1::Reshape nodes.
285+
auto cos_unsqueeze = std::make_shared<ov::op::v1::Reshape>(const_scale, std::make_shared<ov::op::v0::Constant>(ov::element::i64, ov::Shape{4}, std::vector<int64_t>{-1, 1, 1, 96}), true);
286+
auto sin_unsqueeze = std::make_shared<ov::op::v1::Reshape>(sin_scale, std::make_shared<ov::op::v0::Constant>(ov::element::i64, ov::Shape{4}, std::vector<int64_t>{-1, 1, 1, 96}), true);
287+
288+
auto rope_input = std::make_shared<ov::op::v0::Parameter>(ov::element::f16, ov::PartialShape::dynamic(4));
289+
auto rope = std::make_shared<ov::op::internal::RoPE>(ov::OutputVector{rope_input, cos_unsqueeze, sin_unsqueeze}, ov::op::internal::RoPE::Config());
290+
291+
model_ref = std::make_shared<ov::Model>(ov::NodeVector{ rope }, ov::ParameterVector{ input, rope_input, rotary_embd });
292+
}
// Turn on the ATTRIBUTES comparison flag on the fixture's comparator.
293+
comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES);
294+
}

0 commit comments

Comments
 (0)