Skip to content

Commit 80537c6

Browse files
committed
make gradient change code more modular
1 parent 2ad1ba3 commit 80537c6

File tree

2 files changed

+32
-20
lines changed

2 files changed

+32
-20
lines changed

causalpy/pymc_experiments.py

Lines changed: 29 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1041,11 +1041,28 @@ def __init__(
10411041
(new_x,) = build_design_matrices([self._x_design_info], self.x_pred)
10421042
self.pred = self.model.predict(X=np.asarray(new_x))
10431043

1044-
# Calculate the change in gradient by evaluating the function below the kink
1045-
# point, at the kink point, and above the kink point.
1046-
# NOTE: `"treated": np.array([0, 1])`` assumes treatment is applied above
1047-
# (not below) the threshold
1048-
self.x_discon = pd.DataFrame(
1044+
# evaluate gradient change around kink point
1045+
mu_kink_left, mu_kink, mu_kink_right = self._probe_kink_point()
1046+
self.gradient_change = self._eval_gradient_change(
1047+
mu_kink_left, mu_kink, mu_kink_right, epsilon
1048+
)
1049+
1050+
@staticmethod
1051+
def _eval_gradient_change(mu_kink_left, mu_kink, mu_kink_right, epsilon):
1052+
"""Evaluate the gradient change at the kink point.
1053+
It works by evaluating the model below the kink point, at the kink point,
1054+
and above the kink point.
1055+
This is a static method for ease of testing.
1056+
"""
1057+
gradient_left = (mu_kink - mu_kink_left) / epsilon
1058+
gradient_right = (mu_kink_right - mu_kink) / epsilon
1059+
gradient_change = gradient_right - gradient_left
1060+
return gradient_change
1061+
1062+
def _probe_kink_point(self):
1063+
# Create a dataframe to evaluate predicted outcome at the kink point and either
1064+
# side
1065+
x_predict = pd.DataFrame(
10491066
{
10501067
self.running_variable_name: np.array(
10511068
[
@@ -1057,18 +1074,13 @@ def __init__(
10571074
"treated": np.array([0, 1, 1]),
10581075
}
10591076
)
1060-
(new_x,) = build_design_matrices([self._x_design_info], self.x_discon)
1061-
self.pred_discon = self.model.predict(X=np.asarray(new_x))
1062-
1063-
self.gradient_left = (
1064-
self.pred_discon["posterior_predictive"].sel(obs_ind=1)["mu"]
1065-
- self.pred_discon["posterior_predictive"].sel(obs_ind=0)["mu"]
1066-
) / self.epsilon
1067-
self.gradient_right = (
1068-
self.pred_discon["posterior_predictive"].sel(obs_ind=2)["mu"]
1069-
- self.pred_discon["posterior_predictive"].sel(obs_ind=1)["mu"]
1070-
) / self.epsilon
1071-
self.gradient_change = self.gradient_right - self.gradient_left
1077+
(new_x,) = build_design_matrices([self._x_design_info], x_predict)
1078+
predicted = self.model.predict(X=np.asarray(new_x))
1079+
# extract predicted mu values
1080+
mu_kink_left = predicted["posterior_predictive"].sel(obs_ind=0)["mu"]
1081+
mu_kink = predicted["posterior_predictive"].sel(obs_ind=1)["mu"]
1082+
mu_kink_right = predicted["posterior_predictive"].sel(obs_ind=2)["mu"]
1083+
return mu_kink_left, mu_kink, mu_kink_right
10721084

10731085
def _input_validation(self):
10741086
"""Validate the input data and model formula for correctness"""

docs/source/_static/interrogate_badge.svg

Lines changed: 3 additions & 3 deletions
Loading

0 commit comments

Comments
 (0)