|
1 | 1 |
|
2 | 2 | import numpy as np
|
| 3 | +try: |
| 4 | + import matplotlib |
| 5 | + import matplotlib.pyplot as plt |
| 6 | +except ImportError as exn: |
| 7 | + from .utilities import MissingModule |
| 8 | + |
| 9 | + # make any access to matplotlib or plt throw an exception |
| 10 | + matplotlib = plt = MissingModule("matplotlib is no longer a dependency of the main econml package; " |
| 11 | + "install econml[plt] or econml[all] to require it, or install matplotlib " |
| 12 | + "separately, to use the tree interpreters", exn) |
3 | 13 |
|
4 | 14 |
|
5 | 15 | def long(x):
|
@@ -42,3 +52,46 @@ def wide(x):
|
42 | 52 | """
|
43 | 53 | n_units = x.shape[0]
|
44 | 54 | return x.reshape(n_units, -1)
|
| 55 | + |
| 56 | + |
| 57 | +# Auxiliary function for adding xticks and vertical lines when plotting results |
| 58 | +# for dynamic dml vs ground truth parameters. |
| 59 | +def add_vlines(n_periods, n_treatments, hetero_inds): |
| 60 | + locs, labels = plt.xticks([], []) |
| 61 | + locs += [- .501 + (len(hetero_inds) + 1) / 2] |
| 62 | + labels += ["\n\n$\\tau_{{{}}}$".format(0)] |
| 63 | + locs += [qx for qx in np.arange(len(hetero_inds) + 1)] |
| 64 | + labels += ["$1$"] + ["$x_{{{}}}$".format(qx) for qx in hetero_inds] |
| 65 | + for q in np.arange(1, n_treatments): |
| 66 | + plt.axvline(x=q * (len(hetero_inds) + 1) - .5, |
| 67 | + linestyle='--', color='red', alpha=.2) |
| 68 | + locs += [q * (len(hetero_inds) + 1) - .501 + (len(hetero_inds) + 1) / 2] |
| 69 | + labels += ["\n\n$\\tau_{{{}}}$".format(q)] |
| 70 | + locs += [(q * (len(hetero_inds) + 1) + qx) |
| 71 | + for qx in np.arange(len(hetero_inds) + 1)] |
| 72 | + labels += ["$1$"] + ["$x_{{{}}}$".format(qx) for qx in hetero_inds] |
| 73 | + locs += [- .501 + (len(hetero_inds) + 1) * n_treatments / 2] |
| 74 | + labels += ["\n\n\n\n$\\theta_{{{}}}$".format(0)] |
| 75 | + for t in np.arange(1, n_periods): |
| 76 | + plt.axvline(x=t * (len(hetero_inds) + 1) * |
| 77 | + n_treatments - .5, linestyle='-', alpha=.6) |
| 78 | + locs += [t * (len(hetero_inds) + 1) * n_treatments - .501 + |
| 79 | + (len(hetero_inds) + 1) * n_treatments / 2] |
| 80 | + labels += ["\n\n\n\n$\\theta_{{{}}}$".format(t)] |
| 81 | + locs += [t * (len(hetero_inds) + 1) * |
| 82 | + n_treatments - .501 + (len(hetero_inds) + 1) / 2] |
| 83 | + labels += ["\n\n$\\tau_{{{}}}$".format(0)] |
| 84 | + locs += [t * (len(hetero_inds) + 1) * n_treatments + |
| 85 | + qx for qx in np.arange(len(hetero_inds) + 1)] |
| 86 | + labels += ["$1$"] + ["$x_{{{}}}$".format(qx) for qx in hetero_inds] |
| 87 | + for q in np.arange(1, n_treatments): |
| 88 | + plt.axvline(x=t * (len(hetero_inds) + 1) * n_treatments + q * (len(hetero_inds) + 1) - .5, |
| 89 | + linestyle='--', color='red', alpha=.2) |
| 90 | + locs += [t * (len(hetero_inds) + 1) * n_treatments + q * |
| 91 | + (len(hetero_inds) + 1) - .501 + (len(hetero_inds) + 1) / 2] |
| 92 | + labels += ["\n\n$\\tau_{{{}}}$".format(q)] |
| 93 | + locs += [t * (len(hetero_inds) + 1) * n_treatments + (q * (len(hetero_inds) + 1) + qx) |
| 94 | + for qx in np.arange(len(hetero_inds) + 1)] |
| 95 | + labels += ["$1$"] + ["$x_{{{}}}$".format(qx) for qx in hetero_inds] |
| 96 | + plt.xticks(locs, labels) |
| 97 | + plt.tight_layout() |
0 commit comments