4
4
from types import ModuleType
5
5
from .types import FormatInfo , RoundMode
6
6
import numpy as np
7
+ import array_api_compat
7
8
8
9
9
10
def _isodd (v : np .ndarray ) -> np .ndarray :
10
11
return v & 0x1 == 1
11
12
12
13
14
+ def _ldexp (v : np .ndarray , s : np .ndarray ) -> np .ndarray :
15
+ xp = array_api_compat .array_namespace (v , s )
16
+ if (
17
+ array_api_compat .is_torch_array (v )
18
+ or array_api_compat .is_jax_array (v )
19
+ or array_api_compat .is_numpy_array (v )
20
+ ):
21
+ return xp .ldexp (v , s )
22
+
23
+ # Scale away from subnormal/infinite ranges
24
+ offset = 24
25
+ vlo = (v * 2.0 ** + offset ) * 2.0 ** xp .astype (s - offset , v .dtype )
26
+ vhi = (v * 2.0 ** - offset ) * 2.0 ** xp .astype (s + offset , v .dtype )
27
+ return xp .where (v < 1.0 , vlo , vhi )
28
+
29
+
13
30
def round_ndarray (
14
31
fi : FormatInfo ,
15
32
v : np .ndarray ,
16
33
rnd : RoundMode = RoundMode .TiesToEven ,
17
34
sat : bool = False ,
18
35
srbits : Optional [np .ndarray ] = None ,
19
36
srnumbits : int = 0 ,
20
- np : ModuleType = np ,
21
37
) -> np .ndarray :
22
38
"""
23
39
Vectorized version of :meth:`round_float`.
@@ -38,8 +54,6 @@ def round_ndarray(
38
54
srbits (int array): Bits to use for stochastic rounding if rnd == Stochastic.
39
55
srnumbits (int): How many bits are in srbits. Implies srbits < 2**srnumbits.
40
56
41
- np (Module): May be `numpy`, `jax.numpy` or another module cloning numpy
42
-
43
57
Returns:
44
58
An array of floats which is a subset of the format's value set.
45
59
@@ -48,27 +62,38 @@ def round_ndarray(
48
62
(e.g. converting a `NaN`, or an `Inf` when the target has no
49
63
`NaN` or `Inf`, and :paramref:`sat` is false)
50
64
"""
65
+ xp = array_api_compat .array_namespace (v , srbits )
66
+
51
67
p = fi .precision
52
68
bias = fi .expBias
53
69
54
- is_negative = np .signbit (v ) & fi .is_signed
55
- absv = np .where (is_negative , - v , v )
70
+ is_negative = xp .signbit (v ) & fi .is_signed
71
+ absv = xp .where (is_negative , - v , v )
56
72
57
- finite_nonzero = ~ (np .isnan (v ) | np .isinf (v ) | (v == 0 ))
73
+ finite_nonzero = ~ (xp .isnan (v ) | xp .isinf (v ) | (v == 0 ))
58
74
59
75
# Place 1.0 where finite_nonzero is False, to avoid log of {0,inf,nan}
60
- absv_masked = np .where (finite_nonzero , absv , 1.0 )
76
+ absv_masked = xp .where (finite_nonzero , absv , 1.0 )
61
77
62
- expval = np .floor (np .log2 (absv_masked )).astype (int )
78
+ int_type = xp .int64 if fi .k > 8 or srnumbits > 8 else xp .int16
79
+
80
+ def to_int (x : np .ndarray ) -> np .ndarray :
81
+ return xp .astype (x , int_type )
82
+
83
+ def to_float (x : np .ndarray ) -> np .ndarray :
84
+ return xp .astype (x , v .dtype )
85
+
86
+ expval = to_int (xp .floor (xp .log2 (absv_masked )))
63
87
64
88
if fi .has_subnormals :
65
- expval = np .maximum (expval , 1 - bias )
89
+ expval = xp .maximum (expval , 1 - bias )
66
90
67
91
expval = expval - p + 1
68
- fsignificand = np . ldexp (absv_masked , - expval )
92
+ fsignificand = _ldexp (absv_masked , - expval )
69
93
70
- isignificand = np .floor (fsignificand ).astype (np .int64 )
71
- delta = fsignificand - isignificand
94
+ floorfsignificand = xp .floor (fsignificand )
95
+ isignificand = to_int (floorfsignificand )
96
+ delta = fsignificand - floorfsignificand
72
97
73
98
if fi .precision > 1 :
74
99
code_is_odd = _isodd (isignificand )
@@ -77,7 +102,7 @@ def round_ndarray(
77
102
78
103
match rnd :
79
104
case RoundMode .TowardZero :
80
- should_round_away = np .zeros_like (delta , dtype = bool )
105
+ should_round_away = xp .zeros_like (delta , dtype = xp . bool )
81
106
82
107
case RoundMode .TowardPositive :
83
108
should_round_away = ~ is_negative & (delta > 0 )
@@ -95,38 +120,44 @@ def round_ndarray(
95
120
assert srbits is not None
96
121
## RTNE delta to srbits
97
122
d = delta * 2.0 ** srnumbits
98
- floord = np .floor (d ).astype (np .int64 )
99
- dd = d - floord
100
- drnd = floord + (dd > 0.5 ) + ((dd == 0.5 ) & _isodd (floord ))
123
+ floord = to_int (xp .floor (d ))
124
+ dd = d - xp .floor (d )
125
+ should_round_away_tne = (dd > 0.5 ) | ((dd == 0.5 ) & _isodd (floord ))
126
+ drnd = floord + xp .astype (should_round_away_tne , floord .dtype )
101
127
102
- should_round_away = drnd + srbits >= 2.0 ** srnumbits
128
+ should_round_away = drnd + srbits >= 2 ** srnumbits
103
129
104
130
case RoundMode .StochasticOdd :
105
131
assert srbits is not None
106
132
## RTNO delta to srbits
107
133
d = delta * 2.0 ** srnumbits
108
- floord = np .floor (d ).astype (np .int64 )
109
- dd = d - floord
110
- drnd = floord + (dd > 0.5 ) + ((dd == 0.5 ) & ~ _isodd (floord ))
134
+ floord = to_int (xp .floor (d ))
135
+ dd = d - xp .floor (d )
136
+ should_round_away_tno = (dd > 0.5 ) | ((dd == 0.5 ) & ~ _isodd (floord ))
137
+ drnd = floord + xp .astype (should_round_away_tno , floord .dtype )
111
138
112
- should_round_away = drnd + srbits >= 2.0 ** srnumbits
139
+ should_round_away = drnd + srbits >= 2 ** srnumbits
113
140
114
141
case RoundMode .StochasticFast :
115
142
assert srbits is not None
116
- should_round_away = delta + (2 * srbits + 1 ) * 2.0 ** - (1 + srnumbits ) >= 1.0
143
+ should_round_away = (
144
+ delta + to_float (2 * srbits + 1 ) * 2.0 ** - (1 + srnumbits ) >= 1.0
145
+ )
117
146
118
147
case RoundMode .StochasticFastest :
119
148
assert srbits is not None
120
- should_round_away = delta + srbits * 2.0 ** - srnumbits >= 1.0
149
+ should_round_away = delta + to_float (srbits ) * 2.0 ** - srnumbits >= 1.0
150
+
151
+ isignificand = xp .where (should_round_away , isignificand + 1 , isignificand )
121
152
122
- isignificand = np . where ( should_round_away , isignificand + 1 , isignificand )
153
+ fresult = _ldexp ( to_float ( isignificand ), expval )
123
154
124
- result = np .where (finite_nonzero , np . ldexp ( isignificand , expval ) , absv )
155
+ result = xp .where (finite_nonzero , fresult , absv )
125
156
126
- amax = np .where (is_negative , - fi .min , fi .max )
157
+ amax = xp .where (is_negative , - fi .min , fi .max )
127
158
128
159
if sat :
129
- result = np .where (result > amax , amax , result )
160
+ result = xp .where (result > amax , amax , result )
130
161
else :
131
162
match rnd :
132
163
case RoundMode .TowardNegative :
@@ -136,25 +167,25 @@ def round_ndarray(
136
167
case RoundMode .TowardZero :
137
168
put_amax_at = result > amax
138
169
case _:
139
- put_amax_at = np .zeros_like (result , dtype = bool )
170
+ put_amax_at = xp .zeros_like (result , dtype = xp . bool )
140
171
141
- result = np .where (finite_nonzero & put_amax_at , amax , result )
172
+ result = xp .where (finite_nonzero & put_amax_at , amax , result )
142
173
143
174
# Now anything larger than amax goes to infinity or NaN
144
175
if fi .has_infs :
145
- result = np .where (result > amax , np .inf , result )
176
+ result = xp .where (result > amax , xp .inf , result )
146
177
elif fi .num_nans > 0 :
147
- result = np .where (result > amax , np .nan , result )
178
+ result = xp .where (result > amax , xp .nan , result )
148
179
else :
149
- if np .any (result > amax ):
180
+ if xp .any (result > amax ):
150
181
raise ValueError (f"No Infs or NaNs in format { fi } , and sat=False" )
151
182
152
- result = np .where (is_negative , - result , result )
183
+ result = xp .where (is_negative , - result , result )
153
184
154
185
# Make negative zeros negative if has_nz, else make them not negative.
155
186
if fi .has_nz :
156
- result = np .where ((result == 0 ) & is_negative , - 0.0 , result )
187
+ result = xp .where ((result == 0 ) & is_negative , - 0.0 , result )
157
188
else :
158
- result = np .where (result == 0 , 0.0 , result )
189
+ result = xp .where (result == 0 , 0.0 , result )
159
190
160
191
return result
0 commit comments