Skip to content

Commit ce4a5bf

Browse files
authored
Standardize outputs & report for single-series forecasts without target_category_columns (#1028)
2 parents bccfa31 + ab40b35 commit ce4a5bf

File tree

9 files changed

+115
-55
lines changed

9 files changed

+115
-55
lines changed

ads/opctl/operator/lowcode/forecast/model/arima.py

Lines changed: 19 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -164,11 +164,11 @@ def _generate_report(self):
164164
blocks = [
165165
rc.Html(
166166
m.summary().as_html(),
167-
label=s_id,
167+
label=s_id if self.target_cat_col else None,
168168
)
169169
for i, (s_id, m) in enumerate(self.models.items())
170170
]
171-
sec5 = rc.Select(blocks=blocks)
171+
sec5 = rc.Select(blocks=blocks) if len(blocks) > 1 else blocks[0]
172172
all_sections = [sec5_text, sec5]
173173

174174
if self.spec.generate_explanations:
@@ -188,6 +188,21 @@ def _generate_report(self):
188188
axis=1,
189189
)
190190
)
191+
aggregate_local_explanations = pd.DataFrame()
192+
for s_id, local_ex_df in self.local_explanation.items():
193+
local_ex_df_copy = local_ex_df.copy()
194+
local_ex_df_copy["Series"] = s_id
195+
aggregate_local_explanations = pd.concat(
196+
[aggregate_local_explanations, local_ex_df_copy], axis=0
197+
)
198+
self.formatted_local_explanation = aggregate_local_explanations
199+
200+
if not self.target_cat_col:
201+
self.formatted_global_explanation = self.formatted_global_explanation.rename(
202+
{"Series 1": self.original_target_column},
203+
axis=1,
204+
)
205+
self.formatted_local_explanation.drop("Series", axis=1, inplace=True)
191206

192207
# Create a markdown section for the global explainability
193208
global_explanation_section = rc.Block(
@@ -198,26 +213,17 @@ def _generate_report(self):
198213
rc.DataTable(self.formatted_global_explanation, index=True),
199214
)
200215

201-
aggregate_local_explanations = pd.DataFrame()
202-
for s_id, local_ex_df in self.local_explanation.items():
203-
local_ex_df_copy = local_ex_df.copy()
204-
local_ex_df_copy["Series"] = s_id
205-
aggregate_local_explanations = pd.concat(
206-
[aggregate_local_explanations, local_ex_df_copy], axis=0
207-
)
208-
self.formatted_local_explanation = aggregate_local_explanations
209-
210216
blocks = [
211217
rc.DataTable(
212218
local_ex_df.div(local_ex_df.abs().sum(axis=1), axis=0) * 100,
213-
label=s_id,
219+
label=s_id if self.target_cat_col else None,
214220
index=True,
215221
)
216222
for s_id, local_ex_df in self.local_explanation.items()
217223
]
218224
local_explanation_section = rc.Block(
219225
rc.Heading("Local Explanation of Models", level=2),
220-
rc.Select(blocks=blocks),
226+
rc.Select(blocks=blocks) if len(blocks) > 1 else blocks[0],
221227
)
222228

223229
# Append the global explanation text and section to the "all_sections" list

ads/opctl/operator/lowcode/forecast/model/automlx.py

Lines changed: 20 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,8 @@ def _generate_report(self):
223223
selected_models.items(), columns=["series_id", "best_selected_model"]
224224
)
225225
selected_df = selected_models_df["best_selected_model"].apply(pd.Series)
226+
if not self.target_cat_col:
227+
selected_df = selected_df.drop("series_id", axis=1)
226228
selected_models_section = rc.Block(
227229
rc.Heading("Selected Models Overview", level=2),
228230
rc.Text(
@@ -252,15 +254,6 @@ def _generate_report(self):
252254
)
253255
)
254256

255-
# Create a markdown section for the global explainability
256-
global_explanation_section = rc.Block(
257-
rc.Heading("Global Explanation of Models", level=2),
258-
rc.Text(
259-
"The following tables provide the feature attribution for the global explainability."
260-
),
261-
rc.DataTable(self.formatted_global_explanation, index=True),
262-
)
263-
264257
aggregate_local_explanations = pd.DataFrame()
265258
for s_id, local_ex_df in self.local_explanation.items():
266259
local_ex_df_copy = local_ex_df.copy()
@@ -270,17 +263,33 @@ def _generate_report(self):
270263
)
271264
self.formatted_local_explanation = aggregate_local_explanations
272265

