Skip to content

Commit c82deb2

Browse files
author
jax authors
committed
Merge pull request #20373 from pearu:pearu/complex-plane-tests-fix
PiperOrigin-RevId: 618188699
2 parents 6695a85 + fdb5015 commit c82deb2

File tree

3 files changed

+581
-104
lines changed

3 files changed

+581
-104
lines changed

build/test-requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ cloudpickle
44
colorama>=0.4.4
55
flatbuffers
66
hypothesis
7+
mpmath>=1.3
78
numpy>=1.22
89
pillow>=9.1.0
910
portpicker

jax/_src/test_util.py

Lines changed: 308 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1467,11 +1467,11 @@ def complex_plane_sample(dtype, size_re=10, size_im=None):
14671467
>>> print(complex_plane_sample(np.complex64, 0, 3))
14681468
[[-inf -infj 0. -infj inf -infj]
14691469
[-inf-3.4028235e+38j 0.-3.4028235e+38j inf-3.4028235e+38j]
1470-
[-inf-2.0000052e+00j 0.-2.0000052e+00j inf-2.0000052e+00j]
1470+
[-inf-2.0000000e+00j 0.-2.0000000e+00j inf-2.0000000e+00j]
14711471
[-inf-1.1754944e-38j 0.-1.1754944e-38j inf-1.1754944e-38j]
14721472
[-inf+0.0000000e+00j 0.+0.0000000e+00j inf+0.0000000e+00j]
14731473
[-inf+1.1754944e-38j 0.+1.1754944e-38j inf+1.1754944e-38j]
1474-
[-inf+2.0000052e+00j 0.+2.0000052e+00j inf+2.0000052e+00j]
1474+
[-inf+2.0000000e+00j 0.+2.0000000e+00j inf+2.0000000e+00j]
14751475
[-inf+3.4028235e+38j 0.+3.4028235e+38j inf+3.4028235e+38j]
14761476
[-inf +infj 0. +infj inf +infj]]
14771477
@@ -1481,16 +1481,18 @@ def complex_plane_sample(dtype, size_re=10, size_im=None):
14811481
finfo = np.finfo(dtype)
14821482

14831483
def make_axis_points(size):
1484-
logmin = np.log10(abs(finfo.min))
1485-
logtiny = np.log10(finfo.tiny)
1486-
logmax = np.log10(finfo.max)
1484+
prec_dps_ratio = 3.3219280948873626
1485+
logmin = logmax = finfo.maxexp / prec_dps_ratio
1486+
logtiny = finfo.minexp / prec_dps_ratio
14871487
axis_points = np.zeros(3 + 2 * size, dtype=finfo.dtype)
14881488

14891489
with warnings.catch_warnings():
14901490
# Silence RuntimeWarning: overflow encountered in cast
14911491
warnings.simplefilter("ignore")
1492-
axis_points[1:size + 1] = -np.logspace(logmin, logtiny, size, dtype=finfo.dtype)
1493-
axis_points[-size - 1:-1] = np.logspace(logtiny, logmax, size, dtype=finfo.dtype)
1492+
half_neg_line = -np.logspace(logmin, logtiny, size, dtype=finfo.dtype)
1493+
half_line = -half_neg_line[::-1]
1494+
axis_points[-size - 1:-1] = half_line
1495+
axis_points[1:size + 1] = half_neg_line
14941496

14951497
if size > 1:
14961498
axis_points[1] = finfo.min
@@ -1512,3 +1514,302 @@ def make_axis_points(size):
15121514
imag_part = imag_part.reshape((3 + 2 * size_im, -1)).repeat(3 + 2 * size_re, 1)
15131515

