Skip to content

Commit 2f45252

Browse files
committed
add round_to kwarg to text output functions
1 parent 70de921 commit 2f45252

File tree

1 file changed

+62
-46
lines changed

1 file changed

+62
-46
lines changed

causalpy/pymc_experiments.py

Lines changed: 62 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -59,10 +59,13 @@ def idata(self):
5959

6060
return self.model.idata
6161

62-
def print_coefficients(self) -> None:
62+
def print_coefficients(self, round_to=None) -> None:
6363
"""
6464
Prints the model coefficients
6565
66+
:param round_to:
67+
Number of decimals used to round results. Defaults to 2. Use "none" to return raw numbers.
68+
6669
Example
6770
--------
6871
>>> import causalpy as cp
@@ -80,13 +83,13 @@ def print_coefficients(self) -> None:
8083
... "progressbar": False
8184
... }),
8285
... )
83-
>>> result.print_coefficients() # doctest: +NUMBER
86+
>>> result.print_coefficients(round_to=1) # doctest: +NUMBER
8487
Model coefficients:
85-
Intercept 1.0, 94% HDI [1.0, 1.1]
86-
post_treatment[T.True] 0.9, 94% HDI [0.9, 1.0]
87-
group 0.1, 94% HDI [0.0, 0.2]
88+
Intercept 1, 94% HDI [1, 1]
89+
post_treatment[T.True] 1, 94% HDI [0.9, 1]
90+
group 0.2, 94% HDI [0.09, 0.2]
8891
group:post_treatment[T.True] 0.5, 94% HDI [0.4, 0.6]
89-
sigma 0.0, 94% HDI [0.0, 0.1]
92+
sigma 0.08, 94% HDI [0.07, 0.1]
9093
"""
9194
print("Model coefficients:")
9295
coeffs = az.extract(self.idata.posterior, var_names="beta")
@@ -95,13 +98,13 @@ def print_coefficients(self) -> None:
9598
for name in self.labels:
9699
coeff_samples = coeffs.sel(coeffs=name)
97100
print(
98-
f"{name: <30}{coeff_samples.mean().data:.2f}, 94% HDI [{coeff_samples.quantile(0.03).data:.2f}, {coeff_samples.quantile(1-0.03).data:.2f}]" # noqa: E501
101+
f"{name: <30}{round_num(coeff_samples.mean().data, round_to)}, 94% HDI [{round_num(coeff_samples.quantile(0.03).data, round_to)}, {round_num(coeff_samples.quantile(1-0.03).data, round_to)}]" # noqa: E501
99102
)
100103
# add coeff for measurement std
101104
coeff_samples = az.extract(self.model.idata.posterior, var_names="sigma")
102105
name = "sigma"
103106
print(
104-
f"{name: <30}{coeff_samples.mean().data:.2f}, 94% HDI [{coeff_samples.quantile(0.03).data:.2f}, {coeff_samples.quantile(1-0.03).data:.2f}]" # noqa: E501
107+
f"{name: <30}{round_num(coeff_samples.mean().data, round_to)}, 94% HDI [{round_num(coeff_samples.quantile(0.03).data, round_to)}, {round_num(coeff_samples.quantile(1-0.03).data, round_to)}]" # noqa: E501
105108
)
106109

107110

