Skip to content

Commit a7bb326

Browse files
authored
Merge pull request #33 from graphcore-research/fix-overflows
Avoid overflow/conversion errors in vectorized code.
2 parents 49d2fb8 + a9d80c5 commit a7bb326

File tree

4 files changed

+26
-14
lines changed

4 files changed

+26
-14
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,4 +41,4 @@ ignore_missing_imports = true
4141

4242
[tool.pytest.ini_options]
4343
addopts = "--nbval"
44-
testpaths = ["test", "doc"]
44+
testpaths = ["docs", "test"]

src/gfloat/decode_ndarray.py

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -48,18 +48,14 @@ def decode_ndarray(
4848

4949
expBias = fi.expBias
5050

51-
iszero = (exp == 0) & (significand == 0) & fi.has_zero
52-
issubnormal = (exp == 0) & (significand != 0) & fi.has_subnormals
53-
isnormal = ~iszero & ~issubnormal
54-
expval = np.where(~isnormal, 1 - expBias, exp - expBias)
55-
fsignificand = np.where(~isnormal, significand * 2**-t, 1.0 + significand * 2**-t)
56-
57-
# Normal/Subnormal/Zero case, other values will be overwritten
58-
fval = np.where(iszero, 0.0, sign * fsignificand * 2.0**expval)
51+
fval = np.zeros_like(codes, dtype=np.float64)
52+
isspecial = np.zeros_like(codes, dtype=bool)
5953

6054
if fi.has_infs:
6155
fval = np.where(codes == fi.code_of_posinf, np.inf, fval)
56+
isspecial |= codes == fi.code_of_posinf
6257
fval = np.where(codes == fi.code_of_neginf, -np.inf, fval)
58+
isspecial |= codes == fi.code_of_neginf
6359

6460
if fi.num_nans > 0:
6561
code_is_nan = codes == fi.code_of_nan
@@ -70,9 +66,21 @@ def decode_ndarray(
7066
code_is_nan |= abse & (significand >= min_code_with_nan)
7167

7268
fval = np.where(code_is_nan, np.nan, fval)
69+
isspecial |= code_is_nan
7370

74-
# Negative zero
71+
# Zero
72+
iszero = ~isspecial & (exp == 0) & (significand == 0) & fi.has_zero
73+
fval = np.where(iszero, 0.0, fval)
7574
if fi.has_nz:
7675
fval = np.where(iszero & (sign < 0), -0.0, fval)
7776

77+
issubnormal = (exp == 0) & (significand != 0) & fi.has_subnormals
78+
expval = np.where(issubnormal, 1 - expBias, exp - expBias)
79+
fsignificand = np.where(issubnormal, 0.0, 1.0) + np.ldexp(significand, -t)
80+
81+
# Normal/Subnormal/Zero case, other values will be overwritten
82+
expval_safe = np.where(isspecial | iszero, 0, expval)
83+
fval_finite_safe = sign * np.ldexp(fsignificand, expval_safe)
84+
fval = np.where(~(iszero | isspecial), fval_finite_safe, fval)
85+
7886
return fval

src/gfloat/round_ndarray.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ def round_ndarray(
9292
expval += round_up & (isignificand == 1)
9393
isignificand = np.where(round_up, 1, isignificand)
9494

95-
result = np.where(finite_nonzero, isignificand * (2.0**expval), absv)
95+
result = np.where(finite_nonzero, np.ldexp(isignificand, expval), absv)
9696

9797
amax = np.where(is_negative, -fi.min, fi.max)
9898

@@ -189,10 +189,11 @@ def encode_ndarray(fi: FormatInfo, v: np.ndarray) -> np.ndarray:
189189
biased_exp = exp.astype(np.int64) + (fi.expBias - 1)
190190
subnormal_mask = (biased_exp < 1) & fi.has_subnormals
191191

192-
tsig = np.where(subnormal_mask, sig * 2.0**biased_exp, sig * 2 - 1.0)
192+
biased_exp_safe = np.where(subnormal_mask, biased_exp, 0)
193+
tsig = np.where(subnormal_mask, np.ldexp(sig, biased_exp_safe), sig * 2 - 1.0)
193194
biased_exp[subnormal_mask] = 0
194195

195-
isig = np.floor(tsig * 2**t).astype(int)
196+
isig = np.floor(np.ldexp(tsig, t)).astype(np.int64)
196197

197198
zero_mask = fi.has_zero & (isig == 0) & (biased_exp == 0)
198199
if not fi.has_nz:

test/test_decode.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -260,7 +260,10 @@ def test_consistent_decodes_all_values(
260260
npivals = np.arange(
261261
np.iinfo(int_dtype).min, int(np.iinfo(int_dtype).max) + 1, dtype=int_dtype
262262
)
263-
npfvals = npivals.view(dtype=npfmt).astype(np.float64)
263+
264+
with np.errstate(invalid="ignore"):
265+
# Warning here when converting bfloat16 NaNs to float64
266+
npfvals = npivals.view(dtype=npfmt).astype(np.float64)
264267

265268
# Scalar version
266269
for i, npfval in zip(npivals, npfvals):

0 commit comments

Comments
 (0)