15141516
return real_part + imag_part
1517+
1518+
1519+
class vectorize_with_mpmath(np.vectorize):
1520+
"""Same as numpy.vectorize but using mpmath backend for function evaluation.
1521+
"""
1522+
1523+
map_float_to_complex = dict(float16='complex32', float32='complex64', float64='complex128', float128='complex256', longdouble='clongdouble')
1524+
map_complex_to_float = {v: k for k, v in map_float_to_complex.items()}
1525+
1526+
float_prec = dict(
1527+
# float16=11,
1528+
float32=24,
1529+
float64=53,
1530+
# float128=113,
1531+
# longdouble=113
1532+
)
1533+
1534+
float_minexp = dict(
1535+
float16=-14,
1536+
float32=-126,
1537+
float64=-1022,
1538+
float128=-16382
1539+
)
1540+
1541+
float_maxexp = dict(
1542+
float16=16,
1543+
float32=128,
1544+
float64=1024,
1545+
float128=16384,
1546+
)
1547+
1548+
def __init__(self, *args, **kwargs):
1549+
mpmath = kwargs.pop('mpmath', None)
1550+
if mpmath is None:
1551+
raise ValueError('vectorize_with_mpmath: no mpmath argument specified')
1552+
self.extra_prec_multiplier = kwargs.pop('extra_prec_multiplier', 0)
1553+
self.extra_prec = kwargs.pop('extra_prec', 0)
1554+
self.mpmath = mpmath
1555+
self.contexts = dict()
1556+
self.contexts_inv = dict()
1557+
for fp_format, prec in self.float_prec.items():
1558+
ctx = self.mpmath.mp.clone()
1559+
ctx.prec = prec
1560+
self.contexts[fp_format] = ctx
1561+
self.contexts_inv[ctx] = fp_format
1562+
1563+
super().__init__(*args, **kwargs)
1564+
1565+
def get_context(self, x):
1566+
if isinstance(x, (np.ndarray, np.floating, np.complexfloating)):
1567+
fp_format = str(x.dtype)
1568+
fp_format = self.map_complex_to_float.get(fp_format, fp_format)
1569+
return self.contexts[fp_format]
1570+
raise NotImplementedError(f'get mpmath context from {type(x).__name__} instance')
1571+
1572+
def nptomp(self, x):
1573+
"""Convert numpy array/scalar to an array/instance of mpmath number type.
1574+
"""
1575+
if isinstance(x, np.ndarray):
1576+
return np.fromiter(map(self.nptomp, x.flatten()), dtype=object).reshape(x.shape)
1577+
elif isinstance(x, np.floating):
1578+
mpmath = self.mpmath
1579+
ctx = self.get_context(x)
1580+
prec, rounding = ctx._prec_rounding
1581+
if np.isposinf(x):
1582+
return ctx.make_mpf(mpmath.libmp.finf)
1583+
elif np.isneginf(x):
1584+
return ctx.make_mpf(mpmath.libmp.fninf)
1585+
elif np.isnan(x):
1586+
return ctx.make_mpf(mpmath.libmp.fnan)
1587+
elif np.isfinite(x):
1588+
mantissa, exponent = np.frexp(x)
1589+
man = int(np.ldexp(mantissa, prec))
1590+
exp = int(exponent - prec)
1591+
r = ctx.make_mpf(mpmath.libmp.from_man_exp(man, exp, prec, rounding))
1592+
assert ctx.isfinite(r), r._mpf_
1593+
return r
1594+
elif isinstance(x, np.complexfloating):
1595+
re, im = self.nptomp(x.real), self.nptomp(x.imag)
1596+
return re.context.make_mpc((re._mpf_, im._mpf_))
1597+
raise NotImplementedError(f'convert {type(x).__name__} instance to mpmath number type')
1598+
1599+
def mptonp(self, x):
1600+
"""Convert mpmath instance to numpy array/scalar type.
1601+
"""
1602+
if isinstance(x, np.ndarray) and x.dtype.kind == 'O':
1603+
x_flat = x.flatten()
1604+
item = x_flat[0]
1605+
ctx = item.context
1606+
fp_format = self.contexts_inv[ctx]
1607+
if isinstance(item, ctx.mpc):
1608+
dtype = getattr(np, self.map_float_to_complex[fp_format])
1609+
elif isinstance(item, ctx.mpf):
1610+
dtype = getattr(np, fp_format)
1611+
else:
1612+
dtype = None
1613+
if dtype is not None:
1614+
return np.fromiter(map(self.mptonp, x_flat), dtype=dtype).reshape(x.shape)
1615+
elif isinstance(x, self.mpmath.ctx_mp.mpnumeric):
1616+
ctx = x.context
1617+
if isinstance(x, ctx.mpc):
1618+
fp_format = self.contexts_inv[ctx]
1619+
dtype = getattr(np, self.map_float_to_complex[fp_format])
1620+
r = dtype().reshape(1).view(getattr(np, fp_format))
1621+
r[0] = self.mptonp(x.real)
1622+
r[1] = self.mptonp(x.imag)
1623+
return r.view(dtype)[0]
1624+
elif isinstance(x, ctx.mpf):
1625+
fp_format = self.contexts_inv[ctx]
1626+
dtype = getattr(np, fp_format)
1627+
if ctx.isfinite(x):
1628+
sign, man, exp, bc = self.mpmath.libmp.normalize(*x._mpf_, *ctx._prec_rounding)
1629+
assert bc >= 0, (sign, man, exp, bc, x._mpf_)
1630+
if exp + bc < self.float_minexp[fp_format]:
1631+
return -ctx.zero if sign else ctx.zero
1632+
if exp + bc > self.float_maxexp[fp_format]:
1633+
return ctx.ninf if sign else ctx.inf
1634+
man = dtype(-man if sign else man)
1635+
r = np.ldexp(man, exp)
1636+
assert np.isfinite(r), (x, r, x._mpf_, man)
1637+
return r
1638+
elif ctx.isnan(x):
1639+
return dtype(np.nan)
1640+
elif ctx.isinf(x):
1641+
return dtype(-np.inf if x._mpf_[0] else np.inf)
1642+
raise NotImplementedError(f'convert {type(x)} instance to numpy floating point type')
1643+
1644+
def __call__(self, *args, **kwargs):
1645+
mp_args = []
1646+
context = None
1647+
for a in args:
1648+
if isinstance(a, (np.ndarray, np.floating, np.complexfloating)):
1649+
mp_args.append(self.nptomp(a))
1650+
if context is None:
1651+
context = self.get_context(a)
1652+
else:
1653+
assert context is self.get_context(a)
1654+
else:
1655+
mp_args.append(a)
1656+
1657+
extra_prec = int(context.prec * self.extra_prec_multiplier) + self.extra_prec
1658+
with context.extraprec(extra_prec):
1659+
result = super().__call__(*mp_args, **kwargs)
1660+
1661+
if isinstance(result, tuple):
1662+
lst = []
1663+
for r in result:
1664+
if ((isinstance(r, np.ndarray) and r.dtype.kind == 'O')
1665+
or isinstance(r, self.mpmath.ctx_mp.mpnumeric)):
1666+
r = self.mptonp(r)
1667+
lst.append(r)
1668+
return tuple(lst)
1669+
1670+
if ((isinstance(result, np.ndarray) and result.dtype.kind == 'O')
1671+
or isinstance(result, self.mpmath.ctx_mp.mpnumeric)):
1672+
return self.mptonp(result)
1673+
1674+
return result
1675+
1676+
1677+
class numpy_with_mpmath:
1678+
"""Namespace of universal functions on numpy arrays that use mpmath
1679+
backend for evaluation and return numpy arrays as outputs.
1680+
"""
1681+
1682+
_provides = [
1683+
'abs', 'absolute', 'sqrt', 'exp', 'expm1', 'exp2',
1684+
'log', 'log1p', 'log10', 'log2',
1685+
'sin', 'cos', 'tan', 'arcsin', 'arccos', 'arctan',
1686+
'sinh', 'cosh', 'tanh', 'arcsinh', 'arccosh', 'arctanh',
1687+
'square', 'positive', 'negative', 'conjugate', 'sign', 'sinc',
1688+
'normalize',
1689+
]
1690+
1691+
_mp_names = dict(
1692+
abs='absmin', absolute='absmin',
1693+
log='ln',
1694+
arcsin='asin', arccos='acos', arctan='atan',
1695+
arcsinh='asinh', arccosh='acosh', arctanh='atanh',
1696+
)
1697+
1698+
def __init__(self, mpmath, extra_prec_multiplier=0, extra_prec=0):
1699+
self.mpmath = mpmath
1700+
1701+
for name in self._provides:
1702+
mp_name = self._mp_names.get(name, name)
1703+
1704+
if hasattr(self, name):
1705+
op = getattr(self, name)
1706+
else:
1707+
1708+
def op(x, mp_name=mp_name):
1709+
return getattr(x.context, mp_name)(x)
1710+
1711+
setattr(self, name, vectorize_with_mpmath(op, mpmath=mpmath, extra_prec_multiplier=extra_prec_multiplier, extra_prec=extra_prec))
1712+
1713+
# The following function methods operate on mpmath number instances.
1714+
# The corresponding function names must be listed in
1715+
# numpy_with_mpmath._provides list.
1716+
1717+
def square(self, x):
1718+
return x * x
1719+
1720+
def positive(self, x):
1721+
return x
1722+
1723+
def negative(self, x):
1724+
return -x
1725+
1726+
def sqrt(self, x):
1727+
ctx = x.context
1728+
# workaround mpmath bugs:
1729+
if isinstance(x, ctx.mpc):
1730+
if ctx.isinf(x.real) and ctx.isinf(x.imag):
1731+
if x.real > 0: return x
1732+
ninf = x.real
1733+
inf = -ninf
1734+
if x.imag > 0: return ctx.make_mpc((inf._mpf_, inf._mpf_))
1735+
return ctx.make_mpc((inf._mpf_, inf._mpf_))
1736+
elif ctx.isfinite(x.real) and ctx.isinf(x.imag):
1737+
if x.imag > 0:
1738+
inf = x.imag
1739+
return ctx.make_mpc((inf._mpf_, inf._mpf_))
1740+
else:
1741+
ninf = x.imag
1742+
inf = -ninf
1743+
return ctx.make_mpc((inf._mpf_, ninf._mpf_))
1744+
return ctx.sqrt(x)
1745+
1746+
def expm1(self, x):
1747+
return x.context.expm1(x)
1748+
1749+
def log2(self, x):
1750+
return x.context.ln(x) / x.context.ln2
1751+
1752+
def log10(self, x):
1753+
return x.context.ln(x) / x.context.ln10
1754+
1755+
def exp2(self, x):
1756+
return x.context.exp(x * x.context.ln2)
1757+
1758+
def normalize(self, exact, reference, value):
1759+
"""Normalize reference and value using precision defined by the
1760+
difference of exact and reference.
1761+
"""
1762+
def worker(ctx, s, e, r, v):
1763+
ss, sm, se, sbc = s._mpf_
1764+
es, em, ee, ebc = e._mpf_
1765+
rs, rm, re, rbc = r._mpf_
1766+
vs, vm, ve, vbc = v._mpf_
1767+
1768+
if not (ctx.isfinite(e) and ctx.isfinite(r) and ctx.isfinite(v)):
1769+
return r, v
1770+
1771+
me = min(se, ee, re, ve)
1772+
1773+
# transform mantissa parts to the same exponent base
1774+
sm_e = sm << (se - me)
1775+
em_e = em << (ee - me)
1776+
rm_e = rm << (re - me)
1777+
vm_e = vm << (ve - me)
1778+
1779+
# find matching higher and non-matching lower bits of e and r
1780+
sm_b = bin(sm_e)[2:] if sm_e else ''
1781+
em_b = bin(em_e)[2:] if em_e else ''
1782+
rm_b = bin(rm_e)[2:] if rm_e else ''
1783+
vm_b = bin(vm_e)[2:] if vm_e else ''
1784+
1785+
m = max(len(sm_b), len(em_b), len(rm_b), len(vm_b))
1786+
em_b = '0' * (m - len(em_b)) + em_b
1787+
rm_b = '0' * (m - len(rm_b)) + rm_b
1788+
1789+
c1 = 0
1790+
for b0, b1 in zip(em_b, rm_b):
1791+
if b0 != b1:
1792+
break
1793+
c1 += 1
1794+
c0 = m - c1
1795+
1796+
# truncate r and v mantissa
1797+
rm_m = rm_e >> c0
1798+
vm_m = vm_e >> c0
1799+
1800+
# normalized r and v
1801+
nr = ctx.make_mpf((rs, rm_m, -c1, len(bin(rm_m)) - 2)) if rm_m else (-ctx.zero if rs else ctx.zero)
1802+
nv = ctx.make_mpf((vs, vm_m, -c1, len(bin(vm_m)) - 2)) if vm_m else (-ctx.zero if vs else ctx.zero)
1803+
1804+
return nr, nv
1805+
1806+
ctx = exact.context
1807+
scale = abs(exact)
1808+
if isinstance(exact, ctx.mpc):
1809+
rr, rv = worker(ctx, scale, exact.real, reference.real, value.real)
1810+
ir, iv = worker(ctx, scale, exact.imag, reference.imag, value.imag)
1811+
return ctx.make_mpc((rr._mpf_, ir._mpf_)), ctx.make_mpc((rv._mpf_, iv._mpf_))
1812+
elif isinstance(exact, ctx.mpf):
1813+
return worker(ctx, scale, exact, reference, value)
1814+
else:
1815+
assert 0 # unreachable

0 commit comments

Comments
 (0)