
Commit cbdb2bd

bug patch in prophet min/max handling (#1195)
Merge commit cbdb2bd (2 parents: 33c9966 + f29c9cb)

File tree

10 files changed: +345, -231 lines


ads/opctl/operator/lowcode/common/utils.py

Lines changed: 91 additions & 0 deletions
@@ -11,6 +11,7 @@
 import tempfile
 from typing import List, Union

+import cloudpickle
 import fsspec
 import oracledb
 import pandas as pd
@@ -126,7 +127,26 @@ def load_data(data_spec, storage_options=None, **kwargs):
     return data


+def _safe_write(fn, **kwargs):
+    try:
+        fn(**kwargs)
+    except Exception:
+        logger.warning(f'Failed to write file {kwargs.get("filename", "UNKNOWN")}')
+
+
 def write_data(data, filename, format, storage_options=None, index=False, **kwargs):
+    return _safe_write(
+        fn=_write_data,
+        data=data,
+        filename=filename,
+        format=format,
+        storage_options=storage_options,
+        index=index,
+        **kwargs,
+    )
+
+
+def _write_data(data, filename, format, storage_options=None, index=False, **kwargs):
     disable_print()
     if not format:
         _, format = os.path.splitext(filename)
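The new _safe_write wrapper downgrades any exception raised by a write routine to a logged warning, so a failed write no longer propagates out of the caller. A minimal standalone sketch of the pattern (the failing writer below is a stand-in, not the operator's real one):

import logging

logger = logging.getLogger(__name__)


def _safe_write(fn, **kwargs):
    try:
        fn(**kwargs)
    except Exception:
        logger.warning(f'Failed to write file {kwargs.get("filename", "UNKNOWN")}')


def _write_data(data, filename):
    raise OSError("bucket not reachable")  # simulate a failed upload


# The warning is logged and execution continues; nothing propagates.
_safe_write(fn=_write_data, data=[1, 2, 3], filename="oci://bucket@ns/forecast.csv")
print("still running")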
@@ -143,11 +163,24 @@ def write_data(data, filename, format, storage_options=None, index=False, **kwar


 def write_json(json_dict, filename, storage_options=None):
+    return _safe_write(
+        fn=_write_json,
+        json_dict=json_dict,
+        filename=filename,
+        storage_options=storage_options,
+    )
+
+
+def _write_json(json_dict, filename, storage_options=None):
     with fsspec.open(filename, mode="w", **storage_options) as f:
         f.write(json.dumps(json_dict))


 def write_simple_json(data, path):
+    return _safe_write(fn=_write_simple_json, data=data, path=path)
+
+
+def _write_simple_json(data, path):
     if ObjectStorageDetails.is_oci_path(path):
         storage_options = default_signer()
     else:
@@ -156,6 +189,60 @@ def write_simple_json(data, path):
     json.dump(data, f, indent=4)


+def write_file(local_filename, remote_filename, storage_options, **kwargs):
+    return _safe_write(
+        fn=_write_file,
+        local_filename=local_filename,
+        remote_filename=remote_filename,
+        storage_options=storage_options,
+        **kwargs,
+    )
+
+
+def _write_file(local_filename, remote_filename, storage_options, **kwargs):
+    with open(local_filename) as f1:
+        with fsspec.open(
+            remote_filename,
+            "w",
+            **storage_options,
+        ) as f2:
+            f2.write(f1.read())
+
+
+def load_pkl(filepath):
+    return _safe_write(fn=_load_pkl, filepath=filepath)
+
+
+def _load_pkl(filepath):
+    storage_options = {}
+    if ObjectStorageDetails.is_oci_path(filepath):
+        storage_options = default_signer()
+
+    with fsspec.open(filepath, "rb", **storage_options) as f:
+        return cloudpickle.load(f)
+    return None
+
+
+def write_pkl(obj, filename, output_dir, storage_options):
+    return _safe_write(
+        fn=_write_pkl,
+        obj=obj,
+        filename=filename,
+        output_dir=output_dir,
+        storage_options=storage_options,
+    )
+
+
+def _write_pkl(obj, filename, output_dir, storage_options):
+    pkl_path = os.path.join(output_dir, filename)
+    with fsspec.open(
+        pkl_path,
+        "wb",
+        **storage_options,
+    ) as f:
+        cloudpickle.dump(obj, f)
+
+
 def merge_category_columns(data, target_category_columns):
     result = data.apply(
         lambda x: "__".join([str(x[col]) for col in target_category_columns]), axis=1
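A small round trip mirroring the new _write_pkl / _load_pkl helpers, run against a local temp directory rather than an oci:// path (with Object Storage, storage_options would carry the signer, as in the helpers above):

import os
import tempfile

import cloudpickle
import fsspec

obj = {"series_1": [1.0, 2.0, 3.0]}
output_dir = tempfile.mkdtemp()
pkl_path = os.path.join(output_dir, "model.pkl")

# _write_pkl: serialize with cloudpickle through an fsspec file handle
with fsspec.open(pkl_path, "wb") as f:
    cloudpickle.dump(obj, f)

# _load_pkl: read it back the same way
with fsspec.open(pkl_path, "rb") as f:
    restored = cloudpickle.load(f)

assert restored == obj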
@@ -290,4 +377,8 @@ def disable_print():

 # Restore
 def enable_print():
+    try:
+        sys.stdout.close()
+    except Exception:
+        pass
     sys.stdout = sys.__stdout__
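enable_print now closes the temporary stdout before restoring the real one, with the close wrapped in try/except so an already-closed handle cannot raise. A toy sketch of the suppress/restore pairing, assuming disable_print swaps sys.stdout for a throwaway buffer (a hypothetical stand-in for the repo's actual helper):

import io
import sys


def disable_print():
    sys.stdout = io.StringIO()  # hypothetical stand-in: send prints to a buffer


def enable_print():
    try:
        sys.stdout.close()  # release the throwaway stream; ignore close errors
    except Exception:
        pass
    sys.stdout = sys.__stdout__  # restore the interpreter's original stdout


disable_print()
print("suppressed")  # lands in the buffer, not the terminal
enable_print()
print("visible again")  # reaches the terminal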

ads/opctl/operator/lowcode/forecast/model/automlx.py

Lines changed: 14 additions & 3 deletions
@@ -38,6 +38,7 @@ def __init__(self, config: ForecastOperatorConfig, datasets: ForecastDatasets):
         super().__init__(config, datasets)
         self.global_explanation = {}
         self.local_explanation = {}
+        self.explainability_kwargs = {}

     def set_kwargs(self):
         model_kwargs_cleaned = self.spec.model_kwargs
@@ -54,6 +55,9 @@ def set_kwargs(self):
             self.spec.preprocessing.enabled
             or model_kwargs_cleaned.get("preprocessing", True)
         )
+        sample_ratio = model_kwargs_cleaned.pop("sample_to_feature_ratio", None)
+        if sample_ratio is not None:
+            self.explainability_kwargs = {"sample_to_feature_ratio": sample_ratio}
         return model_kwargs_cleaned, time_budget

     def preprocess(self, data, series_id):  # TODO: re-use self.le for explanations
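set_kwargs() now pops sample_to_feature_ratio out of the model kwargs and stashes it in self.explainability_kwargs, which explain_model() later splats into the explainer call. A toy illustration of that routing with plain dicts (keys and values here are made up):

# Hypothetical model_kwargs as a user might supply them in the operator spec.
model_kwargs = {"time_budget": 60, "sample_to_feature_ratio": 5}

explainability_kwargs = {}
sample_ratio = model_kwargs.pop("sample_to_feature_ratio", None)
if sample_ratio is not None:
    explainability_kwargs = {"sample_to_feature_ratio": sample_ratio}

print(model_kwargs)           # {'time_budget': 60} -> stays with the model
print(explainability_kwargs)  # {'sample_to_feature_ratio': 5} -> goes to the explainer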
@@ -445,6 +449,7 @@ def explain_model(self):
                     else None,
                     pd.DataFrame(data_i[self.spec.target_column]),
                     task="forecasting",
+                    **self.explainability_kwargs,
                 )

                 # Generate explanations for the forecast
@@ -518,7 +523,9 @@ def get_validation_score_and_metric(self, model):
         model_params = model.selected_model_params_
         if len(trials) > 0:
             score_col = [col for col in trials.columns if "Score" in col][0]
-            validation_score = trials[trials.Hyperparameters == model_params][score_col].iloc[0]
+            validation_score = trials[trials.Hyperparameters == model_params][
+                score_col
+            ].iloc[0]
         else:
             validation_score = 0
         return -1 * validation_score
@@ -531,8 +538,12 @@ def generate_train_metrics(self) -> pd.DataFrame:
         for s_id in self.forecast_output.list_series_ids():
             try:
                 metrics = {self.spec.metric.upper(): self.models[s_id]["score"]}
-                metrics_df = pd.DataFrame.from_dict(metrics, orient="index", columns=[s_id])
-                logger.warning("AutoMLX failed to generate training metrics. Recovering validation loss instead")
+                metrics_df = pd.DataFrame.from_dict(
+                    metrics, orient="index", columns=[s_id]
+                )
+                logger.warning(
+                    "AutoMLX failed to generate training metrics. Recovering validation loss instead"
+                )
                 total_metrics = pd.concat([total_metrics, metrics_df], axis=1)
             except Exception as e:
                 logger.debug(

ads/opctl/operator/lowcode/forecast/model/base_model.py

Lines changed: 14 additions & 12 deletions
@@ -11,7 +11,6 @@
 from abc import ABC, abstractmethod
 from typing import Tuple

-import fsspec
 import numpy as np
 import pandas as pd
 import report_creator as rc
@@ -25,10 +24,13 @@
     disable_print,
     enable_print,
     human_time_friendly,
+    load_pkl,
     merged_category_column_name,
     seconds_to_datetime,
     write_data,
+    write_file,
     write_json,
+    write_pkl,
 )
 from ads.opctl.operator.lowcode.forecast.utils import (
     _build_metrics_df,
@@ -38,8 +40,6 @@
     evaluate_train_metrics,
     get_auto_select_plot,
     get_forecast_plots,
-    load_pkl,
-    write_pkl,
 )

 from ..const import (
@@ -493,13 +493,11 @@ def _save_report(
         enable_print()

         report_path = os.path.join(unique_output_dir, self.spec.report_filename)
-        with open(report_local_path) as f1:
-            with fsspec.open(
-                report_path,
-                "w",
-                **storage_options,
-            ) as f2:
-                f2.write(f1.read())
+        write_file(
+            local_filename=report_local_path,
+            remote_filename=report_path,
+            storage_options=storage_options,
+        )

         # forecast csv report
         # todo: add test data into forecast.csv
@@ -576,7 +574,9 @@ def _save_report(
             # Round to 4 decimal places before writing
             global_expl_rounded = self.formatted_global_explanation.copy()
             global_expl_rounded = global_expl_rounded.apply(
-                lambda col: np.round(col, 4) if np.issubdtype(col.dtype, np.number) else col
+                lambda col: np.round(col, 4)
+                if np.issubdtype(col.dtype, np.number)
+                else col
             )
             if self.spec.generate_explanation_files:
                 write_data(
@@ -598,7 +598,9 @@ def _save_report(
             # Round to 4 decimal places before writing
             local_expl_rounded = self.formatted_local_explanation.copy()
             local_expl_rounded = local_expl_rounded.apply(
-                lambda col: np.round(col, 4) if np.issubdtype(col.dtype, np.number) else col
+                lambda col: np.round(col, 4)
+                if np.issubdtype(col.dtype, np.number)
+                else col
             )
             if self.spec.generate_explanation_files:
                 write_data(
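The reflowed lambda in _save_report rounds only the numeric columns of the explanation frames to four decimals and leaves everything else untouched. A toy DataFrame showing the effect (column names here are hypothetical):

import numpy as np
import pandas as pd

expl = pd.DataFrame(
    {"Feature": ["trend", "temp"], "importance": [0.1234567, 0.0000491]}
)
rounded = expl.apply(
    lambda col: np.round(col, 4) if np.issubdtype(col.dtype, np.number) else col
)
print(rounded)
# Feature column is untouched; importance becomes 0.1235 and 0.0000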

ads/opctl/operator/lowcode/forecast/model/neuralprophet.py

Lines changed: 10 additions & 14 deletions
@@ -19,12 +19,10 @@
 from ads.opctl.operator.lowcode.common.utils import (
     disable_print,
     enable_print,
-)
-from ads.opctl.operator.lowcode.forecast.utils import (
-    _select_plot_list,
     load_pkl,
     write_pkl,
 )
+from ads.opctl.operator.lowcode.forecast.utils import _select_plot_list

 from ..const import DEFAULT_TRIALS, SupportedModels
 from ..operator_config import ForecastOperatorConfig
@@ -159,20 +157,18 @@ def _train_model(self, i, s_id, df, model_kwargs):
                 upper_bound=self.get_horizon(forecast[upper_bound_col_name]).values,
                 lower_bound=self.get_horizon(forecast[lower_bound_col_name]).values,
             )
-            core_columns = set(forecast.columns) - set(
-                [
-                    "y",
-                    "yhat1",
-                    upper_bound_col_name,
-                    lower_bound_col_name,
-                    "future_regressors_additive",
-                    "future_regressors_multiplicative",
-                ]
-            )
+            core_columns = set(forecast.columns) - {
+                "y",
+                "yhat1",
+                upper_bound_col_name,
+                lower_bound_col_name,
+                "future_regressors_additive",
+                "future_regressors_multiplicative",
+            }
             exog_variables = set(
                 filter(lambda x: x.startswith("future_regressor_"), list(core_columns))
             )
-            combine_terms = list(core_columns - exog_variables - set(["ds"]))
+            combine_terms = list(core_columns - exog_variables - {"ds"})
             temp_df = (
                 forecast[list(core_columns)]
                 .rename({"ds": "Date"}, axis=1)
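The set-literal refactor above does not change behaviour; the logic is: drop the raw target/bound columns, pull out the exogenous regressors, and combine whatever remains. A toy run with hypothetical column names:

# Hypothetical NeuralProphet forecast columns, just to show the set arithmetic.
forecast_columns = ["ds", "y", "yhat1", "trend", "season_yearly", "future_regressor_temp"]

core_columns = set(forecast_columns) - {"y", "yhat1"}
exog_variables = set(
    filter(lambda x: x.startswith("future_regressor_"), core_columns)
)
combine_terms = list(core_columns - exog_variables - {"ds"})

print(sorted(exog_variables))  # ['future_regressor_temp']
print(sorted(combine_terms))   # ['season_yearly', 'trend']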

ads/opctl/operator/lowcode/forecast/model/prophet.py

Lines changed: 12 additions & 7 deletions
@@ -43,22 +43,27 @@ def _fit_model(data, params, additional_regressors):
     from prophet import Prophet

     monthly_seasonality = params.pop("monthly_seasonality", False)
-    data_floor = params.pop("min", None)
-    data_cap = params.pop("max", None)
-    if data_cap or data_floor:
+
+    has_min = "min" in params
+    has_max = "max" in params
+    if has_min or has_max:
         params["growth"] = "logistic"
+    data_floor = params.pop("min", None)
+    data_cap = params.pop("max", None)
+
     model = Prophet(**params)
     if monthly_seasonality:
         model.add_seasonality(name="monthly", period=30.5, fourier_order=5)
     params["monthly_seasonality"] = monthly_seasonality
     for add_reg in additional_regressors:
         model.add_regressor(add_reg)
-    if data_floor:
+
+    if has_min:
         data["floor"] = float(data_floor)
-        params["floor"] = data_floor
-    if data_cap:
+        params["min"] = data_floor
+    if has_max:
         data["cap"] = float(data_cap)
-        params["cap"] = data_cap
+        params["max"] = data_cap

     model.fit(data)
     return model
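The bug behind the commit title appears to be twofold: the old code popped min/max and branched on their truthiness, so a bound of 0 (a perfectly valid floor) was treated as absent, and it wrote the popped values back under floor/cap rather than the original min/max keys. The patched flow keys off key presence and restores the original keys. A plain-dict demonstration of the difference (no Prophet import needed):

# Why the presence check matters: 0 is falsy, so the old test dropped a 0 bound.
params = {"min": 0}                    # hypothetical: user pins the forecast floor at 0
data_floor = params.pop("min", None)
data_cap = params.pop("max", None)
print(bool(data_cap or data_floor))    # False -> the old code never enabled logistic growth

# The patched flow keys off the presence of "min"/"max" instead of their values,
# and puts the popped values back under the same keys, so params round-trips intact.
params = {"min": 0}
has_min, has_max = "min" in params, "max" in params
if has_min or has_max:
    params["growth"] = "logistic"
data_floor = params.pop("min", None)
data_cap = params.pop("max", None)
if has_min:
    params["min"] = data_floor
if has_max:
    params["max"] = data_cap
print(params)                          # {'growth': 'logistic', 'min': 0}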
