Skip to content

Enable multiple treated units in synthetic control quasi experiments #494

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 31 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 21 commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
59c5b6e
initial commit - utilising vibe coding
drbenvincent Jun 28, 2025
26691ec
update multi cell geolift notebook with new functionality/API
drbenvincent Jun 28, 2025
4fa1650
add test that the r2 scores differ across treated units
drbenvincent Jun 28, 2025
7b473af
simplify plot code
drbenvincent Jun 28, 2025
be01357
revert to simpler plot titles
drbenvincent Jun 28, 2025
47e0a2d
rename primary_unit_name -> treated_unit + revert to no "Unit" title
drbenvincent Jun 28, 2025
ee8a92b
revert change in causal impact shaded region colour
drbenvincent Jun 28, 2025
b79743f
code simplifications by always having a treated_units dimension
drbenvincent Jun 28, 2025
aa9920a
code simplification relating to _get_score_title
drbenvincent Jun 28, 2025
ebddbb5
another code simplification - related to scoring
drbenvincent Jun 28, 2025
eac1ef3
code simplification in _ols_plot
drbenvincent Jun 28, 2025
1a5f9bd
code simplification related to PyMCModel._data_setter
drbenvincent Jun 28, 2025
e67a28c
add sphinx-togglebutton for a collapsible admonition + other updates …
drbenvincent Jun 29, 2025
d0fc0d3
update uml diagrams
drbenvincent Jun 29, 2025
b6f5ca8
PyMCModel.score always to get xr.DataArray arguments
drbenvincent Jun 29, 2025
8badc05
simplification to PyMC.print_coefficients
drbenvincent Jun 29, 2025
4a78a50
simplification of WeightedSumFitter.build_model
drbenvincent Jun 29, 2025
f1849b1
clean up PyMCModel.predict + PyMCModel._data_setter
drbenvincent Jun 29, 2025
a55f97d
remove a numerical index in favour of a named dimension
drbenvincent Jun 29, 2025
6341e53
make code comment more specific
drbenvincent Jun 29, 2025
2befccb
consolidate tests, fix doctest
drbenvincent Jun 29, 2025
78be544
Merge branch 'main' into multi-cell-geolift
drbenvincent Jun 29, 2025
37e8de7
set fixture scope to module
drbenvincent Jun 30, 2025
d0c520f
towards a more unified scoring (r2) approach
drbenvincent Jun 30, 2025
3ee430e
more unification with score (r2) in terms of unified naming: unit_{n}_r2
drbenvincent Jun 30, 2025
adf04f9
refactor PyMCModel.score
drbenvincent Jun 30, 2025
25608ef
the grand simplification
drbenvincent Jun 30, 2025
2fba338
tweak docstring
drbenvincent Jun 30, 2025
4e63dfa
Merge branch 'main' into multi-cell-geolift
drbenvincent Jul 5, 2025
04ace83
update class diagram
drbenvincent Jul 5, 2025
192327d
re-run relevant notebooks
drbenvincent Jul 5, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
209 changes: 148 additions & 61 deletions causalpy/experiments/synthetic_control.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,10 +100,10 @@
# make constructing the xarray DataArray objects easier.
self.datapre_control = xr.DataArray(
self.datapre[self.control_units],
dims=["obs_ind", "control_units"],
dims=["obs_ind", "coeffs"],
coords={
"obs_ind": self.datapre[self.control_units].index,
"control_units": self.control_units,
"coeffs": self.control_units,
},
)
self.datapre_treated = xr.DataArray(
Expand All @@ -116,10 +116,10 @@
)
self.datapost_control = xr.DataArray(
self.datapost[self.control_units],
dims=["obs_ind", "control_units"],
dims=["obs_ind", "coeffs"],
coords={
"obs_ind": self.datapost[self.control_units].index,
"control_units": self.control_units,
"coeffs": self.control_units,
},
)
self.datapost_treated = xr.DataArray(
Expand Down Expand Up @@ -156,8 +156,8 @@

# score the goodness of fit to the pre-intervention data
self.score = self.model.score(
X=self.datapre_control.to_numpy(),
y=self.datapre_treated.isel(treated_units=0).to_numpy(),
X=self.datapre_control,
y=self.datapre_treated,
)

# get the model predictions of the observed (pre-intervention) data
Expand Down Expand Up @@ -207,84 +207,105 @@
self.print_coefficients(round_to)

def _bayesian_plot(
self, round_to=None, **kwargs
self, round_to=None, treated_unit: str | None = None, **kwargs
) -> tuple[plt.Figure, List[plt.Axes]]:
"""
Plot the results
Plot the results for a specific treated unit

:param round_to:
Number of decimals used to round results. Defaults to 2. Use "None" to return raw numbers.
:param treated_unit:
Which treated unit to plot. Must be a string name of the treated unit.
If None, plots the first treated unit.
"""
counterfactual_label = "Counterfactual"

fig, ax = plt.subplots(3, 1, sharex=True, figsize=(7, 8))
# TOP PLOT --------------------------------------------------
# pre-intervention period

# Get treated unit name - default to first unit if None
treated_unit = (
treated_unit if treated_unit is not None else self.treated_units[0]
)

if treated_unit not in self.treated_units:
raise ValueError(
f"treated_unit '{treated_unit}' not found. Available units: {self.treated_units}"
)

pre_pred_plot = self.pre_pred["posterior_predictive"].mu.sel(
treated_units=treated_unit
)
post_pred_plot = self.post_pred["posterior_predictive"].mu.sel(
treated_units=treated_unit
)

h_line, h_patch = plot_xY(
self.datapre.index,
self.pre_pred["posterior_predictive"].mu,
pre_pred_plot,
ax=ax[0],
plot_hdi_kwargs={"color": "C0"},
)
handles = [(h_line, h_patch)]
labels = ["Pre-intervention period"]

# Plot observations for the selected treated unit
(h,) = ax[0].plot(
self.datapre.index, self.datapre_treated, "k.", label="Observations"
self.datapre.index,
self.datapre_treated.sel(treated_units=treated_unit),
"k.",
label="Observations",
)
handles.append(h)
labels.append("Observations")

# post intervention period
h_line, h_patch = plot_xY(
self.datapost.index,
self.post_pred["posterior_predictive"].mu,
post_pred_plot,
ax=ax[0],
plot_hdi_kwargs={"color": "C1"},
)
handles.append((h_line, h_patch))
labels.append(counterfactual_label)

ax[0].plot(self.datapost.index, self.datapost_treated, "k.")
# Shaded causal effect
ax[0].plot(
self.datapost.index,
self.datapost_treated.sel(treated_units=treated_unit),
"k.",
)
# Shaded causal effect for the selected treated unit
h = ax[0].fill_between(
self.datapost.index,
y1=az.extract(
self.post_pred, group="posterior_predictive", var_names="mu"
).mean("sample"),
y2=np.squeeze(self.datapost_treated),
y1=post_pred_plot.mean(dim=["chain", "draw"]).values,
y2=self.datapost_treated.sel(treated_units=treated_unit).values,
color="C0",
alpha=0.25,
label="Causal impact",
)
handles.append(h)
labels.append("Causal impact")

ax[0].set(
title=f"""
Pre-intervention Bayesian $R^2$: {round_num(self.score.r2, round_to)}
(std = {round_num(self.score.r2_std, round_to)})
"""
)
ax[0].set(title=f"{self._get_score_title(treated_unit, round_to)}")

# MIDDLE PLOT -----------------------------------------------
plot_xY(
self.datapre.index,
self.pre_impact.sel(treated_units=self.treated_units[0]),
self.pre_impact.sel(treated_units=treated_unit),
ax=ax[1],
plot_hdi_kwargs={"color": "C0"},
)
plot_xY(
self.datapost.index,
self.post_impact.sel(treated_units=self.treated_units[0]),
self.post_impact.sel(treated_units=treated_unit),
ax=ax[1],
plot_hdi_kwargs={"color": "C1"},
)
ax[1].axhline(y=0, c="k")
ax[1].fill_between(
self.datapost.index,
y1=self.post_impact.mean(["chain", "draw"]).sel(
treated_units=self.treated_units[0]
),
y1=self.post_impact.mean(["chain", "draw"]).sel(treated_units=treated_unit),
color="C0",
alpha=0.25,
label="Causal impact",
Expand All @@ -295,7 +316,7 @@
ax[2].set(title="Cumulative Causal Impact")
plot_xY(
self.datapost.index,
self.post_impact_cumulative.sel(treated_units=self.treated_units[0]),
self.post_impact_cumulative.sel(treated_units=treated_unit),
ax=ax[2],
plot_hdi_kwargs={"color": "C1"},
)
Expand Down Expand Up @@ -336,25 +357,40 @@

return fig, ax

def _ols_plot(self, round_to=None, **kwargs) -> tuple[plt.Figure, List[plt.Axes]]:
def _ols_plot(
self, round_to=None, treated_unit: str | None = None, **kwargs
) -> tuple[plt.Figure, List[plt.Axes]]:
"""
Plot the results
Plot the results for OLS model for a specific treated unit

:param round_to:
Number of decimals used to round results. Defaults to 2. Use "None" to return raw numbers.
:param treated_unit:
Which treated unit to plot. Must be a string name of the treated unit.
If None, plots the first treated unit.
"""
counterfactual_label = "Counterfactual"

# Get treated unit name - default to first unit if None
treated_unit = (
treated_unit if treated_unit is not None else self.treated_units[0]
)

if treated_unit not in self.treated_units:
raise ValueError(

Check warning on line 380 in causalpy/experiments/synthetic_control.py

View check run for this annotation

Codecov / codecov/patch

causalpy/experiments/synthetic_control.py#L380

Added line #L380 was not covered by tests
f"treated_unit '{treated_unit}' not found. Available units: {self.treated_units}"
)

fig, ax = plt.subplots(3, 1, sharex=True, figsize=(7, 8))

ax[0].plot(
self.datapre_treated["obs_ind"],
self.datapre_treated.isel(treated_units=0),
self.datapre_treated.sel(treated_units=treated_unit),
"k.",
)
ax[0].plot(
self.datapost_treated["obs_ind"],
self.datapost_treated.isel(treated_units=0),
self.datapost_treated.sel(treated_units=treated_unit),
"k.",
)

Expand All @@ -366,14 +402,14 @@
ls=":",
c="k",
)
ax[0].set(
title=f"$R^2$ on pre-intervention data = {round_num(self.score, round_to)}"
)
ax[0].set(title=f"{self._get_score_title(treated_unit, round_to)}")
# Shaded causal effect
post_pred_values = np.squeeze(self.post_pred)

ax[0].fill_between(
self.datapost.index,
y1=np.squeeze(self.post_pred),
y2=np.squeeze(self.datapost_treated.isel(treated_units=0).data),
y1=post_pred_values,
y2=np.squeeze(self.datapost_treated.sel(treated_units=treated_unit).data),
color="C0",
alpha=0.25,
label="Causal impact",
Expand Down Expand Up @@ -431,12 +467,17 @@

return self.plot_data

def get_plot_data_bayesian(self, hdi_prob: float = 0.94) -> pd.DataFrame:
def get_plot_data_bayesian(
self, hdi_prob: float = 0.94, treated_unit: str | None = None
) -> pd.DataFrame:
"""
Recover the data of the PrePostFit experiment along with the prediction and causal impact information.

:param hdi_prob:
Prob for which the highest density interval will be computed. The default value is defined as the default from the :func:`arviz.hdi` function.
:param treated_unit:
Which treated unit to extract data for. Must be a string name of the treated unit.
If None, uses the first treated unit.
"""
if not isinstance(self.model, PyMCModel):
raise ValueError("Unsupported model type")
Expand All @@ -451,36 +492,82 @@
pre_data = self.datapre.copy()
post_data = self.datapost.copy()

pre_data["prediction"] = (
az.extract(self.pre_pred, group="posterior_predictive", var_names="mu")
.mean("sample")
.values
# Get treated unit name - default to first unit if None
treated_unit = (
treated_unit if treated_unit is not None else self.treated_units[0]
)
post_data["prediction"] = (
az.extract(self.post_pred, group="posterior_predictive", var_names="mu")
.mean("sample")
.values

if treated_unit not in self.treated_units:
raise ValueError(
f"treated_unit '{treated_unit}' not found. Available units: {self.treated_units}"
)

# Extract predictions - handle multi-unit case
pre_pred_vals = az.extract(
self.pre_pred, group="posterior_predictive", var_names="mu"
).mean("sample")
post_pred_vals = az.extract(
self.post_pred, group="posterior_predictive", var_names="mu"
).mean("sample")

# Extract predictions for the specified treated unit (always has treated_units dimension)
pre_data["prediction"] = pre_pred_vals.sel(treated_units=treated_unit).values
post_data["prediction"] = post_pred_vals.sel(treated_units=treated_unit).values

# HDI intervals for predictions (always use treated_units dimension)
pre_hdi = get_hdi_to_df(
self.pre_pred["posterior_predictive"].mu.sel(treated_units=treated_unit),
hdi_prob=hdi_prob,
)
pre_data[[pred_lower_col, pred_upper_col]] = get_hdi_to_df(
self.pre_pred["posterior_predictive"].mu, hdi_prob=hdi_prob
).set_index(pre_data.index)
post_data[[pred_lower_col, pred_upper_col]] = get_hdi_to_df(
self.post_pred["posterior_predictive"].mu, hdi_prob=hdi_prob
).set_index(post_data.index)
post_hdi = get_hdi_to_df(
self.post_pred["posterior_predictive"].mu.sel(treated_units=treated_unit),
hdi_prob=hdi_prob,
)

# Extract only the lower and upper columns and ensure proper indexing
pre_lower_upper = pre_hdi.iloc[:, [0, -1]].values # Get first and last columns
post_lower_upper = post_hdi.iloc[:, [0, -1]].values

pre_data[[pred_lower_col, pred_upper_col]] = pre_lower_upper
post_data[[pred_lower_col, pred_upper_col]] = post_lower_upper

# Impact data for the selected treated unit
pre_data["impact"] = (
self.pre_impact.mean(dim=["chain", "draw"]).isel(treated_units=0).values
self.pre_impact.mean(dim=["chain", "draw"])
.sel(treated_units=treated_unit)
.values
)
post_data["impact"] = (
self.post_impact.mean(dim=["chain", "draw"]).isel(treated_units=0).values
self.post_impact.mean(dim=["chain", "draw"])
.sel(treated_units=treated_unit)
.values
)
# Impact HDI intervals (always use treated_units dimension)
pre_impact_hdi = get_hdi_to_df(
self.pre_impact.sel(treated_units=treated_unit), hdi_prob=hdi_prob
)
post_impact_hdi = get_hdi_to_df(
self.post_impact.sel(treated_units=treated_unit), hdi_prob=hdi_prob
)
pre_data[[impact_lower_col, impact_upper_col]] = get_hdi_to_df(
self.pre_impact, hdi_prob=hdi_prob
).set_index(pre_data.index)
post_data[[impact_lower_col, impact_upper_col]] = get_hdi_to_df(
self.post_impact, hdi_prob=hdi_prob
).set_index(post_data.index)

# Extract only the lower and upper columns for impact HDI
pre_impact_lower_upper = pre_impact_hdi.iloc[:, [0, -1]].values
post_impact_lower_upper = post_impact_hdi.iloc[:, [0, -1]].values

pre_data[[impact_lower_col, impact_upper_col]] = pre_impact_lower_upper
post_data[[impact_lower_col, impact_upper_col]] = post_impact_lower_upper

self.plot_data = pd.concat([pre_data, post_data])

return self.plot_data

def _get_score_title(self, treated_unit: str, round_to=None):
    """
    Build the title string that reports the pre-intervention goodness of fit.

    :param treated_unit:
        Name of the treated unit; used to build the ``"<unit>_r2"`` and
        ``"<unit>_r2_std"`` lookup keys into ``self.score`` when the model
        is a ``PyMCModel``.
    :param round_to:
        Forwarded unchanged to ``round_num`` for every value that is shown.
    :return:
        The formatted title text.
    """
    if not isinstance(self.model, PyMCModel):
        # Non-PyMC model: self.score is handed to round_num as a single value
        # and treated_unit plays no part in the title.
        rounded_score = round_num(self.score, round_to)
        return f"$R^2$ on pre-intervention data = {rounded_score}"

    # PyMC model: self.score is indexed by per-unit keys.
    r2_key = f"{treated_unit}_r2"
    r2_mean = round_num(self.score[r2_key], round_to)
    r2_spread = round_num(self.score[f"{r2_key}_std"], round_to)
    return f"Pre-intervention Bayesian $R^2$: {r2_mean} (std = {r2_spread})"
Loading