Skip to content

Commit 636a890

Browse files
committed
fix out and kwargs in psislw
1 parent b12b4ce commit 636a890

File tree

2 files changed

+23
-9
lines changed

2 files changed

+23
-9
lines changed

arviz/stats/stats.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -806,17 +806,19 @@ def loo(data, pointwise=None, var_name=None, reff=None, scale=None, dask_kwargs=
806806
kwargs = {"input_core_dims": [["__sample__"]]}
807807
logsumexp_dask = dask_kwargs.copy()
808808
logsumexp_dask["output_dtypes"] = [float]
809+
logsumexp_out = xr.zeros_like(log_weights.isel(__sample__=0, drop=True)).data
809810
loo_lppd_i = scale_value * _wrap_xarray_ufunc(
810-
_logsumexp, log_weights, ufunc_kwargs=ufunc_kwargs, dask_kwargs=logsumexp_dask, **kwargs
811+
_logsumexp, log_weights, ufunc_kwargs=ufunc_kwargs, func_kwargs={"out": logsumexp_out}, dask_kwargs=logsumexp_dask, **kwargs
811812
)
812813
loo_lppd = loo_lppd_i.sum().compute().item()
813814
loo_lppd_se = (n_data_points * loo_lppd_i.var().compute().item()) ** 0.5
814815

816+
lppd_out = xr.zeros_like(log_weights.isel(__sample__=0, drop=True)).data
815817
lppd = (
816818
_wrap_xarray_ufunc(
817819
_logsumexp,
818820
log_likelihood,
819-
func_kwargs={"b_inv": n_samples},
821+
func_kwargs={"b_inv": n_samples, "out": lppd_out},
820822
ufunc_kwargs=ufunc_kwargs,
821823
dask_kwargs=logsumexp_dask,
822824
**kwargs,
@@ -930,12 +932,13 @@ def psislw(log_weights, reff=1.0, dask_kwargs=None):
930932

931933
# create output array with proper dimensions
932934
if dask_kwargs.get("dask", "forbidden") in {"allowed", "parallelized"}:
933-
out = xr.zeros_like(log_weights).data, np.empty(shape)
935+
out = xr.zeros_like(log_weights).data, xr.zeros_like(log_weights.isel(__sample__=0, drop=True)).data
936+
dask_kwargs["output_dtypes"] = (float, float)
934937
else:
935938
out = np.empty_like(log_weights), np.empty(shape)
936939

937940
# define kwargs
938-
func_kwargs = {"cutoff_ind": cutoff_ind, "cutoffmin": cutoffmin}
941+
func_kwargs = {"cutoff_ind": cutoff_ind, "cutoffmin": cutoffmin, "out": out}
939942
ufunc_kwargs = {"n_dims": 1, "n_output": 2, "ravel": False, "check_shape": False}
940943
kwargs = {"input_core_dims": [["__sample__"]], "output_core_dims": [["__sample__"], []]}
941944
log_weights, pareto_shape = _wrap_xarray_ufunc(

arviz/tests/base_tests/test_stats_dask.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,25 +4,36 @@
44
import numpy as np
55

66
from ...data import load_arviz_data
7-
from ...stats import loo
7+
from ...stats import loo, psislw
88
from ..helpers import multidim_models, importorskip # pylint: disable=unused-import
99

1010
dask = importorskip("dask", reason="Dask specific tests")
1111

1212
@pytest.fixture()
13-
def centered_eight():
13+
def chunked_centered_eight():
1414
centered_eight = load_arviz_data("centered_eight")
15+
centered_eight.log_likelihood = centered_eight.log_likelihood.chunk({"school": 4})
1516
return centered_eight
1617

1718
@pytest.mark.parametrize("multidim", (True, False))
18-
def test_loo(centered_eight, multidim_models, multidim):
19+
def test_psislw(chunked_centered_eight, multidim_models, multidim):
20+
if multidim:
21+
log_like = multidim_models.model_1.log_likelihood["y"].stack(__sample__=["chain", "draw"]).chunk({"dim2": 3})
22+
else:
23+
log_like = chunked_centered_eight.log_likelihood["obs"].stack(__sample__=["chain", "draw"])
24+
log_weights, khat = psislw(-log_like, dask_kwargs={"dask": "parallelized", })
25+
assert log_weights.shape[:-1] == khat.shape
26+
27+
28+
@pytest.mark.parametrize("multidim", (True, False))
29+
def test_loo(chunked_centered_eight, multidim_models, multidim):
1930
"""Test approximate leave one out criterion calculation"""
2031
if multidim:
2132
idata = multidim_models.model_1
2233
idata.log_likelihood = idata.log_likelihood.chunk({"dim2": 3})
2334
else:
24-
idata = centered_eight
25-
idata.log_likelihood = idata.log_likelihood.chunk({"school": 4})
35+
idata = chunked_centered_eight
36+
idata.log_likelihood = idata.log_likelihood
2637
assert loo(idata, dask_kwargs={"dask": "parallelized"}) is not None
2738
loo_pointwise = loo(idata, pointwise=True, dask_kwargs={"dask": "parallelized"})
2839
assert loo_pointwise is not None

0 commit comments

Comments
 (0)