Skip to content

Commit c491ce4

Browse files
committed
continue work and add tests
1 parent b565643 commit c491ce4

File tree

2 files changed

+54
-5
lines changed

2 files changed

+54
-5
lines changed

arviz/stats/stats.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -768,6 +768,9 @@ def loo(data, pointwise=None, var_name=None, reff=None, scale=None, dask_kwargs=
768768
log_likelihood = _get_log_likelihood(inference_data, var_name=var_name)
769769
pointwise = rcParams["stats.ic_pointwise"] if pointwise is None else pointwise
770770

771+
if dask_kwargs is None:
772+
dask_kwargs = {}
773+
771774
log_likelihood = log_likelihood.stack(__sample__=("chain", "draw"))
772775
shape = log_likelihood.shape
773776
n_samples = shape[-1]
@@ -797,7 +800,9 @@ def loo(data, pointwise=None, var_name=None, reff=None, scale=None, dask_kwargs=
797800
np.hstack([ess_p[v].values.flatten() for v in ess_p.data_vars]).mean() / n_samples
798801
)
799802

800-
log_weights, pareto_shape = psislw(-log_likelihood, reff, dask_kwargs=dask_kwargs)
803+
psis_dask = dask_kwargs.copy()
804+
psis_dask["output_dtypes"] = (float, float)
805+
log_weights, pareto_shape = psislw(-log_likelihood, reff, dask_kwargs=psis_dask)
801806
log_weights += log_likelihood
802807

803808
warn_mg = False
@@ -813,8 +818,10 @@ def loo(data, pointwise=None, var_name=None, reff=None, scale=None, dask_kwargs=
813818

814819
ufunc_kwargs = {"n_dims": 1, "ravel": False}
815820
kwargs = {"input_core_dims": [["__sample__"]]}
821+
logsumexp_dask = dask_kwargs.copy()
822+
logsumexp_dask["output_dtypes"] = [float]
816823
loo_lppd_i = scale_value * _wrap_xarray_ufunc(
817-
_logsumexp, log_weights, ufunc_kwargs=ufunc_kwargs, dask_kwargs=dask_kwargs, **kwargs
824+
_logsumexp, log_weights, ufunc_kwargs=ufunc_kwargs, dask_kwargs=logsumexp_dask, **kwargs
818825
)
819826
loo_lppd = loo_lppd_i.sum().compute().item()
820827
loo_lppd_se = (n_data_points * loo_lppd_i.var().compute().item()) ** 0.5
@@ -825,7 +832,7 @@ def loo(data, pointwise=None, var_name=None, reff=None, scale=None, dask_kwargs=
825832
log_likelihood,
826833
func_kwargs={"b_inv": n_samples},
827834
ufunc_kwargs=ufunc_kwargs,
828-
dask_kwargs=dask_kwargs,
835+
dask_kwargs=logsumexp_dask,
829836
**kwargs,
830837
)
831838
.sum()
@@ -920,6 +927,8 @@ def psislw(log_weights, reff=1.0, dask_kwargs=None):
920927
...: az.psislw(-log_likelihood, reff=0.8)
921928
922929
"""
930+
if dask_kwargs is None:
931+
dask_kwargs = {}
923932
if hasattr(log_weights, "__sample__"):
924933
n_samples = len(log_weights.__sample__)
925934
shape = [
@@ -933,10 +942,13 @@ def psislw(log_weights, reff=1.0, dask_kwargs=None):
933942
cutoffmin = np.log(np.finfo(float).tiny) # pylint: disable=no-member, assignment-from-no-return
934943

935944
# create output array with proper dimensions
936-
out = np.empty_like(log_weights), np.empty(shape)
945+
if dask_kwargs.get("dask", "forbidden") in {"allowed", "parallelized"}:
946+
out = xr.zeros_like(log_weights).data, np.empty(shape)
947+
else:
948+
out = np.empty_like(log_weights), np.empty(shape)
937949

938950
# define kwargs
939-
func_kwargs = {"cutoff_ind": cutoff_ind, "cutoffmin": cutoffmin, "out": out}
951+
func_kwargs = {"cutoff_ind": cutoff_ind, "cutoffmin": cutoffmin}
940952
ufunc_kwargs = {"n_dims": 1, "n_output": 2, "ravel": False, "check_shape": False}
941953
kwargs = {"input_core_dims": [["__sample__"]], "output_core_dims": [["__sample__"], []]}
942954
log_weights, pareto_shape = _wrap_xarray_ufunc(
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
# pylint: disable=redefined-outer-name, no-member
2+
import pytest
3+
4+
import numpy as np
5+
6+
from ...data import load_arviz_data
7+
from ...stats import loo
8+
from ..helpers import multidim_models, importorskip # pylint: disable=unused-import
9+
10+
dask = importorskip("dask", reason="Dask specific tests")
11+
12+
@pytest.fixture()
13+
def centered_eight():
14+
centered_eight = load_arviz_data("centered_eight")
15+
return centered_eight
16+
17+
@pytest.mark.parametrize("multidim", (True, False))
18+
def test_loo(centered_eight, multidim_models, multidim):
19+
"""Test approximate leave one out criterion calculation"""
20+
if multidim:
21+
idata = multidim_models.model_1
22+
idata.log_likelihood = idata.log_likelihood.chunk({"dim2": 3})
23+
else:
24+
idata = centered_eight
25+
idata.log_likelihood = idata.log_likelihood.chunk({"school": 4})
26+
assert loo(idata, dask_kwargs={"dask": "parallelized"}) is not None
27+
loo_pointwise = loo(idata, pointwise=True, dask_kwargs={"dask": "parallelized"})
28+
assert loo_pointwise is not None
29+
assert "loo_i" in loo_pointwise
30+
assert "pareto_k" in loo_pointwise
31+
assert "scale" in loo_pointwise
32+
33+
def test_compare_loo(centered_eight):
34+
loo_ram = loo(centered_eight)
35+
centered_eight.log_likelihood = centered_eight.log_likelihood.chunk({"school": 2})
36+
loo_dask = loo(centered_eight, dask_kwargs={"dask": "parallelized"})
37+
assert np.isclose(loo_ram["elpd"], loo_dask["eldp"])

0 commit comments

Comments
 (0)