@@ -12,7 +12,12 @@ def _isodd(v: int) -> bool:
12
12
13
13
14
14
def round_float (
15
- fi : FormatInfo , v : float , rnd : RoundMode = RoundMode .TiesToEven , sat : bool = False
15
+ fi : FormatInfo ,
16
+ v : float ,
17
+ rnd : RoundMode = RoundMode .TiesToEven ,
18
+ sat : bool = False ,
19
+ srbits : int = - 1 ,
20
+ srnumbits : int = 0 ,
16
21
) -> float :
17
22
"""
18
23
Round input to the given :py:class:`FormatInfo`, given rounding mode and saturation flag
@@ -27,6 +32,8 @@ def round_float(
27
32
v (float): Input value to be rounded
28
33
rnd (RoundMode): Rounding mode to use
29
34
sat (bool): Saturation flag: if True, round overflowed values to `fi.max`
35
+ srbits (int): Bits to use for stochastic rounding if rnd == Stochastic.
36
+ srnumbits (int): How many bits are in srbits. Implies srbits < 2**srnumbits.
30
37
31
38
Returns:
32
39
A float which is one of the values in the format.
@@ -35,12 +42,17 @@ def round_float(
35
42
ValueError: The target format cannot represent the input
36
43
(e.g. converting a `NaN`, or an `Inf` when the target has no
37
44
`NaN` or `Inf`, and :paramref:`sat` is false)
45
+ ValueError: Inconsistent arguments, e.g. srnumbits >= 2**srnumbits
38
46
"""
39
47
40
48
# Constants
41
49
p = fi .precision
42
50
bias = fi .expBias
43
51
52
+ if rnd == RoundMode .Stochastic :
53
+ if srbits >= 2 ** srnumbits :
54
+ raise ValueError (f"srnumbits={ srnumbits } >= 2**srnumbits={ 2 ** srnumbits } " )
55
+
44
56
if np .isnan (v ):
45
57
if fi .num_nans == 0 :
46
58
raise ValueError (f"No NaN in format { fi } " )
@@ -75,12 +87,15 @@ def round_float(
75
87
# Round
76
88
isignificand = math .floor (fsignificand )
77
89
delta = fsignificand - isignificand
90
+
91
+ # fmt: off
78
92
if (
79
93
(rnd == RoundMode .TowardPositive and not sign and delta > 0 )
80
94
or (rnd == RoundMode .TowardNegative and sign and delta > 0 )
81
95
or (rnd == RoundMode .TiesToAway and delta >= 0.5 )
82
96
or (rnd == RoundMode .TiesToEven and delta > 0.5 )
83
97
or (rnd == RoundMode .TiesToEven and delta == 0.5 and _isodd (isignificand ))
98
+ or (rnd == RoundMode .Stochastic and delta > (0.5 + srbits ) * 2.0 ** - srnumbits )
84
99
):
85
100
isignificand += 1
86
101
@@ -95,6 +110,7 @@ def round_float(
95
110
or (rnd == RoundMode .TiesToAway and delta >= 0.5 )
96
111
or (rnd == RoundMode .TiesToEven and delta > 0.5 )
97
112
or (rnd == RoundMode .TiesToEven and delta == 0.5 and code_is_odd )
113
+ or (rnd == RoundMode .Stochastic and delta > (0.5 + srbits ) * 2.0 ** - srnumbits )
98
114
):
99
115
# Go to nextUp.
100
116
# Increment isignificand if zero,
@@ -105,6 +121,7 @@ def round_float(
105
121
assert isignificand == 1
106
122
expval += 1
107
123
## End special case for Precision=1.
124
+ # fmt: on
108
125
109
126
# Reconstruct rounded result to float
110
127
result = isignificand * (2.0 ** expval )
0 commit comments