Skip to content

Commit f1a522d

Browse files
authored
Merge pull request #258 from pymc-labs/fix-plot-error
Fix error in plot method
2 parents adea6f7 + 37515db commit f1a522d

File tree

3 files changed

+25
-16
lines changed

3 files changed

+25
-16
lines changed

causalpy/pymc_experiments.py

Lines changed: 1 addition & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -463,18 +463,6 @@ class DifferenceInDifferences(ExperimentalDesign):
463463
... }
464464
... )
465465
... )
466-
>>> result.summary() # doctest: +NUMBER
467-
===========================Difference in Differences============================
468-
Formula: y ~ 1 + group*post_treatment
469-
<BLANKLINE>
470-
Results:
471-
Causal impact = 0.5, $CI_{94%}$[0.4, 0.6]
472-
Model coefficients:
473-
Intercept 1.0, 94% HDI [1.0, 1.1]
474-
post_treatment[T.True] 0.9, 94% HDI [0.9, 1.0]
475-
group 0.1, 94% HDI [0.0, 0.2]
476-
group:post_treatment[T.True] 0.5, 94% HDI [0.4, 0.6]
477-
sigma 0.0, 94% HDI [0.0, 0.1]
478466
"""
479467

480468
def __init__(
@@ -726,7 +714,7 @@ def _plot_causal_impact_arrow(self, ax):
726714
def _causal_impact_summary_stat(self) -> str:
727715
"""Computes the mean and 94% credible interval bounds for the causal impact."""
728716
percentiles = self.causal_impact.quantile([0.03, 1 - 0.03]).values
729-
ci = "$CI_{94%}$" + f"[{percentiles[0]:.2f}, {percentiles[1]:.2f}]"
717+
ci = "$CI_{94\\%}$" + f"[{percentiles[0]:.2f}, {percentiles[1]:.2f}]"
730718
causal_impact = f"{self.causal_impact.mean():.2f}, "
731719
return f"Causal impact = {causal_impact + ci}"
732720

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
"""
2+
Unit tests for pymc_experiments.py
3+
"""
4+
5+
import causalpy as cp
6+
7+
sample_kwargs = {"tune": 20, "draws": 20, "chains": 2, "cores": 2}
8+
9+
10+
def test_did_summary():
11+
"""Test that the summary stat function returns a string."""
12+
df = cp.load_data("did")
13+
result = cp.pymc_experiments.DifferenceInDifferences(
14+
df,
15+
formula="y ~ 1 + group*post_treatment",
16+
time_variable_name="t",
17+
group_variable_name="group",
18+
model=cp.pymc_models.LinearRegression(sample_kwargs=sample_kwargs),
19+
)
20+
print(type(result._causal_impact_summary_stat()))
21+
assert isinstance(result._causal_impact_summary_stat(), str)

docs/source/_static/interrogate_badge.svg

Lines changed: 3 additions & 3 deletions
Loading

0 commit comments

Comments
 (0)