from functools import lru_cache
import logging
import ads
-from prophet import Prophet
-from neuralprophet import NeuralProphet
-from pmdarima import ARIMA
-from autots import AutoTS
-from automlx._interface.forecaster import AutoForecaster
+from ads.opctl.operator.lowcode.common.utils import load_data
+from ads.opctl.operator.common.operator_config import InputData

ads.set_auth("resource_principal")

   Inference script. This script is used for prediction by scoring server when schema is known.
"""

+AUTOTS = "autots"
+ARIMA = "arima"
+PROPHET = "prophet"
+NEURALPROPHET = "neuralprophet"
+AUTOMLX = "automlx"
+

@lru_cache(maxsize=10)
def load_model():
@@ -132,14 +135,14 @@ def post_inference(yhat):
    return yhat


-def get_forecast(future_df, series_id, model_object, date_col, target_column, target_cat_col, horizon):
+def get_forecast(future_df, model_name, series_id, model_object, date_col, target_column, target_cat_col, horizon):
    date_col_name = date_col["name"]
    date_col_format = date_col["format"]
    future_df[target_cat_col] = future_df[target_cat_col].astype("str")
    future_df[date_col_name] = pd.to_datetime(
        future_df[date_col_name], format=date_col_format
    )
-    if isinstance(model_object, AutoTS):
+    if model_name == AUTOTS:
        series_id_col = "Series"
        full_data_indexed = future_df.rename(columns={target_cat_col: series_id_col})
        additional_regressors = list(
@@ -152,12 +155,12 @@ def get_forecast(future_df, series_id, model_object, date_col, target_column, ta
        )
        pred_obj = model_object.predict(future_regressor=future_reg)
        return pred_obj.forecast[series_id].tolist()
-    elif series_id in model_object and isinstance(model_object[series_id], Prophet):
+    elif model_name == PROPHET and series_id in model_object:
        model = model_object[series_id]
        processed = future_df.rename(columns={date_col_name: 'ds', target_column: 'y'})
        forecast = model.predict(processed)
        return forecast['yhat'].tolist()
-    elif series_id in model_object and isinstance(model_object[series_id], NeuralProphet):
+    elif model_name == NEURALPROPHET and series_id in model_object:
        model = model_object[series_id]
        model.restore_trainer()
        accepted_regressors = list(model.config_regressors.keys())
@@ -166,7 +169,7 @@ def get_forecast(future_df, series_id, model_object, date_col, target_column, ta
        future["y"] = None
        forecast = model.predict(future)
        return forecast['yhat1'].tolist()
-    elif series_id in model_object and isinstance(model_object[series_id], ARIMA):
+    elif model_name == ARIMA and series_id in model_object:
        model = model_object[series_id]
        future_df = future_df.set_index(date_col_name)
        x_pred = future_df.drop(target_cat_col, axis=1)
@@ -177,7 +180,7 @@ def get_forecast(future_df, series_id, model_object, date_col, target_column, ta
        )
        yhat_clean = pd.DataFrame(yhat, index=yhat.index, columns=["yhat"])
        return yhat_clean['yhat'].tolist()
-    elif series_id in model_object and isinstance(model_object[series_id], AutoForecaster):
+    elif model_name == AUTOMLX and series_id in model_object:
        # automlx model
        model = model_object[series_id]
        x_pred = future_df.drop(target_cat_col, axis=1)
@@ -188,7 +191,7 @@ def get_forecast(future_df, series_id, model_object, date_col, target_column, ta
        )
        return forecast[target_column].tolist()
    else:
-        raise Exception( f"Invalid model object type: {type(model_object).__name__}.")
+        raise Exception(f"Invalid model object type: {type(model_object).__name__}.")


def predict(data, model=load_model()) -> dict:
@@ -211,20 +214,26 @@ def predict(data, model=load_model()) -> dict:
    models = model["models"]
    specs = model["spec"]
    horizon = specs["horizon"]
+    model_name = specs["model"]
    date_col = specs["datetime_column"]
    target_column = specs["target_column"]
-    forecasts = {}
-    uri = f"{data['additional_data_uri']}"
    target_category_column = specs["target_category_columns"][0]
-    signer = ads.common.auth.default_signer() if uri.lower().startswith("oci://") else {}
-    additional_data = pd.read_csv(uri, storage_options=signer)
+
+    try:
+        input_data = InputData(**data["additional_data"])
+    except TypeError as e:
+        raise ValueError(f"Validation error: {e}")
+    additional_data = load_data(input_data)
+
    unique_values = additional_data[target_category_column].unique()
+    forecasts = {}
    for key in unique_values:
        try:
            s_id = str(key)
            filtered = additional_data[additional_data[target_category_column] == key]
            future = filtered.tail(horizon)
-            forecast = get_forecast(future, s_id, models, date_col, target_column, target_category_column, horizon)
+            forecast = get_forecast(future, model_name, s_id, models, date_col,
+                                     target_column, target_category_column, horizon)
            forecasts[s_id] = json.dumps(forecast)
        except Exception as e:
            raise RuntimeError(
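
The core of the get_forecast change above is swapping per-library isinstance() checks for a comparison against the model name stored in the artifact spec, so the scoring environment no longer has to import prophet, neuralprophet, pmdarima, autots, or automlx just to choose a branch. A minimal standalone sketch of that dispatch shape (pick_branch and the series names are hypothetical, not code from the PR):

PROPHET = "prophet"
ARIMA = "arima"

def pick_branch(model_name, series_id, model_object):
    # Mirrors the new branching in get_forecast, minus the per-model predict() calls;
    # model_object stays an opaque mapping of series_id -> fitted model.
    if model_name == PROPHET and series_id in model_object:
        return "prophet"
    elif model_name == ARIMA and series_id in model_object:
        return "arima"
    raise Exception(f"Invalid model object type: {type(model_object).__name__}.")

print(pick_branch("prophet", "S1", {"S1": object()}))  # -> prophet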
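
On the predict side, the request now carries an operator-style "additional_data" spec that is validated as InputData and read through load_data, replacing the raw additional_data_uri plus pd.read_csv path. A hedged sketch of what such a payload might look like; the url/format field names and the OCI path are assumptions, not taken from the PR:

# Hypothetical payload for the new predict() contract; field names are assumed.
payload = {
    "additional_data": {
        "url": "oci://bucket@namespace/additional_data.csv",
        "format": "csv",
    }
}

# predict() builds InputData(**payload["additional_data"]) and hands it to load_data(),
# which takes over auth and format handling from the old
# pd.read_csv(uri, storage_options=signer) call.
# forecasts = predict(payload)  # requires the deployed model artifact on the scoring server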