Description
Tell us about it
In Bambi we have a function called plot_cap that is used to obtain visualizations of the fitted curve. We overlay a credible interval so users can visualize the uncertainty around the mean estimate. Internally, we're using az.hdi() to obtain the bounds. Today, I was implementing some improvements and found that the plots look quite noisy. See the following examples:
import arviz as az
import bambi as bmb
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
print(az.__version__)
# 0.14.0
The following is Bambi-specific code; it's not that important for what I want to show.
data = pd.read_csv("https://gist.githubusercontent.com/seankross/a412dfbd88b3db70b74b/raw/5f23f993cd87c283ce766e7ac6b329ee7cc2e1d1/mtcars.csv")
model = bmb.Model("mpg ~ 1 + hp", data)
idata = model.fit(random_seed=1234)
# Obtain predictions
new_data = pd.DataFrame({"hp": np.linspace(50, 320, 200)})
idata = model.predict(idata, data=new_data, inplace=False)
y_hat = idata.posterior["mpg_mean"]
Get the bands using az.hdi()
y_hat_bounds = az.hdi(y_hat, 0.94)["mpg_mean"].T.to_numpy()
fig, ax = plt.subplots(figsize=(7, 5), dpi=120)
ax.fill_between(new_data["hp"], y_hat_bounds[0], y_hat_bounds[1], alpha=0.5);
Get the bands using .quantile() in DataArray, which calls np.quantile under the hood (if I understood correctly)
y_hat_bounds = y_hat.quantile(q=(0.03, 0.97), dim=("chain", "draw"))
fig, ax = plt.subplots(figsize=(7, 5), dpi=120)
ax.fill_between(new_data["hp"], y_hat_bounds[0], y_hat_bounds[1], alpha=0.5);
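To reproduce the comparison without Bambi or a fitted model, here is a minimal NumPy-only sketch. It rests on several assumptions: the posterior draws are simulated (an intercept and a slope drawn from normal distributions with made-up parameters, not the actual mtcars fit), `narrowest_interval` and `roughness` are hypothetical helper names, and the narrowest-interval search is a generic way to compute an HDI for unimodal samples, not a copy of the az.hdi() implementation. It compares how jagged the two kinds of bounds are along the hp grid.

```python
import numpy as np


def narrowest_interval(samples, prob=0.94):
    """Narrowest interval containing roughly `prob` of the 1-D samples."""
    x = np.sort(np.asarray(samples))
    n = x.size
    k = int(np.floor(prob * n))   # index offset between lower and upper bound
    widths = x[k:] - x[: n - k]   # width of every candidate interval
    i = int(np.argmin(widths))    # the argmin can jump between neighboring grid points
    return x[i], x[i + k]


def roughness(curve):
    """Sum of absolute second differences; larger means a more jagged curve."""
    return float(np.abs(np.diff(curve, n=2)).sum())


rng = np.random.default_rng(1234)
n_draws = 4000

# Simulated draws standing in for the posterior (assumed values, not the Bambi fit)
intercept = rng.normal(30.0, 1.6, size=n_draws)
slope = rng.normal(-0.068, 0.010, size=n_draws)

hp = np.linspace(50, 320, 200)
mu = intercept[:, None] + slope[:, None] * hp[None, :]  # shape (n_draws, 200)

# Narrowest-interval bounds, computed independently at each grid point
hdi_bounds = np.array([narrowest_interval(mu[:, j]) for j in range(hp.size)]).T

# Equal-tailed bounds from the 3% and 97% quantiles
eti_bounds = np.quantile(mu, (0.03, 0.97), axis=0)

rough_hdi = roughness(hdi_bounds[0]) + roughness(hdi_bounds[1])
rough_eti = roughness(eti_bounds[0]) + roughness(eti_bounds[1])

print("roughness of narrowest-interval bounds:", rough_hdi)
print("roughness of quantile bounds:", rough_eti)
```

The intuition behind the sketch: near the optimum, many candidate intervals have almost the same width, so the selected interval can shift noticeably from one grid point to the next, whereas a fixed quantile moves continuously with the draws.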
Thoughts on implementation
I'm not aware of the historical details that led to the current implementation of az.hdi(), but I think it's worth considering alternatives, since the current behavior returns very noisy results. I have other examples where it looks even worse, for example here
Tagging @aloctavodia because we talked about this via chat