Skip to content

Commit ca80d4d

Browse files
committed
Enable notebook tests to be run without editable install
Signed-off-by: Keith Battocchi <kebatt@microsoft.com>
1 parent da45e91 commit ca80d4d

File tree

3 files changed

+59
-5
lines changed

3 files changed

+59
-5
lines changed

.github/workflows/ci.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ jobs:
147147
if: ${{ matrix.install_graphviz }}
148148
# Add verbose flag to pip installation if in debug mode
149149
- name: Install econml
150-
run: uv pip install --system -e .${{ matrix.extras }} ${{ fromJSON('["","-v"]')[runner.debug] }} ${{ env.use_lkg && '-r lkg-notebook.txt' }}
150+
run: uv pip install --system .${{ matrix.extras }} ${{ fromJSON('["","-v"]')[runner.debug] }} ${{ env.use_lkg && '-r lkg-notebook.txt' }}
151151
# Install notebook requirements (if not already done as part of lkg)
152152
- name: Install notebook requirements
153153
run: uv pip install --system jupyter jupyter-client nbconvert nbformat seaborn xgboost tqdm

econml/panel/utilities.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,15 @@
11

22
import numpy as np
3+
try:
4+
import matplotlib
5+
import matplotlib.pyplot as plt
6+
except ImportError as exn:
7+
from .utilities import MissingModule
8+
9+
# make any access to matplotlib or plt throw an exception
10+
matplotlib = plt = MissingModule("matplotlib is no longer a dependency of the main econml package; "
11+
"install econml[plt] or econml[all] to require it, or install matplotlib "
12+
"separately, to use the tree interpreters", exn)
313

414

515
def long(x):
@@ -42,3 +52,46 @@ def wide(x):
4252
"""
4353
n_units = x.shape[0]
4454
return x.reshape(n_units, -1)
55+
56+
57+
# Auxiliary function for adding xticks and vertical lines when plotting results
58+
# for dynamic dml vs ground truth parameters.
59+
def add_vlines(n_periods, n_treatments, hetero_inds):
60+
locs, labels = plt.xticks([], [])
61+
locs += [- .501 + (len(hetero_inds) + 1) / 2]
62+
labels += ["\n\n$\\tau_{{{}}}$".format(0)]
63+
locs += [qx for qx in np.arange(len(hetero_inds) + 1)]
64+
labels += ["$1$"] + ["$x_{{{}}}$".format(qx) for qx in hetero_inds]
65+
for q in np.arange(1, n_treatments):
66+
plt.axvline(x=q * (len(hetero_inds) + 1) - .5,
67+
linestyle='--', color='red', alpha=.2)
68+
locs += [q * (len(hetero_inds) + 1) - .501 + (len(hetero_inds) + 1) / 2]
69+
labels += ["\n\n$\\tau_{{{}}}$".format(q)]
70+
locs += [(q * (len(hetero_inds) + 1) + qx)
71+
for qx in np.arange(len(hetero_inds) + 1)]
72+
labels += ["$1$"] + ["$x_{{{}}}$".format(qx) for qx in hetero_inds]
73+
locs += [- .501 + (len(hetero_inds) + 1) * n_treatments / 2]
74+
labels += ["\n\n\n\n$\\theta_{{{}}}$".format(0)]
75+
for t in np.arange(1, n_periods):
76+
plt.axvline(x=t * (len(hetero_inds) + 1) *
77+
n_treatments - .5, linestyle='-', alpha=.6)
78+
locs += [t * (len(hetero_inds) + 1) * n_treatments - .501 +
79+
(len(hetero_inds) + 1) * n_treatments / 2]
80+
labels += ["\n\n\n\n$\\theta_{{{}}}$".format(t)]
81+
locs += [t * (len(hetero_inds) + 1) *
82+
n_treatments - .501 + (len(hetero_inds) + 1) / 2]
83+
labels += ["\n\n$\\tau_{{{}}}$".format(0)]
84+
locs += [t * (len(hetero_inds) + 1) * n_treatments +
85+
qx for qx in np.arange(len(hetero_inds) + 1)]
86+
labels += ["$1$"] + ["$x_{{{}}}$".format(qx) for qx in hetero_inds]
87+
for q in np.arange(1, n_treatments):
88+
plt.axvline(x=t * (len(hetero_inds) + 1) * n_treatments + q * (len(hetero_inds) + 1) - .5,
89+
linestyle='--', color='red', alpha=.2)
90+
locs += [t * (len(hetero_inds) + 1) * n_treatments + q *
91+
(len(hetero_inds) + 1) - .501 + (len(hetero_inds) + 1) / 2]
92+
labels += ["\n\n$\\tau_{{{}}}$".format(q)]
93+
locs += [t * (len(hetero_inds) + 1) * n_treatments + (q * (len(hetero_inds) + 1) + qx)
94+
for qx in np.arange(len(hetero_inds) + 1)]
95+
labels += ["$1$"] + ["$x_{{{}}}$".format(qx) for qx in hetero_inds]
96+
plt.xticks(locs, labels)
97+
plt.tight_layout()

0 commit comments

Comments
 (0)