Skip to content

Commit 564c104

Browse files
authored
Merge pull request #39 from graphcore-research/refactor-encode-round
Refactor encode/round
2 parents 0257255 + 1d20838 commit 564c104

File tree

6 files changed

+188
-174
lines changed

6 files changed

+188
-174
lines changed

src/gfloat/__init__.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,10 @@
99
)
1010
from .decode import decode_float
1111
from .printing import float_pow2str, float_tilde_unless_roundtrip_str
12-
from .round import encode_float, round_float
13-
from .round_ndarray import encode_ndarray, round_ndarray
12+
from .round import round_float
13+
from .encode import encode_float
14+
from .round_ndarray import round_ndarray
15+
from .encode_ndarray import encode_ndarray
1416
from .decode_ndarray import decode_ndarray
1517
from .types import FloatClass, FloatValue, FormatInfo, RoundMode
1618

src/gfloat/block.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,8 @@
1010
import numpy.typing as npt
1111

1212
from .decode import decode_float
13-
from .round import RoundMode, encode_float, round_float
13+
from .round import RoundMode, round_float
14+
from .encode import encode_float
1415
from .types import FormatInfo
1516

1617

src/gfloat/encode.py

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
# Copyright (c) 2024 Graphcore Ltd. All rights reserved.
2+
3+
import math
4+
5+
import numpy as np
6+
7+
from .types import FormatInfo
8+
9+
10+
def encode_float(fi: FormatInfo, v: float) -> int:
11+
"""
12+
Encode input to the given :py:class:`FormatInfo`.
13+
14+
Will round toward zero if :paramref:`v` is not in the value set.
15+
Will saturate to `Inf`, `NaN`, `fi.max` in order of precedence.
16+
Encode -0 to 0 if not `fi.has_nz`
17+
18+
For other roundings and saturations, call :func:`round_float` first.
19+
20+
Args:
21+
fi (FormatInfo): Describes the target format
22+
v (float): The value to be encoded.
23+
24+
Returns:
25+
The integer code point
26+
"""
27+
28+
# Format Constants
29+
k = fi.bits
30+
p = fi.precision
31+
t = p - 1
32+
33+
# Encode
34+
if np.isnan(v):
35+
return fi.code_of_nan
36+
37+
# Overflow/underflow
38+
if v > fi.max:
39+
if fi.has_infs:
40+
return fi.code_of_posinf
41+
if fi.num_nans > 0:
42+
return fi.code_of_nan
43+
return fi.code_of_max
44+
45+
if v < fi.min:
46+
if fi.has_infs:
47+
return fi.code_of_neginf
48+
if fi.num_nans > 0:
49+
return fi.code_of_nan
50+
return fi.code_of_min
51+
52+
# Finite values
53+
sign = fi.is_signed and np.signbit(v)
54+
vpos = -v if sign else v
55+
56+
if fi.has_subnormals and vpos <= fi.smallest_subnormal / 2:
57+
isig = 0
58+
biased_exp = 0
59+
else:
60+
sig, exp = np.frexp(vpos)
61+
exp = int(exp) # All calculations in Python ints
62+
63+
# sig in range [0.5, 1)
64+
sig *= 2
65+
exp -= 1
66+
# now sig in range [1, 2)
67+
68+
biased_exp = exp + fi.expBias
69+
if biased_exp < 1 and fi.has_subnormals:
70+
# subnormal
71+
sig *= 2.0 ** (biased_exp - 1)
72+
biased_exp = 0
73+
assert vpos == sig * 2 ** (1 - fi.expBias)
74+
else:
75+
if sig > 0:
76+
sig -= 1.0
77+
78+
isig = math.floor(sig * 2**t)
79+
80+
# Zero
81+
if isig == 0 and biased_exp == 0 and fi.has_zero:
82+
if sign and fi.has_nz:
83+
return fi.code_of_negzero
84+
else:
85+
return fi.code_of_zero
86+
87+
# Nonzero
88+
assert isig < 2**t
89+
assert biased_exp < 2**fi.expBits or fi.is_twos_complement
90+
91+
# Handle two's complement encoding
92+
if fi.is_twos_complement and sign:
93+
isig = (1 << t) - isig
94+
95+
# Pack values into a single integer
96+
code = (int(sign) << (k - 1)) | (biased_exp << t) | (isig << 0)
97+
98+
return code

