|
80 | 80 | parameters_short.append((model, dataset_i))
|
81 | 81 |
|
82 | 82 |
|
83 |
| -def verify_explanations(global_fn, local_fn, yaml_i, additional_cols): |
84 |
| - glb_expl = pd.read_csv(global_fn, index_col=0) |
85 |
| - loc_expl = pd.read_csv(local_fn) |
| 83 | +def verify_explanations(tmpdirname, additional_cols): |
| 84 | + glb_expl = pd.read_csv(f"{tmpdirname}/results/global_explanation.csv", index_col=0) |
| 85 | + loc_expl = pd.read_csv(f"{tmpdirname}/results/local_explanation.csv") |
86 | 86 | assert loc_expl.shape[0] == PERIODS
|
87 |
| - for x in [yaml_i["spec"]["datetime_column"]["name"], "Series"]: |
| 87 | + for x in ["Date", "Series"]: |
88 | 88 | assert x in set(loc_expl.columns)
|
89 |
| - for x in additional_cols: |
90 |
| - assert x in set(loc_expl.columns) |
91 |
| - assert x in set(glb_expl.index) |
| 89 | + # for x in additional_cols: |
| 90 | + # assert x in set(loc_expl.columns) |
| 91 | + # assert x in set(glb_expl.index) |
92 | 92 | assert "Series 1" in set(glb_expl.columns)
|
93 | 93 |
|
94 | 94 |
|
@@ -146,9 +146,7 @@ def test_load_datasets(model, data_details):
|
146 | 146 | subprocess.run(f"ls -a {output_data_path}", shell=True)
|
147 | 147 | if yaml_i["spec"]["generate_explanations"]:
|
148 | 148 | verify_explanations(
|
149 |
| - global_fn=f"{tmpdirname}/results/global_explanation.csv", |
150 |
| - local_fn=f"{tmpdirname}/results/local_explanation.csv", |
151 |
| - yaml_i=yaml_i, |
| 149 | + tmpdirname=tmpdirname, |
152 | 150 | additional_cols=additional_cols,
|
153 | 151 | )
|
154 | 152 | if include_test_data:
|
|
0 commit comments