266+
if not self.target_cat_col:
267+
self.formatted_global_explanation = self.formatted_global_explanation.rename(
268+
{"Series 1": self.original_target_column},
269+
axis=1,
270+
)
271+
self.formatted_local_explanation.drop("Series", axis=1, inplace=True)
272+
273+
# Create a markdown section for the global explainability
274+
global_explanation_section = rc.Block(
275+
rc.Heading("Global Explanation of Models", level=2),
276+
rc.Text(
277+
"The following tables provide the feature attribution for the global explainability."
278+
),
279+
rc.DataTable(self.formatted_global_explanation, index=True),
280+
)
281+
273282
blocks = [
274283
rc.DataTable(
275284
local_ex_df.div(local_ex_df.abs().sum(axis=1), axis=0) * 100,
276-
label=s_id,
285+
label=s_id if self.target_cat_col else None,
277286
index=True,
278287
)
279288
for s_id, local_ex_df in self.local_explanation.items()
280289
]
281290
local_explanation_section = rc.Block(
282291
rc.Heading("Local Explanation of Models", level=2),
283-
rc.Select(blocks=blocks),
292+
rc.Select(blocks=blocks) if len(blocks) > 1 else blocks[0],
284293
)
285294

286295
# Append the global explanation text and section to the "other_sections" list

