Skip to content

Commit 70de921

Browse files
authored
Merge pull request #272 from pymc-labs/round_to
User specified number of significant figures for numbers in plots
2 parents fc28a3b + 198bde6 commit 70de921

14 files changed

+569
-250
lines changed

causalpy/pymc_experiments.py

+64-30
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,13 @@
2323
from patsy import build_design_matrices, dmatrices
2424
from sklearn.linear_model import LinearRegression as sk_lin_reg
2525

26-
from causalpy.custom_exceptions import BadIndexException
27-
from causalpy.custom_exceptions import DataException, FormulaException
26+
from causalpy.custom_exceptions import (
27+
BadIndexException, # NOQA
28+
DataException,
29+
FormulaException,
30+
)
2831
from causalpy.plot_utils import plot_xY
29-
from causalpy.utils import _is_variable_dummy_coded
32+
from causalpy.utils import _is_variable_dummy_coded, round_num
3033

3134
LEGEND_FONT_SIZE = 12
3235
az.style.use("arviz-darkgrid")
@@ -228,9 +231,12 @@ def _input_validation(self, data, treatment_time):
228231
"If data.index is not DatetimeIndex, treatment_time must be pd.Timestamp." # noqa: E501
229232
)
230233

231-
def plot(self, counterfactual_label="Counterfactual", **kwargs):
234+
def plot(self, counterfactual_label="Counterfactual", round_to=None, **kwargs):
232235
"""
233236
Plot the results
237+
238+
:param round_to:
239+
Number of decimals used to round results. Defaults to 2. Use "none" to return raw numbers.
234240
"""
235241
fig, ax = plt.subplots(3, 1, sharex=True, figsize=(7, 8))
236242

@@ -275,8 +281,8 @@ def plot(self, counterfactual_label="Counterfactual", **kwargs):
275281

276282
ax[0].set(
277283
title=f"""
278-
Pre-intervention Bayesian $R^2$: {self.score.r2:.3f}
279-
(std = {self.score.r2_std:.3f})
284+
Pre-intervention Bayesian $R^2$: {round_num(self.score.r2, round_to)}
285+
(std = {round_num(self.score.r2_std, round_to)})
280286
"""
281287
)
282288

@@ -416,7 +422,11 @@ class SyntheticControl(PrePostFit):
416422
expt_type = "Synthetic Control"
417423

418424
def plot(self, plot_predictors=False, **kwargs):
419-
"""Plot the results"""
425+
"""Plot the results
426+
427+
:param round_to:
428+
Number of decimals used to round results. Defaults to 2. Use "none" to return raw numbers.
429+
"""
420430
fig, ax = super().plot(counterfactual_label="Synthetic control", **kwargs)
421431
if plot_predictors:
422432
# plot control units as well
@@ -580,9 +590,11 @@ def _input_validation(self):
580590
coded. Consisting of 0's and 1's only."""
581591
)
582592

583-
def plot(self):
593+
def plot(self, round_to=None):
584594
"""Plot the results.
585-
Creating the combined mean + HDI legend entries is a bit involved.
595+
596+
:param round_to:
597+
Number of decimals used to round results. Defaults to 2. Use "none" to return raw numbers.
586598
"""
587599
fig, ax = plt.subplots()
588600

@@ -658,7 +670,7 @@ def plot(self):
658670
# formatting
659671
ax.set(
660672
xticks=self.x_pred_treatment[self.time_variable_name].values,
661-
title=self._causal_impact_summary_stat(),
673+
title=self._causal_impact_summary_stat(round_to),
662674
)
663675
ax.legend(
664676
handles=(h_tuple for h_tuple in handles),
@@ -711,11 +723,14 @@ def _plot_causal_impact_arrow(self, ax):
711723
va="center",
712724
)
713725

714-
def _causal_impact_summary_stat(self) -> str:
726+
def _causal_impact_summary_stat(self, round_to=None) -> str:
715727
"""Computes the mean and 94% credible interval bounds for the causal impact."""
716728
percentiles = self.causal_impact.quantile([0.03, 1 - 0.03]).values
717-
ci = "$CI_{94\\%}$" + f"[{percentiles[0]:.2f}, {percentiles[1]:.2f}]"
718-
causal_impact = f"{self.causal_impact.mean():.2f}, "
729+
ci = (
730+
"$CI_{94\\%}$"
731+
+ f"[{round_num(percentiles[0], round_to)}, {round_num(percentiles[1], round_to)}]"
732+
)
733+
causal_impact = f"{round_num(self.causal_impact.mean(), round_to)}, "
719734
return f"Causal impact = {causal_impact + ci}"
720735

721736
def summary(self) -> None:
@@ -893,9 +908,12 @@ def _is_treated(self, x):
893908
"""
894909
return np.greater_equal(x, self.treatment_threshold)
895910

