Skip to content

Commit 526d8a8

Browse files
committed
fix iPrePostNEGD input validation
1 parent 133b987 commit 526d8a8

File tree

1 file changed

+12
-9
lines changed

1 file changed

+12
-9
lines changed

causalpy/pymc_experiments.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
from causalpy.custom_exceptions import BadIndexException # NOQA
2727
from causalpy.custom_exceptions import DataException, FormulaException
2828
from causalpy.plot_utils import plot_xY
29-
from causalpy.utils import _is_variable_dummy_coded, _series_has_2_levels
29+
from causalpy.utils import _is_variable_dummy_coded
3030

3131
LEGEND_FONT_SIZE = 12
3232
az.style.use("arviz-darkgrid")
@@ -978,7 +978,8 @@ class PrePostNEGD(ExperimentalDesign):
978978
:param formula:
979979
A statistical model formula
980980
:param group_variable_name:
981-
Name of the column in data for the group variable
981+
Name of the column in data for the group variable, should be either
982+
binary or boolean
982983
:param pretreatment_variable_name:
983984
Name of the column in data for the pretreatment variable
984985
:param model:
@@ -1058,17 +1059,19 @@ def __init__(
10581059
self.group_variable_name: np.zeros(self.pred_xi.shape),
10591060
}
10601061
)
1061-
(new_x,) = build_design_matrices([self._x_design_info], x_pred_untreated)
1062-
self.pred_untreated = self.model.predict(X=np.asarray(new_x))
1062+
(new_x_untreated,) = build_design_matrices(
1063+
[self._x_design_info], x_pred_untreated
1064+
)
1065+
self.pred_untreated = self.model.predict(X=np.asarray(new_x_untreated))
10631066
# treated
1064-
x_pred_untreated = pd.DataFrame(
1067+
x_pred_treated = pd.DataFrame(
10651068
{
10661069
self.pretreatment_variable_name: self.pred_xi,
10671070
self.group_variable_name: np.ones(self.pred_xi.shape),
10681071
}
10691072
)
1070-
(new_x,) = build_design_matrices([self._x_design_info], x_pred_untreated)
1071-
self.pred_treated = self.model.predict(X=np.asarray(new_x))
1073+
(new_x_treated,) = build_design_matrices([self._x_design_info], x_pred_treated)
1074+
self.pred_treated = self.model.predict(X=np.asarray(new_x_treated))
10721075

10731076
# Evaluate causal impact as equal to the trestment effect
10741077
self.causal_impact = self.idata.posterior["beta"].sel(
@@ -1079,7 +1082,7 @@ def __init__(
10791082

10801083
def _input_validation(self) -> None:
10811084
"""Validate the input data and model formula for correctness"""
1082-
if not _series_has_2_levels(self.data[self.group_variable_name]):
1085+
if not _is_variable_dummy_coded(self.data[self.group_variable_name]):
10831086
raise DataException(
10841087
f"""
10851088
There must be 2 levels of the grouping variable
@@ -1165,7 +1168,7 @@ def _get_treatment_effect_coeff(self) -> str:
11651168
then we want `C(group)[T.1]`.
11661169
"""
11671170
for label in self.labels:
1168-
if ("group" in label) & (":" not in label):
1171+
if (self.group_variable_name in label) & (":" not in label):
11691172
return label
11701173

11711174
raise NameError("Unable to find coefficient name for the treatment effect")

0 commit comments

Comments
 (0)