src/gfloat/encode_ndarray.py

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
# Copyright (c) 2024 Graphcore Ltd. All rights reserved.
2+
3+
from .types import FormatInfo
4+
import numpy as np
5+
6+
7+
def encode_ndarray(fi: FormatInfo, v: np.ndarray) -> np.ndarray:
8+
"""
9+
Vectorized version of :meth:`encode_float`.
10+
11+
Encode inputs to the given :py:class:`FormatInfo`.
12+
13+
Will round toward zero if :paramref:`v` is not in the value set.
14+
Will saturate to `Inf`, `NaN`, `fi.max` in order of precedence.
15+
Encode -0 to 0 if not `fi.has_nz`
16+
17+
For other roundings and saturations, call :func:`round_ndarray` first.
18+
19+
Args:
20+
fi (FormatInfo): Describes the target format
21+
v (float array): The value to be encoded.
22+
23+
Returns:
24+
The integer code point
25+
"""
26+
k = fi.bits
27+
p = fi.precision
28+
t = p - 1
29+
30+
sign = np.signbit(v) & fi.is_signed
31+
vpos = np.where(sign, -v, v)
32+
33+
nan_mask = np.isnan(v)
34+
35+
code = np.zeros_like(v, dtype=np.uint64)
36+
37+
if fi.num_nans > 0:
38+
code[nan_mask] = fi.code_of_nan
39+
else:
40+
assert not np.any(nan_mask)
41+
42+
if fi.has_infs:
43+
code[v > fi.max] = fi.code_of_posinf
44+
code[v < fi.min] = fi.code_of_neginf
45+
else:
46+
code[v > fi.max] = fi.code_of_nan if fi.num_nans > 0 else fi.code_of_max
47+
code[v < fi.min] = fi.code_of_nan if fi.num_nans > 0 else fi.code_of_min
48+
49+
if fi.has_zero:
50+
if fi.has_nz:
51+
code[v == 0] = np.where(sign[v == 0], fi.code_of_negzero, fi.code_of_zero)
52+
else:
53+
code[v == 0] = fi.code_of_zero
54+
55+
finite_mask = (code == 0) & (v != 0)
56+
assert not np.any(np.isnan(vpos[finite_mask]))
57+
if np.any(finite_mask):
58+
finite_vpos = vpos[finite_mask]
59+
finite_sign = sign[finite_mask]
60+
61+
sig, exp = np.frexp(finite_vpos)
62+
63+
biased_exp = exp.astype(np.int64) + (fi.expBias - 1)
64+
subnormal_mask = (biased_exp < 1) & fi.has_subnormals
65+
66+
biased_exp_safe = np.where(subnormal_mask, biased_exp, 0)
67+
tsig = np.where(subnormal_mask, np.ldexp(sig, biased_exp_safe), sig * 2 - 1.0)
68+
biased_exp[subnormal_mask] = 0
69+
70+
isig = np.floor(np.ldexp(tsig, t)).astype(np.int64)
71+
72+
zero_mask = fi.has_zero & (isig == 0) & (biased_exp == 0)
73+
if not fi.has_nz:
74+
finite_sign[zero_mask] = False
75+
76+
# Handle two's complement encoding
77+
if fi.is_twos_complement:
78+
isig[finite_sign] = (1 << t) - isig[finite_sign]
79+
80+
code[finite_mask] = (
81+
(finite_sign.astype(int) << (k - 1)) | (biased_exp << t) | (isig << 0)
82+
)
83+
84+
return code

src/gfloat/round.py

