Skip to content

Commit 7e98eb3

Browse files
author
jvreca
committed
reformated with pre-commit hooks.
1 parent e2c1504 commit 7e98eb3

File tree

2 files changed

+13
-4
lines changed

2 files changed

+13
-4
lines changed

src/qonnx/custom_op/general/quant.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,18 +142,24 @@ def resolve_rounding_mode(mode_string):
142142
elif normalized_mode_string == "FLOOR":
143143
return np.floor
144144
elif normalized_mode_string == "UP":
145+
145146
def round_up(x):
146147
return np.sign(x) * np.ceil(np.abs(x))
148+
147149
return round_up
148150
elif normalized_mode_string == "DOWN":
149151
return np.fix
150152
elif normalized_mode_string == "HALF_UP":
153+
151154
def round_half_up(x):
152155
return np.sign(x) * np.floor(np.abs(x) + 0.5)
156+
153157
return round_half_up
154158
elif normalized_mode_string == "HALF_DOWN":
159+
155160
def round_half_down(x):
156161
return np.sign(x) * np.ceil(np.abs(x) - 0.5)
162+
157163
return round_half_down
158164
else:
159165
raise ValueError(f"Could not resolve rounding mode called: {normalized_mode_string}")

tests/custom_op/test_runding_mode.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,18 @@
44

55
from qonnx.custom_op.general.quant import resolve_rounding_mode
66

7-
@pytest.mark.parametrize("rmode,exp", [
7+
8+
@pytest.mark.parametrize(
9+
"rmode,exp",
10+
[
811
("ROUND", np.array([6, 2, 2, 1, 1, -1, -1, -2, -2, -6])),
9-
("CEIL", np.array([6, 3, 2, 2, 1, -1, -1, -1, -2, - 5])),
12+
("CEIL", np.array([6, 3, 2, 2, 1, -1, -1, -1, -2, -5])),
1013
("FLOOR", np.array([5, 2, 1, 1, 1, -1, -2, -2, -3, -6])),
1114
("UP", np.array([6, 3, 2, 2, 1, -1, -2, -2, -3, -6])),
1215
("DOWN", np.array([5, 2, 1, 1, 1, -1, -1, -1, -2, -5])),
1316
("HALF_UP", np.array([6, 3, 2, 1, 1, -1, -1, -2, -3, -6])),
14-
("HALF_DOWN", np.array([5, 2, 2, 1, 1, -1, -1, -2, -2, -5]))
15-
]
17+
("HALF_DOWN", np.array([5, 2, 2, 1, 1, -1, -1, -2, -2, -5])),
18+
],
1619
)
1720
def test_rounding_modes(rmode, exp):
1821
test_array = np.array([5.5, 2.5, 1.6, 1.1, 1.0, -1.0, -1.1, -1.6, -2.5, -5.5])

0 commit comments

Comments
 (0)