5
5
#include " low_precision/broadcast.hpp"
6
6
7
7
#include < memory>
8
- #include " openvino/opsets/opset1_decl.hpp"
9
- #include " openvino/opsets/opset3_decl.hpp"
10
- #include " openvino/pass/pattern/op/or.hpp"
11
- #include " openvino/pass/pattern/op/wrap_type.hpp"
12
- #include " low_precision/network_helper.hpp"
13
8
14
9
#include " itt.hpp"
10
+ #include " low_precision/network_helper.hpp"
15
11
#include " openvino/op/broadcast.hpp"
12
+ #include " openvino/pass/pattern/op/wrap_type.hpp"
16
13
17
14
using namespace ov ::pass::low_precision;
18
15
19
16
BroadcastTransformation::BroadcastTransformation (const Params& params) : TransparentBaseTransformation(params) {
20
17
MATCHER_SCOPE (BroadcastTransformation);
21
- auto broadcast1 = pattern::wrap_type<ov::opset1::Broadcast>({
22
- pattern::wrap_type<ov::opset1::Multiply>(),
23
- ov::pass::pattern::any_input (),
24
- ov::pass::pattern::any_input () });
25
-
26
- auto broadcast3 = pattern::wrap_type<ov::opset3::Broadcast>({
27
- pattern::wrap_type<ov::opset1::Multiply>(),
28
- ov::pass::pattern::any_input (),
29
- ov::pass::pattern::any_input () });
30
-
31
- const auto matcher = std::make_shared<ov::pass::pattern::op::Or>(ov::OutputVector{ broadcast1, broadcast3 });
18
+ using namespace ov ::pass::pattern;
19
+ auto mul = wrap_type<ov::op::v1::Multiply>();
20
+ auto matcher = wrap_type<ov::op::v3::Broadcast, ov::op::v1::Broadcast>({mul, any_input (), any_input ()}) |
21
+ wrap_type<ov::op::v3::Broadcast>({mul, any_input ()});
32
22
33
- ov::graph_rewrite_callback callback = [this ](pattern:: Matcher& m) {
23
+ ov::graph_rewrite_callback callback = [this ](Matcher& m) {
34
24
auto op = m.get_match_root ();
35
25
if (transformation_callback (op)) {
36
26
return false ;
37
27
}
38
28
return transform (m);
39
29
};
40
30
41
- auto m = std::make_shared<ov::pass::pattern:: Matcher>(matcher, matcher_name);
31
+ auto m = std::make_shared<Matcher>(matcher, matcher_name);
42
32
this ->register_matcher (m, callback);
43
33
}
44
34
@@ -61,17 +51,26 @@ bool BroadcastTransformation::canBeTransformed(const std::shared_ptr<ov::Node>&
61
51
return false ;
62
52
}
63
53
64
- const auto targetShapeConstant = ov::as_type_ptr<ov::opset1::Constant>(layer->get_input_node_shared_ptr (1 ));
65
- const auto & targetShape = targetShapeConstant->cast_vector <int64_t >();
66
- if (targetShape[dequantization.channelDimIndex ] != inputShape[dequantization.channelDimIndex ].get_length ()) {
54
+ const auto & outputShape = layer->get_output_partial_shape (0 );
55
+ if (outputShape[dequantization.channelDimIndex ] != inputShape[dequantization.channelDimIndex ]) {
67
56
return false ;
68
57
}
69
58
70
- const auto axesMappingConstant = ov::as_type_ptr<ov::opset1::Constant>(layer->get_input_node_shared_ptr (2 ));
71
- const auto & axesMapping = axesMappingConstant->cast_vector <int64_t >();
72
- if (static_cast <size_t >(axesMapping[dequantization.channelDimIndex ]) != dequantization.channelDimIndex ) {
59
+ const auto bcast = ov::as_type_ptr<ov::op::util::BroadcastBase>(layer);
60
+ if (bcast == nullptr ) {
73
61
return false ;
74
62
}
63
+ // axisMapping input affects the result only in case of explicit broadcast.
64
+ if (bcast->get_broadcast_spec ().m_type == ov::op::BroadcastType::EXPLICIT && bcast->get_input_size () == 3 ) {
65
+ const auto axesMappingConstant = ov::as_type_ptr<ov::op::v0::Constant>(bcast->get_input_node_shared_ptr (2 ));
66
+ if (!axesMappingConstant) {
67
+ return false ;
68
+ }
69
+ const auto & axesMapping = axesMappingConstant->cast_vector <size_t >();
70
+ if (axesMapping[dequantization.channelDimIndex ] != dequantization.channelDimIndex ) {
71
+ return false ;
72
+ }
73
+ }
75
74
76
75
return true ;
77
76
}
0 commit comments