|
| 1 | +__all__ = ["integrate", "interpolate"] |
| 2 | + |
| 3 | +from functools import wraps |
| 4 | +from typing import Any, Optional, Union |
| 5 | + |
| 6 | +import jax |
| 7 | +import jax.numpy as jnp |
| 8 | +import jpu.numpy as jnpu |
| 9 | +from jpu.core import is_quantity |
| 10 | + |
| 11 | +from jaxoplanet import units |
| 12 | +from jaxoplanet.light_curves.types import LightCurveFunc |
| 13 | +from jaxoplanet.light_curves.utils import vectorize |
| 14 | +from jaxoplanet.types import Array, Quantity |
| 15 | +from jaxoplanet.units import unit_registry as ureg |
| 16 | + |
| 17 | +try: |
| 18 | + from jax.extend import linear_util as lu |
| 19 | +except ImportError: |
| 20 | + from jax import linear_util as lu # type: ignore |
| 21 | + |
| 22 | + |
| 23 | +@units.quantity_input(exposure_time=ureg.d) |
| 24 | +def integrate( |
| 25 | + func: LightCurveFunc, |
| 26 | + exposure_time: Optional[Quantity] = None, |
| 27 | + order: int = 0, |
| 28 | + num_samples: int = 7, |
| 29 | +) -> LightCurveFunc: |
| 30 | + """Transform a light curve function to apply exposure time integration |
| 31 | +
|
| 32 | + This transformation applies a fixed stencil numerical integration scheme to the input |
| 33 | + function ``func`` to convolve the light curve with a top hat exposure time centered |
| 34 | + on the input time, with a full width of ``exposure_time``. |
| 35 | +
|
| 36 | + The order of the integration scheme is set using the ``order`` parameter which must |
| 37 | + be ``0``, ``1``, or ``2``. The default (``0``) uses the "resampling" scheme discussed |
| 38 | + by `Kipping (2010) <https://arxiv.org/abs/1004.3741>`_. The higher order schemes |
| 39 | + ``1`` and ``2`` apply the trapezoid and Simpson's rules respectively, but won't |
| 40 | + necessarily provide higher accuracy results because of discontinuities at the |
| 41 | + contact points. |
| 42 | +
|
| 43 | + In practice, the parameter ``num_samples`` which sets the number of function |
| 44 | + evaluations per integral has the most significant effect on the accuracy of this |
| 45 | + integral, trading off against higher computational cost. |
| 46 | +
|
| 47 | + Args: |
| 48 | + func: A light curve function which takes a time ``Quantity`` as the first |
| 49 | + argument |
| 50 | + exposure_time (Quantity): The exposure time (in days, by default) |
| 51 | + order (int): The order of the integration scheme as discussed above |
| 52 | + num_samples (int): The number of function evaluations made per integral, |
| 53 | + controlling the accuracy of the numerics |
| 54 | +
|
| 55 | + Returns: |
| 56 | + A new light curve function with the same signature as ``func``, computing the |
| 57 | + exposure time integrated flux |
| 58 | + """ |
| 59 | + if exposure_time is None: |
| 60 | + return func |
| 61 | + |
| 62 | + if jnpu.ndim(exposure_time) != 0: |
| 63 | + raise ValueError( |
| 64 | + "The exposure time passed to 'integrate_exposure_time' has shape " |
| 65 | + f"{jnpu.shape(exposure_time)}, but a scalar was expected; " |
| 66 | + "To use exposure time integration with different exposures at different " |
| 67 | + "times, manually 'vmap' or 'vectorize' the function" |
| 68 | + ) |
| 69 | + |
| 70 | + # Ensure 'num_samples' is an odd number |
| 71 | + num_samples = int(num_samples) |
| 72 | + num_samples += 1 - num_samples % 2 |
| 73 | + stencil = jnp.ones(num_samples) |
| 74 | + |
| 75 | + # Construct exposure time integration stencil |
| 76 | + if order == 0: |
| 77 | + dt = jnp.linspace(-0.5, 0.5, 2 * num_samples + 1)[1:-1:2] |
| 78 | + elif order == 1: |
| 79 | + dt = jnp.linspace(-0.5, 0.5, num_samples) |
| 80 | + stencil = 2 * stencil |
| 81 | + stencil = stencil.at[0].set(1) |
| 82 | + stencil = stencil.at[-1].set(1) |
| 83 | + elif order == 2: |
| 84 | + dt = jnp.linspace(-0.5, 0.5, num_samples) |
| 85 | + stencil = stencil.at[1:-1:2].set(4) |
| 86 | + stencil = stencil.at[2:-1:2].set(2) |
| 87 | + else: |
| 88 | + raise ValueError( |
| 89 | + "The parameter 'order' in 'integrate_exposure_time' must be 0, 1, or 2" |
| 90 | + ) |
| 91 | + dt = dt * exposure_time |
| 92 | + stencil /= jnp.sum(stencil) |
| 93 | + |
| 94 | + @wraps(func) |
| 95 | + @units.quantity_input(time=ureg.d) |
| 96 | + @vectorize |
| 97 | + def wrapped(time: Quantity, *args: Any, **kwargs: Any) -> Union[Array, Quantity]: |
| 98 | + if jnpu.ndim(time) != 0: |
| 99 | + raise ValueError( |
| 100 | + "The time passed to 'integrate_exposure_time' has shape " |
| 101 | + f"{jnpu.shape(time)}, but a scalar was expected; " |
| 102 | + "this shouldn't typically happen so please open an issue " |
| 103 | + "on GitHub demonstrating the problem" |
| 104 | + ) |
| 105 | + |
| 106 | + f = lu.wrap_init(jax.vmap(func, in_axes=(0,) + (None,) * len(args))) |
| 107 | + f = apply_exposure_time_integration(f, stencil, dt) # type: ignore |
| 108 | + return f.call_wrapped(time, args, kwargs) # type: ignore |
| 109 | + |
| 110 | + return wrapped |
| 111 | + |
| 112 | + |
| 113 | +@lu.transformation # type: ignore |
| 114 | +def apply_exposure_time_integration(stencil, dt, time, args, kwargs): |
| 115 | + result = yield (time + dt,) + args, kwargs |
| 116 | + yield jnpu.dot(stencil, result) |
| 117 | + |
| 118 | + |
| 119 | +@units.quantity_input(period=ureg.day, time_transit=ureg.day, duration=ureg.day) |
| 120 | +def interpolate( |
| 121 | + func: LightCurveFunc, |
| 122 | + *, |
| 123 | + period: Quantity, |
| 124 | + time_transit: Quantity, |
| 125 | + num_samples: int, |
| 126 | + duration: Optional[Quantity] = None, |
| 127 | + args: tuple[Any, ...] = (), |
| 128 | + kwargs: Optional[dict[str, Any]] = None, |
| 129 | +) -> LightCurveFunc: |
| 130 | + """Transform a light curve function to pre-compute the model on a grid |
| 131 | +
|
| 132 | + Sometimes it can be useful to precompute the light curve on a grid near a transit, |
| 133 | + and then interpolate those computations to the required phases when computing the |
| 134 | + full model. This can speed things up a lot when you have many transits, or a lot of |
| 135 | + out of transit data. This transform uses linear interpolation. |
| 136 | +
|
| 137 | + .. note:: Unlike some other transforms, this function requires that any upstream |
| 138 | + ``*args`` and ``**kwargs`` be passed directly to the transform, rather than when |
| 139 | + calling the transformed function. This is necessary because the model is |
| 140 | + pre-computed when it is tranformed. |
| 141 | +
|
| 142 | + Args: |
| 143 | + func: A light curve function which takes a time ``Quantity`` as the first |
| 144 | + argument |
| 145 | + period (Quantity): The period of the orbit. Used to wrap the input times into the |
| 146 | + domain of the pre-computed model |
| 147 | + time_transit (Quantity): The transit time of the orbit. Used to wrap the input |
| 148 | + times into the domain of the pre-computed model |
| 149 | + duration (Quantity): The duration centered on the transit to pre-compute. By |
| 150 | + default, the full period will be evaluated |
| 151 | + num_samples (int): The number of points in the time grid used for pre-computation |
| 152 | + args (tuple): Any extra positional arguments that should be passed to ``func`` |
| 153 | + kwargs (dict): Any extra keyword arguments that should be passed to ``func`` |
| 154 | +
|
| 155 | + Returns: |
| 156 | + A new light curve function with the same signature as ``func``, computing the |
| 157 | + flux by interpolating a pre-computed model |
| 158 | + """ |
| 159 | + |
| 160 | + kwargs = kwargs or {} |
| 161 | + if duration is None: |
| 162 | + duration = period |
| 163 | + time_grid = time_transit + duration * jnp.linspace(-0.5, 0.5, num_samples) |
| 164 | + flux_grid = func(time_grid, *args, **kwargs) |
| 165 | + |
| 166 | + if is_quantity(flux_grid): |
| 167 | + flux_magnitude = flux_grid.magnitude |
| 168 | + flux_units = flux_grid.units |
| 169 | + else: |
| 170 | + flux_magnitude = flux_grid |
| 171 | + flux_units = None |
| 172 | + |
| 173 | + @wraps(func) |
| 174 | + @units.quantity_input(time=ureg.d) |
| 175 | + @vectorize |
| 176 | + def wrapped(time: Quantity, *args: Any, **kwargs: Any) -> Union[Array, Quantity]: |
| 177 | + del args, kwargs |
| 178 | + time_wrapped = ( |
| 179 | + jnpu.mod(time - time_transit + 0.5 * period, period) |
| 180 | + + 0.5 * period |
| 181 | + + time_transit |
| 182 | + ) |
| 183 | + flux = jnp.interp( |
| 184 | + time_wrapped.magnitude, |
| 185 | + time_grid.magnitude, |
| 186 | + flux_magnitude, |
| 187 | + left=flux_magnitude[0], |
| 188 | + right=flux_magnitude[-1], |
| 189 | + period=period.magnitude, |
| 190 | + ) |
| 191 | + if flux_units is None: |
| 192 | + return flux |
| 193 | + else: |
| 194 | + return flux * flux_units |
| 195 | + |
| 196 | + return wrapped |
0 commit comments