ads/opctl/operator/lowcode/forecast/model/autots.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -242,6 +242,7 @@ def _generate_report(self) -> tuple:
242242
self.models.df_wide_numeric, series=s_id
243243
),
244244
self.datasets.list_series_ids(),
245+
target_category_column=self.target_cat_col
245246
)
246247
section_1 = rc.Block(
247248
rc.Heading("Forecast Overview", level=2),

ads/opctl/operator/lowcode/forecast/model/base_model.py

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
seconds_to_datetime,
2929
write_data,
3030
)
31+
from ads.opctl.operator.lowcode.common.const import DataColumns
3132
from ads.opctl.operator.lowcode.forecast.model.forecast_datasets import TestData
3233
from ads.opctl.operator.lowcode.forecast.utils import (
3334
_build_metrics_df,
@@ -69,7 +70,7 @@ def __init__(self, config: ForecastOperatorConfig, datasets: ForecastDatasets):
6970
self.config: ForecastOperatorConfig = config
7071
self.spec: ForecastOperatorSpec = config.spec
7172
self.datasets: ForecastDatasets = datasets
72-
73+
self.target_cat_col = self.spec.target_category_columns
7374
self.full_data_dict = datasets.get_data_by_series()
7475

7576
self.test_eval_metrics = None
@@ -124,6 +125,9 @@ def generate_report(self):
124125

125126
if self.spec.generate_report or self.spec.generate_metrics:
126127
self.eval_metrics = self.generate_train_metrics()
128+
if not self.target_cat_col:
129+
self.eval_metrics.rename({"Series 1": self.original_target_column},
130+
axis=1, inplace=True)
127131

128132
if self.spec.test_data:
129133
try:
@@ -134,6 +138,9 @@ def generate_report(self):
134138
) = self._test_evaluate_metrics(
135139
elapsed_time=elapsed_time,
136140
)
141+
if not self.target_cat_col:
142+
self.test_eval_metrics.rename({"Series 1": self.original_target_column},
143+
axis=1, inplace=True)
137144
except Exception:
138145
logger.warn("Unable to generate Test Metrics.")
139146
logger.debug(f"Full Traceback: {traceback.format_exc()}")
@@ -179,7 +186,7 @@ def generate_report(self):
179186
first_5_rows_blocks = [
180187
rc.DataTable(
181188
df.head(5),
182-
label=s_id,
189+
label=s_id if self.target_cat_col else None,
183190
index=True,
184191
)
185192
for s_id, df in self.full_data_dict.items()
@@ -188,7 +195,7 @@ def generate_report(self):
188195
last_5_rows_blocks = [
189196
rc.DataTable(
190197
df.tail(5),
191-
label=s_id,
198+
label=s_id if self.target_cat_col else None,
192199
index=True,
193200
)
194201
for s_id, df in self.full_data_dict.items()
@@ -197,7 +204,7 @@ def generate_report(self):
197204
data_summary_blocks = [
198205
rc.DataTable(
199206
df.describe(),
200-
label=s_id,
207+
label=s_id if self.target_cat_col else None,
201208
index=True,
202209
)
203210
for s_id, df in self.full_data_dict.items()
@@ -215,17 +222,17 @@ def generate_report(self):
215222
rc.Block(
216223
first_10_title,
217224
# series_subtext,
218-
rc.Select(blocks=first_5_rows_blocks),
225+
rc.Select(blocks=first_5_rows_blocks) if self.target_cat_col else first_5_rows_blocks[0],
219226
),
220227
rc.Block(
221228
last_10_title,
222229
# series_subtext,
223-
rc.Select(blocks=last_5_rows_blocks),
230+
rc.Select(blocks=last_5_rows_blocks) if self.target_cat_col else last_5_rows_blocks[0],
224231
),
225232
rc.Block(
226233
summary_title,
227234
# series_subtext,
228-
rc.Select(blocks=data_summary_blocks),
235+
rc.Select(blocks=data_summary_blocks) if self.target_cat_col else data_summary_blocks[0],
229236
),
230237
rc.Separator(),
231238
)
@@ -288,6 +295,7 @@ def generate_report(self):
288295
horizon=self.spec.horizon,
289296
test_data=test_data,
290297
ci_interval_width=self.spec.confidence_interval_width,
298+
target_category_column=self.target_cat_col
291299
)
292300
if (
293301
series_name is not None
@@ -463,6 +471,7 @@ def _save_report(
463471
f2.write(f1.read())
464472

465473
# forecast csv report
474+
result_df = result_df if self.target_cat_col else result_df.drop(DataColumns.Series, axis=1)
466475
write_data(
467476
data=result_df,
468477
filename=os.path.join(unique_output_dir, self.spec.forecast_filename),

ads/opctl/operator/lowcode/forecast/model/neuralprophet.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -360,7 +360,7 @@ def _generate_report(self):
360360
pd.Series(
361361
m.state_dict(),
362362
index=m.state_dict().keys(),
363-
name=s_id,
363+
name=s_id if self.target_cat_col else self.original_target_column,
364364
)
365365
)
366366
all_model_states = pd.concat(model_states, axis=1)
@@ -373,6 +373,13 @@ def _generate_report(self):
373373
# If the key is present, call the "explain_model" method
374374
self.explain_model()
375375

376+
if not self.target_cat_col:
377+
self.formatted_global_explanation = self.formatted_global_explanation.rename(
378+
{"Series 1": self.original_target_column},
379+
axis=1,
380+
)
381+
self.formatted_local_explanation.drop("Series", axis=1, inplace=True)
382+
376383
# Create a markdown section for the global explainability
377384
global_explanation_section = rc.Block(
378385
rc.Heading("Global Explainability", level=2),
@@ -385,14 +392,14 @@ def _generate_report(self):
385392
blocks = [
386393
rc.DataTable(
387394
local_ex_df.drop("Series", axis=1),
388-
label=s_id,
395+
label=s_id if self.target_cat_col else None,
389396
index=True,
390397
)
391398
for s_id, local_ex_df in self.local_explanation.items()
392399
]
393400
local_explanation_section = rc.Block(
394401
rc.Heading("Local Explanation of Models", level=2),
395-
rc.Select(blocks=blocks),
402+
rc.Select(blocks=blocks) if len(blocks) > 1 else blocks[0],
396403
)
397404

398405
# Append the global explanation text and section to the "all_sections" list

ads/opctl/operator/lowcode/forecast/model/prophet.py

Lines changed: 24 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -256,6 +256,7 @@ def _generate_report(self):
256256
self.outputs[s_id], include_legend=True
257257
),
258258
series_ids=series_ids,
259+
target_category_column=self.target_cat_col
259260
)
260261
section_1 = rc.Block(
261262
rc.Heading("Forecast Overview", level=2),
@@ -268,6 +269,7 @@ def _generate_report(self):
268269
sec2 = _select_plot_list(
269270
lambda s_id: self.models[s_id].plot_components(self.outputs[s_id]),
270271
series_ids=series_ids,
272+
target_category_column=self.target_cat_col
271273
)
272274
section_2 = rc.Block(
273275
rc.Heading("Forecast Broken Down by Trend Component", level=2), sec2
@@ -281,7 +283,9 @@ def _generate_report(self):
281283
sec3_figs[s_id].gca(), self.models[s_id], self.outputs[s_id]
282284
)
283285
sec3 = _select_plot_list(
284-
lambda s_id: sec3_figs[s_id], series_ids=series_ids
286+
lambda s_id: sec3_figs[s_id],
287+
series_ids=series_ids,
288+
target_category_column=self.target_cat_col
285289
)
286290
section_3 = rc.Block(rc.Heading("Forecast Changepoints", level=2), sec3)
287291

@@ -295,7 +299,7 @@ def _generate_report(self):
295299
pd.Series(
296300
m.seasonalities,
297301
index=pd.Index(m.seasonalities.keys(), dtype="object"),
298-
name=s_id,
302+
name=s_id if self.target_cat_col else self.original_target_column,
299303
dtype="object",
300304
)
301305
)
@@ -316,15 +320,6 @@ def _generate_report(self):
316320
global_explanation_df / global_explanation_df.sum(axis=0) * 100
317321
)
318322

319-
# Create a markdown section for the global explainability
320-
global_explanation_section = rc.Block(
321-
rc.Heading("Global Explanation of Models", level=2),
322-
rc.Text(
323-
"The following tables provide the feature attribution for the global explainability."
324-
),
325-
rc.DataTable(self.formatted_global_explanation, index=True),
326-
)
327-
328323
aggregate_local_explanations = pd.DataFrame()
329324
for s_id, local_ex_df in self.local_explanation.items():
330325
local_ex_df_copy = local_ex_df.copy()
@@ -334,17 +329,33 @@ def _generate_report(self):
334329
)
335330
self.formatted_local_explanation = aggregate_local_explanations
336331

