Skip to content

Commit 0bb857a

Browse files
authored
Merge pull request #34 from graphcore-research/stochastic-rounding
Add stochastic rounding
2 parents a7bb326 + 01972f1 commit 0bb857a

File tree

4 files changed

+164
-1
lines changed

4 files changed

+164
-1
lines changed

docs/source/05-stochastic-rounding.ipynb

Lines changed: 103 additions & 0 deletions
Large diffs are not rendered by default.

src/gfloat/round.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,12 @@ def _isodd(v: int) -> bool:
1212

1313

1414
def round_float(
15-
fi: FormatInfo, v: float, rnd: RoundMode = RoundMode.TiesToEven, sat: bool = False
15+
fi: FormatInfo,
16+
v: float,
17+
rnd: RoundMode = RoundMode.TiesToEven,
18+
sat: bool = False,
19+
srbits: int = -1,
20+
srnumbits: int = 0,
1621
) -> float:
1722
"""
1823
Round input to the given :py:class:`FormatInfo`, given rounding mode and saturation flag
@@ -27,6 +32,8 @@ def round_float(
2732
v (float): Input value to be rounded
2833
rnd (RoundMode): Rounding mode to use
2934
sat (bool): Saturation flag: if True, round overflowed values to `fi.max`
35+
srbits (int): Bits to use for stochastic rounding if rnd == Stochastic.
36+
srnumbits (int): How many bits are in srbits. Implies srbits < 2**srnumbits.
3037
3138
Returns:
3239
A float which is one of the values in the format.
@@ -35,12 +42,17 @@ def round_float(
3542
ValueError: The target format cannot represent the input
3643
(e.g. converting a `NaN`, or an `Inf` when the target has no
3744
`NaN` or `Inf`, and :paramref:`sat` is false)
45+
ValueError: Inconsistent arguments, e.g. srnumbits >= 2**srnumbits
3846
"""
3947

4048
# Constants
4149
p = fi.precision
4250
bias = fi.expBias
4351

52+
if rnd == RoundMode.Stochastic:
53+
if srbits >= 2**srnumbits:
54+
raise ValueError(f"srnumbits={srnumbits} >= 2**srnumbits={2**srnumbits}")
55+
4456
if np.isnan(v):
4557
if fi.num_nans == 0:
4658
raise ValueError(f"No NaN in format {fi}")
@@ -75,12 +87,15 @@ def round_float(
7587
# Round
7688
isignificand = math.floor(fsignificand)
7789
delta = fsignificand - isignificand
90+
91+
# fmt: off
7892
if (
7993
(rnd == RoundMode.TowardPositive and not sign and delta > 0)
8094
or (rnd == RoundMode.TowardNegative and sign and delta > 0)
8195
or (rnd == RoundMode.TiesToAway and delta >= 0.5)
8296
or (rnd == RoundMode.TiesToEven and delta > 0.5)
8397
or (rnd == RoundMode.TiesToEven and delta == 0.5 and _isodd(isignificand))
98+
or (rnd == RoundMode.Stochastic and delta > (0.5 + srbits) * 2.0**-srnumbits)
8499
):
85100
isignificand += 1
86101

@@ -95,6 +110,7 @@ def round_float(
95110
or (rnd == RoundMode.TiesToAway and delta >= 0.5)
96111
or (rnd == RoundMode.TiesToEven and delta > 0.5)
97112
or (rnd == RoundMode.TiesToEven and delta == 0.5 and code_is_odd)
113+
or (rnd == RoundMode.Stochastic and delta > (0.5 + srbits) * 2.0**-srnumbits)
98114
):
99115
# Go to nextUp.
100116
# Increment isignificand if zero,
@@ -105,6 +121,7 @@ def round_float(
105121
assert isignificand == 1
106122
expval += 1
107123
## End special case for Precision=1.
124+
# fmt: on
108125

109126
# Reconstruct rounded result to float
110127
result = isignificand * (2.0**expval)

src/gfloat/types.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ class RoundMode(Enum):
1616
TowardPositive = 3 #: :math:`\min \{ r ~ s.t. ~ r \ge v \}`
1717
TiesToEven = 4 #: Round to nearest, ties to even
1818
TiesToAway = 5 #: Round to nearest, ties away from zero
19+
Stochastic = 6 #: Stochastic rounding
1920

2021

2122
class FloatClass(Enum):

test/test_round.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -488,3 +488,45 @@ def test_round_roundtrip(round_float: Callable, fi: FormatInfo) -> None:
488488
fv = decode_float(fi, i)
489489
fval2 = round_float(fi, fv.fval)
490490
np.testing.assert_equal(fval2, fv.fval)
491+
492+
493+
@pytest.mark.parametrize(
494+
"v, srnumbits, expected_up",
495+
(
496+
(259, 3, 0.0 / 8),
497+
(259, 5, 1.0 / 32),
498+
(277, 3, 3.0 / 8),
499+
(288, 3, 0.5),
500+
(311, 3, 7.0 / 8),
501+
),
502+
)
503+
def test_stochastic_rounding(v: float, srnumbits: int, expected_up: float) -> None:
504+
fi = format_info_ocp_e5m2
505+
506+
v0 = round_float(fi, v, RoundMode.TowardNegative)
507+
v1 = round_float(fi, v, RoundMode.TowardPositive)
508+
509+
n = 10_000
510+
expected_up_count = expected_up * n
511+
512+
srbits = np.random.randint(0, 2**srnumbits, size=(n,))
513+
count_v1 = 0
514+
for k in range(n):
515+
r = round_float(
516+
fi,
517+
v,
518+
RoundMode.Stochastic,
519+
sat=False,
520+
srbits=srbits[k],
521+
srnumbits=srnumbits,
522+
)
523+
if r == v1:
524+
count_v1 += 1
525+
else:
526+
assert r == v0
527+
528+
print(f"SRBits={srnumbits}, observed = {count_v1}, expected = {expected_up_count} ")
529+
# e.g. if expected is 1250/10000, want to be within 0.5,1.5
530+
# this is loose, but should still catch logic errors
531+
atol = n * 2.0 ** (-1 - srnumbits)
532+
np.testing.assert_allclose(count_v1, expected_up_count, atol=atol)

0 commit comments

Comments
 (0)