Skip to content

Commit bbc2479

Browse files
committed
start working on dask compatible loo
1 parent 4fc5000 commit bbc2479

File tree

1 file changed

+29
-25
lines changed

1 file changed

+29
-25
lines changed

arviz/stats/stats.py

Lines changed: 29 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -689,7 +689,7 @@ def _hdi_multimodal(ary, hdi_prob, skipna, max_modes):
689689
return np.array(hdi_intervals)
690690

691691

692-
def loo(data, pointwise=None, var_name=None, reff=None, scale=None):
692+
def loo(data, pointwise=None, var_name=None, reff=None, scale=None, dask_kwargs=None):
693693
"""Compute Pareto-smoothed importance sampling leave-one-out cross-validation (PSIS-LOO-CV).
694694
695695
Estimates the expected log pointwise predictive density (elpd) using Pareto-smoothed
@@ -699,20 +699,20 @@ def loo(data, pointwise=None, var_name=None, reff=None, scale=None):
699699
700700
Parameters
701701
----------
702-
data: obj
702+
data : obj
703703
Any object that can be converted to an :class:`arviz.InferenceData` object.
704704
Refer to documentation of
705705
:func:`arviz.convert_to_dataset` for details.
706-
pointwise: bool, optional
706+
pointwise : bool, optional
707707
If True the pointwise predictive accuracy will be returned. Defaults to
708708
``stats.ic_pointwise`` rcParam.
709709
var_name : str, optional
710710
The name of the variable in log_likelihood groups storing the pointwise log
711711
likelihood data to use for loo computation.
712-
reff: float, optional
712+
reff : float, optional
713713
Relative MCMC efficiency, ``ess / n`` i.e. number of effective samples divided by the number
714714
of actual samples. Computed from trace by default.
715-
scale: str
715+
scale : {"log", "negative_log", "deviance"}, optional
716716
Output scale for loo. Available options are:
717717
718718
- ``log`` : (default) log-score
@@ -721,6 +721,8 @@ def loo(data, pointwise=None, var_name=None, reff=None, scale=None):
721721
722722
A higher log-score (or a lower deviance or negative log_score) indicates a model with
723723
better predictive accuracy.
724+
dask_kwargs : dict, optional
725+
Dask-related keyword arguments passed to :func:`~arviz.wrap_xarray_ufunc`.
724726
725727
Returns
726728
-------
@@ -795,7 +797,7 @@ def loo(data, pointwise=None, var_name=None, reff=None, scale=None):
795797
np.hstack([ess_p[v].values.flatten() for v in ess_p.data_vars]).mean() / n_samples
796798
)
797799

798-
log_weights, pareto_shape = psislw(-log_likelihood, reff)
800+
log_weights, pareto_shape = psislw(-log_likelihood, reff, dask_kwargs=dask_kwargs)
799801
log_weights += log_likelihood
800802

801803
warn_mg = False
@@ -812,20 +814,19 @@ def loo(data, pointwise=None, var_name=None, reff=None, scale=None):
812814
ufunc_kwargs = {"n_dims": 1, "ravel": False}
813815
kwargs = {"input_core_dims": [["__sample__"]]}
814816
loo_lppd_i = scale_value * _wrap_xarray_ufunc(
815-
_logsumexp, log_weights, ufunc_kwargs=ufunc_kwargs, **kwargs
816-
)
817-
loo_lppd = loo_lppd_i.values.sum()
818-
loo_lppd_se = (n_data_points * np.var(loo_lppd_i.values)) ** 0.5
819-
820-
lppd = np.sum(
821-
_wrap_xarray_ufunc(
822-
_logsumexp,
823-
log_likelihood,
824-
func_kwargs={"b_inv": n_samples},
825-
ufunc_kwargs=ufunc_kwargs,
826-
**kwargs,
827-
).values
817+
_logsumexp, log_weights, ufunc_kwargs=ufunc_kwargs, dask_kwargs=dask_kwargs, **kwargs
828818
)
819+
loo_lppd = loo_lppd_i.sum().compute().item()
820+
loo_lppd_se = (n_data_points * loo_lppd_i.var().compute().item()) ** 0.5
821+
822+
lppd = _wrap_xarray_ufunc(
823+
_logsumexp,
824+
log_likelihood,
825+
func_kwargs={"b_inv": n_samples},
826+
ufunc_kwargs=ufunc_kwargs,
827+
dask_kwargs=dask_kwargs,
828+
**kwargs,
829+
).sum().compute().item()
829830
p_loo = lppd - loo_lppd / scale_value
830831

831832
if not pointwise:
@@ -864,7 +865,7 @@ def loo(data, pointwise=None, var_name=None, reff=None, scale=None):
864865
)
865866

866867

867-
def psislw(log_weights, reff=1.0):
868+
def psislw(log_weights, reff=1.0, dask_kwargs=None):
868869
"""
869870
Pareto smoothed importance sampling (PSIS).
870871
@@ -878,16 +879,18 @@ def psislw(log_weights, reff=1.0):
878879
879880
Parameters
880881
----------
881-
log_weights: array
882-
Array of size (n_observations, n_samples)
883-
reff: float
882+
log_weights : array-like
883+
Array of shape (*observation_shape, n_samples)
884+
reff : float, optional
884885
Relative MCMC efficiency, ``ess / n``
886+
dask_kwargs : dict, optional
887+
Dask-related keyword arguments passed to :func:`~arviz.wrap_xarray_ufunc`.
885888
886889
Returns
887890
-------
888-
lw_out: array
891+
lw_out : array
889892
Smoothed log weights
890-
kss: array
893+
kss : array
891894
Pareto tail indices
892895
893896
References
@@ -936,6 +939,7 @@ def psislw(log_weights, reff=1.0):
936939
log_weights,
937940
ufunc_kwargs=ufunc_kwargs,
938941
func_kwargs=func_kwargs,
942+
dask_kwargs=dask_kwargs,
939943
**kwargs,
940944
)
941945
if isinstance(log_weights, xr.DataArray):

0 commit comments

Comments
 (0)