Skip to content

Commit 7d171c9

Browse files
committed
Fixing merge conflict
2 parents b23ada6 + 0f8b68a commit 7d171c9

File tree

4 files changed

+221
-85
lines changed

4 files changed

+221
-85
lines changed
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""This module contains models for computing and transforming light curve models"""
22

3-
from jaxoplanet.light_curves import exposure_time as exposure_time
3+
from jaxoplanet.light_curves import transforms as transforms
44
from jaxoplanet.light_curves.limb_dark import (
55
limb_dark_light_curve as limb_dark_light_curve,
66
)

src/jaxoplanet/light_curves/exposure_time.py

Lines changed: 0 additions & 84 deletions
This file was deleted.
Lines changed: 196 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,196 @@
1+
__all__ = ["integrate", "interpolate"]
2+
3+
from functools import wraps
4+
from typing import Any, Optional, Union
5+
6+
import jax
7+
import jax.numpy as jnp
8+
import jpu.numpy as jnpu
9+
from jpu.core import is_quantity
10+
11+
from jaxoplanet import units
12+
from jaxoplanet.light_curves.types import LightCurveFunc
13+
from jaxoplanet.light_curves.utils import vectorize
14+
from jaxoplanet.types import Array, Quantity
15+
from jaxoplanet.units import unit_registry as ureg
16+
17+
try:
18+
from jax.extend import linear_util as lu
19+
except ImportError:
20+
from jax import linear_util as lu # type: ignore
21+
22+
23+
@units.quantity_input(exposure_time=ureg.d)
24+
def integrate(
25+
func: LightCurveFunc,
26+
exposure_time: Optional[Quantity] = None,
27+
order: int = 0,
28+
num_samples: int = 7,
29+
) -> LightCurveFunc:
30+
"""Transform a light curve function to apply exposure time integration
31+
32+
This transformation applies a fixed stencil numerical integration scheme to the input
33+
function ``func`` to convolve the light curve with a top hat exposure time centered
34+
on the input time, with a full width of ``exposure_time``.
35+
36+
The order of the integration scheme is set using the ``order`` parameter which must
37+
be ``0``, ``1``, or ``2``. The default (``0``) uses the "resampling" scheme discussed
38+
by `Kipping (2010) <https://arxiv.org/abs/1004.3741>`_. The higher order schemes
39+
``1`` and ``2`` apply the trapezoid and Simpson's rules respectively, but won't
40+
necessarily provide higher accuracy results because of discontinuities at the
41+
contact points.
42+
43+
In practice, the parameter ``num_samples`` which sets the number of function
44+
evaluations per integral has the most significant effect on the accuracy of this
45+
integral, trading off against higher computational cost.
46+
47+
Args:
48+
func: A light curve function which takes a time ``Quantity`` as the first
49+
argument
50+
exposure_time (Quantity): The exposure time (in days, by default)
51+
order (int): The order of the integration scheme as discussed above
52+
num_samples (int): The number of function evaluations made per integral,
53+
controlling the accuracy of the numerics
54+
55+
Returns:
56+
A new light curve function with the same signature as ``func``, computing the
57+
exposure time integrated flux
58+
"""
59+
if exposure_time is None:
60+
return func
61+
62+
if jnpu.ndim(exposure_time) != 0:
63+
raise ValueError(
64+
"The exposure time passed to 'integrate_exposure_time' has shape "
65+
f"{jnpu.shape(exposure_time)}, but a scalar was expected; "
66+
"To use exposure time integration with different exposures at different "
67+
"times, manually 'vmap' or 'vectorize' the function"
68+
)
69+
70+
# Ensure 'num_samples' is an odd number
71+
num_samples = int(num_samples)
72+
num_samples += 1 - num_samples % 2
73+
stencil = jnp.ones(num_samples)
74+
75+
# Construct exposure time integration stencil
76+
if order == 0:
77+
dt = jnp.linspace(-0.5, 0.5, 2 * num_samples + 1)[1:-1:2]
78+
elif order == 1:
79+
dt = jnp.linspace(-0.5, 0.5, num_samples)
80+
stencil = 2 * stencil
81+
stencil = stencil.at[0].set(1)
82+
stencil = stencil.at[-1].set(1)
83+
elif order == 2:
84+
dt = jnp.linspace(-0.5, 0.5, num_samples)
85+
stencil = stencil.at[1:-1:2].set(4)
86+
stencil = stencil.at[2:-1:2].set(2)
87+
else:
88+
raise ValueError(
89+
"The parameter 'order' in 'integrate_exposure_time' must be 0, 1, or 2"
90+
)
91+
dt = dt * exposure_time
92+
stencil /= jnp.sum(stencil)
93+
94+
@wraps(func)
95+
@units.quantity_input(time=ureg.d)
96+
@vectorize
97+
def wrapped(time: Quantity, *args: Any, **kwargs: Any) -> Union[Array, Quantity]:
98+
if jnpu.ndim(time) != 0:
99+
raise ValueError(
100+
"The time passed to 'integrate_exposure_time' has shape "
101+
f"{jnpu.shape(time)}, but a scalar was expected; "
102+
"this shouldn't typically happen so please open an issue "
103+
"on GitHub demonstrating the problem"
104+
)
105+
106+
f = lu.wrap_init(jax.vmap(func, in_axes=(0,) + (None,) * len(args)))
107+
f = apply_exposure_time_integration(f, stencil, dt) # type: ignore
108+
return f.call_wrapped(time, args, kwargs) # type: ignore
109+
110+
return wrapped
111+
112+
113+
@lu.transformation # type: ignore
114+
def apply_exposure_time_integration(stencil, dt, time, args, kwargs):
115+
result = yield (time + dt,) + args, kwargs
116+
yield jnpu.dot(stencil, result)
117+
118+
119+
@units.quantity_input(period=ureg.day, time_transit=ureg.day, duration=ureg.day)
120+
def interpolate(
121+
func: LightCurveFunc,
122+
*,
123+
period: Quantity,
124+
time_transit: Quantity,
125+
num_samples: int,
126+
duration: Optional[Quantity] = None,
127+
args: tuple[Any, ...] = (),
128+
kwargs: Optional[dict[str, Any]] = None,
129+
) -> LightCurveFunc:
130+
"""Transform a light curve function to pre-compute the model on a grid
131+
132+
Sometimes it can be useful to precompute the light curve on a grid near a transit,
133+
and then interpolate those computations to the required phases when computing the
134+
full model. This can speed things up a lot when you have many transits, or a lot of
135+
out of transit data. This transform uses linear interpolation.
136+
137+
.. note:: Unlike some other transforms, this function requires that any upstream
138+
``*args`` and ``**kwargs`` be passed directly to the transform, rather than when
139+
calling the transformed function. This is necessary because the model is
140+
pre-computed when it is tranformed.
141+
142+
Args:
143+
func: A light curve function which takes a time ``Quantity`` as the first
144+
argument
145+
period (Quantity): The period of the orbit. Used to wrap the input times into the
146+
domain of the pre-computed model
147+
time_transit (Quantity): The transit time of the orbit. Used to wrap the input
148+
times into the domain of the pre-computed model
149+
duration (Quantity): The duration centered on the transit to pre-compute. By
150+
default, the full period will be evaluated
151+
num_samples (int): The number of points in the time grid used for pre-computation
152+
args (tuple): Any extra positional arguments that should be passed to ``func``
153+
kwargs (dict): Any extra keyword arguments that should be passed to ``func``
154+
155+
Returns:
156+
A new light curve function with the same signature as ``func``, computing the
157+
flux by interpolating a pre-computed model
158+
"""
159+
160+
kwargs = kwargs or {}
161+
if duration is None:
162+
duration = period
163+
time_grid = time_transit + duration * jnp.linspace(-0.5, 0.5, num_samples)
164+
flux_grid = func(time_grid, *args, **kwargs)
165+
166+
if is_quantity(flux_grid):
167+
flux_magnitude = flux_grid.magnitude
168+
flux_units = flux_grid.units
169+
else:
170+
flux_magnitude = flux_grid
171+
flux_units = None
172+
173+
@wraps(func)
174+
@units.quantity_input(time=ureg.d)
175+
@vectorize
176+
def wrapped(time: Quantity, *args: Any, **kwargs: Any) -> Union[Array, Quantity]:
177+
del args, kwargs
178+
time_wrapped = (
179+
jnpu.mod(time - time_transit + 0.5 * period, period)
180+
+ 0.5 * period
181+
+ time_transit
182+
)
183+
flux = jnp.interp(
184+
time_wrapped.magnitude,
185+
time_grid.magnitude,
186+
flux_magnitude,
187+
left=flux_magnitude[0],
188+
right=flux_magnitude[-1],
189+
period=period.magnitude,
190+
)
191+
if flux_units is None:
192+
return flux
193+
else:
194+
return flux * flux_units
195+
196+
return wrapped

tests/light_curves/transforms_test.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
import jax.numpy as jnp
2+
import pytest
3+
4+
from jaxoplanet.light_curves import transforms
5+
from jaxoplanet.test_utils import assert_allclose
6+
from jaxoplanet.units import quantity_input, unit_registry as ureg
7+
8+
9+
@quantity_input(time=ureg.day)
10+
def lc_func1(time):
11+
return jnp.ones_like(time.magnitude)
12+
13+
14+
@quantity_input(time=ureg.day)
15+
def lc_func2(time):
16+
return jnp.stack([0.5 * time.magnitude + 0.1, -1.5 * time.magnitude + 3.6], axis=-1)
17+
18+
19+
@pytest.mark.parametrize("order", [0, 1, 2])
20+
@pytest.mark.parametrize("lc_func", [lc_func1, lc_func2])
21+
def test_integrate_invariant(order, lc_func):
22+
time = jnp.linspace(0, 10, 50)
23+
int_func = transforms.integrate(lc_func, exposure_time=0.1, order=order)
24+
assert_allclose(int_func(time), lc_func(time))

0 commit comments

Comments
 (0)