Skip to content

Commit 1619dee

Browse files
authored
Merge branch 'main' into feature/odsc68406/amlx_global_explainer
2 parents c33c9a2 + a030882 commit 1619dee

File tree

12 files changed

+248
-84
lines changed

12 files changed

+248
-84
lines changed

.github/workflows/run-forecast-unit-tests.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,6 @@ jobs:
5656
$CONDA/bin/conda init
5757
source /home/runner/.bashrc
5858
pip install -r test-requirements-operators.txt
59-
pip install "oracle-automlx[forecasting]>=24.4.1"
59+
pip install "oracle-automlx[forecasting]>=25.1.1"
6060
pip install pandas>=2.2.0
6161
python -m pytest -v -p no:warnings --durations=5 tests/operators/forecast

ads/opctl/anomaly_detection.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
#!/usr/bin/env python
2+
3+
# Copyright (c) 2025 Oracle and/or its affiliates.
4+
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
5+
6+
from ads.opctl.operator.lowcode.anomaly.__main__ import operate
7+
from ads.opctl.operator.lowcode.anomaly.operator_config import AnomalyOperatorConfig
8+
9+
if __name__ == "__main__":
10+
config = AnomalyOperatorConfig()
11+
operate(config)

ads/opctl/forecast.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
#!/usr/bin/env python
2+
3+
# Copyright (c) 2025 Oracle and/or its affiliates.
4+
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
5+
6+
from ads.opctl.operator.lowcode.forecast.__main__ import operate
7+
from ads.opctl.operator.lowcode.forecast.operator_config import ForecastOperatorConfig
8+
9+
if __name__ == "__main__":
10+
config = ForecastOperatorConfig()
11+
operate(config)

ads/opctl/operator/lowcode/forecast/__main__.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
#!/usr/bin/env python
2-
# -*- coding: utf-8 -*--
32

4-
# Copyright (c) 2023, 2024 Oracle and/or its affiliates.
3+
# Copyright (c) 2023, 2025 Oracle and/or its affiliates.
54
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
65

76
import json
@@ -15,17 +14,17 @@
1514
from ads.opctl.operator.common.const import ENV_OPERATOR_ARGS
1615
from ads.opctl.operator.common.utils import _parse_input_args
1716

17+
from .model.forecast_datasets import ForecastDatasets, ForecastResults
1818
from .operator_config import ForecastOperatorConfig
19-
from .model.forecast_datasets import ForecastDatasets
2019
from .whatifserve import ModelDeploymentManager
2120

2221

23-
def operate(operator_config: ForecastOperatorConfig) -> None:
22+
def operate(operator_config: ForecastOperatorConfig) -> ForecastResults:
2423
"""Runs the forecasting operator."""
2524
from .model.factory import ForecastOperatorModelFactory
2625

2726
datasets = ForecastDatasets(operator_config)
28-
ForecastOperatorModelFactory.get_model(
27+
results = ForecastOperatorModelFactory.get_model(
2928
operator_config, datasets
3029
).generate_report()
3130
# saving to model catalog
@@ -36,6 +35,7 @@ def operate(operator_config: ForecastOperatorConfig) -> None:
3635
if spec.what_if_analysis.model_deployment:
3736
mdm.create_deployment()
3837
mdm.save_deployment_info()
38+
return results
3939

4040

4141
def verify(spec: Dict, **kwargs: Dict) -> bool:

ads/opctl/operator/lowcode/forecast/model/automlx.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
#!/usr/bin/env python
2-
# Copyright (c) 2023, 2024 Oracle and/or its affiliates.
2+
# Copyright (c) 2023, 2025 Oracle and/or its affiliates.
33
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
44
import logging
55
import os
@@ -66,8 +66,7 @@ def preprocess(self, data, series_id): # TODO: re-use self.le for explanations
6666
@runtime_dependency(
6767
module="automlx",
6868
err_msg=(
69-
"Please run `pip3 install oracle-automlx>=23.4.1` and "
70-
"`pip3 install oracle-automlx[forecasting]>=23.4.1` "
69+
"Please run `pip3 install oracle-automlx[forecasting]>=25.1.1` "
7170
"to install the required dependencies for automlx."
7271
),
7372
)
@@ -105,7 +104,7 @@ def _build_model(self) -> pd.DataFrame:
105104
engine_opts = (
106105
None
107106
if engine_type == "local"
108-
else ({"ray_setup": {"_temp_dir": "/tmp/ray-temp"}},)
107+
else {"ray_setup": {"_temp_dir": "/tmp/ray-temp"}}
109108
)
110109
init(
111110
engine=engine_type,

ads/opctl/operator/lowcode/forecast/model/base_model.py

Lines changed: 28 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
#!/usr/bin/env python
22

3-
# Copyright (c) 2023, 2024 Oracle and/or its affiliates.
3+
# Copyright (c) 2023, 2025 Oracle and/or its affiliates.
44
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
55

66
import logging
@@ -51,7 +51,7 @@
5151
SupportedModels,
5252
)
5353
from ..operator_config import ForecastOperatorConfig, ForecastOperatorSpec
54-
from .forecast_datasets import ForecastDatasets
54+
from .forecast_datasets import ForecastDatasets, ForecastResults
5555

5656
logging.getLogger("report_creator").setLevel(logging.WARNING)
5757

@@ -350,11 +350,12 @@ def generate_report(self):
350350
)
351351

