Skip to content

Commit 0257255

Browse files
authored
Merge pull request #36 from graphcore-research/sr-bias-example
Add comparison to "SRFast"
2 parents ca391b8 + f8ca462 commit 0257255

File tree

8 files changed

+1170
-80
lines changed

8 files changed

+1170
-80
lines changed

.github/workflows/ci.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ jobs:
1414
- uses: actions/checkout@v3
1515
- uses: actions/setup-python@v4
1616
with:
17-
python-version: "3.9"
17+
python-version: "3.10"
1818
cache: "pip"
1919

2020
- name: Install requirements

docs/source/05-stochastic-rounding.ipynb

Lines changed: 994 additions & 25 deletions
Large diffs are not rendered by default.

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
11
numpy
2+
more_itertools

src/gfloat/printing.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def float_tilde_unless_roundtrip_str(v: float, width: int = 14, d: int = 8) -> s
5151
# it is preceded by a "~" to indicate "approximately equal to"
5252
s = f"{v}"
5353
if len(s) > width:
54-
if abs(v) < 1 and not "e" in s:
54+
if abs(v) < 1 and "e" not in s:
5555
s = f"{v:.{d}f}"
5656
else:
5757
s = f"{v:.{d}}"

src/gfloat/round.py

Lines changed: 34 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def round_float(
4949
p = fi.precision
5050
bias = fi.expBias
5151

52-
if rnd == RoundMode.Stochastic:
52+
if rnd in (RoundMode.Stochastic, RoundMode.StochasticFast):
5353
if srbits >= 2**srnumbits:
5454
raise ValueError(f"srnumbits={srnumbits} >= 2**srnumbits={2**srnumbits}")
5555

@@ -94,18 +94,39 @@ def round_float(
9494
else (isignificand != 0 and _isodd(expval + bias))
9595
)
9696

97-
if rnd == RoundMode.TowardZero:
98-
should_round_away = False
99-
if rnd == RoundMode.TowardPositive:
100-
should_round_away = not sign and delta > 0
101-
if rnd == RoundMode.TowardNegative:
102-
should_round_away = sign and delta > 0
103-
if rnd == RoundMode.TiesToAway:
104-
should_round_away = delta >= 0.5
105-
if rnd == RoundMode.TiesToEven:
106-
should_round_away = delta > 0.5 or (delta == 0.5 and code_is_odd)
107-
if rnd == RoundMode.Stochastic:
108-
should_round_away = delta > (0.5 + srbits) * 2.0**-srnumbits
97+
match rnd:
98+
case RoundMode.TowardZero:
99+
should_round_away = False
100+
case RoundMode.TowardPositive:
101+
should_round_away = not sign and delta > 0
102+
case RoundMode.TowardNegative:
103+
should_round_away = sign and delta > 0
104+
case RoundMode.TiesToAway:
105+
should_round_away = delta >= 0.5
106+
case RoundMode.TiesToEven:
107+
should_round_away = delta > 0.5 or (delta == 0.5 and code_is_odd)
108+
case RoundMode.Stochastic:
109+
## RTNE delta to srbits
110+
d = delta * 2.0**srnumbits
111+
floord = np.floor(d).astype(np.int64)
112+
d = floord + (
113+
(d - floord > 0.5) or ((d - floord == 0.5) and _isodd(floord))
114+
)
115+
116+
should_round_away = d > srbits
117+
case RoundMode.StochasticOdd:
118+
## RTNO delta to srbits
119+
d = delta * 2.0**srnumbits
120+
floord = np.floor(d).astype(np.int64)
121+
d = floord + (
122+
(d - floord > 0.5) or ((d - floord == 0.5) and ~_isodd(floord))
123+
)
124+
125+
should_round_away = d > srbits
126+
case RoundMode.StochasticFast:
127+
should_round_away = delta > (0.5 + srbits) * 2.0**-srnumbits
128+
case RoundMode.StochasticFastest:
129+
should_round_away = delta > srbits * 2.0**-srnumbits
109130

110131
if should_round_away:
111132
# This may increase isignificand to 2**p,

src/gfloat/round_ndarray.py

Lines changed: 53 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
# Copyright (c) 2024 Graphcore Ltd. All rights reserved.
22

3+
from typing import Optional
34
from types import ModuleType
45
from .types import FormatInfo, RoundMode
56
import numpy as np
6-
import math
77

88

99
def _isodd(v: np.ndarray) -> np.ndarray:
@@ -15,6 +15,8 @@ def round_ndarray(
1515
v: np.ndarray,
1616
rnd: RoundMode = RoundMode.TiesToEven,
1717
sat: bool = False,
18+
srbits: Optional[np.ndarray] = None,
19+
srnumbits: int = 0,
1820
np: ModuleType = np,
1921
) -> np.ndarray:
2022
"""
@@ -30,9 +32,12 @@ def round_ndarray(
3032
3133
Args:
3234
fi (FormatInfo): Describes the target format
33-
v (float): Input value to be rounded
35+
v (float array): Input values to be rounded
3436
rnd (RoundMode): Rounding mode to use
3537
sat (bool): Saturation flag: if True, round overflowed values to `fi.max`
38+
srbits (int array): Bits to use for stochastic rounding if rnd is one of the Stochastic modes.
39+
srnumbits (int): How many bits are in srbits. Implies srbits < 2**srnumbits.
40+
3641
np (Module): May be `numpy`, `jax.numpy` or another module cloning numpy
3742
3843
Returns:
@@ -70,18 +75,43 @@ def round_ndarray(
7075
else:
7176
code_is_odd = (isignificand != 0) & _isodd(expval + bias)
7277

73-
if rnd == RoundMode.TowardPositive:
74-
round_up = ~is_negative & (delta > 0)
75-
elif rnd == RoundMode.TowardNegative:
76-
round_up = is_negative & (delta > 0)
77-
elif rnd == RoundMode.TiesToAway:
78-
round_up = delta >= 0.5
79-
elif rnd == RoundMode.TiesToEven:
80-
round_up = (delta > 0.5) | ((delta == 0.5) & code_is_odd)
81-
else:
82-
round_up = np.zeros_like(delta, dtype=bool)
83-
84-
isignificand = np.where(round_up, isignificand + 1, isignificand)
78+
match rnd:
79+
case RoundMode.TowardZero:
80+
should_round_away = np.zeros_like(delta, dtype=bool)
81+
case RoundMode.TowardPositive:
82+
should_round_away = ~is_negative & (delta > 0)
83+
case RoundMode.TowardNegative:
84+
should_round_away = is_negative & (delta > 0)
85+
case RoundMode.TiesToAway:
86+
should_round_away = delta >= 0.5
87+
case RoundMode.TiesToEven:
88+
should_round_away = (delta > 0.5) | ((delta == 0.5) & code_is_odd)
89+
case RoundMode.Stochastic:
90+
assert srbits is not None
91+
## RTNE delta to srbits
92+
d = delta * 2.0**srnumbits
93+
floord = np.floor(d).astype(np.int64)
94+
dd = d - floord
95+
drnd = floord + (dd > 0.5) + ((dd == 0.5) & _isodd(floord))
96+
97+
should_round_away = drnd > srbits
98+
case RoundMode.StochasticOdd:
99+
assert srbits is not None
100+
## RTNO delta to srbits
101+
d = delta * 2.0**srnumbits
102+
floord = np.floor(d).astype(np.int64)
103+
dd = d - floord
104+
drnd = floord + (dd > 0.5) + ((dd == 0.5) & ~_isodd(floord))
105+
106+
should_round_away = drnd > srbits
107+
case RoundMode.StochasticFast:
108+
assert srbits is not None
109+
should_round_away = delta > (2 * srbits + 1) * 2.0 ** -(1 + srnumbits)
110+
case RoundMode.StochasticFastest:
111+
assert srbits is not None
112+
should_round_away = delta > srbits * 2.0**-srnumbits
113+
114+
isignificand = np.where(should_round_away, isignificand + 1, isignificand)
85115

86116
result = np.where(finite_nonzero, np.ldexp(isignificand, expval), absv)
87117

@@ -90,14 +120,15 @@ def round_ndarray(
90120
if sat:
91121
result = np.where(result > amax, amax, result)
92122
else:
93-
if rnd == RoundMode.TowardNegative:
94-
put_amax_at = (result > amax) & ~is_negative
95-
elif rnd == RoundMode.TowardPositive:
96-
put_amax_at = (result > amax) & is_negative
97-
elif rnd == RoundMode.TowardZero:
98-
put_amax_at = result > amax
99-
else:
100-
put_amax_at = np.zeros_like(result, dtype=bool)
123+
match rnd:
124+
case RoundMode.TowardNegative:
125+
put_amax_at = (result > amax) & ~is_negative
126+
case RoundMode.TowardPositive:
127+
put_amax_at = (result > amax) & is_negative
128+
case RoundMode.TowardZero:
129+
put_amax_at = result > amax
130+
case _:
131+
put_amax_at = np.zeros_like(result, dtype=bool)
101132

102133
result = np.where(finite_nonzero & put_amax_at, amax, result)
103134

src/gfloat/types.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,21 @@ class RoundMode(Enum):
1717
TiesToEven = 4 #: Round to nearest, ties to even
1818
TiesToAway = 5 #: Round to nearest, ties away from zero
1919
Stochastic = 6 #: Stochastic rounding
20+
StochasticFast = 7 #: Stochastic rounding - faster, but biased, see [Note 1].
21+
StochasticFastest = 8 #: Stochastic rounding - incorrect, see [Note 1].
22+
StochasticOdd = 9 #: Stochastic rounding, RTNO before comparison
23+
24+
25+
# [Note 1]:
26+
# StochasticFast implements a stochastic rounding scheme that is unbiased in
27+
# infinite precision, but biased when the quantity to be rounded is computed to
28+
# a finite precision.
29+
#
30+
# StochasticFastest implements a stochastic rounding scheme that is biased
31+
# (the rounded value is on average farther from zero than the true value).
32+
#
33+
# With a lot of SRbits (say 8 or more), these biases are negligible, and there
34+
# may be some efficiency advantage in using StochasticFast or StochasticFastest.
2035

2136

2237
class FloatClass(Enum):

test/test_round.py

Lines changed: 71 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import numpy as np
77
import pytest
88

9-
from gfloat import RoundMode, decode_float, round_float, round_ndarray
9+
from gfloat import RoundMode, decode_float, decode_ndarray, round_float, round_ndarray
1010
from gfloat.formats import *
1111

1212

@@ -428,7 +428,7 @@ def get_vals() -> Iterator[Tuple[float, float]]:
428428
]
429429

430430

431-
def _linterp(a: float, b: float, t: float) -> float:
431+
def _linterp(a, b, t): # type: ignore[no-untyped-def]
432432
return a * (1 - t) + b * t
433433

434434

@@ -494,13 +494,16 @@ def test_round_roundtrip(round_float: Callable, fi: FormatInfo) -> None:
494494
"v, srnumbits, expected_up",
495495
(
496496
(259, 3, 0.0 / 8),
497-
(259, 5, 1.0 / 32),
497+
(259, 5, 2.0 / 32),
498498
(277, 3, 3.0 / 8),
499499
(288, 3, 0.5),
500500
(311, 3, 7.0 / 8),
501501
),
502502
)
503-
def test_stochastic_rounding(v: float, srnumbits: int, expected_up: float) -> None:
503+
@pytest.mark.parametrize("impl", ("scalar", "array"))
504+
def test_stochastic_rounding(
505+
impl: bool, v: float, srnumbits: int, expected_up: float
506+
) -> None:
504507
fi = format_info_ocp_e5m2
505508

506509
v0 = round_float(fi, v, RoundMode.TowardNegative)
@@ -510,23 +513,73 @@ def test_stochastic_rounding(v: float, srnumbits: int, expected_up: float) -> No
510513
expected_up_count = expected_up * n
511514

512515
srbits = np.random.randint(0, 2**srnumbits, size=(n,))
513-
count_v1 = 0
514-
for k in range(n):
515-
r = round_float(
516-
fi,
517-
v,
518-
RoundMode.Stochastic,
519-
sat=False,
520-
srbits=srbits[k],
521-
srnumbits=srnumbits,
522-
)
523-
if r == v1:
524-
count_v1 += 1
525-
else:
526-
assert r == v0
516+
if impl == "scalar":
517+
count_v1 = 0
518+
for k in range(n):
519+
r = round_float(
520+
fi,
521+
v,
522+
RoundMode.Stochastic,
523+
sat=False,
524+
srbits=srbits[k],
525+
srnumbits=srnumbits,
526+
)
527+
if r == v1:
528+
count_v1 += 1
529+
else:
530+
assert r == v0
531+
else:
532+
vs = np.full(n, v)
533+
rs = round_ndarray(fi, vs, RoundMode.Stochastic, False, srbits, srnumbits)
534+
assert np.all((rs == v0) | (rs == v1))
535+
count_v1 = np.sum(rs == v1)
527536

528537
print(f"SRBits={srnumbits}, observed = {count_v1}, expected = {expected_up_count} ")
529538
# e.g. if expected is 1250/10000, want to be within 0.5,1.5
530539
# this is loose, but should still catch logic errors
531540
atol = n * 2.0 ** (-1 - srnumbits)
532541
np.testing.assert_allclose(count_v1, expected_up_count, atol=atol)
542+
543+
544+
@pytest.mark.parametrize(
545+
"rnd",
546+
(RoundMode.Stochastic, RoundMode.StochasticFast, RoundMode.StochasticFastest),
547+
)
548+
def test_stochastic_rounding_scalar_eq_array(rnd: RoundMode) -> None:
549+
fi = format_info_p3109(3)
550+
551+
v0 = decode_ndarray(fi, np.arange(255))
552+
v1 = decode_ndarray(fi, np.arange(255) + 1)
553+
ok = np.isfinite(v0) & np.isfinite(v1)
554+
v0 = v0[ok]
555+
v1 = v1[ok]
556+
557+
srnumbits = 3
558+
for srbits in range(2**srnumbits):
559+
for alpha in (0, 0.3, 0.5, 0.6, 0.9, 1.25):
560+
v = _linterp(v0, v1, alpha)
561+
assert np.isfinite(v).all()
562+
val_array = round_ndarray(
563+
fi,
564+
v,
565+
rnd,
566+
sat=False,
567+
srbits=np.asarray(srbits),
568+
srnumbits=srnumbits,
569+
)
570+
571+
val_scalar = [
572+
round_float(
573+
fi,
574+
v,
575+
rnd,
576+
sat=False,
577+
srbits=srbits,
578+
srnumbits=srnumbits,
579+
)
580+
for v in v
581+
]
582+
if alpha < 1.0:
583+
assert ((val_array == v0) | (val_array == v1)).all()
584+
585+
np.testing.assert_equal(val_array, val_scalar)

0 commit comments

Comments
 (0)