@@ -806,17 +806,19 @@ def loo(data, pointwise=None, var_name=None, reff=None, scale=None, dask_kwargs=
806
806
kwargs = {"input_core_dims" : [["__sample__" ]]}
807
807
logsumexp_dask = dask_kwargs .copy ()
808
808
logsumexp_dask ["output_dtypes" ] = [float ]
809
+ logsumexp_out = xr .zeros_like (log_weights .isel (__sample__ = 0 , drop = True )).data
809
810
loo_lppd_i = scale_value * _wrap_xarray_ufunc (
810
- _logsumexp , log_weights , ufunc_kwargs = ufunc_kwargs , dask_kwargs = logsumexp_dask , ** kwargs
811
+ _logsumexp , log_weights , ufunc_kwargs = ufunc_kwargs , func_kwargs = { "out" : logsumexp_out } , dask_kwargs = logsumexp_dask , ** kwargs
811
812
)
812
813
loo_lppd = loo_lppd_i .sum ().compute ().item ()
813
814
loo_lppd_se = (n_data_points * loo_lppd_i .var ().compute ().item ()) ** 0.5
814
815
816
+ lppd_out = xr .zeros_like (log_weights .isel (__sample__ = 0 , drop = True )).data
815
817
lppd = (
816
818
_wrap_xarray_ufunc (
817
819
_logsumexp ,
818
820
log_likelihood ,
819
- func_kwargs = {"b_inv" : n_samples },
821
+ func_kwargs = {"b_inv" : n_samples , "out" : lppd_out },
820
822
ufunc_kwargs = ufunc_kwargs ,
821
823
dask_kwargs = logsumexp_dask ,
822
824
** kwargs ,
@@ -930,12 +932,13 @@ def psislw(log_weights, reff=1.0, dask_kwargs=None):
930
932
931
933
# create output array with proper dimensions
932
934
if dask_kwargs .get ("dask" , "forbidden" ) in {"allowed" , "parallelized" }:
933
- out = xr .zeros_like (log_weights ).data , np .empty (shape )
935
+ out = xr .zeros_like (log_weights ).data , xr .zeros_like (log_weights .isel (__sample__ = 0 , drop = True )).data
936
+ dask_kwargs ["output_dtypes" ] = (float , float )
934
937
else :
935
938
out = np .empty_like (log_weights ), np .empty (shape )
936
939
937
940
# define kwargs
938
- func_kwargs = {"cutoff_ind" : cutoff_ind , "cutoffmin" : cutoffmin }
941
+ func_kwargs = {"cutoff_ind" : cutoff_ind , "cutoffmin" : cutoffmin , "out" : out }
939
942
ufunc_kwargs = {"n_dims" : 1 , "n_output" : 2 , "ravel" : False , "check_shape" : False }
940
943
kwargs = {"input_core_dims" : [["__sample__" ]], "output_core_dims" : [["__sample__" ], []]}
941
944
log_weights , pareto_shape = _wrap_xarray_ufunc (
0 commit comments