Skip to content

Commit 75cfd96

Browse files
committed
summary for skl models
1 parent e092650 commit 75cfd96

File tree

8 files changed

+325
-74
lines changed

8 files changed

+325
-74
lines changed

causalpy/pymc_experiments.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -202,11 +202,9 @@ def __init__(
202202
self.post_X = np.asarray(new_x)
203203
self.post_y = np.asarray(new_y)
204204

205-
# DEVIATION FROM SKL EXPERIMENT CODE =============================
206205
# fit the model to the observed (pre-intervention) data
207206
COORDS = {"coeffs": self.labels, "obs_indx": np.arange(self.pre_X.shape[0])}
208207
self.model.fit(X=self.pre_X, y=self.pre_y, coords=COORDS)
209-
# ================================================================
210208

211209
# score the goodness of fit to the pre-intervention data
212210
self.score = self.model.score(X=self.pre_X, y=self.pre_y)
@@ -347,7 +345,6 @@ def summary(self, round_to=None) -> None:
347345

348346
print(f"{self.expt_type:=^80}")
349347
print(f"Formula: {self.formula}")
350-
# TODO: extra experiment specific outputs here
351348
self.print_coefficients(round_to)
352349

353350

causalpy/skl_experiments.py

Lines changed: 51 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ class ExperimentalDesign:
4545
"""Base class for experiment designs"""
4646

4747
model = None
48+
expt_type = None
4849
outcome_variable_name = None
4950

5051
def __init__(self, model=None, **kwargs):
@@ -53,6 +54,17 @@ def __init__(self, model=None, **kwargs):
5354
if self.model is None:
5455
raise ValueError("fitting_model not set or passed.")
5556

57+
def print_coefficients(self, round_to=None) -> None:
58+
"""
59+
Prints the model coefficients
60+
61+
:param round_to:
62+
Number of decimals used to round results. Defaults to 2. Use "None" to return raw numbers.
63+
"""
64+
print("Model coefficients:")
65+
for name, val in zip(self.labels, self.model.coef_[0]):
66+
print(f"\t{name}\t\t{round_num(val, round_to)}")
67+
5668

5769
class PrePostFit(ExperimentalDesign, PrePostFitDataValidator):
5870
"""
@@ -95,6 +107,8 @@ def __init__(
95107
super().__init__(model=model, **kwargs)
96108
self._input_validation(data, treatment_time)
97109
self.treatment_time = treatment_time
110+
# set experiment type - usually done in subclasses
111+
self.expt_type = "Pre-Post Fit"
98112
# split data in to pre and post intervention
99113
self.datapre = data[data.index < self.treatment_time]
100114
self.datapost = data[data.index >= self.treatment_time]
@@ -103,10 +117,10 @@ def __init__(
103117

104118
# set things up with pre-intervention data
105119
y, X = dmatrices(formula, self.datapre)
120+
self.outcome_variable_name = y.design_info.column_names[0]
106121
self._y_design_info = y.design_info
107122
self._x_design_info = X.design_info
108123
self.labels = X.design_info.column_names
109-
self.outcome_variable_name = y.design_info.column_names[0]
110124
self.pre_y, self.pre_X = np.asarray(y), np.asarray(X)
111125
# process post-intervention data
112126
(new_y, new_x) = build_design_matrices(
@@ -222,6 +236,18 @@ def plot_coeffs(self):
222236
palette=sns.color_palette("husl"),
223237
)
224238

239+
def summary(self, round_to=None) -> None:
240+
"""
241+
Print text output summarising the results
242+
243+
:param round_to:
244+
Number of decimals used to round results. Defaults to 2. Use "None" to return raw numbers.
245+
"""
246+
247+
print(f"{self.expt_type:=^80}")
248+
print(f"Formula: {self.formula}")
249+
self.print_coefficients(round_to)
250+
225251

226252
class InterruptedTimeSeries(PrePostFit):
227253
"""
@@ -351,6 +377,7 @@ def __init__(
351377
):
352378
super().__init__(model=model, **kwargs)
353379
self.data = data
380+
self.expt_type = "Difference in Differences"
354381
self.formula = formula
355382
self.time_variable_name = time_variable_name
356383
self.group_variable_name = group_variable_name
@@ -509,6 +536,24 @@ def plot(self, round_to=None):
509536
ax.legend(fontsize=LEGEND_FONT_SIZE)
510537
return (fig, ax)
511538

539+
def _causal_impact_summary_stat(self, round_to=None) -> str:
540+
causal_impact = f"{round_num(self.causal_impact, round_to)}, "
541+
return f"Causal impact = {causal_impact}"
542+
543+
def summary(self, round_to=None) -> None:
544+
"""
545+
Print text output summarising the results.
546+
547+
:param round_to:
548+
Number of decimals used to round results. Defaults to 2. Use "None" to return raw numbers.
549+
"""
550+
551+
print(f"{self.expt_type:=^80}")
552+
print(f"Formula: {self.formula}")
553+
print("\nResults:")
554+
print(f"Causal impact = {round_num(self.causal_impact[0], round_to)}")
555+
self.print_coefficients(round_to)
556+
512557

513558
class RegressionDiscontinuity(ExperimentalDesign, RDDataValidator):
514559
"""
@@ -687,16 +732,17 @@ def plot(self, round_to=None):
687732
ax.legend(fontsize=LEGEND_FONT_SIZE)
688733
return (fig, ax)
689734

690-
def summary(self):
735+
def summary(self, round_to=None):
691736
"""
692737
Print text output summarising the results
738+
739+
:param round_to:
740+
Number of decimals used to round results. Defaults to 2. Use "None" to return raw numbers.
693741
"""
694742
print("Difference in Differences experiment")
695743
print(f"Formula: {self.formula}")
696744
print(f"Running variable: {self.running_variable_name}")
697745
print(f"Threshold on running variable: {self.treatment_threshold}")
698746
print("\nResults:")
699747
print(f"Discontinuity at threshold = {self.discontinuity_at_threshold:.2f}")
700-
print("Model coefficients:")
701-
for name, val in zip(self.labels, self.model.coef_[0]):
702-
print(f"\t{name}\t\t{val}")
748+
self.print_coefficients(round_to)

docs/source/_static/interrogate_badge.svg

Lines changed: 3 additions & 3 deletions
Loading

docs/source/notebooks/did_skl.ipynb

Lines changed: 29 additions & 3 deletions
Large diffs are not rendered by default.

docs/source/notebooks/its_skl.ipynb

Lines changed: 31 additions & 14 deletions
Large diffs are not rendered by default.

docs/source/notebooks/rd_skl.ipynb

Lines changed: 110 additions & 24 deletions
Large diffs are not rendered by default.

docs/source/notebooks/rd_skl_drinking.ipynb

Lines changed: 6 additions & 14 deletions
Large diffs are not rendered by default.

docs/source/notebooks/sc_skl.ipynb

Lines changed: 95 additions & 8 deletions
Large diffs are not rendered by default.

0 commit comments

Comments
 (0)