|
| 1 | +from functools import partial |
| 2 | + |
1 | 3 | import jax
|
2 | 4 | import jax.numpy as jnp
|
3 | 5 | import numpy as np
|
4 | 6 | import scipy
|
5 | 7 |
|
| 8 | +from jaxoplanet.core.limb_dark import light_curve as _limb_dark_light_curve |
6 | 9 | from jaxoplanet.starry.core.basis import A1, A2_inv, U
|
7 | 10 | from jaxoplanet.starry.core.polynomials import Pijk
|
8 | 11 | from jaxoplanet.starry.core.rotation import left_project
|
|
11 | 14 | from jaxoplanet.starry.system_observable import system_observable
|
12 | 15 |
|
13 | 16 |
|
14 |
| -def _design_matrix_elements( |
15 |
| - surface: Surface, |
16 |
| - r: float | None = None, |
17 |
| - x: float | None = None, |
18 |
| - y: float | None = None, |
19 |
| - z: float | None = None, |
20 |
| - theta: float | None = None, |
21 |
| - order: int = 20, |
22 |
| - higher_precision: bool = False, |
23 |
| - rv: bool = False, |
24 |
| -): |
25 |
| - if higher_precision: |
26 |
| - try: |
27 |
| - from jaxoplanet.starry.multiprecision import ( |
28 |
| - basis as basis_mp, |
29 |
| - utils as utils_mp, |
30 |
| - ) |
31 |
| - except ImportError as e: |
32 |
| - raise ImportError( |
33 |
| - "The `mpmath` Python package is required for higher_precision=True." |
34 |
| - ) from e |
35 |
| - |
36 |
| - total_deg = surface.deg + (surface.vdeg if rv else 0) |
37 |
| - |
38 |
| - rT_deg = rT(total_deg) |
39 |
| - |
40 |
| - x = 0.0 if x is None else x |
41 |
| - y = 0.0 if y is None else y |
42 |
| - z = 0.0 if z is None else z |
43 |
| - |
44 |
| - # no occulting body |
45 |
| - if r is None: |
46 |
| - b_rot = True |
47 |
| - theta_z = 0.0 |
48 |
| - design_matrix_p = rT_deg |
49 |
| - |
50 |
| - # occulting body |
51 |
| - else: |
52 |
| - b = jnp.sqrt(jnp.square(x) + jnp.square(y)) |
53 |
| - b_rot = jnp.logical_or(jnp.greater_equal(b, 1.0 + r), jnp.less_equal(z, 0.0)) |
54 |
| - b_occ = jnp.logical_not(b_rot) |
55 |
| - theta_z = jnp.arctan2(x, y) |
56 |
| - |
57 |
| - # trick to avoid nan `x=jnp.where...` grad caused by nan sT |
58 |
| - r = jnp.where(b_rot, 1.0, r) |
59 |
| - b = jnp.where(b_rot, 1.0, b) |
60 |
| - |
61 |
| - sT = solution_vector(total_deg, order=order)(b, r) |
62 |
| - |
63 |
| - if total_deg > 0: |
64 |
| - if higher_precision: |
65 |
| - A2 = np.atleast_2d(utils_mp.to_numpy(basis_mp.A2(total_deg))) |
66 |
| - else: |
67 |
| - A2 = scipy.sparse.linalg.inv(A2_inv(total_deg)) |
68 |
| - A2 = jax.experimental.sparse.BCOO.from_scipy_sparse(A2) |
69 |
| - else: |
70 |
| - A2 = jnp.array([[1]]) |
71 |
| - |
72 |
| - design_matrix_p = jnp.where(b_occ, sT @ A2, rT_deg) |
73 |
| - |
74 |
| - if surface.ydeg == 0: |
75 |
| - rotated_y = surface.y.todense() |
76 |
| - else: |
77 |
| - rotated_y = left_project( |
78 |
| - surface.ydeg, |
79 |
| - surface._inc, |
80 |
| - surface._obl, |
81 |
| - theta, |
82 |
| - theta_z, |
83 |
| - surface.y.todense(), |
84 |
| - ) |
85 |
| - |
86 |
| - # limb darkening |
87 |
| - if surface.udeg == 0: |
88 |
| - p_u = Pijk.from_dense(jnp.array([1])) |
89 |
| - else: |
90 |
| - u = jnp.array([1, *surface.u]) |
91 |
| - p_u = Pijk.from_dense(u @ U(surface.udeg), degree=surface.udeg) |
92 |
| - |
93 |
| - # surface map * limb darkening map |
94 |
| - if higher_precision: |
95 |
| - A1_val = np.atleast_2d(utils_mp.to_numpy(basis_mp.A1(surface.ydeg))) |
96 |
| - else: |
97 |
| - A1_val = jax.experimental.sparse.BCOO.from_scipy_sparse(A1(surface.ydeg)) |
98 |
| - |
99 |
| - p_y = Pijk.from_dense(A1_val @ rotated_y, degree=surface.ydeg) |
100 |
| - |
101 |
| - norm = np.pi / (p_u.tosparse() @ rT(surface.udeg)) |
102 |
| - |
103 |
| - return design_matrix_p, p_y, p_u, norm, b_occ |
104 |
| - |
105 |
| - |
106 | 17 | def surface_light_curve(
|
107 | 18 | surface: Surface,
|
108 | 19 | r: float | None = None,
|
@@ -165,13 +76,26 @@ def surface_light_curve(
|
165 | 76 | b = jnp.sqrt(jnp.square(x) + jnp.square(y))
|
166 | 77 | b_rot = jnp.logical_or(jnp.greater_equal(b, 1.0 + r), jnp.less_equal(z, 0.0))
|
167 | 78 | b_occ = jnp.logical_not(b_rot)
|
168 |
| - theta_z = jnp.arctan2(x, y) |
169 | 79 |
|
170 | 80 | # trick to avoid nan `x=jnp.where...` grad caused by nan sT
|
171 | 81 | r = jnp.where(b_rot, 1.0, r)
|
172 | 82 | b = jnp.where(b_rot, 1.0, b)
|
173 | 83 |
|
174 |
| - sT = solution_vector(total_deg, order=order)(b, r) |
| 84 | + if surface.ydeg == 0: |
| 85 | + if surface.udeg == 0: |
| 86 | + ld_u = jnp.array([]) |
| 87 | + else: |
| 88 | + ld_u = jnp.concatenate( |
| 89 | + [jnp.atleast_1d(jnp.asarray(u_)) for u_ in surface.u], axis=0 |
| 90 | + ) |
| 91 | + |
| 92 | + lc_func = partial(_limb_dark_light_curve, ld_u, order=order) |
| 93 | + lc = lc_func(b, r) |
| 94 | + return 1.0 + jnp.where(b_occ, lc, 0) |
| 95 | + |
| 96 | + else: |
| 97 | + theta_z = jnp.arctan2(x, y) |
| 98 | + sT = solution_vector(total_deg, order=order)(b, r) |
175 | 99 |
|
176 | 100 | if total_deg > 0:
|
177 | 101 | if higher_precision:
|
@@ -210,10 +134,8 @@ def surface_light_curve(
|
210 | 134 | A1_val = jax.experimental.sparse.BCOO.from_scipy_sparse(A1(surface.ydeg))
|
211 | 135 |
|
212 | 136 | p_y = Pijk.from_dense(A1_val @ rotated_y, degree=surface.ydeg)
|
213 |
| - |
214 |
| - norm = np.pi / (p_u.tosparse() @ rT(surface.udeg)) |
215 |
| - |
216 | 137 | p_yu = p_y * p_u
|
| 138 | + norm = np.pi / (p_u.tosparse() @ rT(surface.udeg)) |
217 | 139 |
|
218 | 140 | return surface.amplitude * (p_yu.tosparse() @ design_matrix_p) * norm
|
219 | 141 |
|
|
0 commit comments