Skip to content

Commit bf1ec22

Browse files
committed
added options for reading and model saving
1 parent e48a966 commit bf1ec22

File tree

5 files changed

+57
-24
lines changed

5 files changed

+57
-24
lines changed

ads/opctl/operator/lowcode/forecast/operator_config.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,14 @@
1919
from .const import SpeedAccuracyMode, SupportedMetrics, SupportedModels
2020

2121

22+
@dataclass(repr=True)
23+
class WhatIfAnalysis(DataClassSerializable):
24+
"""Class representing operator specification for whatif-analysis."""
25+
model_name: str = None
26+
compartment_id: str = None
27+
project_id: str = None
28+
29+
2230
@dataclass(repr=True)
2331
class TestData(InputData):
2432
"""Class representing operator specification test data details."""
@@ -90,7 +98,7 @@ class ForecastOperatorSpec(DataClassSerializable):
9098
confidence_interval_width: float = None
9199
metric: str = None
92100
tuning: Tuning = field(default_factory=Tuning)
93-
what_if_analysis: bool = False
101+
what_if_analysis: WhatIfAnalysis = field(default_factory=WhatIfAnalysis)
94102

95103
def __post_init__(self):
96104
"""Adjusts the specification details."""

ads/opctl/operator/lowcode/forecast/schema.yaml

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -341,9 +341,20 @@ spec:
341341
description: "Report file generation can be enabled using this flag. Defaults to true."
342342

343343
what_if_analysis:
344-
type: boolean
344+
type: dict
345345
required: false
346-
default: false
346+
schema:
347+
model_name:
348+
type: string
349+
required: true
350+
project_id:
351+
type: string
352+
required: false
353+
meta: "If not provided, the project OCID from config.PROJECT_OCID is used."
354+
compartment_id:
355+
type: string
356+
required: false
357+
meta: "If not provided, the compartment OCID from config.NB_SESSION_COMPARTMENT_OCID is used."
347358
meta:
348359
description: "When provided, the models are saved to the model catalog under the given model_name. Optional; project_id and compartment_id fall back to the configured defaults."
349360

