|
15 | 15 | #include "openvino/op/constant.hpp"
|
16 | 16 | #include "openvino/op/multiply.hpp"
|
17 | 17 | #include "openvino/op/scaled_dot_product_attention.hpp"
|
| 18 | +#include "ov_ops/type_relaxed.hpp" |
18 | 19 |
|
19 | 20 | using namespace testing;
|
20 | 21 | using namespace ov::pass;
|
@@ -226,3 +227,49 @@ TEST_F(TransformationTestsF, SDPAScaleFusionTest5) {
|
226 | 227 | comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
|
227 | 228 | comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES);
|
228 | 229 | }
|
| 230 | + |
| 231 | +TEST_F(TransformationTestsF, SDPAScaleFusionTest6) { |
| 232 | + const PartialShape query_shape{1, 32, -1, 32}; |
| 233 | + const PartialShape key_shape{1, 32, -1, 32}; |
| 234 | + const PartialShape value_shape{1, 32, -1, 32}; |
| 235 | + |
| 236 | + const auto query = std::make_shared<ov::op::v0::Parameter>(element::f16, query_shape); |
| 237 | + const auto key = std::make_shared<ov::op::v0::Parameter>(element::i8, key_shape); |
| 238 | + const auto value = std::make_shared<ov::op::v0::Parameter>(element::f16, value_shape); |
| 239 | + const auto scale_const = ov::op::v0::Constant::create(element::f16, ov::Shape{}, std::vector<float>{8.0f}); |
| 240 | + const auto v_scaled = std::make_shared<ov::op::v1::Multiply>(value, scale_const); |
| 241 | + const auto casual = false; |
| 242 | + { |
| 243 | + const auto q_scaled = std::make_shared<ov::op::v1::Multiply>(query, scale_const); |
| 244 | + const auto k_scaled = std::make_shared<ov::op::TypeRelaxed<ov::op::v1::Multiply>>( |
| 245 | + std::vector<element::Type>{element::f16, element::f16}, |
| 246 | + std::vector<element::Type>{element::f16}, |
| 247 | + ov::op::TemporaryReplaceOutputType(key, element::f16).get(), |
| 248 | + ov::op::TemporaryReplaceOutputType(scale_const, element::f16).get()); |
| 249 | + const auto sdpa = |
| 250 | + std::make_shared<ov::op::v13::ScaledDotProductAttention>(q_scaled, k_scaled, v_scaled, casual); |
| 251 | + |
| 252 | + model = std::make_shared<ov::Model>(NodeVector{sdpa}, ParameterVector{query, key, value}); |
| 253 | + manager.register_pass<ov::pass::SDPAScaleFusion>(); |
| 254 | + } |
| 255 | + |
| 256 | + { |
| 257 | + const auto k_scaled_ref = std::make_shared<ov::op::TypeRelaxed<ov::op::v1::Multiply>>( |
| 258 | + std::vector<element::Type>{element::f16, element::f16}, |
| 259 | + std::vector<element::Type>{element::f16}, |
| 260 | + ov::op::TemporaryReplaceOutputType(key, element::f16).get(), |
| 261 | + ov::op::TemporaryReplaceOutputType(scale_const, element::f16).get()); |
| 262 | + const auto new_mask_const = ov::op::v0::Constant::create(element::f16, ov::Shape{}, std::vector<float>{0.0f}); |
| 263 | + const auto new_scale_const = |
| 264 | + ov::op::v0::Constant::create(element::f16, ov::Shape{}, std::vector<float>{8.0f / std::sqrt(32.0f)}); |
| 265 | + const auto sdpa = std::make_shared<ov::op::v13::ScaledDotProductAttention>(query, |
| 266 | + k_scaled_ref, |
| 267 | + v_scaled, |
| 268 | + new_mask_const, |
| 269 | + new_scale_const, |
| 270 | + casual); |
| 271 | + model_ref = std::make_shared<ov::Model>(NodeVector{sdpa}, ParameterVector{query, key, value}); |
| 272 | + } |
| 273 | + |
| 274 | + comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES); |
| 275 | +} |
0 commit comments