Skip to content

Commit 2366cf3

Browse files
authored
Merge pull request #134 from jurevreca12/rounding_mode_new
Rounding mode new
2 parents a6a23ed + bbd214b commit 2366cf3

File tree

3 files changed

+63
-2
lines changed

3 files changed

+63
-2
lines changed

docs/qonnx-custom-ops/quant_op.md

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ This operator is not part of the ONNX standard and is not currently versioned.
2121
<dt><tt>narrow</tt> : int (default is 0)</dt>
2222
<dd>Defines if the value range should be interpreted as narrow, when signed=1. E.g. at 8b regular=[-128, 127] vs narrow=[-127, 127].</dd>
2323
<dt><tt>rounding_mode</tt> : string (default is "ROUND")</dt>
24-
<dd>Defines how rounding should be applied during quantization. Currently available modes are: "ROUND", "CEIL" and "FLOOR". Here "ROUND" implies a round-to-even operation. Lowercase variants for the rounding mode string are also supported: "round", "ceil", "floor".</dd>
24+
<dd>Defines how rounding should be applied during quantization. Avaiable options are ROUND, CEIL, FLOOR, UP, DOWN, HALF_UP, HALF_DOWN. The rounding modes are described in the table bellow. The names of rounding modes can be upper case or lower case.</dd>
2525
</dl>
2626

2727
#### Inputs
@@ -46,6 +46,24 @@ This operator is not part of the ONNX standard and is not currently versioned.
4646
</dl>
4747

4848

49+
#### Rounding modes
50+
<details>
51+
<summary>rounding modes</summary>
52+
53+
| **Number \ ROUNDING_MODE** | ROUND=HALF_EVEN | CEIL | FLOOR | UP | DOWN | HALF_UP | HALF_DOWN |
54+
|----------------------------|-----------------|------|-------|----|------|---------|-----------|
55+
| 5.5 | 6 | 6 | 5 | 6 | 5 | 6 | 5 |
56+
| 2.5 | 2 | 3 | 2 | 3 | 2 | 3 | 2 |
57+
| 1.6 | 2 | 2 | 1 | 2 | 1 | 2 | 2 |
58+
| 1.1 | 1 | 2 | 1 | 2 | 1 | 1 | 1 |
59+
| 1.0 | 1 | 1 | 1 | 1 | 1 | 1 | 1 |
60+
| -1.0 | -1 | -1 | -1 | -1 | -1 | -1 | -1 |
61+
| -1.1 | -1 | -1 | -2 | -2 | -1 | -1 | -1 |
62+
| -1.6 | -2 | -1 | -2 | -2 | -1 | -2 | -2 |
63+
| -2.5 | -2 | -2 | -3 | -3 | -2 | -3 | -2 |
64+
| -5.5 | -6 | -5 | -6 | -6 | -5 | -6 | -5 |
65+
</details>
66+
4967
#### Examples
5068
<details>
5169
<summary>Quant</summary>

src/qonnx/custom_op/general/quant.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,12 +135,32 @@ def resolve_rounding_mode(mode_string):
135135
"""Resolve the rounding mode string of Quant and Trunc ops
136136
to the corresponding numpy functions."""
137137
normalized_mode_string = mode_string.upper()
138-
if normalized_mode_string == "ROUND":
138+
if normalized_mode_string == "ROUND" or normalized_mode_string == "HALF_EVEN":
139139
return np.round
140140
elif normalized_mode_string == "CEIL":
141141
return np.ceil
142142
elif normalized_mode_string == "FLOOR":
143143
return np.floor
144+
elif normalized_mode_string == "UP":
145+
146+
def round_up(x):
147+
return np.sign(x) * np.ceil(np.abs(x))
148+
149+
return round_up
150+
elif normalized_mode_string == "DOWN":
151+
return np.fix
152+
elif normalized_mode_string == "HALF_UP":
153+
154+
def round_half_up(x):
155+
return np.sign(x) * np.floor(np.abs(x) + 0.5)
156+
157+
return round_half_up
158+
elif normalized_mode_string == "HALF_DOWN":
159+
160+
def round_half_down(x):
161+
return np.sign(x) * np.ceil(np.abs(x) - 0.5)
162+
163+
return round_half_down
144164
else:
145165
raise ValueError(f"Could not resolve rounding mode called: {normalized_mode_string}")
146166

tests/custom_op/test_runding_mode.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
import pytest
2+
3+
import numpy as np
4+
5+
from qonnx.custom_op.general.quant import resolve_rounding_mode
6+
7+
8+
@pytest.mark.parametrize(
9+
"rmode,exp",
10+
[
11+
("ROUND", np.array([6, 2, 2, 1, 1, -1, -1, -2, -2, -6])),
12+
("CEIL", np.array([6, 3, 2, 2, 1, -1, -1, -1, -2, -5])),
13+
("FLOOR", np.array([5, 2, 1, 1, 1, -1, -2, -2, -3, -6])),
14+
("UP", np.array([6, 3, 2, 2, 1, -1, -2, -2, -3, -6])),
15+
("DOWN", np.array([5, 2, 1, 1, 1, -1, -1, -1, -2, -5])),
16+
("HALF_UP", np.array([6, 3, 2, 1, 1, -1, -1, -2, -3, -6])),
17+
("HALF_DOWN", np.array([5, 2, 2, 1, 1, -1, -1, -2, -2, -5])),
18+
],
19+
)
20+
def test_rounding_modes(rmode, exp):
21+
test_array = np.array([5.5, 2.5, 1.6, 1.1, 1.0, -1.0, -1.1, -1.6, -2.5, -5.5])
22+
rounding_fn = resolve_rounding_mode(rmode)
23+
assert np.array_equal(rounding_fn(test_array), exp)

0 commit comments

Comments
 (0)