896-
def plot(self):
911+
def plot(self, round_to=None):
897912
"""
898913
Plot the results
914+
915+
:param round_to:
916+
Number of decimals used to round results. Defaults to 2. Use "none" to return raw numbers.
899917
"""
900918
fig, ax = plt.subplots()
901919
# Plot raw data
@@ -918,12 +936,15 @@ def plot(self):
918936
labels = ["Posterior mean"]
919937

920938
# create strings to compose title
921-
title_info = f"{self.score.r2:.3f} (std = {self.score.r2_std:.3f})"
939+
title_info = f"{round_num(self.score.r2, round_to)} (std = {round_num(self.score.r2_std, round_to)})"
922940
r2 = f"Bayesian $R^2$ on all data = {title_info}"
923941
percentiles = self.discontinuity_at_threshold.quantile([0.03, 1 - 0.03]).values
924-
ci = r"$CI_{94\%}$" + f"[{percentiles[0]:.2f}, {percentiles[1]:.2f}]"
942+
ci = (
943+
r"$CI_{94\%}$"
944+
+ f"[{round_num(percentiles[0], round_to)}, {round_num(percentiles[1], round_to)}]"
945+
)
925946
discon = f"""
926-
Discontinuity at threshold = {self.discontinuity_at_threshold.mean():.2f},
947+
Discontinuity at threshold = {round_num(self.discontinuity_at_threshold.mean(), round_to)},
927948
"""
928949
ax.set(title=r2 + "\n" + discon + ci)
929950
# Intervention line
@@ -1104,9 +1125,12 @@ def _is_treated(self, x):
11041125
"""Returns ``True`` if `x` is greater than or equal to the treatment threshold.""" # noqa: E501
11051126
return np.greater_equal(x, self.kink_point)
11061127

1107-
def plot(self):
1128+
def plot(self, round_to=None):
11081129
"""
11091130
Plot the results
1131+
1132+
:param round_to:
1133+
Number of decimals used to round results. Defaults to 2. Use "none" to return raw numbers.
11101134
"""
11111135
fig, ax = plt.subplots()
11121136
# Plot raw data
@@ -1129,12 +1153,15 @@ def plot(self):
11291153
labels = ["Posterior mean"]
11301154

