Skip to content

Commit 29e45a9

Browse files
feat: use limb darkening light curve in system starry light curve (#261)
* feat: use limb darkening light curve in system starry light curve if central.ydeg is 0 * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix: use jax assert_allclose --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 866fffc commit 29e45a9

File tree

3 files changed

+44
-122
lines changed

3 files changed

+44
-122
lines changed

src/jaxoplanet/starry/core/solution.py

Lines changed: 24 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -121,9 +121,7 @@ def p_integral(order: int, l_max: int, b: Array, r: Array, kappa0: Array) -> Arr
121121

122122
# low order variables
123123
low_order = np.min([order, 20])
124-
zeros = jnp.zeros(order - low_order)
125124
roots, low_weights = roots_legendre(low_order)
126-
low_weights = jnp.hstack((low_weights, zeros))
127125
phi = rng * (roots + 1)
128126
low_s2 = jnp.square(jnp.sin(phi))
129127
low_a1 = low_s2 - jnp.square(low_s2)
@@ -139,9 +137,12 @@ def p_integral(order: int, l_max: int, b: Array, r: Array, kappa0: Array) -> Arr
139137
high_a2 = jnp.where(r_cond, 0, delta + high_s2)
140138
high_a4 = 1 - 2 * high_s2
141139

142-
indices = []
143-
weights = []
144-
integrand = []
140+
low_indices = []
141+
low_integrand = []
142+
143+
high_indices = []
144+
high_integrand = []
145+
145146
n = 0
146147

147148
for l in range(l_max + 1): # noqa
@@ -160,8 +161,8 @@ def p_integral(order: int, l_max: int, b: Array, r: Array, kappa0: Array) -> Arr
160161
result = (
161162
2 * r * (r - b * c) * (1 - z2 * zero_safe_sqrt(z2)) / (3 * omz2)
162163
)
163-
integrand.append(jnp.where(cond, 0, 2 * result))
164-
weights.append(high_weights)
164+
high_integrand.append(jnp.where(cond, 0, 2 * result))
165+
high_indices.append(n)
165166

166167
elif mu % 2 == 0 and (mu // 2) % 2 == 0:
167168
f = (
@@ -170,18 +171,18 @@ def p_integral(order: int, l_max: int, b: Array, r: Array, kappa0: Array) -> Arr
170171
* low_a1 ** (0.25 * (mu + 4))
171172
* low_a2 ** (0.5 * nu)
172173
)
173-
integrand.append(2 * jnp.hstack((f, zeros)))
174-
weights.append(low_weights)
174+
low_integrand.append(2 * f)
175+
low_indices.append(n)
175176

176177
elif mu == 1 and l % 2 == 0:
177178
f = high_fa3 * high_a1 ** (l // 2 - 1) * high_a4
178-
integrand.append(2 * f)
179-
weights.append(high_weights)
179+
high_integrand.append(2 * f)
180+
high_indices.append(n)
180181

181182
elif mu == 1:
182183
f = high_fa3 * high_a1 ** ((l - 3) // 2) * high_a2 * high_a4
183-
integrand.append(2 * f)
184-
weights.append(high_weights)
184+
high_integrand.append(2 * f)
185+
high_indices.append(n)
185186

186187
elif (mu - 1) % 2 == 0 and ((mu - 1) // 2) % 2 == 0:
187188
f = (
@@ -190,23 +191,26 @@ def p_integral(order: int, l_max: int, b: Array, r: Array, kappa0: Array) -> Arr
190191
* high_a1 ** ((mu - 1) // 4)
191192
* high_a2 ** (0.5 * (nu - 1))
192193
)
193-
integrand.append(2 * f)
194-
weights.append(high_weights)
194+
high_integrand.append(2 * f)
195+
high_indices.append(n)
195196

196197
else:
197198
n += 1
198199
continue
199200

200-
indices.append(n)
201201
n += 1
202202

203-
indices = np.stack(indices)
204-
weights = jnp.stack(weights)
203+
low_indices = np.stack(low_indices)
204+
high_indices = np.stack(high_indices)
205+
206+
low_P0 = rng * jnp.sum(jnp.stack(low_integrand) * low_weights, axis=1)
207+
high_P0 = rng * jnp.sum(jnp.stack(high_integrand) * high_weights, axis=1)
205208

206-
P0 = rng * jnp.sum(jnp.stack(integrand) * weights, axis=1)
207209
P = jnp.zeros(l_max**2 + 2 * l_max + 1)
210+
P = P.at[low_indices].set(low_P0)
211+
P = P.at[high_indices].set(high_P0)
208212

209-
return P.at[indices].set(P0)
213+
return P
210214

211215

212216
def rT(lmax: int) -> Array:

src/jaxoplanet/starry/light_curves.py

Lines changed: 19 additions & 97 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
1+
from functools import partial
2+
13
import jax
24
import jax.numpy as jnp
35
import numpy as np
46
import scipy
57

8+
from jaxoplanet.core.limb_dark import light_curve as _limb_dark_light_curve
69
from jaxoplanet.starry.core.basis import A1, A2_inv, U
710
from jaxoplanet.starry.core.polynomials import Pijk
811
from jaxoplanet.starry.core.rotation import left_project
@@ -11,98 +14,6 @@
1114
from jaxoplanet.starry.system_observable import system_observable
1215

1316

14-
def _design_matrix_elements(
15-
surface: Surface,
16-
r: float | None = None,
17-
x: float | None = None,
18-
y: float | None = None,
19-
z: float | None = None,
20-
theta: float | None = None,
21-
order: int = 20,
22-
higher_precision: bool = False,
23-
rv: bool = False,
24-
):
25-
if higher_precision:
26-
try:
27-
from jaxoplanet.starry.multiprecision import (
28-
basis as basis_mp,
29-
utils as utils_mp,
30-
)
31-
except ImportError as e:
32-
raise ImportError(
33-
"The `mpmath` Python package is required for higher_precision=True."
34-
) from e
35-
36-
total_deg = surface.deg + (surface.vdeg if rv else 0)
37-
38-
rT_deg = rT(total_deg)
39-
40-
x = 0.0 if x is None else x
41-
y = 0.0 if y is None else y
42-
z = 0.0 if z is None else z
43-
44-
# no occulting body
45-
if r is None:
46-
b_rot = True
47-
theta_z = 0.0
48-
design_matrix_p = rT_deg
49-
50-
# occulting body
51-
else:
52-
b = jnp.sqrt(jnp.square(x) + jnp.square(y))
53-
b_rot = jnp.logical_or(jnp.greater_equal(b, 1.0 + r), jnp.less_equal(z, 0.0))
54-
b_occ = jnp.logical_not(b_rot)
55-
theta_z = jnp.arctan2(x, y)
56-
57-
# trick to avoid nan `x=jnp.where...` grad caused by nan sT
58-
r = jnp.where(b_rot, 1.0, r)
59-
b = jnp.where(b_rot, 1.0, b)
60-
61-
sT = solution_vector(total_deg, order=order)(b, r)
62-
63-
if total_deg > 0:
64-
if higher_precision:
65-
A2 = np.atleast_2d(utils_mp.to_numpy(basis_mp.A2(total_deg)))
66-
else:
67-
A2 = scipy.sparse.linalg.inv(A2_inv(total_deg))
68-
A2 = jax.experimental.sparse.BCOO.from_scipy_sparse(A2)
69-
else:
70-
A2 = jnp.array([[1]])
71-
72-
design_matrix_p = jnp.where(b_occ, sT @ A2, rT_deg)
73-
74-
if surface.ydeg == 0:
75-
rotated_y = surface.y.todense()
76-
else:
77-
rotated_y = left_project(
78-
surface.ydeg,
79-
surface._inc,
80-
surface._obl,
81-
theta,
82-
theta_z,
83-
surface.y.todense(),
84-
)
85-
86-
# limb darkening
87-
if surface.udeg == 0:
88-
p_u = Pijk.from_dense(jnp.array([1]))
89-
else:
90-
u = jnp.array([1, *surface.u])
91-
p_u = Pijk.from_dense(u @ U(surface.udeg), degree=surface.udeg)
92-
93-
# surface map * limb darkening map
94-
if higher_precision:
95-
A1_val = np.atleast_2d(utils_mp.to_numpy(basis_mp.A1(surface.ydeg)))
96-
else:
97-
A1_val = jax.experimental.sparse.BCOO.from_scipy_sparse(A1(surface.ydeg))
98-
99-
p_y = Pijk.from_dense(A1_val @ rotated_y, degree=surface.ydeg)
100-
101-
norm = np.pi / (p_u.tosparse() @ rT(surface.udeg))
102-
103-
return design_matrix_p, p_y, p_u, norm, b_occ
104-
105-
10617
def surface_light_curve(
10718
surface: Surface,
10819
r: float | None = None,
@@ -165,13 +76,26 @@ def surface_light_curve(
16576
b = jnp.sqrt(jnp.square(x) + jnp.square(y))
16677
b_rot = jnp.logical_or(jnp.greater_equal(b, 1.0 + r), jnp.less_equal(z, 0.0))
16778
b_occ = jnp.logical_not(b_rot)
168-
theta_z = jnp.arctan2(x, y)
16979

17080
# trick to avoid nan `x=jnp.where...` grad caused by nan sT
17181
r = jnp.where(b_rot, 1.0, r)
17282
b = jnp.where(b_rot, 1.0, b)
17383

174-
sT = solution_vector(total_deg, order=order)(b, r)
84+
if surface.ydeg == 0:
85+
if surface.udeg == 0:
86+
ld_u = jnp.array([])
87+
else:
88+
ld_u = jnp.concatenate(
89+
[jnp.atleast_1d(jnp.asarray(u_)) for u_ in surface.u], axis=0
90+
)
91+
92+
lc_func = partial(_limb_dark_light_curve, ld_u, order=order)
93+
lc = lc_func(b, r)
94+
return 1.0 + jnp.where(b_occ, lc, 0)
95+
96+
else:
97+
theta_z = jnp.arctan2(x, y)
98+
sT = solution_vector(total_deg, order=order)(b, r)
17599

176100
if total_deg > 0:
177101
if higher_precision:
@@ -210,10 +134,8 @@ def surface_light_curve(
210134
A1_val = jax.experimental.sparse.BCOO.from_scipy_sparse(A1(surface.ydeg))
211135

212136
p_y = Pijk.from_dense(A1_val @ rotated_y, degree=surface.ydeg)
213-
214-
norm = np.pi / (p_u.tosparse() @ rT(surface.udeg))
215-
216137
p_yu = p_y * p_u
138+
norm = np.pi / (p_u.tosparse() @ rT(surface.udeg))
217139

218140
return surface.amplitude * (p_yu.tosparse() @ design_matrix_p) * norm
219141

tests/starry/light_curve_test.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -470,11 +470,7 @@ def test_EB_interchanged():
470470
flux_reversed = light_curve(system_2)(time + params["period"] / 2).sum(1)
471471

472472
# for some reason the assert_allclose wasn't catching error here
473-
np.testing.assert_allclose(
474-
flux_ordered,
475-
flux_reversed,
476-
atol=1e-6 if flux_ordered.dtype.name == "float32" else 1e-12,
477-
)
473+
assert_allclose(flux_ordered, flux_reversed)
478474

479475

480476
@pytest.mark.parametrize("order", [20, 100, 500])

0 commit comments

Comments
 (0)