@@ -19,6 +19,7 @@ def __init__(self, config: ForecastOperatorConfig, datasets: ForecastDatasets):
19
19
self .local_explanation = {}
20
20
self .formatted_global_explanation = None
21
21
self .formatted_local_explanation = None
22
+ self .date_col = config .spec .datetime_column .name
22
23
23
24
def set_kwargs (self ):
24
25
"""
@@ -73,8 +74,8 @@ def _train_model(self, data_train, data_test, model_kwargs):
73
74
alpha = model_kwargs ["lower_quantile" ],
74
75
),
75
76
},
76
- freq = pd .infer_freq (data_train ["Date" ].drop_duplicates ())
77
- or pd .infer_freq (data_train ["Date" ].drop_duplicates ()[- 5 :]),
77
+ freq = pd .infer_freq (data_train [self . date_col ].drop_duplicates ())
78
+ or pd .infer_freq (data_train [self . date_col ].drop_duplicates ()[- 5 :]),
78
79
target_transforms = [Differences ([12 ])],
79
80
lags = model_kwargs .get (
80
81
"lags" ,
@@ -104,7 +105,7 @@ def _train_model(self, data_train, data_test, model_kwargs):
104
105
data_train [self .model_columns ],
105
106
static_features = model_kwargs .get ("static_features" , []),
106
107
id_col = ForecastOutputColumns .SERIES ,
107
- time_col = self .spec . datetime_column . name ,
108
+ time_col = self .date_col ,
108
109
target_col = self .spec .target_column ,
109
110
fitted = True ,
110
111
max_horizon = None if num_models is False else self .spec .horizon ,
@@ -168,7 +169,7 @@ def _build_model(self) -> pd.DataFrame:
168
169
confidence_interval_width = self .spec .confidence_interval_width ,
169
170
horizon = self .spec .horizon ,
170
171
target_column = self .original_target_column ,
171
- dt_column = self .spec . datetime_column . name ,
172
+ dt_column = self .date_col ,
172
173
)
173
174
self ._train_model (data_train , data_test , model_kwargs )
174
175
return self .forecast_output .get_forecast_long ()
0 commit comments