11311155
# create strings to compose title
1132-
title_info = f"{self.score.r2:.3f} (std = {self.score.r2_std:.3f})"
1156+
title_info = f"{round_num(self.score.r2, round_to)} (std = {round_num(self.score.r2_std, round_to)})"
11331157
r2 = f"Bayesian $R^2$ on all data = {title_info}"
11341158
percentiles = self.gradient_change.quantile([0.03, 1 - 0.03]).values
1135-
ci = r"$CI_{94\%}$" + f"[{percentiles[0]:.2f}, {percentiles[1]:.2f}]"
1159+
ci = (
1160+
r"$CI_{94\%}$"
1161+
+ f"[{round_num(percentiles[0], round_to)}, {round_num(percentiles[1], round_to)}]"
1162+
)
11361163
grad_change = f"""
1137-
Change in gradient = {self.gradient_change.mean():.2f},
1164+
Change in gradient = {round_num(self.gradient_change.mean(), round_to)},
11381165
"""
11391166
ax.set(title=r2 + "\n" + grad_change + ci)
11401167
# Intervention line
@@ -1210,9 +1237,9 @@ class PrePostNEGD(ExperimentalDesign):
12101237
Formula: post ~ 1 + C(group) + pre
12111238
<BLANKLINE>
12121239
Results:
1213-
Causal impact = 1.8, $CI_{94%}$[1.6, 2.0]
1240+
Causal impact = 1.8, $CI_{94%}$[1.7, 2.1]
12141241
Model coefficients:
1215-
Intercept -0.4, 94% HDI [-1.2, 0.2]
1242+
Intercept -0.4, 94% HDI [-1.1, 0.2]
12161243
C(group)[T.1] 1.8, 94% HDI [1.6, 2.0]
12171244
pre 1.0, 94% HDI [0.9, 1.1]
12181245
sigma 0.5, 94% HDI [0.4, 0.5]
@@ -1292,8 +1319,12 @@ def _input_validation(self) -> None:
12921319
"""
12931320
)
12941321

1295-
def plot(self):
1296-
"""Plot the results"""
1322+
def plot(self, round_to=None):
1323+
"""Plot the results
1324+
1325+
:param round_to:
1326+
Number of decimals used to round results. Defaults to 2. Use "none" to return raw numbers.
1327+
"""
12971328
fig, ax = plt.subplots(
12981329
2, 1, figsize=(7, 9), gridspec_kw={"height_ratios": [3, 1]}
12991330
)
@@ -1339,18 +1370,21 @@ def plot(self):
13391370
)
13401371

13411372
# Plot estimated caual impact / treatment effect
1342-
az.plot_posterior(self.causal_impact, ref_val=0, ax=ax[1])
1373+
az.plot_posterior(self.causal_impact, ref_val=0, ax=ax[1], round_to=round_to)
13431374
ax[1].set(title="Estimated treatment effect")
13441375
return fig, ax
13451376

1346-
def _causal_impact_summary_stat(self) -> str:
1377+
def _causal_impact_summary_stat(self, round_to) -> str:
13471378
"""Computes the mean and 94% credible interval bounds for the causal impact."""
13481379
percentiles = self.causal_impact.quantile([0.03, 1 - 0.03]).values
1349-
ci = r"$CI_{94%}$" + f"[{percentiles[0]:.2f}, {percentiles[1]:.2f}]"
1380+
ci = (
1381+
r"$CI_{94%}$"
1382+
+ f"[{round_num(percentiles[0], round_to)}, {round_num(percentiles[1], round_to)}]"
1383+
)
13501384
causal_impact = f"{self.causal_impact.mean():.2f}, "
13511385
return f"Causal impact = {causal_impact + ci}"
13521386

1353-
def summary(self) -> None:
1387+
def summary(self, round_to=None) -> None:
13541388
"""
13551389
Print text output summarising the results
13561390
"""
@@ -1359,7 +1393,7 @@ def summary(self) -> None:
13591393
print(f"Formula: {self.formula}")
13601394
print("\nResults:")
13611395
# TODO: extra experiment specific outputs here
1362-
print(self._causal_impact_summary_stat())
1396+
print(self._causal_impact_summary_stat(round_to))
13631397
self.print_coefficients()
13641398

13651399
def _get_treatment_effect_coeff(self) -> str:

causalpy/skl_experiments.py

+35-13
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
import seaborn as sns
1818
from patsy import build_design_matrices, dmatrices
1919

20+
from causalpy.utils import round_num
21+
2022
LEGEND_FONT_SIZE = 12
2123

2224

@@ -113,8 +115,12 @@ def __init__(
113115
# cumulative impact post
114116
self.post_impact_cumulative = np.cumsum(self.post_impact)
115117

116-
def plot(self, counterfactual_label="Counterfactual", **kwargs):
117-
"""Plot experiment results"""
118+
def plot(self, counterfactual_label="Counterfactual", round_to=None, **kwargs):
119+
"""Plot experiment results
120+
121+
:param round_to:
122+
Number of decimals used to round results. Defaults to 2. Use "none" to return raw numbers.
123+
"""
118124
fig, ax = plt.subplots(3, 1, sharex=True, figsize=(7, 8))
119125

120126
ax[0].plot(self.datapre.index, self.pre_y, "k.")
@@ -128,7 +134,9 @@ def plot(self, counterfactual_label="Counterfactual", **kwargs):
128134
ls=":",
129135
c="k",
130136
)
131-
ax[0].set(title=f"$R^2$ on pre-intervention data = {self.score:.3f}")
137+
ax[0].set(
138+
title=f"$R^2$ on pre-intervention data = {round_num(self.score, round_to)}"
139+
)
132140

133141
ax[1].plot(self.datapre.index, self.pre_impact, "k.")
134142
ax[1].plot(
@@ -258,9 +266,15 @@ class SyntheticControl(PrePostFit):
258266
... )
259267
"""
260268

