@@ -61,6 +61,18 @@ def _train_model(self, data_train, data_test, model_kwargs):
61
61
"verbosity" : - 1 ,
62
62
"num_leaves" : 512 ,
63
63
}
64
+ additional_data_params = {}
65
+ if len (self .datasets .get_additional_data_column_names ()) > 0 :
66
+ additional_data_params = {
67
+ "target_transforms" : [Differences ([12 ])],
68
+ "lags" : model_kwargs .get ("lags" , [1 , 6 , 12 ]),
69
+ "lag_transforms" : (
70
+ {
71
+ 1 : [ExpandingMean ()],
72
+ 12 : [RollingMean (window_size = 24 )],
73
+ }
74
+ ),
75
+ }
64
76
65
77
fcst = MLForecast (
66
78
models = {
@@ -80,24 +92,7 @@ def _train_model(self, data_train, data_test, model_kwargs):
80
92
},
81
93
freq = pd .infer_freq (data_train [self .date_col ].drop_duplicates ())
82
94
or pd .infer_freq (data_train [self .date_col ].drop_duplicates ()[- 5 :]),
83
- target_transforms = [Differences ([12 ])],
84
- lags = model_kwargs .get (
85
- "lags" ,
86
- (
87
- [1 , 6 , 12 ]
88
- if len (self .datasets .get_additional_data_column_names ()) > 0
89
- else []
90
- ),
91
- ),
92
- lag_transforms = (
93
- {
94
- 1 : [ExpandingMean ()],
95
- 12 : [RollingMean (window_size = 24 )],
96
- }
97
- if len (self .datasets .get_additional_data_column_names ()) > 0
98
- else {}
99
- ),
100
- # date_features=[hour_index],
95
+ ** additional_data_params ,
101
96
)
102
97
103
98
num_models = model_kwargs .get ("recursive_models" , False )
@@ -164,6 +159,7 @@ def _train_model(self, data_train, data_test, model_kwargs):
164
159
"error" : str (e ),
165
160
}
166
161
logger .debug (f"Encountered Error: { e } . Skipping." )
162
+ raise e
167
163
168
164
def _build_model (self ) -> pd .DataFrame :
169
165
data_train = self .datasets .get_all_data_long (include_horizon = False )
0 commit comments