332+
if not self.target_cat_col:
333+
self.formatted_global_explanation = self.formatted_global_explanation.rename(
334+
{"Series 1": self.original_target_column},
335+
axis=1,
336+
)
337+
self.formatted_local_explanation.drop("Series", axis=1, inplace=True)
338+
339+
# Create a markdown section for the global explainability
340+
global_explanation_section = rc.Block(
341+
rc.Heading("Global Explanation of Models", level=2),
342+
rc.Text(
343+
"The following tables provide the feature attribution for the global explainability."
344+
),
345+
rc.DataTable(self.formatted_global_explanation, index=True),
346+
)
347+
337348
blocks = [
338349
rc.DataTable(
339350
local_ex_df.div(local_ex_df.abs().sum(axis=1), axis=0) * 100,
340-
label=s_id,
351+
label=s_id if self.target_cat_col else None,
341352
index=True,
342353
)
343354
for s_id, local_ex_df in self.local_explanation.items()
344355
]
345356
local_explanation_section = rc.Block(
346357
rc.Heading("Local Explanation of Models", level=2),
347-
rc.Select(blocks=blocks),
358+
rc.Select(blocks=blocks) if len(blocks) > 1 else blocks[0],
348359
)
349360

350361
# Append the global explanation text and section to the "all_sections" list

ads/opctl/operator/lowcode/forecast/utils.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -250,8 +250,8 @@ def evaluate_train_metrics(output):
250250
return total_metrics
251251

252252

253-
def _select_plot_list(fn, series_ids):
254-
blocks = [rc.Widget(fn(s_id=s_id), label=s_id) for s_id in series_ids]
253+
def _select_plot_list(fn, series_ids, target_category_column):
254+
blocks = [rc.Widget(fn(s_id=s_id), label=s_id if target_category_column else None) for s_id in series_ids]
255255
return rc.Select(blocks=blocks) if len(blocks) > 1 else blocks[0]
256256

257257

@@ -283,6 +283,7 @@ def get_forecast_plots(
283283
horizon,
284284
test_data=None,
285285
ci_interval_width=0.95,
286+
target_category_column=None
286287
):
287288
def plot_forecast_plotly(s_id):
288289
fig = go.Figure()
@@ -379,7 +380,7 @@ def plot_forecast_plotly(s_id):
379380
)
380381
return fig
381382

382-
return _select_plot_list(plot_forecast_plotly, forecast_output.list_series_ids())
383+
return _select_plot_list(plot_forecast_plotly, forecast_output.list_series_ids(), target_category_column)
383384

384385

385386
def convert_target(target: str, target_col: str):

0 commit comments

Comments
 (0)