Skip to content

Commit 2d2a250

Browse files
authored
Use array-api for cross-framework compatibility (#45)
#45
1 parent 279e4f1 commit 2d2a250

File tree

10 files changed

+163
-84
lines changed

10 files changed

+163
-84
lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,3 +162,6 @@ cython_debug/
162162
#.idea/
163163
.vscode/settings.json
164164
.vscode/launch.json
165+
166+
# Local
167+
tmp/

docs/source/04-benchmark.ipynb

Lines changed: 8 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
"cells": [
33
{
44
"cell_type": "code",
5-
"execution_count": 1,
5+
"execution_count": 3,
66
"metadata": {},
77
"outputs": [],
88
"source": [
@@ -34,24 +34,17 @@
3434
},
3535
{
3636
"cell_type": "code",
37-
"execution_count": 2,
37+
"execution_count": 4,
3838
"metadata": {},
3939
"outputs": [
40-
{
41-
"name": "stderr",
42-
"output_type": "stream",
43-
"text": [
44-
"An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.\n"
45-
]
46-
},
4740
{
4841
"name": "stdout",
4942
"output_type": "stream",
5043
"text": [
51-
"GFloat scalar : 6306.18 nsec (25 runs at size 10000)\n",
52-
"GFloat vectorized, numpy arrays: 52.52 nsec (25 runs at size 1000000)\n",
53-
"GFloat vectorized, JAX JIT : 3.04 nsec (500 runs at size 1000000)\n",
54-
"ML_dtypes : 2.69 nsec (500 runs at size 1000000)\n"
44+
"GFloat scalar : 7510.22 nsec (25 runs at size 10000)\n",
45+
"GFloat vectorized, numpy arrays: 43.82 nsec (25 runs at size 1000000)\n",
46+
"GFloat vectorized, JAX JIT : 2.69 nsec (500 runs at size 1000000)\n",
47+
"ML_dtypes : 2.57 nsec (500 runs at size 1000000)\n"
5548
]
5649
}
5750
],
@@ -61,7 +54,7 @@
6154
"N = 1_000_000\n",
6255
"a = np.random.rand(N)\n",
6356
"\n",
64-
"jax_round_jit = jax.jit(lambda x: gfloat.round_ndarray(format_info_ocp_e5m2, x, np=jnp))\n",
57+
"jax_round_jit = jax.jit(lambda x: gfloat.round_ndarray(format_info_ocp_e5m2, x))\n",
6558
"ja = jnp.array(a)\n",
6659
"jax_round_jit(ja) # Cache compilation\n",
6760
"\n",
@@ -108,7 +101,7 @@
108101
],
109102
"metadata": {
110103
"kernelspec": {
111-
"display_name": "Python 3",
104+
"display_name": ".venv",
112105
"language": "python",
113106
"name": "python3"
114107
},

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ fast = true
3636

3737
[tool.mypy]
3838
[[tool.mypy.overrides]]
39-
module = "mx.*"
39+
module = ["mx.*", "array_api_compat.*", "array_api_strict.*"]
4040
ignore_missing_imports = true
4141

4242
[tool.pytest.ini_options]

requirements-dev.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ nbval
44
ml_dtypes
55
jaxlib
66
jax
7+
torch
8+
array-api-strict
79
airium
810
pandas
911
matplotlib

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
numpy
22
more_itertools
3+
array-api-compat

src/gfloat/round_ndarray.py

Lines changed: 66 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -4,20 +4,36 @@
44
from types import ModuleType
55
from .types import FormatInfo, RoundMode
66
import numpy as np
7+
import array_api_compat
78

89

910
def _isodd(v: np.ndarray) -> np.ndarray:
1011
return v & 0x1 == 1
1112

1213

14+
def _ldexp(v: np.ndarray, s: np.ndarray) -> np.ndarray:
15+
xp = array_api_compat.array_namespace(v, s)
16+
if (
17+
array_api_compat.is_torch_array(v)
18+
or array_api_compat.is_jax_array(v)
19+
or array_api_compat.is_numpy_array(v)
20+
):
21+
return xp.ldexp(v, s)
22+
23+
# Scale away from subnormal/infinite ranges
24+
offset = 24
25+
vlo = (v * 2.0**+offset) * 2.0 ** xp.astype(s - offset, v.dtype)
26+
vhi = (v * 2.0**-offset) * 2.0 ** xp.astype(s + offset, v.dtype)
27+
return xp.where(v < 1.0, vlo, vhi)
28+
29+
1330
def round_ndarray(
1431
fi: FormatInfo,
1532
v: np.ndarray,
1633
rnd: RoundMode = RoundMode.TiesToEven,
1734
sat: bool = False,
1835
srbits: Optional[np.ndarray] = None,
1936
srnumbits: int = 0,
20-
np: ModuleType = np,
2137
) -> np.ndarray:
2238
"""
2339
Vectorized version of :meth:`round_float`.
@@ -38,8 +54,6 @@ def round_ndarray(
3854
srbits (int array): Bits to use for stochastic rounding if rnd == Stochastic.
3955
srnumbits (int): How many bits are in srbits. Implies srbits < 2**srnumbits.
4056
41-
np (Module): May be `numpy`, `jax.numpy` or another module cloning numpy
42-
4357
Returns:
4458
An array of floats which is a subset of the format's value set.
4559
@@ -48,27 +62,38 @@ def round_ndarray(
4862
(e.g. converting a `NaN`, or an `Inf` when the target has no
4963
`NaN` or `Inf`, and :paramref:`sat` is false)
5064
"""
65+
xp = array_api_compat.array_namespace(v, srbits)
66+
5167
p = fi.precision
5268
bias = fi.expBias
5369

54-
is_negative = np.signbit(v) & fi.is_signed
55-
absv = np.where(is_negative, -v, v)
70+
is_negative = xp.signbit(v) & fi.is_signed
71+
absv = xp.where(is_negative, -v, v)
5672

57-
finite_nonzero = ~(np.isnan(v) | np.isinf(v) | (v == 0))
73+
finite_nonzero = ~(xp.isnan(v) | xp.isinf(v) | (v == 0))
5874

5975
# Place 1.0 where finite_nonzero is False, to avoid log of {0,inf,nan}
60-
absv_masked = np.where(finite_nonzero, absv, 1.0)
76+
absv_masked = xp.where(finite_nonzero, absv, 1.0)
6177

62-
expval = np.floor(np.log2(absv_masked)).astype(int)
78+
int_type = xp.int64 if fi.k > 8 or srnumbits > 8 else xp.int16
79+
80+
def to_int(x: np.ndarray) -> np.ndarray:
81+
return xp.astype(x, int_type)
82+
83+
def to_float(x: np.ndarray) -> np.ndarray:
84+
return xp.astype(x, v.dtype)
85+
86+
expval = to_int(xp.floor(xp.log2(absv_masked)))
6387

6488
if fi.has_subnormals:
65-
expval = np.maximum(expval, 1 - bias)
89+
expval = xp.maximum(expval, 1 - bias)
6690

6791
expval = expval - p + 1
68-
fsignificand = np.ldexp(absv_masked, -expval)
92+
fsignificand = _ldexp(absv_masked, -expval)
6993

70-
isignificand = np.floor(fsignificand).astype(np.int64)
71-
delta = fsignificand - isignificand
94+
floorfsignificand = xp.floor(fsignificand)
95+
isignificand = to_int(floorfsignificand)
96+
delta = fsignificand - floorfsignificand
7297

7398
if fi.precision > 1:
7499
code_is_odd = _isodd(isignificand)
@@ -77,7 +102,7 @@ def round_ndarray(
77102

78103
match rnd:
79104
case RoundMode.TowardZero:
80-
should_round_away = np.zeros_like(delta, dtype=bool)
105+
should_round_away = xp.zeros_like(delta, dtype=xp.bool)
81106

82107
case RoundMode.TowardPositive:
83108
should_round_away = ~is_negative & (delta > 0)
@@ -95,38 +120,44 @@ def round_ndarray(
95120
assert srbits is not None
96121
## RTNE delta to srbits
97122
d = delta * 2.0**srnumbits
98-
floord = np.floor(d).astype(np.int64)
99-
dd = d - floord
100-
drnd = floord + (dd > 0.5) + ((dd == 0.5) & _isodd(floord))
123+
floord = to_int(xp.floor(d))
124+
dd = d - xp.floor(d)
125+
should_round_away_tne = (dd > 0.5) | ((dd == 0.5) & _isodd(floord))
126+
drnd = floord + xp.astype(should_round_away_tne, floord.dtype)
101127

102-
should_round_away = drnd + srbits >= 2.0**srnumbits
128+
should_round_away = drnd + srbits >= 2**srnumbits
103129

104130
case RoundMode.StochasticOdd:
105131
assert srbits is not None
106132
## RTNO delta to srbits
107133
d = delta * 2.0**srnumbits
108-
floord = np.floor(d).astype(np.int64)
109-
dd = d - floord
110-
drnd = floord + (dd > 0.5) + ((dd == 0.5) & ~_isodd(floord))
134+
floord = to_int(xp.floor(d))
135+
dd = d - xp.floor(d)
136+
should_round_away_tno = (dd > 0.5) | ((dd == 0.5) & ~_isodd(floord))
137+
drnd = floord + xp.astype(should_round_away_tno, floord.dtype)
111138

112-
should_round_away = drnd + srbits >= 2.0**srnumbits
139+
should_round_away = drnd + srbits >= 2**srnumbits
113140

114141
case RoundMode.StochasticFast:
115142
assert srbits is not None
116-
should_round_away = delta + (2 * srbits + 1) * 2.0 ** -(1 + srnumbits) >= 1.0
143+
should_round_away = (
144+
delta + to_float(2 * srbits + 1) * 2.0 ** -(1 + srnumbits) >= 1.0
145+
)
117146

118147
case RoundMode.StochasticFastest:
119148
assert srbits is not None
120-
should_round_away = delta + srbits * 2.0**-srnumbits >= 1.0
149+
should_round_away = delta + to_float(srbits) * 2.0**-srnumbits >= 1.0
150+
151+
isignificand = xp.where(should_round_away, isignificand + 1, isignificand)
121152

122-
isignificand = np.where(should_round_away, isignificand + 1, isignificand)
153+
fresult = _ldexp(to_float(isignificand), expval)
123154

124-
result = np.where(finite_nonzero, np.ldexp(isignificand, expval), absv)
155+
result = xp.where(finite_nonzero, fresult, absv)
125156

126-
amax = np.where(is_negative, -fi.min, fi.max)
157+
amax = xp.where(is_negative, -fi.min, fi.max)
127158

128159
if sat:
129-
result = np.where(result > amax, amax, result)
160+
result = xp.where(result > amax, amax, result)
130161
else:
131162
match rnd:
132163
case RoundMode.TowardNegative:
@@ -136,25 +167,25 @@ def round_ndarray(
136167
case RoundMode.TowardZero:
137168
put_amax_at = result > amax
138169
case _:
139-
put_amax_at = np.zeros_like(result, dtype=bool)
170+
put_amax_at = xp.zeros_like(result, dtype=xp.bool)
140171

141-
result = np.where(finite_nonzero & put_amax_at, amax, result)
172+
result = xp.where(finite_nonzero & put_amax_at, amax, result)
142173

143174
# Now anything larger than amax goes to infinity or NaN
144175
if fi.has_infs:
145-
result = np.where(result > amax, np.inf, result)
176+
result = xp.where(result > amax, xp.inf, result)
146177
elif fi.num_nans > 0:
147-
result = np.where(result > amax, np.nan, result)
178+
result = xp.where(result > amax, xp.nan, result)
148179
else:
149-
if np.any(result > amax):
180+
if xp.any(result > amax):
150181
raise ValueError(f"No Infs or NaNs in format {fi}, and sat=False")
151182

152-
result = np.where(is_negative, -result, result)
183+
result = xp.where(is_negative, -result, result)
153184

154185
# Make negative zeros negative if has_nz, else make them not negative.
155186
if fi.has_nz:
156-
result = np.where((result == 0) & is_negative, -0.0, result)
187+
result = xp.where((result == 0) & is_negative, -0.0, result)
157188
else:
158-
result = np.where(result == 0, 0.0, result)
189+
result = xp.where(result == 0, 0.0, result)
159190

160191
return result

src/gfloat/types.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -232,7 +232,7 @@ def min(self) -> float:
232232
return -self.max
233233
else:
234234
assert not self.has_infs and self.num_high_nans == 0 and not self.has_nz
235-
return -(2 ** (self.emax + 1))
235+
return -(2.0 ** (self.emax + 1))
236236
elif self.has_zero:
237237
return 0.0
238238
else:

test/test_array_api.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
# Copyright (c) 2024 Graphcore Ltd. All rights reserved.
2+
3+
import array_api_strict as xp
4+
import numpy as np
5+
import pytest
6+
7+
from gfloat import (
8+
RoundMode,
9+
FormatInfo,
10+
decode_float,
11+
decode_ndarray,
12+
round_float,
13+
round_ndarray,
14+
)
15+
from gfloat.formats import *
16+
17+
xp.set_array_api_strict_flags(api_version="2024.12")
18+
19+
# Hack until https://github.com/data-apis/array-api/issues/807
20+
_xp_where = xp.where
21+
xp.where = lambda a, t, f: _xp_where(a, xp.asarray(t), xp.asarray(f))
22+
_xp_maximum = xp.maximum
23+
xp.maximum = lambda a, b: _xp_maximum(xp.asarray(a), xp.asarray(b))
24+
25+
26+
@pytest.mark.parametrize("fi", all_formats)
27+
@pytest.mark.parametrize("rnd", RoundMode)
28+
@pytest.mark.parametrize("sat", [True, False])
29+
def test_array_api(fi: FormatInfo, rnd: RoundMode, sat: bool) -> None:
30+
a = np.random.rand(23, 1, 34) - 0.5
31+
a = xp.asarray(a)
32+
33+
srnumbits = 32
34+
srbits = np.random.randint(0, 2**srnumbits, a.shape)
35+
srbits = xp.asarray(srbits)
36+
37+
round_ndarray(fi, a, rnd, sat, srbits=srbits, srnumbits=srnumbits)

test/test_jax.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,11 @@ def test_jax() -> None:
2222
a8 = a.astype(ml_dtypes.float8_e5m2).astype(jnp.float64)
2323

2424
fi = format_info_ocp_e5m2
25-
j8 = gfloat.round_ndarray(fi, jnp.array(a), np=jnp) # type: ignore [arg-type]
25+
j8 = gfloat.round_ndarray(fi, jnp.array(a)) # type: ignore [arg-type]
2626

2727
np.testing.assert_equal(a8, j8)
2828

29-
jax_round_array = jax.jit(lambda x: gfloat.round_ndarray(fi, x, np=jnp))
29+
jax_round_array = jax.jit(lambda x: gfloat.round_ndarray(fi, x))
3030
j8i = jax_round_array(a)
3131

3232
np.testing.assert_equal(a8, j8i)

0 commit comments

Comments
 (0)