Skip to content

Commit 52a0929

Browse files
committed
fix NP formatting
1 parent f56b5ff commit 52a0929

File tree

3 files changed

+10
-12
lines changed

3 files changed

+10
-12
lines changed

ads/opctl/operator/lowcode/forecast/model/base_model.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,6 @@
4545
SupportedModels,
4646
SpeedAccuracyMode,
4747
)
48-
from ..const import ForecastOutputColumns
4948
from ..operator_config import ForecastOperatorConfig, ForecastOperatorSpec
5049
from ads.common.decorator.runtime_dependency import runtime_dependency
5150
from .forecast_datasets import ForecastDatasets, ForecastOutput
@@ -711,9 +710,6 @@ def local_explainer(self, kernel_explainer, series_id, datetime_col_name) -> Non
711710
local_kernel_explnr_df = pd.DataFrame(
712711
local_kernel_explnr_vals, columns=data.columns
713712
)
714-
local_kernel_explnr_df = local_kernel_explnr_df.rename(
715-
{datetime_col_name: ForecastOutputColumns.DATE}, axis=0
716-
)
717713
self.local_explanation[series_id] = local_kernel_explnr_df
718714

719715
def get_explain_predict_fn(self, series_id, fcst_col_name="yhat"):

ads/opctl/operator/lowcode/forecast/model/neuralprophet.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -440,15 +440,17 @@ def explain_model(self):
440440
for s_id, expl_df in self.explanations_info.items():
441441
expl_df = expl_df.rename(rename_cols, axis=1)
442442
# Local Expl
443-
self.local_explanation[s_id] = self.get_horizon(expl_df)
443+
self.local_explanation[s_id] = self.get_horizon(expl_df).drop(
444+
["future_regressors_additive"], axis=1
445+
)
444446
self.local_explanation[s_id]["Series"] = s_id
445-
447+
self.local_explanation[s_id].index.rename(self.dt_column_name, inplace=True)
446448
# Global Expl
447449
g_expl = self.drop_horizon(expl_df).mean()
448450
g_expl.name = s_id
449451
global_expl.append(g_expl)
450452
self.global_explanation = pd.concat(global_expl, axis=1)
451-
self.formatted_global_explanation = self.global_explanation.drop(
453+
self.global_explanation = self.global_explanation.drop(
452454
index=["future_regressors_additive"], axis=0
453455
)
454456
self.formatted_global_explanation = (

tests/operators/forecast/test_datasets.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@
1919

2020

2121
DATASETS_LIST = [
22-
# "AirPassengersDataset",
23-
# "AusBeerDataset",
22+
"AirPassengersDataset",
23+
"AusBeerDataset",
2424
"AustralianTourismDataset",
2525
# # "ETTh1Dataset",
2626
# # "ETTh2Dataset",
@@ -34,7 +34,7 @@
3434
# # "ILINetDataset",
3535
# # "IceCreamHeaterDataset",
3636
# # "MonthlyMilkDataset",
37-
# "MonthlyMilkIncompleteDataset",
37+
"MonthlyMilkIncompleteDataset",
3838
# # "SunspotsDataset",
3939
# # "TaylorDataset",
4040
# # "TemperatureDataset",
@@ -43,7 +43,7 @@
4343
# "UberTLCDataset",
4444
# "WeatherDataset",
4545
# "WineDataset",
46-
# "WoolyDataset",
46+
"WoolyDataset",
4747
]
4848

4949
MODELS = [
@@ -93,7 +93,7 @@ def verify_explanations(global_fn, local_fn, yaml_i, additional_cols):
9393
glb_expl = pd.read_csv(global_fn, index_col=0)
9494
loc_expl = pd.read_csv(local_fn)
9595
assert loc_expl.shape[0] == PERIODS
96-
for x in ["Date", "Series"]:
96+
for x in [yaml_i["spec"]["datetime_column"]["name"], "Series"]:
9797
assert x in set(loc_expl.columns)
9898
for x in additional_cols:
9999
assert x in set(loc_expl.columns)

0 commit comments

Comments
 (0)