352352
# save the report and result CSV
353-
self._save_report(
353+
return self._save_report(
354354
report_sections=report_sections,
355355
result_df=result_df,
356356
metrics_df=self.eval_metrics,
357357
test_metrics_df=self.test_eval_metrics,
358+
test_data=test_data,
358359
)
359360

360361
def _test_evaluate_metrics(self, elapsed_time=0):
@@ -471,10 +472,12 @@ def _save_report(
471472
result_df: pd.DataFrame,
472473
metrics_df: pd.DataFrame,
473474
test_metrics_df: pd.DataFrame,
475+
test_data: pd.DataFrame,
474476
):
475477
"""Saves resulting reports to the given folder."""
476478

477479
unique_output_dir = self.spec.output_directory.url
480+
results = ForecastResults()
478481

479482
if ObjectStorageDetails.is_oci_path(unique_output_dir):
480483
storage_options = default_signer()
@@ -500,6 +503,11 @@ def _save_report(
500503
f2.write(f1.read())
501504

502505
# forecast csv report
506+
# todo: add test data into forecast.csv
507+
# if self.spec.test_data is not None:
508+
# test_data_dict = test_data.get_dict_by_series()
509+
# for series_id, test_data_values in test_data_dict.items():
510+
# result_df[DataColumns.Series] = test_data_values[]
503511
result_df = (
504512
result_df
505513
if self.target_cat_col
@@ -511,6 +519,7 @@ def _save_report(
511519
format="csv",
512520
storage_options=storage_options,
513521
)
522+
results.set_forecast(result_df)
514523

515524
# metrics csv report
516525
if self.spec.generate_metrics:
@@ -520,17 +529,19 @@ def _save_report(
520529
else "Series 1"
521530
)
522531
if metrics_df is not None:
532+
metrics_df_formatted = metrics_df.reset_index().rename(
533+
{"index": "metrics", "Series 1": metrics_col_name}, axis=1
534+
)
523535
write_data(
524-
data=metrics_df.reset_index().rename(
525-
{"index": "metrics", "Series 1": metrics_col_name}, axis=1
526-
),
536+
data=metrics_df_formatted,
527537
filename=os.path.join(
528538
unique_output_dir, self.spec.metrics_filename
529539
),
530540
format="csv",
531541
storage_options=storage_options,
532542
index=False,
533543
)
544+
results.set_metrics(metrics_df_formatted)
534545
else:
535546
logger.warn(
536547
f"Attempted to generate the {self.spec.metrics_filename} file with the training metrics, however the training metrics could not be properly generated."
@@ -539,17 +550,19 @@ def _save_report(
539550
# test_metrics csv report
540551
if self.spec.test_data is not None:
541552
if test_metrics_df is not None:
553+
test_metrics_df_formatted = test_metrics_df.reset_index().rename(
554+
{"index": "metrics", "Series 1": metrics_col_name}, axis=1
555+
)
542556
write_data(
543-
data=test_metrics_df.reset_index().rename(
544-
{"index": "metrics", "Series 1": metrics_col_name}, axis=1
545-
),
557+
data=test_metrics_df_formatted,
546558
filename=os.path.join(
547559
unique_output_dir, self.spec.test_metrics_filename
548560
),
549561
format="csv",
550562
storage_options=storage_options,
551563
index=False,
552564
)
565+
results.set_test_metrics(test_metrics_df_formatted)
553566
else:
554567
logger.warn(
555568
f"Attempted to generate the {self.spec.test_metrics_filename} file with the test metrics, however the test metrics could not be properly generated."
@@ -567,6 +580,7 @@ def _save_report(
567580
storage_options=storage_options,
568581
index=True,
569582
)
583+
results.set_global_explanations(self.formatted_global_explanation)
570584
else:
571585
logger.warn(
572586
f"Attempted to generate global explanations for the {self.spec.global_explanation_filename} file, but an issue occured in formatting the explanations."
@@ -582,6 +596,7 @@ def _save_report(
582596
storage_options=storage_options,
583597
index=True,
584598
)
599+
results.set_local_explanations(self.formatted_local_explanation)
585600
else:
586601
logger.warn(
587602
f"Attempted to generate local explanations for the {self.spec.local_explanation_filename} file, but an issue occured in formatting the explanations."
@@ -602,10 +617,12 @@ def _save_report(
602617
index=True,
603618
indent=4,
604619
)
620+
results.set_model_parameters(self.model_parameters)
605621

606622
# model pickle
607623
if self.spec.generate_model_pickle:
608624
self._save_model(unique_output_dir, storage_options)
625+
results.set_models(self.models)
609626

610627
logger.info(
611628
f"The outputs have been successfully "
@@ -625,8 +642,10 @@ def _save_report(
625642
index=True,
626643
indent=4,
627644
)
645+
results.set_errors_dict(self.errors_dict)
628646
else:
629647
logger.info("All modeling completed successfully.")
648+
return results
630649

631650
def preprocess(self, df, series_id):
632651
"""The method that needs to be implemented on the particular model level."""

ads/opctl/operator/lowcode/forecast/model/forecast_datasets.py

Lines changed: 60 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
#!/usr/bin/env python
22

3-
# Copyright (c) 2023, 2024 Oracle and/or its affiliates.
3+
# Copyright (c) 2023, 2025 Oracle and/or its affiliates.
44
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
55

6+
from typing import Dict, List
7+
68
import pandas as pd
79

810
from ads.opctl import logger
@@ -167,7 +169,7 @@ def get_data_multi_indexed(self):
167169
self.historical_data.data,
168170
self.additional_data.data,
169171
],
170-
axis=1
172+
axis=1,
171173
)
172174

173175
def get_data_by_series(self, include_horizon=True):
@@ -416,3 +418,59 @@ def get_forecast_long(self):
416418
for df in self.series_id_map.values():
417419
output = pd.concat([output, df])
418420
return output.reset_index(drop=True)
421+
422+
423+
class ForecastResults:
424+
"""
425+
Forecast Results contains all outputs from the forecast run.
426+
This class is returned to users who use the Forecast's `operate` method.
427+
428+
"""
429+
430+
def set_forecast(self, df: pd.DataFrame):
431+
self.forecast = df
432+
433+
def get_forecast(self):
434+
return getattr(self, "forecast", None)
435+
436+
def set_metrics(self, df: pd.DataFrame):
437+
self.metrics = df
438+
439+
def get_metrics(self):
440+
return getattr(self, "metrics", None)
441+
442+
def set_test_metrics(self, df: pd.DataFrame):
443+
self.test_metrics = df
444+
445+
def get_test_metrics(self):
446+
return getattr(self, "test_metrics", None)
447+
448+
def set_local_explanations(self, df: pd.DataFrame):
449+
self.local_explanations = df
450+
451+
def get_local_explanations(self):
452+
return getattr(self, "local_explanations", None)
453+
454+
def set_global_explanations(self, df: pd.DataFrame):
455+
self.global_explanations = df
456+
457+
def get_global_explanations(self):
458+
return getattr(self, "global_explanations", None)
459+
460+
def set_model_parameters(self, df: pd.DataFrame):
461+
self.model_parameters = df
462+
463+
def get_model_parameters(self):
464+
return getattr(self, "model_parameters", None)
465+
466+
def set_models(self, models: List):
467+
self.models = models
468+
469+
def get_models(self):
470+
return getattr(self, "models", None)
471+
472+
def set_errors_dict(self, errors_dict: Dict):
473+
self.errors_dict = errors_dict
474+
475+
def get_errors_dict(self):
476+
return getattr(self, "errors_dict", None)

0 commit comments

Comments
 (0)