@@ -689,7 +689,7 @@ def _hdi_multimodal(ary, hdi_prob, skipna, max_modes):
689
689
return np .array (hdi_intervals )
690
690
691
691
692
- def loo (data , pointwise = None , var_name = None , reff = None , scale = None ):
692
+ def loo (data , pointwise = None , var_name = None , reff = None , scale = None , dask_kwargs = None ):
693
693
"""Compute Pareto-smoothed importance sampling leave-one-out cross-validation (PSIS-LOO-CV).
694
694
695
695
Estimates the expected log pointwise predictive density (elpd) using Pareto-smoothed
@@ -699,20 +699,20 @@ def loo(data, pointwise=None, var_name=None, reff=None, scale=None):
699
699
700
700
Parameters
701
701
----------
702
- data: obj
702
+ data : obj
703
703
Any object that can be converted to an :class:`arviz.InferenceData` object.
704
704
Refer to documentation of
705
705
:func:`arviz.convert_to_dataset` for details.
706
- pointwise: bool, optional
706
+ pointwise : bool, optional
707
707
If True the pointwise predictive accuracy will be returned. Defaults to
708
708
``stats.ic_pointwise`` rcParam.
709
709
var_name : str, optional
710
710
The name of the variable in log_likelihood groups storing the pointwise log
711
711
likelihood data to use for loo computation.
712
- reff: float, optional
712
+ reff : float, optional
713
713
Relative MCMC efficiency, ``ess / n`` i.e. number of effective samples divided by the number
714
714
of actual samples. Computed from trace by default.
715
- scale: str
715
+ scale : {"log", "negative_log", "deviance"}, optional
716
716
Output scale for loo. Available options are:
717
717
718
718
- ``log`` : (default) log-score
@@ -721,6 +721,8 @@ def loo(data, pointwise=None, var_name=None, reff=None, scale=None):
721
721
722
722
A higher log-score (or a lower deviance or negative log_score) indicates a model with
723
723
better predictive accuracy.
724
+ dask_kwargs : dict, optional
725
+ Dask related kwargs passed to :func:`~arviz.wrap_xarray_ufunc`.
724
726
725
727
Returns
726
728
-------
@@ -795,7 +797,7 @@ def loo(data, pointwise=None, var_name=None, reff=None, scale=None):
795
797
np .hstack ([ess_p [v ].values .flatten () for v in ess_p .data_vars ]).mean () / n_samples
796
798
)
797
799
798
- log_weights , pareto_shape = psislw (- log_likelihood , reff )
800
+ log_weights , pareto_shape = psislw (- log_likelihood , reff , dask_kwargs = dask_kwargs )
799
801
log_weights += log_likelihood
800
802
801
803
warn_mg = False
@@ -812,20 +814,19 @@ def loo(data, pointwise=None, var_name=None, reff=None, scale=None):
812
814
ufunc_kwargs = {"n_dims" : 1 , "ravel" : False }
813
815
kwargs = {"input_core_dims" : [["__sample__" ]]}
814
816
loo_lppd_i = scale_value * _wrap_xarray_ufunc (
815
- _logsumexp , log_weights , ufunc_kwargs = ufunc_kwargs , ** kwargs
816
- )
817
- loo_lppd = loo_lppd_i .values .sum ()
818
- loo_lppd_se = (n_data_points * np .var (loo_lppd_i .values )) ** 0.5
819
-
820
- lppd = np .sum (
821
- _wrap_xarray_ufunc (
822
- _logsumexp ,
823
- log_likelihood ,
824
- func_kwargs = {"b_inv" : n_samples },
825
- ufunc_kwargs = ufunc_kwargs ,
826
- ** kwargs ,
827
- ).values
817
+ _logsumexp , log_weights , ufunc_kwargs = ufunc_kwargs , dask_kwargs = dask_kwargs , ** kwargs
828
818
)
819
+ loo_lppd = loo_lppd_i .sum ().compute ().item ()
820
+ loo_lppd_se = (n_data_points * loo_lppd_i .var ().compute ().item ()) ** 0.5
821
+
822
+ lppd = _wrap_xarray_ufunc (
823
+ _logsumexp ,
824
+ log_likelihood ,
825
+ func_kwargs = {"b_inv" : n_samples },
826
+ ufunc_kwargs = ufunc_kwargs ,
827
+ dask_kwargs = dask_kwargs ,
828
+ ** kwargs ,
829
+ ).sum ().compute .item ()
829
830
p_loo = lppd - loo_lppd / scale_value
830
831
831
832
if not pointwise :
@@ -864,7 +865,7 @@ def loo(data, pointwise=None, var_name=None, reff=None, scale=None):
864
865
)
865
866
866
867
867
- def psislw (log_weights , reff = 1.0 ):
868
+ def psislw (log_weights , reff = 1.0 , dask_kwargs = None ):
868
869
"""
869
870
Pareto smoothed importance sampling (PSIS).
870
871
@@ -878,16 +879,18 @@ def psislw(log_weights, reff=1.0):
878
879
879
880
Parameters
880
881
----------
881
- log_weights: array
882
- Array of size (n_observations , n_samples)
883
- reff: float
882
+ log_weights : array-like
883
+ Array of shape (*observation_shape , n_samples)
884
+ reff : float, optional
884
885
relative MCMC efficiency, ``ess / n``
886
+ dask_kwargs : dict, optional
887
+ Dask related kwargs passed to :func:`~arviz.wrap_xarray_ufunc`.
885
888
886
889
Returns
887
890
-------
888
- lw_out: array
891
+ lw_out : array
889
892
Smoothed log weights
890
- kss: array
893
+ kss : array
891
894
Pareto tail indices
892
895
893
896
References
@@ -936,6 +939,7 @@ def psislw(log_weights, reff=1.0):
936
939
log_weights ,
937
940
ufunc_kwargs = ufunc_kwargs ,
938
941
func_kwargs = func_kwargs ,
942
+ dask_kwargs = dask_kwargs ,
939
943
** kwargs ,
940
944
)
941
945
if isinstance (log_weights , xr .DataArray ):
0 commit comments