26
26
from causalpy .custom_exceptions import BadIndexException # NOQA
27
27
from causalpy .custom_exceptions import DataException , FormulaException
28
28
from causalpy .plot_utils import plot_xY
29
- from causalpy .utils import _is_variable_dummy_coded , _series_has_2_levels
29
+ from causalpy .utils import _is_variable_dummy_coded
30
30
31
31
LEGEND_FONT_SIZE = 12
32
32
az .style .use ("arviz-darkgrid" )
@@ -978,7 +978,8 @@ class PrePostNEGD(ExperimentalDesign):
978
978
:param formula:
979
979
A statistical model formula
980
980
:param group_variable_name:
981
- Name of the column in data for the group variable
981
+ Name of the column in data for the group variable, should be either
982
+ binary or boolean
982
983
:param pretreatment_variable_name:
983
984
Name of the column in data for the pretreatment variable
984
985
:param model:
@@ -1058,17 +1059,19 @@ def __init__(
1058
1059
self .group_variable_name : np .zeros (self .pred_xi .shape ),
1059
1060
}
1060
1061
)
1061
- (new_x ,) = build_design_matrices ([self ._x_design_info ], x_pred_untreated )
1062
- self .pred_untreated = self .model .predict (X = np .asarray (new_x ))
1062
+ (new_x_untreated ,) = build_design_matrices (
1063
+ [self ._x_design_info ], x_pred_untreated
1064
+ )
1065
+ self .pred_untreated = self .model .predict (X = np .asarray (new_x_untreated ))
1063
1066
# treated
1064
- x_pred_untreated = pd .DataFrame (
1067
+ x_pred_treated = pd .DataFrame (
1065
1068
{
1066
1069
self .pretreatment_variable_name : self .pred_xi ,
1067
1070
self .group_variable_name : np .ones (self .pred_xi .shape ),
1068
1071
}
1069
1072
)
1070
- (new_x ,) = build_design_matrices ([self ._x_design_info ], x_pred_untreated )
1071
- self .pred_treated = self .model .predict (X = np .asarray (new_x ))
1073
+ (new_x_treated ,) = build_design_matrices ([self ._x_design_info ], x_pred_treated )
1074
+ self .pred_treated = self .model .predict (X = np .asarray (new_x_treated ))
1072
1075
1073
1076
# Evaluate causal impact as equal to the trestment effect
1074
1077
self .causal_impact = self .idata .posterior ["beta" ].sel (
@@ -1079,7 +1082,7 @@ def __init__(
1079
1082
1080
1083
def _input_validation (self ) -> None :
1081
1084
"""Validate the input data and model formula for correctness"""
1082
- if not _series_has_2_levels (self .data [self .group_variable_name ]):
1085
+ if not _is_variable_dummy_coded (self .data [self .group_variable_name ]):
1083
1086
raise DataException (
1084
1087
f"""
1085
1088
There must be 2 levels of the grouping variable
@@ -1165,7 +1168,7 @@ def _get_treatment_effect_coeff(self) -> str:
1165
1168
then we want `C(group)[T.1]`.
1166
1169
"""
1167
1170
for label in self .labels :
1168
- if ("group" in label ) & (":" not in label ):
1171
+ if (self . group_variable_name in label ) & (":" not in label ):
1169
1172
return label
1170
1173
1171
1174
raise NameError ("Unable to find coefficient name for the treatment effect" )
0 commit comments