@@ -138,18 +141,18 @@ class PrePostFit(ExperimentalDesign):
138141
... }
139142
... ),
140143
... )
141-
>>> result.summary() # doctest: +NUMBER
144+
>>> result.summary(round_to=1) # doctest: +NUMBER
142145
==================================Pre-Post Fit==================================
143146
Formula: actual ~ 0 + a + b + c + d + e + f + g
144147
Model coefficients:
145-
a 0.3, 94% HDI [0.3, 0.3]
146-
b 0.0, 94% HDI [0.0, 0.0]
147-
c 0.3, 94% HDI [0.2, 0.3]
148-
d 0.0, 94% HDI [0.0, 0.1]
149-
e 0.0, 94% HDI [0.0, 0.0]
150-
f 0.1, 94% HDI [0.1, 0.2]
151-
g 0.0, 94% HDI [0.0, 0.0]
152-
sigma 0.2, 94% HDI [0.2, 0.3]
148+
a 0.3, 94% HDI [0.3, 0.4]
149+
b 0.05, 94% HDI [0.009, 0.09]
150+
c 0.3, 94% HDI [0.3, 0.3]
151+
d 0.05, 94% HDI [0.01, 0.1]
152+
e 0.03, 94% HDI [0.001, 0.07]
153+
f 0.2, 94% HDI [0.1, 0.3]
154+
g 0.04, 94% HDI [0.003, 0.09]
155+
sigma 0.3, 94% HDI [0.2, 0.3]
153156
"""
154157

155158
def __init__(
@@ -336,15 +339,18 @@ def plot(self, counterfactual_label="Counterfactual", round_to=None, **kwargs):
336339

337340
return fig, ax
338341

339-
def summary(self) -> None:
342+
def summary(self, round_to=None) -> None:
340343
"""
341344
Print text output summarising the results
345+
346+
:param round_to:
347+
Number of decimals used to round results. Defaults to 2. Use "none" to return raw numbers.
342348
"""
343349

344350
print(f"{self.expt_type:=^80}")
345351
print(f"Formula: {self.formula}")
346352
# TODO: extra experiment specific outputs here
347-
self.print_coefficients()
353+
self.print_coefficients(round_to)
348354

349355

350356
class InterruptedTimeSeries(PrePostFit):
@@ -733,17 +739,19 @@ def _causal_impact_summary_stat(self, round_to=None) -> str:
733739
causal_impact = f"{round_num(self.causal_impact.mean(), round_to)}, "
734740
return f"Causal impact = {causal_impact + ci}"
735741

736-
def summary(self) -> None:
742+
def summary(self, round_to=None) -> None:
737743
"""
738-
Print text output summarising the results
744+
Print text output summarising the results.
745+
746+
:param round_to:
747+
Number of decimals used to round results. Defaults to 2. Use "none" to return raw numbers.
739748
"""
740749

741750
print(f"{self.expt_type:=^80}")
742751
print(f"Formula: {self.formula}")
743752
print("\nResults:")
744-
# TODO: extra experiment specific outputs here
745-
print(self._causal_impact_summary_stat())
746-
self.print_coefficients()
753+
print(round_num(self._causal_impact_summary_stat(), round_to))
754+
self.print_coefficients(round_to)
747755

748756

749757
class RegressionDiscontinuity(ExperimentalDesign):
@@ -785,20 +793,20 @@ class RegressionDiscontinuity(ExperimentalDesign):
785793
... ),
786794
... treatment_threshold=0.5,
787795
... )
788-
>>> result.summary() # doctest: +NUMBER
796+
>>> result.summary(round_to=1) # doctest: +NUMBER
789797
============================Regression Discontinuity============================
790798
Formula: y ~ 1 + x + treated + x:treated
791799
Running variable: x
792800
Threshold on running variable: 0.5
793801
<BLANKLINE>
794802
Results:
795-
Discontinuity at threshold = 0.91
803+
Discontinuity at threshold = 0.9
796804
Model coefficients:
797-
Intercept 0.0, 94% HDI [0.0, 0.1]
798-
treated[T.True] 2.4, 94% HDI [1.6, 3.2]
799-
x 1.3, 94% HDI [1.1, 1.5]
800-
x:treated[T.True] -3.0, 94% HDI [-4.1, -2.0]
801-
sigma 0.3, 94% HDI [0.3, 0.4]
805+
Intercept 0.09, 94% HDI [-0.001, 0.2]
806+
treated[T.True] 2, 94% HDI [2, 3]
807+
x 1, 94% HDI [1, 2]
808+
x:treated[T.True] -3, 94% HDI [-4, -2]
809+
sigma 0.4, 94% HDI [0.3, 0.4]
802810
"""
803811

804812
def __init__(
@@ -962,9 +970,12 @@ def plot(self, round_to=None):
962970
)
963971
return fig, ax
964972

965-
def summary(self) -> None:
973+
def summary(self, round_to: None) -> None:
966974
"""
967975
Print text output summarising the results
976+
977+
:param round_to:
978+
Number of decimals used to round results. Defaults to 2. Use "none" to return raw numbers.
968979
"""
969980

