start working on dask compatible loo #2205

Draft · wants to merge 5 commits into main
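In short: ``loo`` and ``psislw`` gain a ``dask_kwargs`` argument that is forwarded to ``wrap_xarray_ufunc``, so the pointwise computations can run on dask-backed arrays. A minimal usage sketch, mirroring the new tests below (the chunk size is illustrative):

    import arviz as az

    idata = az.load_arviz_data("centered_eight")
    # chunk the pointwise log likelihood so dask can parallelize over observations
    idata.log_likelihood = idata.log_likelihood.chunk({"school": 4})
    # dask_kwargs is forwarded to arviz.wrap_xarray_ufunc
    result = az.loo(idata, pointwise=True, dask_kwargs={"dask": "parallelized"})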
54 changes: 39 additions & 15 deletions arviz/stats/stats.py
@@ -675,7 +675,7 @@ def _hdi_multimodal(ary, hdi_prob, skipna, max_modes):
    return np.array(hdi_intervals)


-def loo(data, pointwise=None, var_name=None, reff=None, scale=None):
+def loo(data, pointwise=None, var_name=None, reff=None, scale=None, dask_kwargs=None):
    """Compute Pareto-smoothed importance sampling leave-one-out cross-validation (PSIS-LOO-CV).

    Estimates the expected log pointwise predictive density (elpd) using Pareto-smoothed
@@ -685,20 +685,20 @@ def loo(data, pointwise=None, var_name=None, reff=None, scale=None):

    Parameters
    ----------
-    data: obj
+    data : obj
        Any object that can be converted to an :class:`arviz.InferenceData` object.
        Refer to documentation of
        :func:`arviz.convert_to_dataset` for details.
-    pointwise: bool, optional
+    pointwise : bool, optional
        If True the pointwise predictive accuracy will be returned. Defaults to
        ``stats.ic_pointwise`` rcParam.
    var_name : str, optional
        The name of the variable in log_likelihood groups storing the pointwise log
        likelihood data to use for loo computation.
-    reff: float, optional
+    reff : float, optional
        Relative MCMC efficiency, ``ess / n`` i.e. number of effective samples divided by the number
        of actual samples. Computed from trace by default.
-    scale: str
+    scale : {"log", "negative_log", "deviance"}, optional
        Output scale for loo. Available options are:

        - ``log`` : (default) log-score
@@ -707,6 +707,8 @@ def loo(data, pointwise=None, var_name=None, reff=None, scale=None):

        A higher log-score (or a lower deviance or negative log_score) indicates a model with
        better predictive accuracy.
+    dask_kwargs : dict, optional
+        Dask related kwargs passed to :func:`~arviz.wrap_xarray_ufunc`.

    Returns
    -------
@@ -752,6 +754,9 @@ def loo(data, pointwise=None, var_name=None, reff=None, scale=None):
    log_likelihood = _get_log_likelihood(inference_data, var_name=var_name)
    pointwise = rcParams["stats.ic_pointwise"] if pointwise is None else pointwise

+    if dask_kwargs is None:
+        dask_kwargs = {}
+
    log_likelihood = log_likelihood.stack(__sample__=("chain", "draw"))
    shape = log_likelihood.shape
    n_samples = shape[-1]
@@ -781,7 +786,9 @@ def loo(data, pointwise=None, var_name=None, reff=None, scale=None):
        np.hstack([ess_p[v].values.flatten() for v in ess_p.data_vars]).mean() / n_samples
    )

-    log_weights, pareto_shape = psislw(-log_likelihood, reff)
+    psis_dask = dask_kwargs.copy()
+    psis_dask["output_dtypes"] = (float, float)
+    log_weights, pareto_shape = psislw(-log_likelihood, reff, dask_kwargs=psis_dask)
    log_weights += log_likelihood

    warn_mg = False
@@ -797,20 +804,28 @@ def loo(data, pointwise=None, var_name=None, reff=None, scale=None):

    ufunc_kwargs = {"n_dims": 1, "ravel": False}
    kwargs = {"input_core_dims": [["__sample__"]]}
+    logsumexp_dask = dask_kwargs.copy()
+    logsumexp_dask["output_dtypes"] = [float]
+    logsumexp_out = xr.zeros_like(log_weights.isel(__sample__=0, drop=True)).data
    loo_lppd_i = scale_value * _wrap_xarray_ufunc(
-        _logsumexp, log_weights, ufunc_kwargs=ufunc_kwargs, **kwargs
+        _logsumexp, log_weights, ufunc_kwargs=ufunc_kwargs, func_kwargs={"out": logsumexp_out}, dask_kwargs=logsumexp_dask, **kwargs
    )
-    loo_lppd = loo_lppd_i.values.sum()
-    loo_lppd_se = (n_data_points * np.var(loo_lppd_i.values)) ** 0.5
+    loo_lppd = loo_lppd_i.sum().compute().item()
+    loo_lppd_se = (n_data_points * loo_lppd_i.var().compute().item()) ** 0.5

-    lppd = np.sum(
+    lppd_out = xr.zeros_like(log_weights.isel(__sample__=0, drop=True)).data
+    lppd = (
        _wrap_xarray_ufunc(
            _logsumexp,
            log_likelihood,
-            func_kwargs={"b_inv": n_samples},
+            func_kwargs={"b_inv": n_samples, "out": lppd_out},
            ufunc_kwargs=ufunc_kwargs,
+            dask_kwargs=logsumexp_dask,
            **kwargs,
-        ).values
-    )
+        )
+        .sum()
+        .compute()
+        .item()
+    )
    p_loo = lppd - loo_lppd / scale_value

@@ -850,7 +865,7 @@ def loo(data, pointwise=None, var_name=None, reff=None, scale=None):
    )


-def psislw(log_weights, reff=1.0):
+def psislw(log_weights, reff=1.0, dask_kwargs=None):
    """
    Pareto smoothed importance sampling (PSIS).

@@ -865,9 +880,11 @@ def psislw(log_weights, reff=1.0):
    Parameters
    ----------
    log_weights : DataArray or (..., N) array-like
-        Array of size (n_observations, n_samples)
+        Array of shape (*observation_shape, n_samples)
    reff : float, default 1
        relative MCMC efficiency, ``ess / n``
+    dask_kwargs : dict, optional
+        Dask related kwargs passed to :func:`~arviz.wrap_xarray_ufunc`.

    Returns
    -------
@@ -899,6 +916,8 @@ def psislw(log_weights, reff=1.0):
        ...: az.psislw(-log_likelihood, reff=0.8)

    """
+    if dask_kwargs is None:
+        dask_kwargs = {}
    if hasattr(log_weights, "__sample__"):
        n_samples = len(log_weights.__sample__)
        shape = [
@@ -912,7 +931,11 @@ def psislw(log_weights, reff=1.0):
    cutoffmin = np.log(np.finfo(float).tiny)  # pylint: disable=no-member, assignment-from-no-return

    # create output array with proper dimensions
-    out = np.empty_like(log_weights), np.empty(shape)
+    if dask_kwargs.get("dask", "forbidden") in {"allowed", "parallelized"}:
+        out = xr.zeros_like(log_weights).data, xr.zeros_like(log_weights.isel(__sample__=0, drop=True)).data
+        dask_kwargs["output_dtypes"] = (float, float)
+    else:
+        out = np.empty_like(log_weights), np.empty(shape)

    # define kwargs
    func_kwargs = {"cutoff_ind": cutoff_ind, "cutoffmin": cutoffmin, "out": out}
@@ -923,6 +946,7 @@ def psislw(log_weights, reff=1.0):
        log_weights,
        ufunc_kwargs=ufunc_kwargs,
        func_kwargs=func_kwargs,
+        dask_kwargs=dask_kwargs,
        **kwargs,
    )
    if isinstance(log_weights, xr.DataArray):
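``psislw`` can also be called directly with dask enabled; a hedged sketch mirroring ``test_psislw`` below (the chunk size is again illustrative):

    import arviz as az

    idata = az.load_arviz_data("centered_eight")
    # chunk over observations, then collapse (chain, draw) into the __sample__ dim
    log_like = idata.log_likelihood["obs"].chunk({"school": 4}).stack(__sample__=("chain", "draw"))
    log_weights, khat = az.psislw(-log_like, dask_kwargs={"dask": "parallelized"})
    assert log_weights.shape[:-1] == khat.shape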
48 changes: 48 additions & 0 deletions arviz/tests/base_tests/test_stats_dask.py
@@ -0,0 +1,48 @@
# pylint: disable=redefined-outer-name, no-member
import numpy as np
import pytest

from ...data import load_arviz_data
from ...stats import loo, psislw
from ..helpers import importorskip, multidim_models  # pylint: disable=unused-import

dask = importorskip("dask", reason="Dask specific tests")


@pytest.fixture()
def chunked_centered_eight():
    centered_eight = load_arviz_data("centered_eight")
    centered_eight.log_likelihood = centered_eight.log_likelihood.chunk({"school": 4})
    return centered_eight


@pytest.mark.parametrize("multidim", (True, False))
def test_psislw(chunked_centered_eight, multidim_models, multidim):
    if multidim:
        log_like = (
            multidim_models.model_1.log_likelihood["y"]
            .stack(__sample__=["chain", "draw"])
            .chunk({"dim2": 3})
        )
    else:
        log_like = chunked_centered_eight.log_likelihood["obs"].stack(__sample__=["chain", "draw"])
    log_weights, khat = psislw(-log_like, dask_kwargs={"dask": "parallelized"})
    assert log_weights.shape[:-1] == khat.shape


@pytest.mark.parametrize("multidim", (True, False))
def test_loo(chunked_centered_eight, multidim_models, multidim):
    """Test approximate leave one out criterion calculation."""
    if multidim:
        idata = multidim_models.model_1
        idata.log_likelihood = idata.log_likelihood.chunk({"dim2": 3})
    else:
        idata = chunked_centered_eight
    assert loo(idata, dask_kwargs={"dask": "parallelized"}) is not None
    loo_pointwise = loo(idata, pointwise=True, dask_kwargs={"dask": "parallelized"})
    assert loo_pointwise is not None
    assert "loo_i" in loo_pointwise
    assert "pareto_k" in loo_pointwise
    assert "scale" in loo_pointwise


def test_compare_loo():
    """Check that the dask-parallelized loo matches the in-memory result."""
    centered_eight = load_arviz_data("centered_eight")
    loo_ram = loo(centered_eight)
    centered_eight.log_likelihood = centered_eight.log_likelihood.chunk({"school": 2})
    loo_dask = loo(centered_eight, dask_kwargs={"dask": "parallelized"})
    assert np.isclose(loo_ram["elpd"], loo_dask["elpd"])
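For context on the mechanism: with ``{"dask": "parallelized"}``, ``wrap_xarray_ufunc`` hands these kwargs on to ``xarray.apply_ufunc``, which maps the ufunc over each chunk lazily. A standalone sketch of that underlying behavior, independent of this PR (scipy's ``logsumexp`` stands in for arviz's internal ``_logsumexp``; shapes and chunk sizes are illustrative):

    import numpy as np
    import xarray as xr
    from scipy.special import logsumexp

    # dask-backed DataArray: 8 observations x 2000 posterior samples, chunked over observations
    da = xr.DataArray(
        np.random.randn(8, 2000), dims=("school", "__sample__")
    ).chunk({"school": 4})
    out = xr.apply_ufunc(
        logsumexp,
        da,
        kwargs={"axis": -1},  # reduce over the core (sample) dimension
        input_core_dims=[["__sample__"]],
        dask="parallelized",  # map over chunks lazily
        output_dtypes=[float],
    )
    print(out.compute())  # triggers the actual dask computation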