Skip to content

Commit 3c1e665

Browse files
committed
small fixes
1 parent 7d1aead commit 3c1e665

File tree

2 files changed

+13
-11
lines changed

2 files changed

+13
-11
lines changed

ads/opctl/operator/lowcode/forecast/model/base_model.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -666,10 +666,11 @@ def explain_model(self):
666666
lambda x: x.timestamp()
667667
)
668668

669-
# Explainer fails when boolean columns are passed
670-
_, data_trimmed_encoded = _label_encode_dataframe(
671-
data_trimmed, no_encode={datetime_col_name, self.original_target_column}
672-
)
669+
# Explainer fails when boolean columns are passed for arima
670+
if self.spec.model == SupportedModels.Arima:
671+
_, data_trimmed_encoded = _label_encode_dataframe(
672+
data_trimmed, no_encode={datetime_col_name, self.original_target_column}
673+
)
673674

674675
kernel_explnr = PermutationExplainer(
675676
model=explain_predict_fn, masker=data_trimmed_encoded
@@ -714,15 +715,17 @@ def local_explainer(self, kernel_explainer, series_id, datetime_col_name) -> Non
714715
kernel_explainer: The kernel explainer object to use for generating explanations.
715716
"""
716717
data = self.datasets.get_horizon_at_series(s_id=series_id)
718+
# columns that were dropped in train_model in arima, should be dropped here as well
717719
if self.spec.model == SupportedModels.Arima and series_id in self.constant_cols:
718720
data = data.drop(columns=self.constant_cols[series_id])
719721
data[datetime_col_name] = datetime_to_seconds(data[datetime_col_name])
720722
data = data.reset_index(drop=True)
721723

722-
# Explainer fails when boolean columns are passed
723-
_, data = _label_encode_dataframe(
724-
data, no_encode={datetime_col_name, self.original_target_column}
725-
)
724+
# Explainer fails when boolean columns are passed for arima
725+
if self.spec.model == SupportedModels.Arima:
726+
_, data = _label_encode_dataframe(
727+
data, no_encode={datetime_col_name, self.original_target_column}
728+
)
726729
# Generate local SHAP values using the kernel explainer
727730
local_kernel_explnr_vals = kernel_explainer.shap_values(data)
728731

ads/opctl/operator/lowcode/forecast/model/prophet.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -133,8 +133,6 @@ def _train_model(self, i, series_id, df, model_kwargs):
133133
}
134134

135135
def _build_model(self) -> pd.DataFrame:
136-
from prophet import Prophet
137-
from prophet.diagnostics import cross_validation, performance_metrics
138136

139137
full_data_dict = self.datasets.get_data_by_series()
140138
self.models = dict()
@@ -160,6 +158,8 @@ def _build_model(self) -> pd.DataFrame:
160158
return self.forecast_output.get_forecast_long()
161159

162160
def run_tuning(self, data_i, model_kwargs_i):
161+
from prophet import Prophet
162+
from prophet.diagnostics import cross_validation, performance_metrics
163163
def objective(trial):
164164
params = {
165165
"seasonality_mode": trial.suggest_categorical(
@@ -245,7 +245,6 @@ def objective(trial):
245245
def _generate_report(self):
246246
import datapane as dp
247247
from prophet.plot import add_changepoints_to_plot
248-
self.models = dict()
249248
series_ids = self.models.keys()
250249
all_sections = []
251250
if len(series_ids) > 0:

0 commit comments

Comments
 (0)