Skip to content

Commit 4a4b0d3

Browse files
authored
[LPT] BroadcastTransformation: broadcast with 2 inputs support (#31258)
### Tickets: - *CVS-170294*
1 parent 761e85c commit 4a4b0d3

File tree

3 files changed

+53
-42
lines changed

3 files changed

+53
-42
lines changed

src/common/low_precision_transformations/src/broadcast.cpp

Lines changed: 23 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -5,40 +5,30 @@
55
#include "low_precision/broadcast.hpp"
66

77
#include <memory>
8-
#include "openvino/opsets/opset1_decl.hpp"
9-
#include "openvino/opsets/opset3_decl.hpp"
10-
#include "openvino/pass/pattern/op/or.hpp"
11-
#include "openvino/pass/pattern/op/wrap_type.hpp"
12-
#include "low_precision/network_helper.hpp"
138

149
#include "itt.hpp"
10+
#include "low_precision/network_helper.hpp"
1511
#include "openvino/op/broadcast.hpp"
12+
#include "openvino/pass/pattern/op/wrap_type.hpp"
1613

1714
using namespace ov::pass::low_precision;
1815

1916
BroadcastTransformation::BroadcastTransformation(const Params& params) : TransparentBaseTransformation(params) {
2017
MATCHER_SCOPE(BroadcastTransformation);
21-
auto broadcast1 = pattern::wrap_type<ov::opset1::Broadcast>({
22-
pattern::wrap_type<ov::opset1::Multiply>(),
23-
ov::pass::pattern::any_input(),
24-
ov::pass::pattern::any_input() });
25-
26-
auto broadcast3 = pattern::wrap_type<ov::opset3::Broadcast>({
27-
pattern::wrap_type<ov::opset1::Multiply>(),
28-
ov::pass::pattern::any_input(),
29-
ov::pass::pattern::any_input() });
30-
31-
const auto matcher = std::make_shared<ov::pass::pattern::op::Or>(ov::OutputVector{ broadcast1, broadcast3 });
18+
using namespace ov::pass::pattern;
19+
auto mul = wrap_type<ov::op::v1::Multiply>();
20+
auto matcher = wrap_type<ov::op::v3::Broadcast, ov::op::v1::Broadcast>({mul, any_input(), any_input()}) |
21+
wrap_type<ov::op::v3::Broadcast>({mul, any_input()});
3222

33-
ov::graph_rewrite_callback callback = [this](pattern::Matcher& m) {
23+
ov::graph_rewrite_callback callback = [this](Matcher& m) {
3424
auto op = m.get_match_root();
3525
if (transformation_callback(op)) {
3626
return false;
3727
}
3828
return transform(m);
3929
};
4030

41-
auto m = std::make_shared<ov::pass::pattern::Matcher>(matcher, matcher_name);
31+
auto m = std::make_shared<Matcher>(matcher, matcher_name);
4232
this->register_matcher(m, callback);
4333
}
4434

@@ -61,17 +51,26 @@ bool BroadcastTransformation::canBeTransformed(const std::shared_ptr<ov::Node>&
6151
return false;
6252
}
6353

64-
const auto targetShapeConstant = ov::as_type_ptr<ov::opset1::Constant>(layer->get_input_node_shared_ptr(1));
65-
const auto& targetShape = targetShapeConstant->cast_vector<int64_t>();
66-
if (targetShape[dequantization.channelDimIndex] != inputShape[dequantization.channelDimIndex].get_length()) {
54+
const auto& outputShape = layer->get_output_partial_shape(0);
55+
if (outputShape[dequantization.channelDimIndex] != inputShape[dequantization.channelDimIndex]) {
6756
return false;
6857
}
6958

70-
const auto axesMappingConstant = ov::as_type_ptr<ov::opset1::Constant>(layer->get_input_node_shared_ptr(2));
71-
const auto& axesMapping = axesMappingConstant->cast_vector<int64_t>();
72-
if (static_cast<size_t>(axesMapping[dequantization.channelDimIndex]) != dequantization.channelDimIndex) {
59+
const auto bcast = ov::as_type_ptr<ov::op::util::BroadcastBase>(layer);
60+
if (bcast == nullptr) {
7361
return false;
7462
}
63+
// axisMapping input affects the result only in case of explicit broadcast.
64+
if (bcast->get_broadcast_spec().m_type == ov::op::BroadcastType::EXPLICIT && bcast->get_input_size() == 3) {
65+
const auto axesMappingConstant = ov::as_type_ptr<ov::op::v0::Constant>(bcast->get_input_node_shared_ptr(2));
66+
if (!axesMappingConstant) {
67+
return false;
68+
}
69+
const auto& axesMapping = axesMappingConstant->cast_vector<size_t>();
70+
if (axesMapping[dequantization.channelDimIndex] != dequantization.channelDimIndex) {
71+
return false;
72+
}
73+
}
7574

7675
return true;
7776
}

src/common/low_precision_transformations/tests/broadcast_transformation.cpp

Lines changed: 20 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -44,20 +44,13 @@ typedef std::tuple<
4444
class BroadcastTransformation : public LayerTransformation, public testing::WithParamInterface<BroadcastTransformationParams> {
4545
public:
4646
void SetUp() override {
47-
const ov::PartialShape inputShape = std::get<0>(GetParam());
48-
const bool v1 = std::get<1>(GetParam());
49-
const BroadcastTransformationTestValues testValues = std::get<2>(GetParam());
50-
51-
// batch update support
52-
auto tagetShape = testValues.tagetShape;
53-
tagetShape[0] = inputShape[0].get_length();
54-
47+
const auto [inputShape, v1, testValues] = GetParam();
5548
actualFunction = BroadcastFunction::get(
5649
v1,
5750
inputShape,
5851
testValues.actual.precisionBeforeDequantization,
5952
testValues.actual.dequantizationBefore,
60-
tagetShape,
53+
testValues.tagetShape,
6154
testValues.axesMapping,
6255
testValues.actual.dequantizationAfter);
6356

@@ -70,16 +63,13 @@ class BroadcastTransformation : public LayerTransformation, public testing::With
7063
inputShape,
7164
testValues.expected.precisionBeforeDequantization,
7265
testValues.expected.dequantizationBefore,
73-
tagetShape,
66+
testValues.tagetShape,
7467
testValues.axesMapping,
7568
testValues.expected.dequantizationAfter);
7669
}
7770

7871
static std::string getTestCaseName(testing::TestParamInfo<BroadcastTransformationParams> obj) {
79-
const ov::PartialShape inputShape = std::get<0>(obj.param);
80-
const bool v1 = std::get<1>(obj.param);
81-
const BroadcastTransformationTestValues testValues = std::get<2>(obj.param);
82-
72+
const auto [inputShape, v1, testValues] = obj.param;
8373
std::ostringstream result;
8474
result <<
8575
v1 << "_" <<
@@ -108,7 +98,6 @@ TEST_P(BroadcastTransformation, CompareFunctions) {
10898
namespace hw_broadcast {
10999
const std::vector<ov::PartialShape> inputShapes = {
110100
{ 1, 3, 1, 1 },
111-
{ 4, 3, 1, 1 },
112101
};
113102

114103
const std::vector<BroadcastTransformationTestValues> testValues = {
@@ -181,6 +170,22 @@ const std::vector<BroadcastTransformationTestValues> testValues = {
181170
{{}, {}, {}},
182171
{{ov::element::f32}, {0.1f}, {0.2f}}
183172
}
173+
},
174+
{
175+
LayerTransformation::createParamsU8I8(),
176+
{ 1, 9, 9, 9},
177+
// empty axis mapping => bcast with 2 inputs is created
178+
{},
179+
{
180+
ov::element::u8,
181+
{{ov::element::f32}, {0.1f}, {0.2f}},
182+
{{}, {}, {}},
183+
},
184+
{
185+
ov::element::u8,
186+
{{}, {}, {}},
187+
{{ov::element::f32}, {0.1f}, {0.2f}}
188+
}
184189
}
185190
};
186191

src/tests/ov_helpers/ov_lpt_models/src/broadcast.cpp

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,18 @@ namespace subgraph {
1717

1818
namespace {
1919
template <typename T>
20-
std::shared_ptr<ov::Node> make_broadcast(const std::shared_ptr<ov::Node>& parent, const Shape& tagetShape, const Shape& axesMapping) {
20+
std::shared_ptr<ov::Node> make_broadcast(const std::shared_ptr<ov::Node>& parent,
21+
const Shape& tagetShape,
22+
const Shape& axesMapping) {
23+
if (axesMapping.empty()) {
24+
return std::make_shared<T>(
25+
parent,
26+
std::make_shared<ov::opset1::Constant>(ov::element::i32, Shape{tagetShape.size()}, tagetShape));
27+
}
2128
return std::make_shared<T>(
2229
parent,
23-
std::make_shared<ov::opset1::Constant>(ov::element::i32, Shape{ tagetShape.size() }, tagetShape),
24-
std::make_shared<ov::opset1::Constant>(ov::element::i32, Shape{ axesMapping.size() }, axesMapping));
30+
std::make_shared<ov::opset1::Constant>(ov::element::i32, Shape{tagetShape.size()}, tagetShape),
31+
std::make_shared<ov::opset1::Constant>(ov::element::i32, Shape{axesMapping.size()}, axesMapping));
2532
}
2633
} // namespace
2734

0 commit comments

Comments
 (0)