Skip to content

Commit 279e4f1

Browse files
authored
Add-based SR (#44)
Previous implementation was based on comparison: x - floor(x) > rand This changes to be based on add: x + rand >= floor(x) + 1 No behavioural change expected, but bias direction and rounding profiles are inverted
1 parent 465ee41 commit 279e4f1

File tree

7 files changed

+154
-93
lines changed

7 files changed

+154
-93
lines changed

docs/source/03-value-tables.ipynb

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -230,9 +230,7 @@
230230
" 0xFF,\n",
231231
" ):\n",
232232
" print(\n",
233-
" str_tablerow(\n",
234-
" fi, decode_float(fi, i), show_b16_info=True, vs_width=8, vs_d=4\n",
235-
" )\n",
233+
" str_tablerow(fi, decode_float(fi, i), show_b16_info=True, vs_width=8, vs_d=4)\n",
236234
" )"
237235
]
238236
},
@@ -3266,7 +3264,7 @@
32663264
],
32673265
"metadata": {
32683266
"kernelspec": {
3269-
"display_name": "Python 3",
3267+
"display_name": "gfloat-clean",
32703268
"language": "python",
32713269
"name": "python3"
32723270
},

docs/source/05-stochastic-rounding.ipynb

Lines changed: 131 additions & 72 deletions
Large diffs are not rendered by default.

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ dependencies = {file = ["requirements.txt"]}
3131
optional-dependencies = {dev = {file = ["requirements-dev.txt"]}}
3232

3333
[tool.black]
34-
line-length = 88
34+
line-length = 90
3535
fast = true
3636

3737
[tool.mypy]

src/gfloat/decode_ndarray.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,7 @@
55
from .types import FormatInfo
66

77

8-
def decode_ndarray(
9-
fi: FormatInfo, codes: np.ndarray, np: ModuleType = np
10-
) -> np.ndarray:
8+
def decode_ndarray(fi: FormatInfo, codes: np.ndarray, np: ModuleType = np) -> np.ndarray:
119
r"""
1210
Vectorized version of :meth:`decode_float`
1311

src/gfloat/round.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ def round_float(
102102
case RoundMode.TowardNegative:
103103
should_round_away = sign and delta > 0
104104
case RoundMode.TiesToAway:
105-
should_round_away = delta >= 0.5
105+
should_round_away = delta + 0.5 >= 1.0
106106
case RoundMode.TiesToEven:
107107
should_round_away = delta > 0.5 or (delta == 0.5 and code_is_odd)
108108
case RoundMode.Stochastic:
@@ -113,20 +113,20 @@ def round_float(
113113
(d - floord > 0.5) or ((d - floord == 0.5) and _isodd(floord))
114114
)
115115

116-
should_round_away = d > srbits
116+
should_round_away = d + srbits >= 2.0**srnumbits
117117
case RoundMode.StochasticOdd:
118118
## RTNE delta to srbits
119119
d = delta * 2.0**srnumbits
120120
floord = np.floor(d).astype(np.int64)
121121
d = floord + (
122-
(d - floord > 0.5) or ((d - floord == 0.5) and ~_isodd(floord))
122+
(d - floord > 0.5) or ((d - floord == 0.5) and not _isodd(floord))
123123
)
124124

125-
should_round_away = d > srbits
125+
should_round_away = d + srbits >= 2.0**srnumbits
126126
case RoundMode.StochasticFast:
127-
should_round_away = delta > (0.5 + srbits) * 2.0**-srnumbits
127+
should_round_away = delta + (0.5 + srbits) * 2.0**-srnumbits >= 1.0
128128
case RoundMode.StochasticFastest:
129-
should_round_away = delta > srbits * 2.0**-srnumbits
129+
should_round_away = delta + srbits * 2.0**-srnumbits >= 1.0
130130

131131
if should_round_away:
132132
# This may increase isignificand to 2**p,

src/gfloat/round_ndarray.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -78,14 +78,19 @@ def round_ndarray(
7878
match rnd:
7979
case RoundMode.TowardZero:
8080
should_round_away = np.zeros_like(delta, dtype=bool)
81+
8182
case RoundMode.TowardPositive:
8283
should_round_away = ~is_negative & (delta > 0)
84+
8385
case RoundMode.TowardNegative:
8486
should_round_away = is_negative & (delta > 0)
87+
8588
case RoundMode.TiesToAway:
8689
should_round_away = delta >= 0.5
90+
8791
case RoundMode.TiesToEven:
8892
should_round_away = (delta > 0.5) | ((delta == 0.5) & code_is_odd)
93+
8994
case RoundMode.Stochastic:
9095
assert srbits is not None
9196
## RTNE delta to srbits
@@ -94,7 +99,8 @@ def round_ndarray(
9499
dd = d - floord
95100
drnd = floord + (dd > 0.5) + ((dd == 0.5) & _isodd(floord))
96101

97-
should_round_away = drnd > srbits
102+
should_round_away = drnd + srbits >= 2.0**srnumbits
103+
98104
case RoundMode.StochasticOdd:
99105
assert srbits is not None
100106
## RTNO delta to srbits
@@ -103,13 +109,15 @@ def round_ndarray(
103109
dd = d - floord
104110
drnd = floord + (dd > 0.5) + ((dd == 0.5) & ~_isodd(floord))
105111

106-
should_round_away = drnd > srbits
112+
should_round_away = drnd + srbits >= 2.0**srnumbits
113+
107114
case RoundMode.StochasticFast:
108115
assert srbits is not None
109-
should_round_away = delta > (2 * srbits + 1) * 2.0 ** -(1 + srnumbits)
116+
should_round_away = delta + (2 * srbits + 1) * 2.0 ** -(1 + srnumbits) >= 1.0
117+
110118
case RoundMode.StochasticFastest:
111119
assert srbits is not None
112-
should_round_away = delta > srbits * 2.0**-srnumbits
120+
should_round_away = delta + srbits * 2.0**-srnumbits >= 1.0
113121

114122
isignificand = np.where(should_round_away, isignificand + 1, isignificand)
115123

test/test_decode.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -222,9 +222,7 @@ def test_p3109_k8_specials(fi: FormatInfo) -> None:
222222
assert fi.code_of_neginf == 0xFF
223223

224224

225-
@pytest.mark.parametrize(
226-
"k,p", [(8, 3), (8, 1), (6, 1), (6, 5), (3, 1), (3, 2), (11, 3)]
227-
)
225+
@pytest.mark.parametrize("k,p", [(8, 3), (8, 1), (6, 1), (6, 5), (3, 1), (3, 2), (11, 3)])
228226
def test_p3109_specials(k: int, p: int) -> None:
229227
fi = format_info_p3109(k, p)
230228
assert fi.code_of_nan == 2 ** (k - 1)

0 commit comments

Comments
 (0)