Skip to content

Commit 2916688

Browse files
authored
Merge pull request #355 from pymc-labs/summary
Enable `summary` method for all currently implemented frequentist experiments
2 parents 4af4af6 + 9fc0798 commit 2916688

File tree

11 files changed

+338
-86
lines changed

11 files changed

+338
-86
lines changed

causalpy/pymc_experiments.py

Lines changed: 0 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -202,11 +202,9 @@ def __init__(
202202
self.post_X = np.asarray(new_x)
203203
self.post_y = np.asarray(new_y)
204204

205-
# DEVIATION FROM SKL EXPERIMENT CODE =============================
206205
# fit the model to the observed (pre-intervention) data
207206
COORDS = {"coeffs": self.labels, "obs_indx": np.arange(self.pre_X.shape[0])}
208207
self.model.fit(X=self.pre_X, y=self.pre_y, coords=COORDS)
209-
# ================================================================
210208

211209
# score the goodness of fit to the pre-intervention data
212210
self.score = self.model.score(X=self.pre_X, y=self.pre_y)
@@ -347,7 +345,6 @@ def summary(self, round_to=None) -> None:
347345

348346
print(f"{self.expt_type:=^80}")
349347
print(f"Formula: {self.formula}")
350-
# TODO: extra experiment specific outputs here
351348
self.print_coefficients(round_to)
352349

353350

causalpy/skl_experiments.py

Lines changed: 55 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -45,6 +45,7 @@ class ExperimentalDesign:
4545
"""Base class for experiment designs"""
4646

4747
model = None
48+
expt_type = None
4849
outcome_variable_name = None
4950

5051
def __init__(self, model=None, **kwargs):
@@ -53,6 +54,24 @@ def __init__(self, model=None, **kwargs):
5354
if self.model is None:
5455
raise ValueError("fitting_model not set or passed.")
5556

57+
def print_coefficients(self, round_to=None) -> None:
58+
"""
59+
Prints the model coefficients
60+
61+
:param round_to:
62+
Number of decimals used to round results. Defaults to 2. Use "None" to return raw numbers.
63+
"""
64+
print("Model coefficients:")
65+
# Determine the width of the longest label
66+
max_label_length = max(len(name) for name in self.labels)
67+
# Print each coefficient with formatted alignment
68+
for name, val in zip(self.labels, self.model.coef_[0]):
69+
# Left-align the name
70+
formatted_name = f"{name:<{max_label_length}}"
71+
# Right-align the value with width 10
72+
formatted_val = f"{round_num(val, round_to):>10}"
73+
print(f" {formatted_name}\t{formatted_val}")
74+
5675

5776
class PrePostFit(ExperimentalDesign, PrePostFitDataValidator):
5877
"""
@@ -95,6 +114,8 @@ def __init__(
95114
super().__init__(model=model, **kwargs)
96115
self._input_validation(data, treatment_time)
97116
self.treatment_time = treatment_time
117+
# set experiment type - usually done in subclasses
118+
self.expt_type = "Pre-Post Fit"
98119
# split data in to pre and post intervention
99120
self.datapre = data[data.index < self.treatment_time]
100121
self.datapost = data[data.index >= self.treatment_time]
@@ -103,10 +124,10 @@ def __init__(
103124

104125
# set things up with pre-intervention data
105126
y, X = dmatrices(formula, self.datapre)
127+
self.outcome_variable_name = y.design_info.column_names[0]
106128
self._y_design_info = y.design_info
107129
self._x_design_info = X.design_info
108130
self.labels = X.design_info.column_names
109-
self.outcome_variable_name = y.design_info.column_names[0]
110131
self.pre_y, self.pre_X = np.asarray(y), np.asarray(X)
111132
# process post-intervention data
112133
(new_y, new_x) = build_design_matrices(
@@ -222,6 +243,18 @@ def plot_coeffs(self):
222243
palette=sns.color_palette("husl"),
223244
)
224245

246+
def summary(self, round_to=None) -> None:
247+
"""
248+
Print text output summarising the results
249+
250+
:param round_to:
251+
Number of decimals used to round results. Defaults to 2. Use "None" to return raw numbers.
252+
"""
253+
254+
print(f"{self.expt_type:=^80}")
255+
print(f"Formula: {self.formula}")
256+
self.print_coefficients(round_to)
257+
225258

226259
class InterruptedTimeSeries(PrePostFit):
227260
"""
@@ -253,7 +286,6 @@ class InterruptedTimeSeries(PrePostFit):
253286
... formula="y ~ 1 + t + C(month)",
254287
... model = LinearRegression()
255288
... )
256-
257289
"""
258290

259291
expt_type = "Interrupted Time Series"
@@ -351,6 +383,7 @@ def __init__(
351383
):
352384
super().__init__(model=model, **kwargs)
353385
self.data = data
386+
self.expt_type = "Difference in Differences"
354387
self.formula = formula
355388
self.time_variable_name = time_variable_name
356389
self.group_variable_name = group_variable_name
@@ -509,6 +542,20 @@ def plot(self, round_to=None):
509542
ax.legend(fontsize=LEGEND_FONT_SIZE)
510543
return (fig, ax)
511544

545+
def summary(self, round_to=None) -> None:
546+
"""
547+
Print text output summarising the results.
548+
549+
:param round_to:
550+
Number of decimals used to round results. Defaults to 2. Use "None" to return raw numbers.
551+
"""
552+
553+
print(f"{self.expt_type:=^80}")
554+
print(f"Formula: {self.formula}")
555+
print("\nResults:")
556+
print(f"Causal impact = {round_num(self.causal_impact[0], round_to)}")
557+
self.print_coefficients(round_to)
558+
512559

513560
class RegressionDiscontinuity(ExperimentalDesign, RDDataValidator):
514561
"""
@@ -542,17 +589,6 @@ class RegressionDiscontinuity(ExperimentalDesign, RDDataValidator):
542589
... model=LinearRegression(),
543590
... treatment_threshold=0.5,
544591
... )
545-
>>> result.summary() # doctest: +NORMALIZE_WHITESPACE,+NUMBER
546-
Difference in Differences experiment
547-
Formula: y ~ 1 + x + treated
548-
Running variable: x
549-
Threshold on running variable: 0.5
550-
Results:
551-
Discontinuity at threshold = 0.19
552-
Model coefficients:
553-
Intercept 0.0
554-
treated[T.True] 0.19
555-
x 1.23
556592
"""
557593

558594
def __init__(
@@ -687,16 +723,18 @@ def plot(self, round_to=None):
687723
ax.legend(fontsize=LEGEND_FONT_SIZE)
688724
return (fig, ax)
689725

690-
def summary(self):
726+
def summary(self, round_to=None) -> None:
691727
"""
692728
Print text output summarising the results
729+
730+
:param round_to:
731+
Number of decimals used to round results. Defaults to 2. Use "None" to return raw numbers.
693732
"""
694733
print("Difference in Differences experiment")
695734
print(f"Formula: {self.formula}")
696735
print(f"Running variable: {self.running_variable_name}")
697736
print(f"Threshold on running variable: {self.treatment_threshold}")
698737
print("\nResults:")
699738
print(f"Discontinuity at threshold = {self.discontinuity_at_threshold:.2f}")
700-
print("Model coefficients:")
701-
for name, val in zip(self.labels, self.model.coef_[0]):
702-
print(f"\t{name}\t\t{val}")
739+
print("\n")
740+
self.print_coefficients(round_to)

causalpy/tests/test_integration_skl_examples.py

Lines changed: 7 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -42,6 +42,7 @@ def test_did():
4242
)
4343
assert isinstance(data, pd.DataFrame)
4444
assert isinstance(result, cp.skl_experiments.DifferenceInDifferences)
45+
result.summary()
4546

4647

4748
@pytest.mark.integration
@@ -68,6 +69,7 @@ def test_rd_drinking():
6869
)
6970
assert isinstance(df, pd.DataFrame)
7071
assert isinstance(result, cp.skl_experiments.RegressionDiscontinuity)
72+
result.summary()
7173

7274

7375
@pytest.mark.integration
@@ -94,6 +96,7 @@ def test_its():
9496
)
9597
assert isinstance(df, pd.DataFrame)
9698
assert isinstance(result, cp.skl_experiments.SyntheticControl)
99+
result.summary()
97100

98101

99102
@pytest.mark.integration
@@ -115,6 +118,7 @@ def test_sc():
115118
)
116119
assert isinstance(df, pd.DataFrame)
117120
assert isinstance(result, cp.skl_experiments.SyntheticControl)
121+
result.summary()
118122

119123

120124
@pytest.mark.integration
@@ -136,6 +140,7 @@ def test_rd_linear_main_effects():
136140
)
137141
assert isinstance(data, pd.DataFrame)
138142
assert isinstance(result, cp.skl_experiments.RegressionDiscontinuity)
143+
result.summary()
139144

140145

141146
@pytest.mark.integration
@@ -159,6 +164,7 @@ def test_rd_linear_main_effects_bandwidth():
159164
)
160165
assert isinstance(data, pd.DataFrame)
161166
assert isinstance(result, cp.skl_experiments.RegressionDiscontinuity)
167+
result.summary()
162168

163169

164170
@pytest.mark.integration
@@ -180,6 +186,7 @@ def test_rd_linear_with_interaction():
180186
)
181187
assert isinstance(data, pd.DataFrame)
182188
assert isinstance(result, cp.skl_experiments.RegressionDiscontinuity)
189+
result.summary()
183190

184191

185192
@pytest.mark.integration

docs/source/_static/classes.png

11.3 KB
Loading

docs/source/_static/interrogate_badge.svg

Lines changed: 3 additions & 3 deletions
Loading

docs/source/_static/packages.png

546 Bytes
Loading

docs/source/notebooks/did_skl.ipynb

Lines changed: 29 additions & 3 deletions
Large diffs are not rendered by default.

docs/source/notebooks/its_skl.ipynb

Lines changed: 31 additions & 14 deletions
Large diffs are not rendered by default.

docs/source/notebooks/rd_skl.ipynb

Lines changed: 110 additions & 24 deletions
Large diffs are not rendered by default.

docs/source/notebooks/rd_skl_drinking.ipynb

Lines changed: 8 additions & 14 deletions
Large diffs are not rendered by default.

0 commit comments

Comments
 (0)