
Commit de56ae2

added support for What-If Analysis
1 parent affeae4 commit de56ae2

10 files changed, +409 -3 lines changed

ads/opctl/operator/lowcode/common/utils.py

Lines changed: 1 addition & 1 deletion
@@ -121,7 +121,7 @@ def load_data(data_spec, storage_options=None, **kwargs):
     return data


-def write_data(data, filename, format, storage_options, index=False, **kwargs):
+def write_data(data, filename, format, storage_options=None, index=False, **kwargs):
     if not format:
         _, format = os.path.splitext(filename)
         format = format[1:]
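
The only change here is the storage_options default, which lets downstream code (such as the new deployment manager) write files without passing it. A minimal sketch of that kind of call, assuming a local output path; the DataFrame contents and the /tmp path are illustrative, not part of this commit:

import pandas as pd
from ads.opctl.operator.lowcode.common.utils import write_data

# Illustrative mapping frame, shaped like the model_ids.csv written by the deployment manager.
mapping = pd.DataFrame([{"catalog_id": "None", "series": ["S1", "S2"]}])

# storage_options now defaults to None, so a local write needs no extra argument.
write_data(data=mapping, filename="/tmp/model_ids.csv", format="csv")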

ads/opctl/operator/lowcode/forecast/__main__.py

Lines changed: 7 additions & 0 deletions
@@ -17,6 +17,7 @@

 from .operator_config import ForecastOperatorConfig
 from .model.forecast_datasets import ForecastDatasets
+from .whatifserve import ModelDeploymentManager


 def operate(operator_config: ForecastOperatorConfig) -> None:
@@ -27,6 +28,12 @@ def operate(operator_config: ForecastOperatorConfig) -> None:
     ForecastOperatorModelFactory.get_model(
         operator_config, datasets
     ).generate_report()
+    # saving to model catalog
+    spec = operator_config.spec
+    if spec.what_if_analysis and datasets.additional_data:
+        mdm = ModelDeploymentManager(spec, datasets.additional_data)
+        mdm.save_to_catalog()
+

 def verify(spec: Dict, **kwargs: Dict) -> bool:
     """Verifies the forecasting operator config."""

ads/opctl/operator/lowcode/forecast/model/forecast_datasets.py

Lines changed: 1 addition & 0 deletions
@@ -168,6 +168,7 @@ def get_data_multi_indexed(self):
                 self.additional_data.data,
             ],
             axis=1,
+            join='inner'
         )

     def get_data_by_series(self, include_horizon=True):

ads/opctl/operator/lowcode/forecast/operator_config.py

Lines changed: 2 additions & 0 deletions
@@ -90,12 +90,14 @@ class ForecastOperatorSpec(DataClassSerializable):
     confidence_interval_width: float = None
     metric: str = None
     tuning: Tuning = field(default_factory=Tuning)
+    what_if_analysis: bool = False

     def __post_init__(self):
         """Adjusts the specification details."""
         self.output_directory = self.output_directory or OutputDirectory(
             url=find_output_dirname(self.output_directory)
         )
+        self.generate_model_pickle = True if self.generate_model_pickle or self.what_if_analysis else False
         self.metric = (self.metric or "").lower() or SupportedMetrics.SMAPE.lower()
         self.model = self.model or SupportedModels.Prophet
         self.confidence_interval_width = self.confidence_interval_width or 0.80
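
One way to read the added __post_init__ line: enabling what_if_analysis implicitly forces generate_model_pickle, since save_to_catalog later loads that pickle from the output directory. A minimal sketch of the coercion with plain booleans (no ADS objects involved; values are illustrative):

# values as they might arrive from the YAML spec (illustrative)
generate_model_pickle, what_if_analysis = False, True

# same expression as the added __post_init__ line
generate_model_pickle = True if generate_model_pickle or what_if_analysis else False
print(generate_model_pickle)  # True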

ads/opctl/operator/lowcode/forecast/schema.yaml

Lines changed: 7 additions & 0 deletions
@@ -340,6 +340,13 @@ spec:
      meta:
        description: "Report file generation can be enabled using this flag. Defaults to true."

+    what_if_analysis:
+      type: boolean
+      required: false
+      default: false
+      meta:
+        description: "When enabled, the models are saved to the model catalog. Defaults to false."
+
    generate_metrics:
      type: boolean
      required: false
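
For context, the new flag is set alongside the other spec fields in the operator YAML. A minimal sketch of a forecast.yaml that would exercise it; the data URIs, column names, horizon, and model below are placeholders, not part of this commit:

kind: operator
type: forecast
version: v1
spec:
  historical_data:
    url: oci://bucket@namespace/historical.csv   # placeholder
  additional_data:
    url: oci://bucket@namespace/additional.csv   # placeholder; What-If analysis needs additional data
  datetime_column:
    name: ds
  target_column: y
  horizon: 14
  model: prophet
  what_if_analysis: true   # new flag introduced in this commit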

ads/opctl/operator/lowcode/forecast/whatifserve/__init__.py

Lines changed: 7 additions & 0 deletions

@@ -0,0 +1,7 @@
+#!/usr/bin/env python
+
+# Copyright (c) 2023, 2024 Oracle and/or its affiliates.
+# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
+
+
+from .deployment_manager import ModelDeploymentManager

ads/opctl/operator/lowcode/forecast/whatifserve/deployment_manager.py

Lines changed: 115 additions & 0 deletions

@@ -0,0 +1,115 @@
+#!/usr/bin/env python
+import json
+# Copyright (c) 2023, 2024 Oracle and/or its affiliates.
+# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
+
+import os
+import pickle
+import shutil
+import sys
+import tempfile
+
+import pandas as pd
+from joblib import dump
+
+from ads.common.model_export_util import prepare_generic_model
+from ads.opctl.operator.lowcode.common.utils import write_data, call_pandas_fsspec
+
+from ..model.forecast_datasets import AdditionalData
+from ..operator_config import ForecastOperatorSpec
+
+
+class ModelDeploymentManager:
+    def __init__(self, spec: ForecastOperatorSpec, additional_data: AdditionalData, previous_model_version=None):
+        self.spec = spec
+        # self.model_path = spec.output_directory.url
+        self.model_name = spec.model
+        self.horizon = spec.horizon
+        self.additional_data = additional_data.get_dict_by_series()
+        self.model_obj = {}
+        self.path_to_artifact = f"{self.spec.output_directory.url}/artifacts/"
+        self.pickle_file_path = f"{self.spec.output_directory.url}/model.pkl"
+        self.model_version = previous_model_version + 1 if previous_model_version else 1
+
+    def _satiny_test(self):
+        """
+        Function perform sanity test for saved artifact
+        """
+        sys.path.insert(0, f"{self.path_to_artifact}")
+        from score import load_model, predict
+        _ = load_model()
+
+        # Write additional data to tmp file and perform sanity check
+        with tempfile.NamedTemporaryFile(suffix='.csv') as temp_file:
+            one_series = next(iter(self.additional_data))
+            sample_prediction_data = self.additional_data[one_series].tail(self.horizon)
+            sample_prediction_data[self.spec.target_category_columns[0]] = one_series
+            date_col_name = self.spec.datetime_column.name
+            date_col_format = self.spec.datetime_column.format
+            sample_prediction_data[date_col_name] = sample_prediction_data[date_col_name].dt.strftime(date_col_format)
+            sample_prediction_data.to_csv(temp_file.name, index=False)
+            additional_data_uri = "additional_data_uri"
+            input_data = {additional_data_uri: temp_file.name}
+            prediction_test = predict(input_data, _)
+            print(f"prediction test completed with result :{prediction_test}")
+
+    def _copy_score_file(self):
+        """
+        Copies the score.py to the artifact_path.
+        """
+        try:
+            current_dir = os.path.dirname(os.path.abspath(__file__))
+            score_file = os.path.join(current_dir, "score.py")
+            destination_file = os.path.join(self.path_to_artifact, os.path.basename(score_file))
+            shutil.copy2(score_file, destination_file)
+            print(f"score.py copied successfully to {self.path_to_artifact}")
+        except Exception as e:
+            print(f"Error copying file: {e}")
+            raise e
+
+    def save_to_catalog(self):
+        """Save the model to a model catalog"""
+        with open(self.pickle_file_path, 'rb') as file:
+            self.model_obj = pickle.load(file)
+
+        if not os.path.exists(self.path_to_artifact):
+            os.mkdir(self.path_to_artifact)
+
+        artifact_dict = {"spec": self.spec.to_dict(), "models": self.model_obj}
+        dump(artifact_dict, os.path.join(self.path_to_artifact, "model.joblib"))
+        artifact = prepare_generic_model(self.path_to_artifact, function_artifacts=False, force_overwrite=True,
+                                         data_science_env=True)
+
+        self._copy_score_file()
+        self._satiny_test()
+
+        if isinstance(self.model_obj, dict):
+            series = self.model_obj.keys()
+        else:
+            series = self.additional_data.keys()
+        description = f"The object contains {len(series)} {self.model_name} models"
+
+        catalog_id = "None"
+        if not os.environ.get("TEST_MODE", False):
+            catalog_entry = artifact.save(display_name=f"{self.model_name}-v{self.model_version}",
+                                          description=description)
+            catalog_id = catalog_entry.id
+
+
+        print(f"Saved {self.model_name} version-v{self.model_version} to model catalog"
+              f" with catalog id : {catalog_id}")
+
+        catalog_mapping = {"catalog_id": catalog_id, "series": list(series)}
+
+        write_data(
+            data=pd.DataFrame([catalog_mapping]),
+            filename=os.path.join(
+                self.spec.output_directory.url, "model_ids.csv"
+            ),
+            format="csv"
+        )
+        return catalog_id
+
+    def create_deployment(self, deployment_config):
+        """Create a model deployment serving"""
+        pass
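
Putting the pieces together, the wiring added to __main__.py is the only caller of this class in the commit. A standalone sketch of the same flow, assuming a spec file that enables what_if_analysis and points at additional data; the config path is illustrative:

from ads.opctl.operator.lowcode.forecast.operator_config import ForecastOperatorConfig
from ads.opctl.operator.lowcode.forecast.model.forecast_datasets import ForecastDatasets
from ads.opctl.operator.lowcode.forecast.whatifserve import ModelDeploymentManager

# Illustrative: load a spec that sets what_if_analysis: true and provides additional data.
operator_config = ForecastOperatorConfig.from_yaml(uri="forecast.yaml")
datasets = ForecastDatasets(operator_config)

if operator_config.spec.what_if_analysis and datasets.additional_data:
    # previous_model_version is optional; omitting it registers the artifact as v1.
    mdm = ModelDeploymentManager(operator_config.spec, datasets.additional_data)
    catalog_id = mdm.save_to_catalog()  # returns the model catalog id, or 'None' in TEST_MODE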
