@@ -149,36 +149,42 @@ def _train_model(self, i, s_id, df, model_kwargs):
149
149
logger .debug (f"-----------------Model { i } ----------------------" )
150
150
logger .debug (forecast .tail ())
151
151
152
- # TODO; could also extract trend and seasonality?
153
- cols_to_read = set (
154
- forecast .columns [forecast .columns .str .startswith ("future_regressor" )]
155
- + ["ds" , "trend" ]
156
- )
157
- cols_to_read = cols_to_read - {
158
- "future_regressors_additive" ,
159
- "future_regressors_multiplicative" ,
160
- }
161
- combine_terms = cols_to_read - set (self .accepted_regressors [s_id ])
162
- temp_df = (
163
- forecast [list (cols_to_read )]
164
- .rename ({"ds" : "Date" }, axis = 1 )
165
- .set_index ("Date" )
166
- )
167
- temp_df [self .spec .target_column ] = temp_df [combine_terms ].sum (axis = 1 )
168
- self .explanations_info [s_id ] = temp_df .drop (combine_terms , axis = 1 )
169
-
170
152
self .outputs [s_id ] = forecast
153
+ upper_bound_col_name = f"yhat1 { model_kwargs ['quantiles' ][1 ]* 100 } %"
154
+ lower_bound_col_name = f"yhat1 { model_kwargs ['quantiles' ][0 ]* 100 } %"
171
155
self .forecast_output .populate_series_output (
172
156
series_id = s_id ,
173
157
fit_val = self .drop_horizon (forecast ["yhat1" ]).values ,
174
158
forecast_val = self .get_horizon (forecast ["yhat1" ]).values ,
175
- upper_bound = self .get_horizon (
176
- forecast [f"yhat1 { model_kwargs ['quantiles' ][1 ]* 100 } %" ]
177
- ).values ,
178
- lower_bound = self .get_horizon (
179
- forecast [f"yhat1 { model_kwargs ['quantiles' ][0 ]* 100 } %" ]
180
- ).values ,
159
+ upper_bound = self .get_horizon (forecast [upper_bound_col_name ]).values ,
160
+ lower_bound = self .get_horizon (forecast [lower_bound_col_name ]).values ,
181
161
)
162
+ core_columns = set (forecast .columns ) - set (
163
+ [
164
+ "y" ,
165
+ "yhat1" ,
166
+ upper_bound_col_name ,
167
+ lower_bound_col_name ,
168
+ "future_regressors_additive" ,
169
+ "future_regressors_multiplicative" ,
170
+ ]
171
+ )
172
+ exog_variables = set (
173
+ filter (lambda x : x .startswith ("future_regressor_" ), list (core_columns ))
174
+ )
175
+ combine_terms = list (core_columns - exog_variables - set (["ds" ]))
176
+ temp_df = (
177
+ forecast [list (core_columns )]
178
+ .rename ({"ds" : "Date" }, axis = 1 )
179
+ .set_index ("Date" )
180
+ )
181
+ if combine_terms :
182
+ temp_df [self .spec .target_column ] = temp_df [combine_terms ].sum (axis = 1 )
183
+ temp_df = temp_df .drop (combine_terms , axis = 1 )
184
+ else :
185
+ temp_df [self .spec .target_column ] = 0
186
+ # Todo: check for columns that were dropped, and set them to 0
187
+ self .explanations_info [s_id ] = temp_df
182
188
183
189
self .trainers [s_id ] = model .trainer
184
190
self .models [s_id ] = {}
0 commit comments