Skip to content

Enable multiple treated units in synthetic control quasi experiments #494

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 28 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
59c5b6e
initial commit - utilising vibe coding
drbenvincent Jun 28, 2025
26691ec
update multi cell geolift notebook with new functionality/API
drbenvincent Jun 28, 2025
4fa1650
add test that the r2 scores differ across treated units
drbenvincent Jun 28, 2025
7b473af
simplify plot code
drbenvincent Jun 28, 2025
be01357
revert to simpler plot titles
drbenvincent Jun 28, 2025
47e0a2d
rename primary_unit_name -> treated_unit + revert to no "Unit" title
drbenvincent Jun 28, 2025
ee8a92b
revert change in causal impact shaded region colour
drbenvincent Jun 28, 2025
b79743f
code simplifications by always having a treated_units dimension
drbenvincent Jun 28, 2025
aa9920a
code simplification relating to _get_score_title
drbenvincent Jun 28, 2025
ebddbb5
another code simplification - related to scoring
drbenvincent Jun 28, 2025
eac1ef3
code simplification in _ols_plot
drbenvincent Jun 28, 2025
1a5f9bd
code simplification related to PyMCModel._data_setter
drbenvincent Jun 28, 2025
e67a28c
add sphinx-togglebutton for a collapsible admonition + other updates …
drbenvincent Jun 29, 2025
d0fc0d3
update uml diagrams
drbenvincent Jun 29, 2025
b6f5ca8
PyMCModel.score always to get xr.DataArray arguments
drbenvincent Jun 29, 2025
8badc05
simplification to PyMC.print_coefficients
drbenvincent Jun 29, 2025
4a78a50
simplification of WeightedSumFitter.build_model
drbenvincent Jun 29, 2025
f1849b1
clean up PyMCModel.predict + PyMCModel._data_setter
drbenvincent Jun 29, 2025
a55f97d
remove a numerical index in favour of a named dimension
drbenvincent Jun 29, 2025
6341e53
make code comment more specific
drbenvincent Jun 29, 2025
2befccb
consolidate tests, fix doctest
drbenvincent Jun 29, 2025
78be544
Merge branch 'main' into multi-cell-geolift
drbenvincent Jun 29, 2025
37e8de7
set fixture scope to module
drbenvincent Jun 30, 2025
d0c520f
towards a more unified scoring (r2) approach
drbenvincent Jun 30, 2025
3ee430e
more unification with score (r2) in terms of unified naming: unit_{n}_r2
drbenvincent Jun 30, 2025
adf04f9
refactor PyMCModel.score
drbenvincent Jun 30, 2025
25608ef
the grand simplification
drbenvincent Jun 30, 2025
2fba338
tweak docstring
drbenvincent Jun 30, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 27 additions & 13 deletions causalpy/experiments/diff_in_diff.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,14 +114,18 @@ def __init__(
},
)
self.y = xr.DataArray(
self.y[:, 0],
dims=["obs_ind"],
coords={"obs_ind": np.arange(self.y.shape[0])},
self.y,
dims=["obs_ind", "treated_units"],
coords={"obs_ind": np.arange(self.y.shape[0]), "treated_units": ["unit_0"]},
)

# fit model
if isinstance(self.model, PyMCModel):
COORDS = {"coeffs": self.labels, "obs_ind": np.arange(self.X.shape[0])}
COORDS = {
"coeffs": self.labels,
"obs_ind": np.arange(self.X.shape[0]),
"treated_units": ["unit_0"],
}
self.model.fit(X=self.X, y=self.y, coords=COORDS)
elif isinstance(self.model, RegressorMixin):
self.model.fit(X=self.X, y=self.y)
Expand Down Expand Up @@ -203,7 +207,7 @@ def __init__(
# TODO: CHECK FOR CORRECTNESS
self.causal_impact = (
self.y_pred_treatment[1] - self.y_pred_counterfactual[0]
)
).item()
else:
raise ValueError("Model type not recognized")

