@@ -754,6 +754,9 @@ def loo(data, pointwise=None, var_name=None, reff=None, scale=None, dask_kwargs=
754
754
log_likelihood = _get_log_likelihood (inference_data , var_name = var_name )
755
755
pointwise = rcParams ["stats.ic_pointwise" ] if pointwise is None else pointwise
756
756
757
+ if dask_kwargs is None :
758
+ dask_kwargs = {}
759
+
757
760
log_likelihood = log_likelihood .stack (__sample__ = ("chain" , "draw" ))
758
761
shape = log_likelihood .shape
759
762
n_samples = shape [- 1 ]
@@ -783,7 +786,9 @@ def loo(data, pointwise=None, var_name=None, reff=None, scale=None, dask_kwargs=
783
786
np .hstack ([ess_p [v ].values .flatten () for v in ess_p .data_vars ]).mean () / n_samples
784
787
)
785
788
786
- log_weights , pareto_shape = psislw (- log_likelihood , reff , dask_kwargs = dask_kwargs )
789
+ psis_dask = dask_kwargs .copy ()
790
+ psis_dask ["output_dtypes" ] = (float , float )
791
+ log_weights , pareto_shape = psislw (- log_likelihood , reff , dask_kwargs = psis_dask )
787
792
log_weights += log_likelihood
788
793
789
794
warn_mg = False
@@ -799,8 +804,10 @@ def loo(data, pointwise=None, var_name=None, reff=None, scale=None, dask_kwargs=
799
804
800
805
ufunc_kwargs = {"n_dims" : 1 , "ravel" : False }
801
806
kwargs = {"input_core_dims" : [["__sample__" ]]}
807
+ logsumexp_dask = dask_kwargs .copy ()
808
+ logsumexp_dask ["output_dtypes" ] = [float ]
802
809
loo_lppd_i = scale_value * _wrap_xarray_ufunc (
803
- _logsumexp , log_weights , ufunc_kwargs = ufunc_kwargs , dask_kwargs = dask_kwargs , ** kwargs
810
+ _logsumexp , log_weights , ufunc_kwargs = ufunc_kwargs , dask_kwargs = logsumexp_dask , ** kwargs
804
811
)
805
812
loo_lppd = loo_lppd_i .sum ().compute ().item ()
806
813
loo_lppd_se = (n_data_points * loo_lppd_i .var ().compute ().item ()) ** 0.5
@@ -811,7 +818,7 @@ def loo(data, pointwise=None, var_name=None, reff=None, scale=None, dask_kwargs=
811
818
log_likelihood ,
812
819
func_kwargs = {"b_inv" : n_samples },
813
820
ufunc_kwargs = ufunc_kwargs ,
814
- dask_kwargs = dask_kwargs ,
821
+ dask_kwargs = logsumexp_dask ,
815
822
** kwargs ,
816
823
)
817
824
.sum ()
@@ -907,6 +914,8 @@ def psislw(log_weights, reff=1.0, dask_kwargs=None):
907
914
...: az.psislw(-log_likelihood, reff=0.8)
908
915
909
916
"""
917
+ if dask_kwargs is None :
918
+ dask_kwargs = {}
910
919
if hasattr (log_weights , "__sample__" ):
911
920
n_samples = len (log_weights .__sample__ )
912
921
shape = [
@@ -920,10 +929,13 @@ def psislw(log_weights, reff=1.0, dask_kwargs=None):
920
929
cutoffmin = np .log (np .finfo (float ).tiny ) # pylint: disable=no-member, assignment-from-no-return
921
930
922
931
# create output array with proper dimensions
923
- out = np .empty_like (log_weights ), np .empty (shape )
932
+ if dask_kwargs .get ("dask" , "forbidden" ) in {"allowed" , "parallelized" }:
933
+ out = xr .zeros_like (log_weights ).data , np .empty (shape )
934
+ else :
935
+ out = np .empty_like (log_weights ), np .empty (shape )
924
936
925
937
# define kwargs
926
- func_kwargs = {"cutoff_ind" : cutoff_ind , "cutoffmin" : cutoffmin , "out" : out }
938
+ func_kwargs = {"cutoff_ind" : cutoff_ind , "cutoffmin" : cutoffmin }
927
939
ufunc_kwargs = {"n_dims" : 1 , "n_output" : 2 , "ravel" : False , "check_shape" : False }
928
940
kwargs = {"input_core_dims" : [["__sample__" ]], "output_core_dims" : [["__sample__" ], []]}
929
941
log_weights , pareto_shape = _wrap_xarray_ufunc (
0 commit comments