Lines changed: 0 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -166,94 +166,3 @@ def round_float(
166166
result = -result
167167

168168
return result
169-
170-
171-
def encode_float(fi: FormatInfo, v: float) -> int:
172-
"""
173-
Encode input to the given :py:class:`FormatInfo`.
174-
175-
Will round toward zero if :paramref:`v` is not in the value set.
176-
Will saturate to `Inf`, `NaN`, `fi.max` in order of precedence.
177-
Encode -0 to 0 if not `fi.has_nz`
178-
179-
For other roundings and saturations, call :func:`round_float` first.
180-
181-
Args:
182-
fi (FormatInfo): Describes the target format
183-
v (float): The value to be encoded.
184-
185-
Returns:
186-
The integer code point
187-
"""
188-
189-
# Format Constants
190-
k = fi.bits
191-
p = fi.precision
192-
t = p - 1
193-
194-
# Encode
195-
if np.isnan(v):
196-
return fi.code_of_nan
197-
198-
# Overflow/underflow
199-
if v > fi.max:
200-
if fi.has_infs:
201-
return fi.code_of_posinf
202-
if fi.num_nans > 0:
203-
return fi.code_of_nan
204-
return fi.code_of_max
205-
206-
if v < fi.min:
207-
if fi.has_infs:
208-
return fi.code_of_neginf
209-
if fi.num_nans > 0:
210-
return fi.code_of_nan
211-
return fi.code_of_min
212-
213-
# Finite values
214-
sign = fi.is_signed and np.signbit(v)
215-
vpos = -v if sign else v
216-
217-
if fi.has_subnormals and vpos <= fi.smallest_subnormal / 2:
218-
isig = 0
219-
biased_exp = 0
220-
else:
221-
sig, exp = np.frexp(vpos)
222-
exp = int(exp) # All calculations in Python ints
223-
224-
# sig in range [0.5, 1)
225-
sig *= 2
226-
exp -= 1
227-
# now sig in range [1, 2)
228-
229-
biased_exp = exp + fi.expBias
230-
if biased_exp < 1 and fi.has_subnormals:
231-
# subnormal
232-
sig *= 2.0 ** (biased_exp - 1)
233-
biased_exp = 0
234-
assert vpos == sig * 2 ** (1 - fi.expBias)
235-
else:
236-
if sig > 0:
237-
sig -= 1.0
238-
239-
isig = math.floor(sig * 2**t)
240-
241-
# Zero
242-
if isig == 0 and biased_exp == 0 and fi.has_zero:
243-
if sign and fi.has_nz:
244-
return fi.code_of_negzero
245-
else:
246-
return fi.code_of_zero
247-
248-
# Nonzero
249-
assert isig < 2**t
250-
assert biased_exp < 2**fi.expBits or fi.is_twos_complement
251-
252-
# Handle two's complement encoding
253-
if fi.is_twos_complement and sign:
254-
isig = (1 << t) - isig
255-
256-
# Pack values into a single integer
257-
code = (int(sign) << (k - 1)) | (biased_exp << t) | (isig << 0)
258-
259-
return code

src/gfloat/round_ndarray.py

Lines changed: 0 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -150,83 +150,3 @@ def round_ndarray(
150150
result = np.where(result == 0, 0.0, result)
151151

152152
return result
153-
154-
155-
def encode_ndarray(fi: FormatInfo, v: np.ndarray) -> np.ndarray:
156-
"""
157-
Vectorized version of :meth:`encode_float`.
158-
159-
Encode inputs to the given :py:class:`FormatInfo`.
160-
161-
Will round toward zero if :paramref:`v` is not in the value set.
162-
Will saturate to `Inf`, `NaN`, `fi.max` in order of precedence.
163-
Encode -0 to 0 if not `fi.has_nz`
164-
165-
For other roundings and saturations, call :func:`round_ndarray` first.
166-
167-
Args:
168-
fi (FormatInfo): Describes the target format
169-
v (float array): The value to be encoded.
170-
171-
Returns:
172-
The integer code point
173-
"""
174-
k = fi.bits
175-
p = fi.precision
176-
t = p - 1
177-
178-
sign = np.signbit(v) & fi.is_signed
179-
vpos = np.where(sign, -v, v)
180-
181-
nan_mask = np.isnan(v)
182-
183-
code = np.zeros_like(v, dtype=np.uint64)
184-
185-
if fi.num_nans > 0:
186-
code[nan_mask] = fi.code_of_nan
187-
else:
188-
assert not np.any(nan_mask)
189-
190-
if fi.has_infs:
191-
code[v > fi.max] = fi.code_of_posinf
192-
code[v < fi.min] = fi.code_of_neginf
193-
else:
194-
code[v > fi.max] = fi.code_of_nan if fi.num_nans > 0 else fi.code_of_max
195-
code[v < fi.min] = fi.code_of_nan if fi.num_nans > 0 else fi.code_of_min
196-
197-
if fi.has_zero:
198-
if fi.has_nz:
199-
code[v == 0] = np.where(sign[v == 0], fi.code_of_negzero, fi.code_of_zero)
200-
else:
201-
code[v == 0] = fi.code_of_zero
202-
203-
finite_mask = (code == 0) & (v != 0)
204-
assert not np.any(np.isnan(vpos[finite_mask]))
205-
if np.any(finite_mask):
206-
finite_vpos = vpos[finite_mask]
207-
finite_sign = sign[finite_mask]
208-
209-
sig, exp = np.frexp(finite_vpos)
210-
211-
biased_exp = exp.astype(np.int64) + (fi.expBias - 1)
212-
subnormal_mask = (biased_exp < 1) & fi.has_subnormals
213-
214-
biased_exp_safe = np.where(subnormal_mask, biased_exp, 0)
215-
tsig = np.where(subnormal_mask, np.ldexp(sig, biased_exp_safe), sig * 2 - 1.0)
216-
biased_exp[subnormal_mask] = 0
217-
218-
isig = np.floor(np.ldexp(tsig, t)).astype(np.int64)
219-
220-
zero_mask = fi.has_zero & (isig == 0) & (biased_exp == 0)
221-
if not fi.has_nz:
222-
finite_sign[zero_mask] = False
223-
224-
# Handle two's complement encoding
225-
if fi.is_twos_complement:
226-
isig[finite_sign] = (1 << t) - isig[finite_sign]
227-
228-
code[finite_mask] = (
229-
(finite_sign.astype(int) << (k - 1)) | (biased_exp << t) | (isig << 0)
230-
)
231-
232-
return code

0 commit comments

Comments
 (0)