Skip to content

Commit d392a6c

Browse files
committed
update testcase
1 parent daa8797 commit d392a6c

File tree

1 file changed

+51
-17
lines changed

1 file changed

+51
-17
lines changed

tests/operators/forecast/test_explainers.py

Lines changed: 51 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,9 @@ def test_generate_datasets():
9696
assert "target" not in additional_columns
9797

9898

99-
def setup_test_data(model, freq, num_series, horizon=5, num_points=100, seed=42, include_additional=True):
99+
def setup_test_data(
100+
model, freq, num_series, horizon=5, num_points=100, seed=42, include_additional=True
101+
):
100102
"""
101103
Setup test data for the given parameters.
102104
@@ -113,17 +115,21 @@ def setup_test_data(model, freq, num_series, horizon=5, num_points=100, seed=42,
113115
- Tuple containing primary, additional datasets and the operator configuration.
114116
"""
115117
primary, additional, _, _ = generate_datasets(
116-
freq=freq, horizon=horizon, num_series=num_series, num_points=num_points, seed=seed
118+
freq=freq,
119+
horizon=horizon,
120+
num_series=num_series,
121+
num_points=num_points,
122+
seed=seed,
117123
)
118124

119125
yaml_i = deepcopy(TEMPLATE_YAML)
120126
yaml_i["spec"]["historical_data"].pop("url")
121127
yaml_i["spec"]["historical_data"]["data"] = primary
122128
yaml_i["spec"]["historical_data"]["format"] = "pandas"
123-
129+
124130
if include_additional:
125131
yaml_i["spec"]["additional_data"] = {"data": additional, "format": "pandas"}
126-
132+
127133
yaml_i["spec"]["model"] = model
128134
yaml_i["spec"]["target_column"] = "target"
129135
yaml_i["spec"]["datetime_column"]["name"] = "ds"
@@ -177,10 +183,18 @@ def test_explanations_output_and_columns(model, freq, num_series):
177183
not (local_explanations == 0).all().all()
178184
), "Local explanations contain only 0 values"
179185

180-
additional_columns = additional.columns.tolist()
186+
additional_columns = list(
187+
set(additional.columns.tolist())
188+
- set(operator_config.spec.target_category_columns)
189+
- {operator_config.spec.datetime_column.name}
190+
)
181191
for column in additional_columns:
182-
assert column in global_explanations.columns, f"Column {column} missing in global explanations"
183-
assert column in local_explanations.columns, f"Column {column} missing in local explanations"
192+
assert (
193+
column in global_explanations.T.columns
194+
), f"Column {column} missing in global explanations"
195+
assert (
196+
column in local_explanations.columns
197+
), f"Column {column} missing in local explanations"
184198

185199

186200
@pytest.mark.parametrize("model", MODELS)
@@ -208,11 +222,19 @@ def test_explanations_filenames(model, num_series):
208222

209223
results = forecast_operate(operator_config)
210224

211-
global_explanation_path = os.path.join(output_directory, global_explanation_filename)
212-
local_explanation_path = os.path.join(output_directory, local_explanation_filename)
225+
global_explanation_path = os.path.join(
226+
output_directory, global_explanation_filename
227+
)
228+
local_explanation_path = os.path.join(
229+
output_directory, local_explanation_filename
230+
)
213231

214-
assert os.path.exists(global_explanation_path), f"Global explanation file not found at {global_explanation_path}"
215-
assert os.path.exists(local_explanation_path), f"Local explanation file not found at {local_explanation_path}"
232+
assert os.path.exists(
233+
global_explanation_path
234+
), f"Global explanation file not found at {global_explanation_path}"
235+
assert os.path.exists(
236+
local_explanation_path
237+
), f"Local explanation file not found at {local_explanation_path}"
216238

217239

218240
@pytest.mark.parametrize("model", MODELS)
@@ -231,19 +253,23 @@ def test_explanations_no_additional_data(model, num_series, caplog):
231253
with tempfile.TemporaryDirectory() as tmpdirname:
232254
output_directory = tmpdirname
233255

234-
_, _, operator_config = setup_test_data(model, "D", num_series, include_additional=False)
256+
_, _, operator_config = setup_test_data(
257+
model, "D", num_series, include_additional=False
258+
)
235259
operator_config.spec.output_directory.url = output_directory
236260

237261
forecast_operate(operator_config)
238262

239263
assert any(
240264
"Unable to generate explanations as there is no additional data passed in. Either set generate_explanations to False, or pass in additional data."
241-
in message for message in caplog.messages
265+
in message
266+
for message in caplog.messages
242267
), "Required warning message not found in logs"
243268

244269

245270
MODES = ["BALANCED", "HIGH_ACCURACY"]
246271

272+
247273
@pytest.mark.skip(reason="Disabled by default. Enable to run this test.")
248274
@pytest.mark.parametrize("mode", MODES)
249275
@pytest.mark.parametrize("model", MODELS)
@@ -269,11 +295,19 @@ def test_explanations_accuracy_mode(mode, model, num_series):
269295

270296
results = forecast_operate(operator_config)
271297

272-
global_explanation_path = os.path.join(output_directory, operator_config.spec.global_explanation_filename)
273-
local_explanation_path = os.path.join(output_directory, operator_config.spec.local_explanation_filename)
298+
global_explanation_path = os.path.join(
299+
output_directory, operator_config.spec.global_explanation_filename
300+
)
301+
local_explanation_path = os.path.join(
302+
output_directory, operator_config.spec.local_explanation_filename
303+
)
274304

275-
assert os.path.exists(global_explanation_path), f"Global explanation file not found at {global_explanation_path}"
276-
assert os.path.exists(local_explanation_path), f"Local explanation file not found at {local_explanation_path}"
305+
assert os.path.exists(
306+
global_explanation_path
307+
), f"Global explanation file not found at {global_explanation_path}"
308+
assert os.path.exists(
309+
local_explanation_path
310+
), f"Local explanation file not found at {local_explanation_path}"
277311

278312

279313
@pytest.mark.parametrize("model", MODELS)

0 commit comments

Comments
 (0)