ads/opctl/operator/lowcode/forecast/whatifserve/deployment_manager.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,9 @@ def __init__(self, spec: ForecastOperatorSpec, additional_data: AdditionalData,
2727
self.horizon = spec.horizon
2828
self.additional_data = additional_data.get_dict_by_series()
2929
self.model_obj = {}
30+
self.display_name = spec.what_if_analysis.model_name
31+
self.project_id = spec.what_if_analysis.project_id
32+
self.compartment_id = spec.what_if_analysis.compartment_id
3033
self.path_to_artifact = f"{self.spec.output_directory.url}/artifacts/"
3134
self.pickle_file_path = f"{self.spec.output_directory.url}/model.pkl"
3235
self.model_version = previous_model_version + 1 if previous_model_version else 1
@@ -48,8 +51,8 @@ def _satiny_test(self):
4851
date_col_format = self.spec.datetime_column.format
4952
sample_prediction_data[date_col_name] = sample_prediction_data[date_col_name].dt.strftime(date_col_format)
5053
sample_prediction_data.to_csv(temp_file.name, index=False)
51-
additional_data_uri = "additional_data_uri"
52-
input_data = {additional_data_uri: temp_file.name}
54+
additional_data_uri = "additional_data"
55+
input_data = {additional_data_uri: {"url": temp_file.name}}
5356
prediction_test = predict(input_data, _)
5457
logger.info(f"prediction test completed with result :{prediction_test}")
5558

@@ -91,7 +94,9 @@ def save_to_catalog(self):
9194

9295
catalog_id = "None"
9396
if not os.environ.get("TEST_MODE", False):
94-
catalog_entry = artifact.save(display_name=f"{self.model_name}-v{self.model_version}",
97+
catalog_entry = artifact.save(display_name=self.display_name,
98+
compartment_id=self.compartment_id,
99+
project_id=self.project_id,
95100
description=description)
96101
catalog_id = catalog_entry.id
97102

ads/opctl/operator/lowcode/forecast/whatifserve/score.py

Lines changed: 26 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,8 @@
1111
from functools import lru_cache
1212
import logging
1313
import ads
14-
from prophet import Prophet
15-
from neuralprophet import NeuralProphet
16-
from pmdarima import ARIMA
17-
from autots import AutoTS
18-
from automlx._interface.forecaster import AutoForecaster
14+
from ads.opctl.operator.lowcode.common.utils import load_data
15+
from ads.opctl.operator.common.operator_config import InputData
1916

2017
ads.set_auth("resource_principal")
2118

@@ -29,6 +26,12 @@
2926
Inference script. This script is used for prediction by scoring server when schema is known.
3027
"""
3128

29+
AUTOTS = "autots"
30+
ARIMA = "arima"
31+
PROPHET = "prophet"
32+
NEURALPROPHET = "neuralprophet"
33+
AUTOMLX = "automlx"
34+
3235

3336
@lru_cache(maxsize=10)
3437
def load_model():
@@ -132,14 +135,14 @@ def post_inference(yhat):
132135
return yhat
133136

134137

135-
def get_forecast(future_df, series_id, model_object, date_col, target_column, target_cat_col, horizon):
138+
def get_forecast(future_df, model_name, series_id, model_object, date_col, target_column, target_cat_col, horizon):
136139
date_col_name = date_col["name"]
137140
date_col_format = date_col["format"]
138141
future_df[target_cat_col] = future_df[target_cat_col].astype("str")
139142
future_df[date_col_name] = pd.to_datetime(
140143
future_df[date_col_name], format=date_col_format
141144
)
142-
if isinstance(model_object, AutoTS):
145+
if model_name == AUTOTS:
143146
series_id_col = "Series"
144147
full_data_indexed = future_df.rename(columns={target_cat_col: series_id_col})
145148
additional_regressors = list(
@@ -152,12 +155,12 @@ def get_forecast(future_df, series_id, model_object, date_col, target_column, ta
152155
)
153156
pred_obj = model_object.predict(future_regressor=future_reg)
154157
return pred_obj.forecast[series_id].tolist()
155-
elif series_id in model_object and isinstance(model_object[series_id], Prophet):
158+
elif model_name == PROPHET and series_id in model_object:
156159
model = model_object[series_id]
157160
processed = future_df.rename(columns={date_col_name: 'ds', target_column: 'y'})
158161
forecast = model.predict(processed)
159162
return forecast['yhat'].tolist()
160-
elif series_id in model_object and isinstance(model_object[series_id], NeuralProphet):
163+
elif model_name == NEURALPROPHET and series_id in model_object:
161164
model = model_object[series_id]
162165
model.restore_trainer()
163166
accepted_regressors = list(model.config_regressors.keys())
@@ -166,7 +169,7 @@ def get_forecast(future_df, series_id, model_object, date_col, target_column, ta
166169
future["y"] = None
167170
forecast = model.predict(future)
168171
return forecast['yhat1'].tolist()
169-
elif series_id in model_object and isinstance(model_object[series_id], ARIMA):
172+
elif model_name == ARIMA and series_id in model_object:
170173
model = model_object[series_id]
171174
future_df = future_df.set_index(date_col_name)
172175
x_pred = future_df.drop(target_cat_col, axis=1)
@@ -177,7 +180,7 @@ def get_forecast(future_df, series_id, model_object, date_col, target_column, ta
177180
)
178181
yhat_clean = pd.DataFrame(yhat, index=yhat.index, columns=["yhat"])
179182
return yhat_clean['yhat'].tolist()
180-
elif series_id in model_object and isinstance(model_object[series_id], AutoForecaster):
183+
elif model_name == AUTOMLX and series_id in model_object:
181184
# automlx model
182185
model = model_object[series_id]
183186
x_pred = future_df.drop(target_cat_col, axis=1)
@@ -188,7 +191,7 @@ def get_forecast(future_df, series_id, model_object, date_col, target_column, ta
188191
)
189192
return forecast[target_column].tolist()
190193
else:
191-
raise Exception( f"Invalid model object type: {type(model_object).__name__}.")
194+
raise Exception(f"Invalid model object type: {type(model_object).__name__}.")
192195

193196

194197
def predict(data, model=load_model()) -> dict:
@@ -211,20 +214,26 @@ def predict(data, model=load_model()) -> dict:
211214
models = model["models"]
212215
specs = model["spec"]
213216
horizon = specs["horizon"]
217+
model_name = specs["model"]
214218
date_col = specs["datetime_column"]
215219
target_column = specs["target_column"]
216-
forecasts = {}
217-
uri = f"{data['additional_data_uri']}"
218220
target_category_column = specs["target_category_columns"][0]
219-
signer = ads.common.auth.default_signer() if uri.lower().startswith("oci://") else {}
220-
additional_data = pd.read_csv(uri, storage_options=signer)
221+
222+
try:
223+
input_data = InputData(**data["additional_data"])
224+
except TypeError as e:
225+
raise ValueError(f"Validation error: {e}")
226+
additional_data = load_data(input_data)
227+
221228
unique_values = additional_data[target_category_column].unique()
229+
forecasts = {}
222230
for key in unique_values:
223231
try:
224232
s_id = str(key)
225233
filtered = additional_data[additional_data[target_category_column] == key]
226234
future = filtered.tail(horizon)
227-
forecast = get_forecast(future, s_id, models, date_col, target_column, target_category_column, horizon)
235+
forecast = get_forecast(future, model_name, s_id, models, date_col,
236+
target_column, target_category_column, horizon)
228237
forecasts[s_id] = json.dumps(forecast)
229238
except Exception as e:
230239
raise RuntimeError(

tests/operators/forecast/test_errors.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -754,7 +754,7 @@ def test_what_if_analysis(operator_setup, model):
754754
)
755755
yaml_i["spec"]["horizon"] = 10
756756
yaml_i["spec"]["model"] = model
757-
yaml_i["spec"]["what_if_analysis"] = True
757+
yaml_i["spec"]["what_if_analysis"] = {"model_name": f"model_{model}"}
758758

759759
run_yaml(
760760
tmpdirname=tmpdirname,

0 commit comments

Comments
 (0)