Skip to content

Commit 48b27b3

Browse files
committed
update prophet explanations
1 parent ea9258b commit 48b27b3

File tree

4 files changed

+40
-28
lines changed

4 files changed

+40
-28
lines changed

ads/opctl/operator/lowcode/forecast/model/base_model.py

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -635,7 +635,9 @@ def _save_report(
635635
)
636636
if self.errors_dict:
637637
write_data(
638-
data=pd.DataFrame.from_dict(self.errors_dict),
638+
data=pd.DataFrame(
639+
self.errors_dict, index=np.arange(len(self.errors_dict.keys()))
640+
),
639641
filename=os.path.join(
640642
unique_output_dir, self.spec.errors_dict_filename
641643
),

ads/opctl/operator/lowcode/forecast/model/neuralprophet.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
#!/usr/bin/env python
22

3-
# Copyright (c) 2023, 2024 Oracle and/or its affiliates.
3+
# Copyright (c) 2023, 2025 Oracle and/or its affiliates.
44
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
55

66
import logging
@@ -153,10 +153,8 @@ def _train_model(self, i, s_id, df, model_kwargs):
153153
cols_to_read = filter(
154154
lambda x: x.startswith("future_regressor"), forecast.columns
155155
)
156-
self.explanations_info[s_id] = forecast[cols_to_read]
157-
self.explanations_info[s_id]["Date"] = forecast["ds"]
158-
self.explanations_info[s_id] = self.explanations_info[s_id].set_index(
159-
"Date"
156+
self.explanations_info[s_id] = (
157+
forecast[cols_to_read].rename({"ds": "Date"}, axis=1).set_index("Date")
160158
)
161159

162160
self.outputs[s_id] = forecast

ads/opctl/operator/lowcode/forecast/model/prophet.py

Lines changed: 32 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -22,7 +22,6 @@
2222
from ..const import (
2323
DEFAULT_TRIALS,
2424
PROPHET_INTERNAL_DATE_COL,
25-
ForecastOutputColumns,
2625
SupportedModels,
2726
)
2827
from .base_model import ForecastOperatorBaseModel
@@ -123,6 +122,14 @@ def _train_model(self, i, series_id, df, model_kwargs):
123122
upper_bound=self.get_horizon(forecast["yhat_upper"]).values,
124123
lower_bound=self.get_horizon(forecast["yhat_lower"]).values,
125124
)
125+
# Get all features that make up the forecast. Exclude CI (upper/lower) and drop yhat ([:-1])
126+
core_columns = forecast.columns[
127+
~forecast.columns.str.endswith("_lower")
128+
& ~forecast.columns.str.endswith("_upper")
129+
][:-1]
130+
self.explanations_info[series_id] = (
131+
forecast[core_columns].rename({"ds": "Date"}, axis=1).set_index("Date")
132+
)
126133

127134
self.models[series_id] = {}
128135
self.models[series_id]["model"] = model
@@ -151,6 +158,7 @@ def _build_model(self) -> pd.DataFrame:
151158
full_data_dict = self.datasets.get_data_by_series()
152159
self.models = {}
153160
self.outputs = {}
161+
self.explanations_info = {}
154162
self.additional_regressors = self.datasets.get_additional_data_column_names()
155163
model_kwargs = self.set_kwargs()
156164
self.forecast_output = ForecastOutput(
@@ -257,6 +265,25 @@ def objective(trial):
257265
model_kwargs_i = study.best_params
258266
return model_kwargs_i
259267

268+
def explain_model(self):
    """Derive local and global explanations from ``self.explanations_info``.

    ``self.explanations_info`` maps each series id to a DataFrame of
    per-date forecast components. For every series this method:

    * stores ``self.get_horizon(df)`` in ``self.local_explanation``, adds a
      ``"Series"`` column holding the series id, and renames the index to
      ``self.dt_column_name``;
    * takes the column-wise mean of ``self.drop_horizon(df)`` as that
      series' global explanation.

    NOTE(review): ``get_horizon`` / ``drop_horizon`` are defined outside this
    view; presumably they select the forecast-horizon rows and the remaining
    (historical) rows respectively -- confirm against the base class.

    Side effects: sets ``self.local_explanation``,
    ``self.global_explanation`` (one column per series),
    ``self.formatted_global_explanation`` (each column divided by its own
    sum and multiplied by 100) and ``self.formatted_local_explanation``
    (all local frames concatenated row-wise).
    """
    self.local_explanation = {}
    per_series_means = []

    for series_key, components in self.explanations_info.items():
        # Local explanation: horizon slice tagged with its series id.
        self.local_explanation[series_key] = self.get_horizon(components)
        local_frame = self.local_explanation[series_key]
        local_frame["Series"] = series_key
        local_frame.index.rename(self.dt_column_name, inplace=True)

        # Global explanation: mean of each component outside the horizon.
        component_means = self.drop_horizon(components).mean()
        component_means.name = series_key
        per_series_means.append(component_means)

    # One column per series, one row per component.
    self.global_explanation = pd.concat(per_series_means, axis=1)

    # Express every component as a percentage of its series' column total.
    column_totals = self.global_explanation.sum(axis=0)
    self.formatted_global_explanation = (
        self.global_explanation / column_totals * 100
    )

    self.formatted_local_explanation = pd.concat(
        list(self.local_explanation.values())
    )
286+
260287
def _generate_report(self):
261288
import report_creator as rc
262289
from prophet.plot import add_changepoints_to_plot
@@ -335,22 +362,6 @@ def _generate_report(self):
335362
# If the key is present, call the "explain_model" method
336363
self.explain_model()
337364

338-
# Convert the global explanation data to a DataFrame
339-
global_explanation_df = pd.DataFrame(self.global_explanation)
340-
341-
self.formatted_global_explanation = (
342-
global_explanation_df / global_explanation_df.sum(axis=0) * 100
343-
)
344-
345-
aggregate_local_explanations = pd.DataFrame()
346-
for s_id, local_ex_df in self.local_explanation.items():
347-
local_ex_df_copy = local_ex_df.copy()
348-
local_ex_df_copy[ForecastOutputColumns.SERIES] = s_id
349-
aggregate_local_explanations = pd.concat(
350-
[aggregate_local_explanations, local_ex_df_copy], axis=0
351-
)
352-
self.formatted_local_explanation = aggregate_local_explanations
353-
354365
if not self.target_cat_col:
355366
self.formatted_global_explanation = (
356367
self.formatted_global_explanation.rename(
@@ -364,7 +375,7 @@ def _generate_report(self):
364375

365376
# Create a markdown section for the global explainability
366377
global_explanation_section = rc.Block(
367-
rc.Heading("Global Explanation of Models", level=2),
378+
rc.Heading("Global Explainability", level=2),
368379
rc.Text(
369380
"The following tables provide the feature attribution for the global explainability."
370381
),
@@ -373,7 +384,7 @@ def _generate_report(self):
373384

374385
blocks = [
375386
rc.DataTable(
376-
local_ex_df.div(local_ex_df.abs().sum(axis=1), axis=0) * 100,
387+
local_ex_df.drop("Series", axis=1),
377388
label=s_id if self.target_cat_col else None,
378389
index=True,
379390
)
@@ -393,6 +404,8 @@ def _generate_report(self):
393404
# Do not fail the whole run due to explanations failure
394405
logger.warning(f"Failed to generate Explanations with error: {e}.")
395406
logger.debug(f"Full Traceback: {traceback.format_exc()}")
407+
self.errors_dict["explainer_error"] = str(e)
408+
self.errors_dict["explainer_error_error"] = traceback.format_exc()
396409

397410
model_description = rc.Text(
398411
"""Prophet is a procedure for forecasting time series data based on an additive model where non-linear trends are fit with yearly, weekly, and daily seasonality, plus holiday effects. It works best with time series that have strong seasonal effects and several seasons of historical data. Prophet is robust to missing data and shifts in the trend, and typically handles outliers well."""

tests/operators/forecast/test_errors.py

Lines changed: 2 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -961,6 +961,7 @@ def test_generate_files(operator_setup, model):
961961
yaml_i["spec"]["generate_forecast_file"] = False
962962
yaml_i["spec"]["generate_metrics_file"] = False
963963
yaml_i["spec"]["generate_explanations"] = True
964+
yaml_i["spec"]["model_kwargs"] = {"min": 0, "max": 20}
964965

965966
df = pd.concat([HISTORICAL_DATETIME_COL[:15], TARGET_COL[:15]], axis=1)
966967
df_add = pd.concat([HISTORICAL_DATETIME_COL[:18], ADD_COLS[:18]], axis=1)
@@ -971,8 +972,7 @@ def test_generate_files(operator_setup, model):
971972
files = os.listdir(operator_setup)
972973
if "errors.json" in files:
973974
with open(os.path.join(operator_setup, "errors.json")) as f:
974-
print(f"Errors in build! {f.read()}")
975-
assert False, "Failed due to errors.json being created"
975+
assert False, f"Failed due to errors.json being created: {f.read()}"
976976
assert "report.html" in files, "Failed to generate report"
977977
assert (
978978
"forecast.csv" not in files
@@ -988,7 +988,6 @@ def test_generate_files(operator_setup, model):
988988
), "Generated metrics file, but `generate_explanation_files` was set False"
989989
assert not results.get_forecast().empty
990990
assert not results.get_metrics().empty
991-
print(f"global expl: {results.get_global_explanations()}")
992991
assert not results.get_global_explanations().empty
993992
assert not results.get_local_explanations().empty
994993

0 commit comments

Comments
 (0)