970981
print(f"{self.expt_type:=^80}")
@@ -973,9 +984,9 @@ def summary(self) -> None:
973984
print(f"Threshold on running variable: {self.treatment_threshold}")
974985
print("\nResults:")
975986
print(
976-
f"Discontinuity at threshold = {self.discontinuity_at_threshold.mean():.2f}"
987+
f"Discontinuity at threshold = {round_num(self.discontinuity_at_threshold.mean(), round_to)}"
977988
)
978-
self.print_coefficients()
989+
self.print_coefficients(round_to)
979990

980991

981992
class RegressionKink(ExperimentalDesign):
@@ -1179,9 +1190,12 @@ def plot(self, round_to=None):
11791190
)
11801191
return fig, ax
11811192

1182-
def summary(self) -> None:
1193+
def summary(self, round_to=None) -> None:
11831194
"""
11841195
Print text output summarising the results
1196+
1197+
:param round_to:
1198+
Number of decimals used to round results. Defaults to 2. Use "none" to return raw numbers.
11851199
"""
11861200

11871201
print(
@@ -1192,10 +1206,10 @@ def summary(self) -> None:
11921206
Kink point on running variable: {self.kink_point}
11931207
11941208
Results:
1195-
Change in slope at kink point = {self.gradient_change.mean():.2f}
1209+
Change in slope at kink point = {round_num(self.gradient_change.mean(), round_to)}
11961210
"""
11971211
)
1198-
self.print_coefficients()
1212+
self.print_coefficients(round_to)
11991213

12001214

12011215
class PrePostNEGD(ExperimentalDesign):
@@ -1232,17 +1246,17 @@ class PrePostNEGD(ExperimentalDesign):
12321246
... }
12331247
... )
12341248
... )
1235-
>>> result.summary() # doctest: +NUMBER
1249+
>>> result.summary(round_to=1) # doctest: +NUMBER
12361250
==================Pretest/posttest Nonequivalent Group Design===================
12371251
Formula: post ~ 1 + C(group) + pre
12381252
<BLANKLINE>
12391253
Results:
1240-
Causal impact = 1.8, $CI_{94%}$[1.7, 2.1]
1254+
Causal impact = 2, $CI_{94%}$[2, 2]
12411255
Model coefficients:
1242-
Intercept -0.4, 94% HDI [-1.1, 0.2]
1243-
C(group)[T.1] 1.8, 94% HDI [1.6, 2.0]
1244-
pre 1.0, 94% HDI [0.9, 1.1]
1245-
sigma 0.5, 94% HDI [0.4, 0.5]
1256+
Intercept -0.5, 94% HDI [-1, 0.2]
1257+
C(group)[T.1] 2, 94% HDI [2, 2]
1258+
pre 1, 94% HDI [1, 1]
1259+
sigma 0.5, 94% HDI [0.5, 0.6]
12461260
"""
12471261

12481262
def __init__(
@@ -1381,20 +1395,23 @@ def _causal_impact_summary_stat(self, round_to) -> str:
13811395
r"$CI_{94%}$"
13821396
+ f"[{round_num(percentiles[0], round_to)}, {round_num(percentiles[1], round_to)}]"
13831397
)
1384-
causal_impact = f"{self.causal_impact.mean():.2f}, "
1398+
causal_impact = f"{round_num(self.causal_impact.mean(), round_to)}, "
13851399
return f"Causal impact = {causal_impact + ci}"
13861400

13871401
def summary(self, round_to=None) -> None:
13881402
"""
13891403
Print text output summarising the results
1404+
1405+
:param round_to:
1406+
Number of decimals used to round results. Defaults to 2. Use "none" to return raw numbers.
13901407
"""
13911408

13921409
print(f"{self.expt_type:=^80}")
13931410
print(f"Formula: {self.formula}")
13941411
print("\nResults:")
13951412
# TODO: extra experiment specific outputs here
13961413
print(self._causal_impact_summary_stat(round_to))
1397-
self.print_coefficients()
1414+
self.print_coefficients(round_to)
13981415

13991416
def _get_treatment_effect_coeff(self) -> str:
14001417
"""Find the beta regression coefficient corresponding to the
@@ -1471,7 +1488,6 @@ class InstrumentalVariable(ExperimentalDesign):
14711488
... formula=formula,
14721489
... model=InstrumentalVariableRegression(sample_kwargs=sample_kwargs),
14731490
... )
1474-
14751491
"""
14761492

14771493
def __init__(

0 commit comments

Comments
 (0)