
Commit 7219aa8

hyunback, timxu826 and e-ddykim authored
[GPU] Ignore SDPAScaleFusion pass when output of Q & V scales have di… (#29554)
Backport of (#29450) to releases/2025/1
Co-authored-by: jag.Xu <jia3.xu@intel.com>
Co-authored-by: Eddy Kim <eddy.kim@intel.com>
1 parent 047ae94 commit 7219aa8
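
For context, a minimal sketch (not taken from the commit page) of the scenario this backport guards against: an SDPA whose K-side scale is a TypeRelaxed Multiply that changes precision, here an i8 key multiplied in f16, mirroring the new test further down. The sdpa_scale_fusion.hpp include path and the standalone pass-manager driver are assumptions.

#include <memory>
#include <vector>

#include "openvino/core/model.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/multiply.hpp"
#include "openvino/op/parameter.hpp"
#include "openvino/op/scaled_dot_product_attention.hpp"
#include "openvino/pass/manager.hpp"
#include "ov_ops/type_relaxed.hpp"
#include "transformations/common_optimizations/sdpa_scale_fusion.hpp"  // assumed include path

int main() {
    using namespace ov;
    // f16 query/value, i8 key: the K-side scale below is TypeRelaxed (i8 in, f16 out).
    const auto query = std::make_shared<op::v0::Parameter>(element::f16, PartialShape{1, 32, -1, 32});
    const auto key = std::make_shared<op::v0::Parameter>(element::i8, PartialShape{1, 32, -1, 32});
    const auto value = std::make_shared<op::v0::Parameter>(element::f16, PartialShape{1, 32, -1, 32});
    const auto scale = op::v0::Constant::create(element::f16, Shape{}, std::vector<float>{8.0f});

    const auto q_scaled = std::make_shared<op::v1::Multiply>(query, scale);
    const auto k_scaled = std::make_shared<op::TypeRelaxed<op::v1::Multiply>>(
        std::vector<element::Type>{element::f16, element::f16},
        std::vector<element::Type>{element::f16},
        op::TemporaryReplaceOutputType(key, element::f16).get(),
        op::TemporaryReplaceOutputType(scale, element::f16).get());
    const auto v_scaled = std::make_shared<op::v1::Multiply>(value, scale);

    const auto sdpa = std::make_shared<op::v13::ScaledDotProductAttention>(q_scaled, k_scaled, v_scaled, false);
    auto model = std::make_shared<Model>(NodeVector{sdpa}, ParameterVector{query, key, value});

    // With this backport, the pass folds only the f16 Q scale and leaves the
    // type-mismatched K scale in the graph instead of dropping it.
    pass::Manager manager;
    manager.register_pass<pass::SDPAScaleFusion>();
    manager.run_passes(model);
    return 0;
}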

2 files changed: +66 −9 lines changed

src/common/transformations/src/transformations/common_optimizations/sdpa_scale_fusion.cpp

Lines changed: 19 additions & 9 deletions

@@ -49,8 +49,8 @@ SDPAScaleFusion::SDPAScaleFusion() {
 
         auto sdpa = m.get_match_root();
 
-        const bool has_q_scale = pattern_map.count(scaled_q);
-        const bool has_k_scale = pattern_map.count(scaled_k);
+        bool has_q_scale = pattern_map.count(scaled_q);
+        bool has_k_scale = pattern_map.count(scaled_k);
 
         // Nothing to do
         if (!has_q_scale && !has_k_scale)
@@ -83,22 +83,32 @@ SDPAScaleFusion::SDPAScaleFusion() {
         // Extract scalar scale values for Q and K if those are constant and set new inputs for SDPA
         if (has_q_scale) {
             scale_q_node = pattern_map.at(scale_q).get_node_shared_ptr();
-            if (ov::is_type<ov::op::v0::Constant>(scale_q_node)) {
-                scale_q_value = ov::as_type_ptr<ov::op::v0::Constant>(scale_q_node)->cast_vector<float>()[0];
-                q_input = pattern_map.at(q);
+            if (pattern_map.at(q).get_element_type() == q_input.get_element_type()) {
+                if (ov::is_type<ov::op::v0::Constant>(scale_q_node)) {
+                    scale_q_value = ov::as_type_ptr<ov::op::v0::Constant>(scale_q_node)->cast_vector<float>()[0];
+                    q_input = pattern_map.at(q);
+                }
+            } else {
+                has_q_scale = false;
             }
         }
         if (has_k_scale) {
             scale_k_node = pattern_map.at(scale_k).get_node_shared_ptr();
-            if (ov::is_type<ov::op::v0::Constant>(scale_k_node)) {
-                scale_k_value = ov::as_type_ptr<ov::op::v0::Constant>(scale_k_node)->cast_vector<float>()[0];
-                k_input = pattern_map.at(k);
+            if (pattern_map.at(k).get_element_type() == k_input.get_element_type()) {
+                if (ov::is_type<ov::op::v0::Constant>(scale_k_node)) {
+                    scale_k_value = ov::as_type_ptr<ov::op::v0::Constant>(scale_k_node)->cast_vector<float>()[0];
+                    k_input = pattern_map.at(k);
+                }
+            } else {
+                has_k_scale = false;
             }
         }
 
+        if (!has_q_scale && !has_k_scale)
+            return false;
+
         Output<ov::Node> new_scale_node;
         auto new_scale_val = prev_scale_value * scale_q_value * scale_k_value;
-
         // If new scale is 1 and we have non-constant scale node for either Q or K, then we can make it a scale of SDPA
         if (new_scale_val == 1.0f) {
             if (has_q_scale && !ov::is_type<ov::op::v0::Constant>(scale_q_node)) {
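
To restate the new condition outside the pass: a Q/K scale value is extracted and folded only when the unscaled tensor has the same element type as the input currently wired to the SDPA and the scale node is a scalar constant; on a type mismatch the corresponding has_*_scale flag is cleared, and if both flags end up cleared the callback now bails out with return false. The helper below is a hypothetical sketch of that check, not code from this commit.

#include <memory>

#include "openvino/core/node.hpp"
#include "openvino/core/type.hpp"
#include "openvino/op/constant.hpp"

// Hypothetical helper mirroring the guard above; it is not part of SDPAScaleFusion.
// It answers: would the pass fold this scale into the combined SDPA scale constant?
bool scale_is_foldable(const ov::Output<ov::Node>& unscaled_input,
                       const ov::Output<ov::Node>& current_sdpa_input,
                       const std::shared_ptr<ov::Node>& scale_node) {
    // A TypeRelaxed Multiply (e.g. i8 input producing f16) fails this element-type
    // check, so the pass now skips it instead of rewiring the SDPA to the raw input.
    if (unscaled_input.get_element_type() != current_sdpa_input.get_element_type())
        return false;
    // Only scalar constant scales are extracted and folded into the new scale value.
    return ov::is_type<ov::op::v0::Constant>(scale_node);
}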

src/common/transformations/tests/common_optimizations/sdpa_scale_fusion_test.cpp

Lines changed: 47 additions & 0 deletions

@@ -15,6 +15,7 @@
 #include "openvino/op/constant.hpp"
 #include "openvino/op/multiply.hpp"
 #include "openvino/op/scaled_dot_product_attention.hpp"
+#include "ov_ops/type_relaxed.hpp"
 
 using namespace testing;
 using namespace ov::pass;
@@ -226,3 +227,49 @@ TEST_F(TransformationTestsF, SDPAScaleFusionTest5) {
     comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
     comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES);
 }
+
+TEST_F(TransformationTestsF, SDPAScaleFusionTest6) {
+    const PartialShape query_shape{1, 32, -1, 32};
+    const PartialShape key_shape{1, 32, -1, 32};
+    const PartialShape value_shape{1, 32, -1, 32};
+
+    const auto query = std::make_shared<ov::op::v0::Parameter>(element::f16, query_shape);
+    const auto key = std::make_shared<ov::op::v0::Parameter>(element::i8, key_shape);
+    const auto value = std::make_shared<ov::op::v0::Parameter>(element::f16, value_shape);
+    const auto scale_const = ov::op::v0::Constant::create(element::f16, ov::Shape{}, std::vector<float>{8.0f});
+    const auto v_scaled = std::make_shared<ov::op::v1::Multiply>(value, scale_const);
+    const auto casual = false;
+    {
+        const auto q_scaled = std::make_shared<ov::op::v1::Multiply>(query, scale_const);
+        const auto k_scaled = std::make_shared<ov::op::TypeRelaxed<ov::op::v1::Multiply>>(
+            std::vector<element::Type>{element::f16, element::f16},
+            std::vector<element::Type>{element::f16},
+            ov::op::TemporaryReplaceOutputType(key, element::f16).get(),
+            ov::op::TemporaryReplaceOutputType(scale_const, element::f16).get());
+        const auto sdpa =
+            std::make_shared<ov::op::v13::ScaledDotProductAttention>(q_scaled, k_scaled, v_scaled, casual);
+
+        model = std::make_shared<ov::Model>(NodeVector{sdpa}, ParameterVector{query, key, value});
+        manager.register_pass<ov::pass::SDPAScaleFusion>();
+    }
+
+    {
+        const auto k_scaled_ref = std::make_shared<ov::op::TypeRelaxed<ov::op::v1::Multiply>>(
+            std::vector<element::Type>{element::f16, element::f16},
+            std::vector<element::Type>{element::f16},
+            ov::op::TemporaryReplaceOutputType(key, element::f16).get(),
+            ov::op::TemporaryReplaceOutputType(scale_const, element::f16).get());
+        const auto new_mask_const = ov::op::v0::Constant::create(element::f16, ov::Shape{}, std::vector<float>{0.0f});
+        const auto new_scale_const =
+            ov::op::v0::Constant::create(element::f16, ov::Shape{}, std::vector<float>{8.0f / std::sqrt(32.0f)});
+        const auto sdpa = std::make_shared<ov::op::v13::ScaledDotProductAttention>(query,
+                                                                                   k_scaled_ref,
+                                                                                   v_scaled,
+                                                                                   new_mask_const,
+                                                                                   new_scale_const,
+                                                                                   casual);
+        model_ref = std::make_shared<ov::Model>(NodeVector{sdpa}, ParameterVector{query, key, value});
+    }
+
+    comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
+}
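
A note on the reference constants, inferred from the diff rather than stated in the commit: with no explicit scale input, SDPA's implicit scale is 1/sqrt(head_size) = 1/sqrt(32). The f16 Q-side constant (8.0) passes the new element-type check and is folded, giving 8.0 / sqrt(32) ≈ 1.414, which is exactly the value of new_scale_const. The i8-fed TypeRelaxed K-side Multiply fails the check, so k_scaled_ref keeps it in the reference model, and the zero-valued f16 new_mask_const merely fills the attention-mask position so the explicit scale can be passed as the fifth input.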
