23
23
from patsy import build_design_matrices , dmatrices
24
24
from sklearn .linear_model import LinearRegression as sk_lin_reg
25
25
26
- from causalpy .custom_exceptions import BadIndexException
27
- from causalpy .custom_exceptions import DataException , FormulaException
26
+ from causalpy .custom_exceptions import (
27
+ BadIndexException , # NOQA
28
+ DataException ,
29
+ FormulaException ,
30
+ )
28
31
from causalpy .plot_utils import plot_xY
29
- from causalpy .utils import _is_variable_dummy_coded
32
+ from causalpy .utils import _is_variable_dummy_coded , round_num
30
33
31
34
LEGEND_FONT_SIZE = 12
32
35
az .style .use ("arviz-darkgrid" )
@@ -228,9 +231,12 @@ def _input_validation(self, data, treatment_time):
228
231
"If data.index is not DatetimeIndex, treatment_time must be pd.Timestamp." # noqa: E501
229
232
)
230
233
231
- def plot (self , counterfactual_label = "Counterfactual" , ** kwargs ):
234
+ def plot (self , counterfactual_label = "Counterfactual" , round_to = None , ** kwargs ):
232
235
"""
233
236
Plot the results
237
+
238
+ :param round_to:
239
+ Number of decimals used to round results. Defaults to 2. Use "none" to return raw numbers.
234
240
"""
235
241
fig , ax = plt .subplots (3 , 1 , sharex = True , figsize = (7 , 8 ))
236
242
@@ -275,8 +281,8 @@ def plot(self, counterfactual_label="Counterfactual", **kwargs):
275
281
276
282
ax [0 ].set (
277
283
title = f"""
278
- Pre-intervention Bayesian $R^2$: { self .score .r2 :.3f }
279
- (std = { self .score .r2_std :.3f } )
284
+ Pre-intervention Bayesian $R^2$: { round_num ( self .score .r2 , round_to ) }
285
+ (std = { round_num ( self .score .r2_std , round_to ) } )
280
286
"""
281
287
)
282
288
@@ -416,7 +422,11 @@ class SyntheticControl(PrePostFit):
416
422
expt_type = "Synthetic Control"
417
423
418
424
def plot (self , plot_predictors = False , ** kwargs ):
419
- """Plot the results"""
425
+ """Plot the results
426
+
427
+ :param round_to:
428
+ Number of decimals used to round results. Defaults to 2. Use "none" to return raw numbers.
429
+ """
420
430
fig , ax = super ().plot (counterfactual_label = "Synthetic control" , ** kwargs )
421
431
if plot_predictors :
422
432
# plot control units as well
@@ -580,9 +590,11 @@ def _input_validation(self):
580
590
coded. Consisting of 0's and 1's only."""
581
591
)
582
592
583
- def plot (self ):
593
+ def plot (self , round_to = None ):
584
594
"""Plot the results.
585
- Creating the combined mean + HDI legend entries is a bit involved.
595
+
596
+ :param round_to:
597
+ Number of decimals used to round results. Defaults to 2. Use "none" to return raw numbers.
586
598
"""
587
599
fig , ax = plt .subplots ()
588
600
@@ -658,7 +670,7 @@ def plot(self):
658
670
# formatting
659
671
ax .set (
660
672
xticks = self .x_pred_treatment [self .time_variable_name ].values ,
661
- title = self ._causal_impact_summary_stat (),
673
+ title = self ._causal_impact_summary_stat (round_to ),
662
674
)
663
675
ax .legend (
664
676
handles = (h_tuple for h_tuple in handles ),
@@ -711,11 +723,14 @@ def _plot_causal_impact_arrow(self, ax):
711
723
va = "center" ,
712
724
)
713
725
714
- def _causal_impact_summary_stat (self ) -> str :
726
+ def _causal_impact_summary_stat (self , round_to = None ) -> str :
715
727
"""Computes the mean and 94% credible interval bounds for the causal impact."""
716
728
percentiles = self .causal_impact .quantile ([0.03 , 1 - 0.03 ]).values
717
- ci = "$CI_{94\\ %}$" + f"[{ percentiles [0 ]:.2f} , { percentiles [1 ]:.2f} ]"
718
- causal_impact = f"{ self .causal_impact .mean ():.2f} , "
729
+ ci = (
730
+ "$CI_{94\\ %}$"
731
+ + f"[{ round_num (percentiles [0 ], round_to )} , { round_num (percentiles [1 ], round_to )} ]"
732
+ )
733
+ causal_impact = f"{ round_num (self .causal_impact .mean (), round_to )} , "
719
734
return f"Causal impact = { causal_impact + ci } "
720
735
721
736
def summary (self ) -> None :
@@ -893,9 +908,12 @@ def _is_treated(self, x):
893
908
"""
894
909
return np .greater_equal (x , self .treatment_threshold )
895
910
896
- def plot (self ):
911
+ def plot (self , round_to = None ):
897
912
"""
898
913
Plot the results
914
+
915
+ :param round_to:
916
+ Number of decimals used to round results. Defaults to 2. Use "none" to return raw numbers.
899
917
"""
900
918
fig , ax = plt .subplots ()
901
919
# Plot raw data
@@ -918,12 +936,15 @@ def plot(self):
918
936
labels = ["Posterior mean" ]
919
937
920
938
# create strings to compose title
921
- title_info = f"{ self .score .r2 :.3f } (std = { self .score .r2_std :.3f } )"
939
+ title_info = f"{ round_num ( self .score .r2 , round_to ) } (std = { round_num ( self .score .r2_std , round_to ) } )"
922
940
r2 = f"Bayesian $R^2$ on all data = { title_info } "
923
941
percentiles = self .discontinuity_at_threshold .quantile ([0.03 , 1 - 0.03 ]).values
924
- ci = r"$CI_{94\%}$" + f"[{ percentiles [0 ]:.2f} , { percentiles [1 ]:.2f} ]"
942
+ ci = (
943
+ r"$CI_{94\%}$"
944
+ + f"[{ round_num (percentiles [0 ], round_to )} , { round_num (percentiles [1 ], round_to )} ]"
945
+ )
925
946
discon = f"""
926
- Discontinuity at threshold = { self .discontinuity_at_threshold .mean ():.2f } ,
947
+ Discontinuity at threshold = { round_num ( self .discontinuity_at_threshold .mean (), round_to ) } ,
927
948
"""
928
949
ax .set (title = r2 + "\n " + discon + ci )
929
950
# Intervention line
@@ -1104,9 +1125,12 @@ def _is_treated(self, x):
1104
1125
"""Returns ``True`` if `x` is greater than or equal to the treatment threshold.""" # noqa: E501
1105
1126
return np .greater_equal (x , self .kink_point )
1106
1127
1107
- def plot (self ):
1128
+ def plot (self , round_to = None ):
1108
1129
"""
1109
1130
Plot the results
1131
+
1132
+ :param round_to:
1133
+ Number of decimals used to round results. Defaults to 2. Use "none" to return raw numbers.
1110
1134
"""
1111
1135
fig , ax = plt .subplots ()
1112
1136
# Plot raw data
@@ -1129,12 +1153,15 @@ def plot(self):
1129
1153
labels = ["Posterior mean" ]
1130
1154
1131
1155
# create strings to compose title
1132
- title_info = f"{ self .score .r2 :.3f } (std = { self .score .r2_std :.3f } )"
1156
+ title_info = f"{ round_num ( self .score .r2 , round_to ) } (std = { round_num ( self .score .r2_std , round_to ) } )"
1133
1157
r2 = f"Bayesian $R^2$ on all data = { title_info } "
1134
1158
percentiles = self .gradient_change .quantile ([0.03 , 1 - 0.03 ]).values
1135
- ci = r"$CI_{94\%}$" + f"[{ percentiles [0 ]:.2f} , { percentiles [1 ]:.2f} ]"
1159
+ ci = (
1160
+ r"$CI_{94\%}$"
1161
+ + f"[{ round_num (percentiles [0 ], round_to )} , { round_num (percentiles [1 ], round_to )} ]"
1162
+ )
1136
1163
grad_change = f"""
1137
- Change in gradient = { self .gradient_change .mean ():.2f } ,
1164
+ Change in gradient = { round_num ( self .gradient_change .mean (), round_to ) } ,
1138
1165
"""
1139
1166
ax .set (title = r2 + "\n " + grad_change + ci )
1140
1167
# Intervention line
@@ -1210,9 +1237,9 @@ class PrePostNEGD(ExperimentalDesign):
1210
1237
Formula: post ~ 1 + C(group) + pre
1211
1238
<BLANKLINE>
1212
1239
Results:
1213
- Causal impact = 1.8, $CI_{94%}$[1.6 , 2.0 ]
1240
+ Causal impact = 1.8, $CI_{94%}$[1.7 , 2.1 ]
1214
1241
Model coefficients:
1215
- Intercept -0.4, 94% HDI [-1.2 , 0.2]
1242
+ Intercept -0.4, 94% HDI [-1.1 , 0.2]
1216
1243
C(group)[T.1] 1.8, 94% HDI [1.6, 2.0]
1217
1244
pre 1.0, 94% HDI [0.9, 1.1]
1218
1245
sigma 0.5, 94% HDI [0.4, 0.5]
@@ -1292,8 +1319,12 @@ def _input_validation(self) -> None:
1292
1319
"""
1293
1320
)
1294
1321
1295
- def plot (self ):
1296
- """Plot the results"""
1322
+ def plot (self , round_to = None ):
1323
+ """Plot the results
1324
+
1325
+ :param round_to:
1326
+ Number of decimals used to round results. Defaults to 2. Use "none" to return raw numbers.
1327
+ """
1297
1328
fig , ax = plt .subplots (
1298
1329
2 , 1 , figsize = (7 , 9 ), gridspec_kw = {"height_ratios" : [3 , 1 ]}
1299
1330
)
@@ -1339,18 +1370,21 @@ def plot(self):
1339
1370
)
1340
1371
1341
1372
# Plot estimated caual impact / treatment effect
1342
- az .plot_posterior (self .causal_impact , ref_val = 0 , ax = ax [1 ])
1373
+ az .plot_posterior (self .causal_impact , ref_val = 0 , ax = ax [1 ], round_to = round_to )
1343
1374
ax [1 ].set (title = "Estimated treatment effect" )
1344
1375
return fig , ax
1345
1376
1346
- def _causal_impact_summary_stat (self ) -> str :
1377
+ def _causal_impact_summary_stat (self , round_to ) -> str :
1347
1378
"""Computes the mean and 94% credible interval bounds for the causal impact."""
1348
1379
percentiles = self .causal_impact .quantile ([0.03 , 1 - 0.03 ]).values
1349
- ci = r"$CI_{94%}$" + f"[{ percentiles [0 ]:.2f} , { percentiles [1 ]:.2f} ]"
1380
+ ci = (
1381
+ r"$CI_{94%}$"
1382
+ + f"[{ round_num (percentiles [0 ], round_to )} , { round_num (percentiles [1 ], round_to )} ]"
1383
+ )
1350
1384
causal_impact = f"{ self .causal_impact .mean ():.2f} , "
1351
1385
return f"Causal impact = { causal_impact + ci } "
1352
1386
1353
- def summary (self ) -> None :
1387
+ def summary (self , round_to = None ) -> None :
1354
1388
"""
1355
1389
Print text output summarising the results
1356
1390
"""
@@ -1359,7 +1393,7 @@ def summary(self) -> None:
1359
1393
print (f"Formula: { self .formula } " )
1360
1394
print ("\n Results:" )
1361
1395
# TODO: extra experiment specific outputs here
1362
- print (self ._causal_impact_summary_stat ())
1396
+ print (self ._causal_impact_summary_stat (round_to ))
1363
1397
self .print_coefficients ()
1364
1398
1365
1399
def _get_treatment_effect_coeff (self ) -> str :
0 commit comments