261-
def plot(self, plot_predictors=False, **kwargs):
262-
"""Plot the results"""
263-
fig, ax = super().plot(counterfactual_label="Synthetic control", **kwargs)
269+
def plot(self, plot_predictors=False, round_to=None, **kwargs):
270+
"""Plot the results
271+
272+
:param round_to:
273+
Number of decimals used to round results. Defaults to 2. Use "none" to return raw numbers.
274+
"""
275+
fig, ax = super().plot(
276+
counterfactual_label="Synthetic control", round_to=round_to, **kwargs
277+
)
264278
if plot_predictors:
265279
# plot control units as well
266280
ax[0].plot(self.datapre.index, self.pre_X, "-", c=[0.8, 0.8, 0.8], zorder=1)
@@ -397,8 +411,12 @@ def __init__(
397411
# TODO: THIS IS NOT YET CORRECT
398412
self.causal_impact = self.y_pred_treatment[1] - self.y_pred_counterfactual[0]
399413

400-
def plot(self):
401-
"""Plot results"""
414+
def plot(self, round_to=None):
415+
"""Plot results
416+
417+
:param round_to:
418+
Number of decimals used to round results. Defaults to 2. Use "none" to return raw numbers.
419+
"""
402420
fig, ax = plt.subplots()
403421

404422
# Plot raw data
@@ -462,7 +480,7 @@ def plot(self):
462480
xlim=[-0.05, 1.1],
463481
xticks=[0, 1],
464482
xticklabels=["pre", "post"],
465-
title=f"Causal impact = {self.causal_impact[0]:.2f}",
483+
title=f"Causal impact = {round_num(self.causal_impact[0], round_to)}",
466484
)
467485
ax.legend(fontsize=LEGEND_FONT_SIZE)
468486
return (fig, ax)
@@ -607,8 +625,12 @@ def _is_treated(self, x):
607625
"""
608626
return np.greater_equal(x, self.treatment_threshold)
609627

610-
def plot(self):
611-
"""Plot results"""
628+
def plot(self, round_to=None):
629+
"""Plot results
630+
631+
:param round_to:
632+
Number of decimals used to round results. Defaults to 2. Use "none" to return raw numbers.
633+
"""
612634
fig, ax = plt.subplots()
613635
# Plot raw data
614636
sns.scatterplot(
@@ -627,8 +649,8 @@ def plot(self):
627649
label="model fit",
628650
)
629651
# create strings to compose title
630-
r2 = f"$R^2$ on all data = {self.score:.3f}"
631-
discon = f"Discontinuity at threshold = {self.discontinuity_at_threshold:.2f}"
652+
r2 = f"$R^2$ on all data = {round_num(self.score, round_to)}"
653+
discon = f"Discontinuity at threshold = {round_num(self.discontinuity_at_threshold, round_to)}"
632654
ax.set(title=r2 + "\n" + discon)
633655
# Intervention line
634656
ax.axvline(

causalpy/tests/test_utils.py

+21-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
import pandas as pd
66

7-
from causalpy.utils import _is_variable_dummy_coded, _series_has_2_levels
7+
from causalpy.utils import _is_variable_dummy_coded, _series_has_2_levels, round_num
88

99

1010
def test_dummy_coding():
@@ -24,3 +24,23 @@ def test_2_level_series():
2424
assert _series_has_2_levels(pd.Series(["water", "tea", "coffee"])) is False
2525
assert _series_has_2_levels(pd.Series([0, 1, 0, 1])) is True
2626
assert _series_has_2_levels(pd.Series([0, 1, 0, 2])) is False
27+
28+
29+
def test_round_num():
30+
"""Test if the function to round numbers works correctly"""
31+
assert round_num(0.12345, None) == "0.12"
32+
assert round_num(0.12345, 0) == "0.1"
33+
assert round_num(0.12345, 1) == "0.1"
34+
assert round_num(0.12345, 2) == "0.12"
35+
assert round_num(0.12345, 3) == "0.123"
36+
assert round_num(0.12345, 4) == "0.1235"
37+
assert round_num(0.12345, 5) == "0.12345"
38+
assert round_num(0.12345, 6) == "0.12345"
39+
assert round_num(123.456, None) == "123"
40+
assert round_num(123.456, 1) == "123"
41+
assert round_num(123.456, 2) == "123"
42+
assert round_num(123.456, 3) == "123"
43+
assert round_num(123.456, 4) == "123.5"
44+
assert round_num(123.456, 5) == "123.46"
45+
assert round_num(123.456, 6) == "123.456"
46+
assert round_num(123.456, 7) == "123.456"

0 commit comments

Comments
 (0)