@@ -768,6 +768,9 @@ def loo(data, pointwise=None, var_name=None, reff=None, scale=None, dask_kwargs=
768
768
log_likelihood = _get_log_likelihood (inference_data , var_name = var_name )
769
769
pointwise = rcParams ["stats.ic_pointwise" ] if pointwise is None else pointwise
770
770
771
+ if dask_kwargs is None :
772
+ dask_kwargs = {}
773
+
771
774
log_likelihood = log_likelihood .stack (__sample__ = ("chain" , "draw" ))
772
775
shape = log_likelihood .shape
773
776
n_samples = shape [- 1 ]
@@ -797,7 +800,9 @@ def loo(data, pointwise=None, var_name=None, reff=None, scale=None, dask_kwargs=
797
800
np .hstack ([ess_p [v ].values .flatten () for v in ess_p .data_vars ]).mean () / n_samples
798
801
)
799
802
800
- log_weights , pareto_shape = psislw (- log_likelihood , reff , dask_kwargs = dask_kwargs )
803
+ psis_dask = dask_kwargs .copy ()
804
+ psis_dask ["output_dtypes" ] = (float , float )
805
+ log_weights , pareto_shape = psislw (- log_likelihood , reff , dask_kwargs = psis_dask )
801
806
log_weights += log_likelihood
802
807
803
808
warn_mg = False
@@ -813,8 +818,10 @@ def loo(data, pointwise=None, var_name=None, reff=None, scale=None, dask_kwargs=
813
818
814
819
ufunc_kwargs = {"n_dims" : 1 , "ravel" : False }
815
820
kwargs = {"input_core_dims" : [["__sample__" ]]}
821
+ logsumexp_dask = dask_kwargs .copy ()
822
+ logsumexp_dask ["output_dtypes" ] = [float ]
816
823
loo_lppd_i = scale_value * _wrap_xarray_ufunc (
817
- _logsumexp , log_weights , ufunc_kwargs = ufunc_kwargs , dask_kwargs = dask_kwargs , ** kwargs
824
+ _logsumexp , log_weights , ufunc_kwargs = ufunc_kwargs , dask_kwargs = logsumexp_dask , ** kwargs
818
825
)
819
826
loo_lppd = loo_lppd_i .sum ().compute ().item ()
820
827
loo_lppd_se = (n_data_points * loo_lppd_i .var ().compute ().item ()) ** 0.5
@@ -825,7 +832,7 @@ def loo(data, pointwise=None, var_name=None, reff=None, scale=None, dask_kwargs=
825
832
log_likelihood ,
826
833
func_kwargs = {"b_inv" : n_samples },
827
834
ufunc_kwargs = ufunc_kwargs ,
828
- dask_kwargs = dask_kwargs ,
835
+ dask_kwargs = logsumexp_dask ,
829
836
** kwargs ,
830
837
)
831
838
.sum ()
@@ -920,6 +927,8 @@ def psislw(log_weights, reff=1.0, dask_kwargs=None):
920
927
...: az.psislw(-log_likelihood, reff=0.8)
921
928
922
929
"""
930
+ if dask_kwargs is None :
931
+ dask_kwargs = {}
923
932
if hasattr (log_weights , "__sample__" ):
924
933
n_samples = len (log_weights .__sample__ )
925
934
shape = [
@@ -933,10 +942,13 @@ def psislw(log_weights, reff=1.0, dask_kwargs=None):
933
942
cutoffmin = np .log (np .finfo (float ).tiny ) # pylint: disable=no-member, assignment-from-no-return
934
943
935
944
# create output array with proper dimensions
936
- out = np .empty_like (log_weights ), np .empty (shape )
945
+ if dask_kwargs .get ("dask" , "forbidden" ) in {"allowed" , "parallelized" }:
946
+ out = xr .zeros_like (log_weights ).data , np .empty (shape )
947
+ else :
948
+ out = np .empty_like (log_weights ), np .empty (shape )
937
949
938
950
# define kwargs
939
- func_kwargs = {"cutoff_ind" : cutoff_ind , "cutoffmin" : cutoffmin , "out" : out }
951
+ func_kwargs = {"cutoff_ind" : cutoff_ind , "cutoffmin" : cutoffmin }
940
952
ufunc_kwargs = {"n_dims" : 1 , "n_output" : 2 , "ravel" : False , "check_shape" : False }
941
953
kwargs = {"input_core_dims" : [["__sample__" ]], "output_core_dims" : [["__sample__" ], []]}
942
954
log_weights , pareto_shape = _wrap_xarray_ufunc (
0 commit comments