diff --git a/.gitignore b/.gitignore index 57fa349..625d238 100644 --- a/.gitignore +++ b/.gitignore @@ -162,3 +162,6 @@ cython_debug/ #.idea/ .vscode/settings.json .vscode/launch.json + +# Local +tmp/ diff --git a/docs/source/04-benchmark.ipynb b/docs/source/04-benchmark.ipynb index 2e964c2..1318d19 100644 --- a/docs/source/04-benchmark.ipynb +++ b/docs/source/04-benchmark.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": 1, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ @@ -34,24 +34,17 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 4, "metadata": {}, "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.\n" - ] - }, { "name": "stdout", "output_type": "stream", "text": [ - "GFloat scalar : 6306.18 nsec (25 runs at size 10000)\n", - "GFloat vectorized, numpy arrays: 52.52 nsec (25 runs at size 1000000)\n", - "GFloat vectorized, JAX JIT : 3.04 nsec (500 runs at size 1000000)\n", - "ML_dtypes : 2.69 nsec (500 runs at size 1000000)\n" + "GFloat scalar : 7510.22 nsec (25 runs at size 10000)\n", + "GFloat vectorized, numpy arrays: 43.82 nsec (25 runs at size 1000000)\n", + "GFloat vectorized, JAX JIT : 2.69 nsec (500 runs at size 1000000)\n", + "ML_dtypes : 2.57 nsec (500 runs at size 1000000)\n" ] } ], @@ -61,7 +54,7 @@ "N = 1_000_000\n", "a = np.random.rand(N)\n", "\n", - "jax_round_jit = jax.jit(lambda x: gfloat.round_ndarray(format_info_ocp_e5m2, x, np=jnp))\n", + "jax_round_jit = jax.jit(lambda x: gfloat.round_ndarray(format_info_ocp_e5m2, x))\n", "ja = jnp.array(a)\n", "jax_round_jit(ja) # Cache compilation\n", "\n", @@ -108,7 +101,7 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3", + "display_name": ".venv", "language": "python", "name": "python3" }, diff --git a/pyproject.toml b/pyproject.toml index bb75f7c..4a4d103 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,7 +36,7 @@ fast = true [tool.mypy] [[tool.mypy.overrides]] -module = "mx.*" +module = ["mx.*", "array_api_compat.*", "array_api_strict.*"] ignore_missing_imports = true [tool.pytest.ini_options] diff --git a/requirements-dev.txt b/requirements-dev.txt index 675a4c3..1020b42 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -4,6 +4,8 @@ nbval ml_dtypes jaxlib jax +torch +array-api-strict airium pandas matplotlib diff --git a/requirements.txt b/requirements.txt index 4a5a8c5..ab2d116 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,2 +1,3 @@ numpy more_itertools +array-api-compat diff --git a/src/gfloat/round_ndarray.py b/src/gfloat/round_ndarray.py index fe1ef54..93ca0d3 100644 --- a/src/gfloat/round_ndarray.py +++ b/src/gfloat/round_ndarray.py @@ -4,12 +4,29 @@ from types import ModuleType from .types import FormatInfo, RoundMode import numpy as np +import array_api_compat def _isodd(v: np.ndarray) -> np.ndarray: return v & 0x1 == 1 +def _ldexp(v: np.ndarray, s: np.ndarray) -> np.ndarray: + xp = array_api_compat.array_namespace(v, s) + if ( + array_api_compat.is_torch_array(v) + or array_api_compat.is_jax_array(v) + or array_api_compat.is_numpy_array(v) + ): + return xp.ldexp(v, s) + + # Scale away from subnormal/infinite ranges + offset = 24 + vlo = (v * 2.0**+offset) * 2.0 ** xp.astype(s - offset, v.dtype) + vhi = (v * 2.0**-offset) * 2.0 ** xp.astype(s + offset, v.dtype) + return xp.where(v < 1.0, vlo, vhi) + + def round_ndarray( fi: FormatInfo, v: np.ndarray, @@ -17,7 +34,6 @@ def round_ndarray( sat: bool = False, srbits: Optional[np.ndarray] = None, srnumbits: int = 0, - np: ModuleType = np, ) -> np.ndarray: """ Vectorized version of :meth:`round_float`. @@ -38,8 +54,6 @@ def round_ndarray( srbits (int array): Bits to use for stochastic rounding if rnd == Stochastic. srnumbits (int): How many bits are in srbits. Implies srbits < 2**srnumbits. - np (Module): May be `numpy`, `jax.numpy` or another module cloning numpy - Returns: An array of floats which is a subset of the format's value set. @@ -48,27 +62,38 @@ def round_ndarray( (e.g. converting a `NaN`, or an `Inf` when the target has no `NaN` or `Inf`, and :paramref:`sat` is false) """ + xp = array_api_compat.array_namespace(v, srbits) + p = fi.precision bias = fi.expBias - is_negative = np.signbit(v) & fi.is_signed - absv = np.where(is_negative, -v, v) + is_negative = xp.signbit(v) & fi.is_signed + absv = xp.where(is_negative, -v, v) - finite_nonzero = ~(np.isnan(v) | np.isinf(v) | (v == 0)) + finite_nonzero = ~(xp.isnan(v) | xp.isinf(v) | (v == 0)) # Place 1.0 where finite_nonzero is False, to avoid log of {0,inf,nan} - absv_masked = np.where(finite_nonzero, absv, 1.0) + absv_masked = xp.where(finite_nonzero, absv, 1.0) - expval = np.floor(np.log2(absv_masked)).astype(int) + int_type = xp.int64 if fi.k > 8 or srnumbits > 8 else xp.int16 + + def to_int(x: np.ndarray) -> np.ndarray: + return xp.astype(x, int_type) + + def to_float(x: np.ndarray) -> np.ndarray: + return xp.astype(x, v.dtype) + + expval = to_int(xp.floor(xp.log2(absv_masked))) if fi.has_subnormals: - expval = np.maximum(expval, 1 - bias) + expval = xp.maximum(expval, 1 - bias) expval = expval - p + 1 - fsignificand = np.ldexp(absv_masked, -expval) + fsignificand = _ldexp(absv_masked, -expval) - isignificand = np.floor(fsignificand).astype(np.int64) - delta = fsignificand - isignificand + floorfsignificand = xp.floor(fsignificand) + isignificand = to_int(floorfsignificand) + delta = fsignificand - floorfsignificand if fi.precision > 1: code_is_odd = _isodd(isignificand) @@ -77,7 +102,7 @@ def round_ndarray( match rnd: case RoundMode.TowardZero: - should_round_away = np.zeros_like(delta, dtype=bool) + should_round_away = xp.zeros_like(delta, dtype=xp.bool) case RoundMode.TowardPositive: should_round_away = ~is_negative & (delta > 0) @@ -95,38 +120,44 @@ def round_ndarray( assert srbits is not None ## RTNE delta to srbits d = delta * 2.0**srnumbits - floord = np.floor(d).astype(np.int64) - dd = d - floord - drnd = floord + (dd > 0.5) + ((dd == 0.5) & _isodd(floord)) + floord = to_int(xp.floor(d)) + dd = d - xp.floor(d) + should_round_away_tne = (dd > 0.5) | ((dd == 0.5) & _isodd(floord)) + drnd = floord + xp.astype(should_round_away_tne, floord.dtype) - should_round_away = drnd + srbits >= 2.0**srnumbits + should_round_away = drnd + srbits >= 2**srnumbits case RoundMode.StochasticOdd: assert srbits is not None ## RTNO delta to srbits d = delta * 2.0**srnumbits - floord = np.floor(d).astype(np.int64) - dd = d - floord - drnd = floord + (dd > 0.5) + ((dd == 0.5) & ~_isodd(floord)) + floord = to_int(xp.floor(d)) + dd = d - xp.floor(d) + should_round_away_tno = (dd > 0.5) | ((dd == 0.5) & ~_isodd(floord)) + drnd = floord + xp.astype(should_round_away_tno, floord.dtype) - should_round_away = drnd + srbits >= 2.0**srnumbits + should_round_away = drnd + srbits >= 2**srnumbits case RoundMode.StochasticFast: assert srbits is not None - should_round_away = delta + (2 * srbits + 1) * 2.0 ** -(1 + srnumbits) >= 1.0 + should_round_away = ( + delta + to_float(2 * srbits + 1) * 2.0 ** -(1 + srnumbits) >= 1.0 + ) case RoundMode.StochasticFastest: assert srbits is not None - should_round_away = delta + srbits * 2.0**-srnumbits >= 1.0 + should_round_away = delta + to_float(srbits) * 2.0**-srnumbits >= 1.0 + + isignificand = xp.where(should_round_away, isignificand + 1, isignificand) - isignificand = np.where(should_round_away, isignificand + 1, isignificand) + fresult = _ldexp(to_float(isignificand), expval) - result = np.where(finite_nonzero, np.ldexp(isignificand, expval), absv) + result = xp.where(finite_nonzero, fresult, absv) - amax = np.where(is_negative, -fi.min, fi.max) + amax = xp.where(is_negative, -fi.min, fi.max) if sat: - result = np.where(result > amax, amax, result) + result = xp.where(result > amax, amax, result) else: match rnd: case RoundMode.TowardNegative: @@ -136,25 +167,25 @@ def round_ndarray( case RoundMode.TowardZero: put_amax_at = result > amax case _: - put_amax_at = np.zeros_like(result, dtype=bool) + put_amax_at = xp.zeros_like(result, dtype=xp.bool) - result = np.where(finite_nonzero & put_amax_at, amax, result) + result = xp.where(finite_nonzero & put_amax_at, amax, result) # Now anything larger than amax goes to infinity or NaN if fi.has_infs: - result = np.where(result > amax, np.inf, result) + result = xp.where(result > amax, xp.inf, result) elif fi.num_nans > 0: - result = np.where(result > amax, np.nan, result) + result = xp.where(result > amax, xp.nan, result) else: - if np.any(result > amax): + if xp.any(result > amax): raise ValueError(f"No Infs or NaNs in format {fi}, and sat=False") - result = np.where(is_negative, -result, result) + result = xp.where(is_negative, -result, result) # Make negative zeros negative if has_nz, else make them not negative. if fi.has_nz: - result = np.where((result == 0) & is_negative, -0.0, result) + result = xp.where((result == 0) & is_negative, -0.0, result) else: - result = np.where(result == 0, 0.0, result) + result = xp.where(result == 0, 0.0, result) return result diff --git a/src/gfloat/types.py b/src/gfloat/types.py index f67fad1..4c6cd66 100644 --- a/src/gfloat/types.py +++ b/src/gfloat/types.py @@ -232,7 +232,7 @@ def min(self) -> float: return -self.max else: assert not self.has_infs and self.num_high_nans == 0 and not self.has_nz - return -(2 ** (self.emax + 1)) + return -(2.0 ** (self.emax + 1)) elif self.has_zero: return 0.0 else: diff --git a/test/test_array_api.py b/test/test_array_api.py new file mode 100644 index 0000000..504b4e7 --- /dev/null +++ b/test/test_array_api.py @@ -0,0 +1,37 @@ +# Copyright (c) 2024 Graphcore Ltd. All rights reserved. + +import array_api_strict as xp +import numpy as np +import pytest + +from gfloat import ( + RoundMode, + FormatInfo, + decode_float, + decode_ndarray, + round_float, + round_ndarray, +) +from gfloat.formats import * + +xp.set_array_api_strict_flags(api_version="2024.12") + +# Hack until https://github.com/data-apis/array-api/issues/807 +_xp_where = xp.where +xp.where = lambda a, t, f: _xp_where(a, xp.asarray(t), xp.asarray(f)) +_xp_maximum = xp.maximum +xp.maximum = lambda a, b: _xp_maximum(xp.asarray(a), xp.asarray(b)) + + +@pytest.mark.parametrize("fi", all_formats) +@pytest.mark.parametrize("rnd", RoundMode) +@pytest.mark.parametrize("sat", [True, False]) +def test_array_api(fi: FormatInfo, rnd: RoundMode, sat: bool) -> None: + a = np.random.rand(23, 1, 34) - 0.5 + a = xp.asarray(a) + + srnumbits = 32 + srbits = np.random.randint(0, 2**srnumbits, a.shape) + srbits = xp.asarray(srbits) + + round_ndarray(fi, a, rnd, sat, srbits=srbits, srnumbits=srnumbits) diff --git a/test/test_jax.py b/test/test_jax.py index 3422cb3..481fcbd 100644 --- a/test/test_jax.py +++ b/test/test_jax.py @@ -22,11 +22,11 @@ def test_jax() -> None: a8 = a.astype(ml_dtypes.float8_e5m2).astype(jnp.float64) fi = format_info_ocp_e5m2 - j8 = gfloat.round_ndarray(fi, jnp.array(a), np=jnp) # type: ignore [arg-type] + j8 = gfloat.round_ndarray(fi, jnp.array(a)) # type: ignore [arg-type] np.testing.assert_equal(a8, j8) - jax_round_array = jax.jit(lambda x: gfloat.round_ndarray(fi, x, np=jnp)) + jax_round_array = jax.jit(lambda x: gfloat.round_ndarray(fi, x)) j8i = jax_round_array(a) np.testing.assert_equal(a8, j8i) diff --git a/test/test_round.py b/test/test_round.py index fc2c979..badea4d 100644 --- a/test/test_round.py +++ b/test/test_round.py @@ -19,7 +19,8 @@ def rnd_scalar( def rnd_array( fi: FormatInfo, v: float, mode: RoundMode = RoundMode.TiesToEven, sat: bool = False ) -> float: - return round_ndarray(fi, np.array([v]), mode, sat).item() + a = round_ndarray(fi, np.asarray([v]), mode, sat) + return float(a[0]) @pytest.mark.parametrize("round_float", (rnd_scalar, rnd_array)) @@ -535,18 +536,27 @@ def test_stochastic_rounding( count_v1 = np.sum(rs == v1) print(f"SRBits={srnumbits}, observed = {count_v1}, expected = {expected_up_count} ") - # e.g. if expected is 1250/10000, want to be within 0.5,1.5 + # e.g. if expected is 1250/10000, want to be within 0.75,1.25 # this is loose, but should still catch logic errors atol = n * 2.0 ** (-1 - srnumbits) - np.testing.assert_allclose(count_v1, expected_up_count, atol=atol) + np.testing.assert_allclose(count_v1, expected_up_count, atol=atol / 2) @pytest.mark.parametrize( "rnd", - (RoundMode.Stochastic, RoundMode.StochasticFast, RoundMode.StochasticFastest), + ( + RoundMode.Stochastic, + RoundMode.StochasticOdd, + RoundMode.StochasticFast, + RoundMode.StochasticFastest, + ), ) -def test_stochastic_rounding_scalar_eq_array(rnd: RoundMode) -> None: - fi = format_info_p3109(8, 3) +@pytest.mark.parametrize("srnumbits", [3, 8, 9, 16, 32]) +@pytest.mark.parametrize("sat", (True, False)) +def test_stochastic_rounding_scalar_eq_array( + rnd: RoundMode, srnumbits: int, sat: bool +) -> None: + fi = format_info_ocp_e5m2 v0 = decode_ndarray(fi, np.arange(255)) v1 = decode_ndarray(fi, np.arange(255) + 1) @@ -554,32 +564,34 @@ def test_stochastic_rounding_scalar_eq_array(rnd: RoundMode) -> None: v0 = v0[ok] v1 = v1[ok] - srnumbits = 3 - for srbits in range(2**srnumbits): - for alpha in (0, 0.3, 0.5, 0.6, 0.9, 1.25): - v = _linterp(v0, v1, alpha) - assert np.isfinite(v).all() - val_array = round_ndarray( + for alpha in (0, 0.3, 0.5, 0.6, 0.7, 0.9, 1.25): + v = _linterp(v0, v1, alpha) + assert np.isfinite(v).all() + srbits = np.random.randint(0, 2**srnumbits, v.shape) + + val_array = round_ndarray( + fi, + v, + rnd, + sat=sat, + srbits=srbits, + srnumbits=srnumbits, + ) + + val_scalar = [ + round_float( fi, - v, + vi, rnd, - sat=False, - srbits=np.asarray(srbits), + sat=sat, + srbits=srbitsi, srnumbits=srnumbits, ) + for vi, srbitsi in zip(v, srbits) + ] + + np.testing.assert_equal(val_array, val_scalar) - val_scalar = [ - round_float( - fi, - v, - rnd, - sat=False, - srbits=srbits, - srnumbits=srnumbits, - ) - for v in v - ] - if alpha < 1.0: - assert ((val_array == v0) | (val_array == v1)).all() - - np.testing.assert_equal(val_array, val_scalar) + # Ensure faithful rounding + if alpha < 1.0: + assert ((val_array == v0) | (val_array == v1)).all()