Skip to content

Commit 719af72

Browse files
authored
[PT] Fix aten::add decomposition for i4 weights (#29525)
**Details:** Fix aten::add decomposition for i4 weights Cherry-pick of #29511 **Ticket:** 164397 --------- Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>
1 parent c796469 commit 719af72

File tree

3 files changed

+3
-3
lines changed

3 files changed

+3
-3
lines changed

src/frontends/pytorch/src/op/add.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ OutputVector translate_add_common(const NodeContext& context, bool inplace) {
6767

6868
if (alpha.get_node_shared_ptr()) {
6969
auto converted_alpha = ComplexTypeMark::convert_like(context, alpha, rhs);
70-
rhs = ComplexTypeMark::mul(context, rhs, converted_alpha);
70+
rhs = ComplexTypeMark::mul(context, converted_alpha, rhs);
7171
}
7272

7373
auto add = ComplexTypeMark::add(context, lhs, rhs);

src/frontends/pytorch/src/op/sub.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ OutputVector translate_sub_common(const NodeContext& context, bool inplace) {
3838
auto alpha = context.get_input(2);
3939
auto casted_alpha = ComplexTypeMark::convert_like(context, alpha, y);
4040

41-
y = ComplexTypeMark::mul(context, y, casted_alpha);
41+
y = ComplexTypeMark::mul(context, casted_alpha, y);
4242
}
4343

4444
auto sub = ComplexTypeMark::sub(context, x, y);

tests/layer_tests/py_frontend_tests/test_torch_frontend.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -455,7 +455,7 @@ def forward(self, x):
455455
converted_model = fe.convert(input_model)
456456
assert converted_model
457457
assert [n.get_type_name() for n in converted_model.get_ordered_ops()] == [
458-
"Parameter", "Convert", "Convert", "Cos", "Relu", "Constant", "Convert", "Multiply", "Add", "Result"]
458+
"Parameter", "Convert", "Convert", "Cos", "Constant", "Convert", "Relu", "Multiply", "Add", "Result"]
459459

460460
converted_model = convert_model(model, example_input=(
461461
torch.randn(100),), extension=[ModuleExtension(CosModel, "aten::sin"), ModuleExtension(model.relu_module, "aten::tan")])

0 commit comments

Comments
 (0)