Expand Down Expand Up @@ -321,7 +325,7 @@ def _plot_causal_impact_arrow(results, ax):
time_points = self.x_pred_control[self.time_variable_name].values
h_line, h_patch = plot_xY(
time_points,
self.y_pred_control.posterior_predictive.mu,
self.y_pred_control["posterior_predictive"].mu.isel(treated_units=0),
ax=ax,
plot_hdi_kwargs={"color": "C0"},
label="Control group",
Expand All @@ -333,7 +337,7 @@ def _plot_causal_impact_arrow(results, ax):
time_points = self.x_pred_control[self.time_variable_name].values
h_line, h_patch = plot_xY(
time_points,
self.y_pred_treatment.posterior_predictive.mu,
self.y_pred_treatment["posterior_predictive"].mu.isel(treated_units=0),
ax=ax,
plot_hdi_kwargs={"color": "C1"},
label="Treatment group",
Expand All @@ -345,12 +349,20 @@ def _plot_causal_impact_arrow(results, ax):
# had occurred.
time_points = self.x_pred_counterfactual[self.time_variable_name].values
if len(time_points) == 1:
y_pred_cf = az.extract(
self.y_pred_counterfactual,
group="posterior_predictive",
var_names="mu",
)
# Select single unit data for plotting
y_pred_cf_single = y_pred_cf.isel(treated_units=0)
violin_data = (
y_pred_cf_single.values
if hasattr(y_pred_cf_single, "values")
else y_pred_cf_single
)
parts = ax.violinplot(
az.extract(
self.y_pred_counterfactual,
group="posterior_predictive",
var_names="mu",
).values.T,
violin_data.T,
positions=self.x_pred_counterfactual[self.time_variable_name].values,
showmeans=False,
showmedians=False,
Expand All @@ -363,7 +375,9 @@ def _plot_causal_impact_arrow(results, ax):
else:
h_line, h_patch = plot_xY(
time_points,
self.y_pred_counterfactual.posterior_predictive.mu,
self.y_pred_counterfactual.posterior_predictive.mu.isel(
treated_units=0
),
ax=ax,
plot_hdi_kwargs={"color": "C2"},
label="Counterfactual",
Expand Down
120 changes: 87 additions & 33 deletions causalpy/experiments/interrupted_time_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,9 +120,9 @@ def __init__(
},
)
self.pre_y = xr.DataArray(
self.pre_y[:, 0],
dims=["obs_ind"],
coords={"obs_ind": self.datapre.index},
self.pre_y, # Keep 2D shape
dims=["obs_ind", "treated_units"],
coords={"obs_ind": self.datapre.index, "treated_units": ["unit_0"]},
)
self.post_X = xr.DataArray(
self.post_X,
Expand All @@ -133,17 +133,22 @@ def __init__(
},
)
self.post_y = xr.DataArray(
self.post_y[:, 0],
dims=["obs_ind"],
coords={"obs_ind": self.datapost.index},
self.post_y, # Keep 2D shape
dims=["obs_ind", "treated_units"],
coords={"obs_ind": self.datapost.index, "treated_units": ["unit_0"]},
)

# fit the model to the observed (pre-intervention) data
if isinstance(self.model, PyMCModel):
COORDS = {"coeffs": self.labels, "obs_ind": np.arange(self.pre_X.shape[0])}
COORDS = {
"coeffs": self.labels,
"obs_ind": np.arange(self.pre_X.shape[0]),
"treated_units": ["unit_0"],
}
self.model.fit(X=self.pre_X, y=self.pre_y, coords=COORDS)
elif isinstance(self.model, RegressorMixin):
self.model.fit(X=self.pre_X, y=self.pre_y)
# For OLS models, use 1D y data
self.model.fit(X=self.pre_X, y=self.pre_y.isel(treated_units=0))
else:
raise ValueError("Model type not recognized")

Expand All @@ -155,8 +160,21 @@ def __init__(

# calculate the counterfactual
self.post_pred = self.model.predict(X=self.post_X)
self.pre_impact = self.model.calculate_impact(self.pre_y, self.pre_pred)
self.post_impact = self.model.calculate_impact(self.post_y, self.post_pred)

# calculate impact - use appropriate y data format for each model type
if isinstance(self.model, PyMCModel):
# PyMC models work with 2D data
self.pre_impact = self.model.calculate_impact(self.pre_y, self.pre_pred)
self.post_impact = self.model.calculate_impact(self.post_y, self.post_pred)
elif isinstance(self.model, RegressorMixin):
# SKL models work with 1D data
self.pre_impact = self.model.calculate_impact(
self.pre_y.isel(treated_units=0), self.pre_pred
)
self.post_impact = self.model.calculate_impact(
self.post_y.isel(treated_units=0), self.post_pred
)

self.post_impact_cumulative = self.model.calculate_cumulative_impact(
self.post_impact
)
Expand Down Expand Up @@ -202,35 +220,53 @@ def _bayesian_plot(
# pre-intervention period
h_line, h_patch = plot_xY(
self.datapre.index,
self.pre_pred["posterior_predictive"].mu,
self.pre_pred["posterior_predictive"].mu.isel(treated_units=0),
ax=ax[0],
plot_hdi_kwargs={"color": "C0"},
)
handles = [(h_line, h_patch)]
labels = ["Pre-intervention period"]

(h,) = ax[0].plot(self.datapre.index, self.pre_y, "k.", label="Observations")
(h,) = ax[0].plot(
self.datapre.index,
self.pre_y.isel(treated_units=0)
if hasattr(self.pre_y, "isel")
else self.pre_y[:, 0],
"k.",
label="Observations",
)
handles.append(h)
labels.append("Observations")

# post intervention period
h_line, h_patch = plot_xY(
self.datapost.index,
self.post_pred["posterior_predictive"].mu,
self.post_pred["posterior_predictive"].mu.isel(treated_units=0),
ax=ax[0],
plot_hdi_kwargs={"color": "C1"},
)
handles.append((h_line, h_patch))
labels.append(counterfactual_label)

ax[0].plot(self.datapost.index, self.post_y, "k.")
ax[0].plot(
self.datapost.index,
self.post_y.isel(treated_units=0)
if hasattr(self.post_y, "isel")
else self.post_y[:, 0],
"k.",
)
# Shaded causal effect
post_pred_mu = (
az.extract(self.post_pred, group="posterior_predictive", var_names="mu")
.isel(treated_units=0)
.mean("sample")
) # Add .mean("sample") to get 1D array
h = ax[0].fill_between(
self.datapost.index,
y1=az.extract(
self.post_pred, group="posterior_predictive", var_names="mu"
).mean("sample"),
y2=np.squeeze(self.post_y),
y1=post_pred_mu,
y2=self.post_y.isel(treated_units=0)
if hasattr(self.post_y, "isel")
else self.post_y[:, 0],
color="C0",
alpha=0.25,
)
Expand All @@ -239,28 +275,28 @@ def _bayesian_plot(

ax[0].set(
title=f"""
Pre-intervention Bayesian $R^2$: {round_num(self.score.r2, round_to)}
(std = {round_num(self.score.r2_std, round_to)})
Pre-intervention Bayesian $R^2$: {round_num(self.score["unit_0_r2"], round_to)}
(std = {round_num(self.score["unit_0_r2_std"], round_to)})
"""
)

# MIDDLE PLOT -----------------------------------------------
plot_xY(
self.datapre.index,
self.pre_impact,
self.pre_impact.isel(treated_units=0),
ax=ax[1],
plot_hdi_kwargs={"color": "C0"},
)
plot_xY(
self.datapost.index,
self.post_impact,
self.post_impact.isel(treated_units=0),
ax=ax[1],
plot_hdi_kwargs={"color": "C1"},
)
ax[1].axhline(y=0, c="k")
ax[1].fill_between(
self.datapost.index,
y1=self.post_impact.mean(["chain", "draw"]),
y1=self.post_impact.mean(["chain", "draw"]).isel(treated_units=0),
color="C0",
alpha=0.25,
label="Causal impact",
Expand All @@ -271,7 +307,7 @@ def _bayesian_plot(
ax[2].set(title="Cumulative Causal Impact")
plot_xY(
self.datapost.index,
self.post_impact_cumulative,
self.post_impact_cumulative.isel(treated_units=0),
ax=ax[2],
plot_hdi_kwargs={"color": "C1"},
)
Expand Down Expand Up @@ -387,27 +423,45 @@ def get_plot_data_bayesian(self, hdi_prob: float = 0.94) -> pd.DataFrame:
pre_data["prediction"] = (
az.extract(self.pre_pred, group="posterior_predictive", var_names="mu")
.mean("sample")
.isel(treated_units=0)
.values
)
post_data["prediction"] = (
az.extract(self.post_pred, group="posterior_predictive", var_names="mu")
.mean("sample")
.isel(treated_units=0)
.values
)
pre_data[[pred_lower_col, pred_upper_col]] = get_hdi_to_df(
hdi_pre_pred = get_hdi_to_df(
self.pre_pred["posterior_predictive"].mu, hdi_prob=hdi_prob
).set_index(pre_data.index)
post_data[[pred_lower_col, pred_upper_col]] = get_hdi_to_df(
)
hdi_post_pred = get_hdi_to_df(
self.post_pred["posterior_predictive"].mu, hdi_prob=hdi_prob
)
# Select the single unit from the MultiIndex results
pre_data[[pred_lower_col, pred_upper_col]] = hdi_pre_pred.xs(
"unit_0", level="treated_units"
).set_index(pre_data.index)
post_data[[pred_lower_col, pred_upper_col]] = hdi_post_pred.xs(
"unit_0", level="treated_units"
).set_index(post_data.index)

pre_data["impact"] = self.pre_impact.mean(dim=["chain", "draw"]).values
post_data["impact"] = self.post_impact.mean(dim=["chain", "draw"]).values
pre_data[[impact_lower_col, impact_upper_col]] = get_hdi_to_df(
self.pre_impact, hdi_prob=hdi_prob
pre_data["impact"] = (
self.pre_impact.mean(dim=["chain", "draw"]).isel(treated_units=0).values
)
post_data["impact"] = (
self.post_impact.mean(dim=["chain", "draw"])
.isel(treated_units=0)
.values
)
hdi_pre_impact = get_hdi_to_df(self.pre_impact, hdi_prob=hdi_prob)
hdi_post_impact = get_hdi_to_df(self.post_impact, hdi_prob=hdi_prob)
# Select the single unit from the MultiIndex results
pre_data[[impact_lower_col, impact_upper_col]] = hdi_pre_impact.xs(
"unit_0", level="treated_units"
).set_index(pre_data.index)
post_data[[impact_lower_col, impact_upper_col]] = get_hdi_to_df(
self.post_impact, hdi_prob=hdi_prob
post_data[[impact_lower_col, impact_upper_col]] = hdi_post_impact.xs(
"unit_0", level="treated_units"
).set_index(post_data.index)

self.plot_data = pd.concat([pre_data, post_data])
Expand Down
16 changes: 10 additions & 6 deletions causalpy/experiments/prepostnegd.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,14 +122,18 @@ def __init__(
},
)
self.y = xr.DataArray(
self.y[:, 0],
dims=["obs_ind"],
coords={"obs_ind": self.data.index},
self.y,
dims=["obs_ind", "treated_units"],
coords={"obs_ind": self.data.index, "treated_units": ["unit_0"]},
)

# fit the model to the observed (pre-intervention) data
if isinstance(self.model, PyMCModel):
COORDS = {"coeffs": self.labels, "obs_ind": np.arange(self.X.shape[0])}
COORDS = {
"coeffs": self.labels,
"obs_ind": np.arange(self.X.shape[0]),
"treated_units": ["unit_0"],
}
self.model.fit(X=self.X, y=self.y, coords=COORDS)
elif isinstance(self.model, RegressorMixin):
raise NotImplementedError("Not implemented for OLS model")
Expand Down Expand Up @@ -239,7 +243,7 @@ def _bayesian_plot(
# plot posterior predictive of untreated
h_line, h_patch = plot_xY(
self.pred_xi,
self.pred_untreated["posterior_predictive"].mu,
self.pred_untreated["posterior_predictive"].mu.isel(treated_units=0),
ax=ax[0],
plot_hdi_kwargs={"color": "C0"},
label="Control group",
Expand All @@ -250,7 +254,7 @@ def _bayesian_plot(
# plot posterior predictive of treated
h_line, h_patch = plot_xY(
self.pred_xi,
self.pred_treated["posterior_predictive"].mu,
self.pred_treated["posterior_predictive"].mu.isel(treated_units=0),
ax=ax[0],
plot_hdi_kwargs={"color": "C1"},
label="Treatment group",
Expand Down
16 changes: 10 additions & 6 deletions causalpy/experiments/regression_discontinuity.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,15 +131,19 @@ def __init__(
},
)
self.y = xr.DataArray(
self.y[:, 0],
dims=["obs_ind"],
coords={"obs_ind": np.arange(self.y.shape[0])},
self.y,
dims=["obs_ind", "treated_units"],
coords={"obs_ind": np.arange(self.y.shape[0]), "treated_units": ["unit_0"]},
)

# fit model
if isinstance(self.model, PyMCModel):
# fit the model to the observed (pre-intervention) data
COORDS = {"coeffs": self.labels, "obs_ind": np.arange(self.X.shape[0])}
COORDS = {
"coeffs": self.labels,
"obs_ind": np.arange(self.X.shape[0]),
"treated_units": ["unit_0"],
}
self.model.fit(X=self.X, y=self.y, coords=COORDS)
elif isinstance(self.model, RegressorMixin):
self.model.fit(X=self.X, y=self.y)
Expand Down Expand Up @@ -248,15 +252,15 @@ def _bayesian_plot(self, round_to=None, **kwargs) -> tuple[plt.Figure, plt.Axes]
# Plot model fit to data
h_line, h_patch = plot_xY(
self.x_pred[self.running_variable_name],
self.pred["posterior_predictive"].mu,
self.pred["posterior_predictive"].mu.isel(treated_units=0),
ax=ax,
plot_hdi_kwargs={"color": "C1"},
)
handles = [(h_line, h_patch)]
labels = ["Posterior mean"]

# create strings to compose title
title_info = f"{round_num(self.score.r2, round_to)} (std = {round_num(self.score.r2_std, round_to)})"
title_info = f"{round_num(self.score['unit_0_r2'], round_to)} (std = {round_num(self.score['unit_0_r2_std'], round_to)})"
r2 = f"Bayesian $R^2$ on all data = {title_info}"
percentiles = self.discontinuity_at_threshold.quantile([0.03, 1 - 0.03]).values
ci = (
Expand Down
Loading