Skip to content

Commit b12b4ce

Browse files
committed
continue work and add tests
1 parent 9ba36ab commit b12b4ce

File tree

2 files changed

+54
-5
lines changed

2 files changed

+54
-5
lines changed

arviz/stats/stats.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -754,6 +754,9 @@ def loo(data, pointwise=None, var_name=None, reff=None, scale=None, dask_kwargs=
754754
log_likelihood = _get_log_likelihood(inference_data, var_name=var_name)
755755
pointwise = rcParams["stats.ic_pointwise"] if pointwise is None else pointwise
756756

757+
if dask_kwargs is None:
758+
dask_kwargs = {}
759+
757760
log_likelihood = log_likelihood.stack(__sample__=("chain", "draw"))
758761
shape = log_likelihood.shape
759762
n_samples = shape[-1]
@@ -783,7 +786,9 @@ def loo(data, pointwise=None, var_name=None, reff=None, scale=None, dask_kwargs=
783786
np.hstack([ess_p[v].values.flatten() for v in ess_p.data_vars]).mean() / n_samples
784787
)
785788

786-
log_weights, pareto_shape = psislw(-log_likelihood, reff, dask_kwargs=dask_kwargs)
789+
psis_dask = dask_kwargs.copy()
790+
psis_dask["output_dtypes"] = (float, float)
791+
log_weights, pareto_shape = psislw(-log_likelihood, reff, dask_kwargs=psis_dask)
787792
log_weights += log_likelihood
788793

789794
warn_mg = False
@@ -799,8 +804,10 @@ def loo(data, pointwise=None, var_name=None, reff=None, scale=None, dask_kwargs=
799804

800805
ufunc_kwargs = {"n_dims": 1, "ravel": False}
801806
kwargs = {"input_core_dims": [["__sample__"]]}
807+
logsumexp_dask = dask_kwargs.copy()
808+
logsumexp_dask["output_dtypes"] = [float]
802809
loo_lppd_i = scale_value * _wrap_xarray_ufunc(
803-
_logsumexp, log_weights, ufunc_kwargs=ufunc_kwargs, dask_kwargs=dask_kwargs, **kwargs
810+
_logsumexp, log_weights, ufunc_kwargs=ufunc_kwargs, dask_kwargs=logsumexp_dask, **kwargs
804811
)
805812
loo_lppd = loo_lppd_i.sum().compute().item()
806813
loo_lppd_se = (n_data_points * loo_lppd_i.var().compute().item()) ** 0.5
@@ -811,7 +818,7 @@ def loo(data, pointwise=None, var_name=None, reff=None, scale=None, dask_kwargs=
811818
log_likelihood,
812819
func_kwargs={"b_inv": n_samples},
813820
ufunc_kwargs=ufunc_kwargs,
814-
dask_kwargs=dask_kwargs,
821+
dask_kwargs=logsumexp_dask,
815822
**kwargs,
816823
)
817824
.sum()
@@ -907,6 +914,8 @@ def psislw(log_weights, reff=1.0, dask_kwargs=None):
907914
...: az.psislw(-log_likelihood, reff=0.8)
908915
909916
"""
917+
if dask_kwargs is None:
918+
dask_kwargs = {}
910919
if hasattr(log_weights, "__sample__"):
911920
n_samples = len(log_weights.__sample__)
912921
shape = [
@@ -920,10 +929,13 @@ def psislw(log_weights, reff=1.0, dask_kwargs=None):
920929
cutoffmin = np.log(np.finfo(float).tiny) # pylint: disable=no-member, assignment-from-no-return
921930

922931
# create output array with proper dimensions
923-
out = np.empty_like(log_weights), np.empty(shape)
932+
if dask_kwargs.get("dask", "forbidden") in {"allowed", "parallelized"}:
933+
out = xr.zeros_like(log_weights).data, np.empty(shape)
934+
else:
935+
out = np.empty_like(log_weights), np.empty(shape)
924936

925937
# define kwargs
926-
func_kwargs = {"cutoff_ind": cutoff_ind, "cutoffmin": cutoffmin, "out": out}
938+
func_kwargs = {"cutoff_ind": cutoff_ind, "cutoffmin": cutoffmin}
927939
ufunc_kwargs = {"n_dims": 1, "n_output": 2, "ravel": False, "check_shape": False}
928940
kwargs = {"input_core_dims": [["__sample__"]], "output_core_dims": [["__sample__"], []]}
929941
log_weights, pareto_shape = _wrap_xarray_ufunc(
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
# pylint: disable=redefined-outer-name, no-member
2+
import pytest
3+
4+
import numpy as np
5+
6+
from ...data import load_arviz_data
7+
from ...stats import loo
8+
from ..helpers import multidim_models, importorskip # pylint: disable=unused-import
9+
10+
dask = importorskip("dask", reason="Dask specific tests")
11+
12+
@pytest.fixture()
def centered_eight():
    """Provide the ``centered_eight`` example dataset loaded via ``load_arviz_data``."""
    return load_arviz_data("centered_eight")
16+
17+
@pytest.mark.parametrize("multidim", (True, False))
def test_loo(centered_eight, multidim_models, multidim):
    """Run ``loo`` on a chunked log likelihood with dask parallelization enabled.

    The multidimensional case chunks ``model_1`` along ``dim2``; the other
    case chunks the ``centered_eight`` data along ``school``.
    """
    # Pick the dataset together with the chunk specification applied to it.
    idata, chunk_spec = (
        (multidim_models.model_1, {"dim2": 3}) if multidim else (centered_eight, {"school": 4})
    )
    idata.log_likelihood = idata.log_likelihood.chunk(chunk_spec)

    # Default call: only require that a result is produced.
    assert loo(idata, dask_kwargs={"dask": "parallelized"}) is not None

    # Pointwise call: the result must also expose the pointwise entries.
    loo_pointwise = loo(idata, pointwise=True, dask_kwargs={"dask": "parallelized"})
    assert loo_pointwise is not None
    for expected_key in ("loo_i", "pareto_k", "scale"):
        assert expected_key in loo_pointwise
32+
33+
def test_compare_loo(centered_eight):
    """Check that dask-parallelized ``loo`` agrees with the unchunked computation."""
    # Reference result, computed before the log likelihood is chunked.
    loo_ram = loo(centered_eight)
    # Chunk the log likelihood along ``school`` and rerun with dask parallelization.
    centered_eight.log_likelihood = centered_eight.log_likelihood.chunk({"school": 2})
    loo_dask = loo(centered_eight, dask_kwargs={"dask": "parallelized"})
    # Both results must be read with the same key; the dask side previously used
    # the misspelled key "eldp", which fails on lookup instead of comparing values.
    # NOTE(review): confirm "elpd" matches the index label of the object returned by loo.
    assert np.isclose(loo_ram["elpd"], loo_dask["elpd"])

0 commit comments

Comments
 (0)