@@ -176,44 +176,3 @@ def policy_gen(Tpre, X, period):
176
176
return self ._gen_data_with_policy (n_units , policy_gen , random_seed = random_seed )
177
177
178
178
179
- # Auxiliary function for adding xticks and vertical lines when plotting results
180
- # for dynamic dml vs ground truth parameters.
181
- def add_vlines (n_periods , n_treatments , hetero_inds ):
182
- locs , labels = plt .xticks ([], [])
183
- locs += [- .501 + (len (hetero_inds ) + 1 ) / 2 ]
184
- labels += ["\n \n $\\ tau_{{{}}}$" .format (0 )]
185
- locs += [qx for qx in np .arange (len (hetero_inds ) + 1 )]
186
- labels += ["$1$" ] + ["$x_{{{}}}$" .format (qx ) for qx in hetero_inds ]
187
- for q in np .arange (1 , n_treatments ):
188
- plt .axvline (x = q * (len (hetero_inds ) + 1 ) - .5 ,
189
- linestyle = '--' , color = 'red' , alpha = .2 )
190
- locs += [q * (len (hetero_inds ) + 1 ) - .501 + (len (hetero_inds ) + 1 ) / 2 ]
191
- labels += ["\n \n $\\ tau_{{{}}}$" .format (q )]
192
- locs += [(q * (len (hetero_inds ) + 1 ) + qx )
193
- for qx in np .arange (len (hetero_inds ) + 1 )]
194
- labels += ["$1$" ] + ["$x_{{{}}}$" .format (qx ) for qx in hetero_inds ]
195
- locs += [- .501 + (len (hetero_inds ) + 1 ) * n_treatments / 2 ]
196
- labels += ["\n \n \n \n $\\ theta_{{{}}}$" .format (0 )]
197
- for t in np .arange (1 , n_periods ):
198
- plt .axvline (x = t * (len (hetero_inds ) + 1 ) *
199
- n_treatments - .5 , linestyle = '-' , alpha = .6 )
200
- locs += [t * (len (hetero_inds ) + 1 ) * n_treatments - .501 +
201
- (len (hetero_inds ) + 1 ) * n_treatments / 2 ]
202
- labels += ["\n \n \n \n $\\ theta_{{{}}}$" .format (t )]
203
- locs += [t * (len (hetero_inds ) + 1 ) *
204
- n_treatments - .501 + (len (hetero_inds ) + 1 ) / 2 ]
205
- labels += ["\n \n $\\ tau_{{{}}}$" .format (0 )]
206
- locs += [t * (len (hetero_inds ) + 1 ) * n_treatments +
207
- qx for qx in np .arange (len (hetero_inds ) + 1 )]
208
- labels += ["$1$" ] + ["$x_{{{}}}$" .format (qx ) for qx in hetero_inds ]
209
- for q in np .arange (1 , n_treatments ):
210
- plt .axvline (x = t * (len (hetero_inds ) + 1 ) * n_treatments + q * (len (hetero_inds ) + 1 ) - .5 ,
211
- linestyle = '--' , color = 'red' , alpha = .2 )
212
- locs += [t * (len (hetero_inds ) + 1 ) * n_treatments + q *
213
- (len (hetero_inds ) + 1 ) - .501 + (len (hetero_inds ) + 1 ) / 2 ]
214
- labels += ["\n \n $\\ tau_{{{}}}$" .format (q )]
215
- locs += [t * (len (hetero_inds ) + 1 ) * n_treatments + (q * (len (hetero_inds ) + 1 ) + qx )
216
- for qx in np .arange (len (hetero_inds ) + 1 )]
217
- labels += ["$1$" ] + ["$x_{{{}}}$" .format (qx ) for qx in hetero_inds ]
218
- plt .xticks (locs , labels )
219
- plt .tight_layout ()
0 commit comments