Skip to content

Commit 5ce130b

Browse files
codeloopahosler
authored andcommitted
add the report section, update the pr commenets
1 parent 990bdd6 commit 5ce130b

File tree

3 files changed

+58
-7
lines changed

3 files changed

+58
-7
lines changed

THIRD_PARTY_LICENSES.txt

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -447,6 +447,12 @@ python-fire
447447
* Source code: https://github.com/google/python-fire
448448
* Project home: https://github.com/google/python-fire
449449

450+
mlforecast
451+
* Copyright 2024 Nixtla
452+
* License: Apache License 2.0
453+
* Source code: https://github.com/Nixtla/mlforecast
454+
* Project home: https://github.com/Nixtla/mlforecast
455+
450456
=======
451457
=============================== Licenses ===============================
452458
------------------------------------------------------------------------

ads/opctl/operator/lowcode/forecast/model/forecast_datasets.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,8 @@ def get_all_data_long(self, include_horizon=True):
159159
on=[self._datetime_column_name, ForecastOutputColumns.SERIES],
160160
).reset_index()
161161

162-
def get_all_data_long_test(self):
162+
def get_all_data_long_forecast_horizon(self):
163+
"""Returns all data in long format for the forecast horizon."""
163164
test_data = pd.merge(
164165
self.historical_data.data,
165166
self.additional_data.data,

ads/opctl/operator/lowcode/forecast/model/ml_forecast.py

Lines changed: 50 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
from ads.opctl import logger
55
from ads.common.decorator import runtime_dependency
6+
from ads.opctl.operator.lowcode.forecast.utils import _select_plot_list
67
from .base_model import ForecastOperatorBaseModel
78
from .forecast_datasets import ForecastDatasets, ForecastOutput
89
from ..operator_config import ForecastOperatorConfig
@@ -105,7 +106,7 @@ def _train_model(self, data_train, data_test, model_kwargs):
105106
ignore_index=True,
106107
).fillna(0),
107108
)
108-
fitted_values = fcst.forecast_fitted_values()
109+
self.fitted_values = fcst.forecast_fitted_values()
109110
for s_id in self.datasets.list_series_ids():
110111
self.forecast_output.init_series_output(
111112
series_id=s_id,
@@ -114,8 +115,8 @@ def _train_model(self, data_train, data_test, model_kwargs):
114115

115116
self.forecast_output.populate_series_output(
116117
series_id=s_id,
117-
fit_val=fitted_values[
118-
fitted_values[ForecastOutputColumns.SERIES] == s_id
118+
fit_val=self.fitted_values[
119+
self.fitted_values[ForecastOutputColumns.SERIES] == s_id
119120
].forecast.values,
120121
forecast_val=self.outputs[
121122
self.outputs[ForecastOutputColumns.SERIES] == s_id
@@ -135,7 +136,6 @@ def _train_model(self, data_train, data_test, model_kwargs):
135136

136137
logger.debug("===========Done===========")
137138

138-
return self.forecast_output.get_forecast_long()
139139
except Exception as e:
140140
self.errors_dict[self.spec.model] = {
141141
"model_name": self.spec.model,
@@ -154,7 +154,51 @@ def _build_model(self) -> pd.DataFrame:
154154
dt_column=self.spec.datetime_column.name,
155155
)
156156
self._train_model(data_train, data_test, model_kwargs)
157-
pass
157+
return self.forecast_output.get_forecast_long()
158158

159159
def _generate_report(self):
160-
pass
160+
"""
161+
Generates the report for the model
162+
"""
163+
import datapane as dp
164+
from utilsforecast.plotting import plot_series
165+
166+
# Section 1: Forecast Overview
167+
sec1_text = dp.Text(
168+
"## Forecast Overview \n"
169+
"These plots show your forecast in the context of historical data."
170+
)
171+
sec_1 = _select_plot_list(
172+
lambda s_id: plot_series(
173+
self.datasets.get_all_data_long(include_horizon=False),
174+
pd.concat([self.fitted_values,self.outputs], axis=0, ignore_index=True),
175+
id_col=ForecastOutputColumns.SERIES,
176+
time_col=self.spec.datetime_column.name,
177+
target_col=self.original_target_column,
178+
seed=42,
179+
ids=[s_id],
180+
),
181+
self.datasets.list_series_ids(),
182+
)
183+
184+
# Section 2: MlForecast Model Parameters
185+
sec2_text = dp.Text(
186+
"## MlForecast Model Parameters \n"
187+
"These are the parameters used for the MlForecast model."
188+
)
189+
blocks = [
190+
dp.HTML(
191+
s_id[1],
192+
label=s_id[0],
193+
)
194+
for _, s_id in enumerate(self.model_parameters.items())
195+
]
196+
sec_2 = dp.Select(blocks=blocks) if len(blocks) > 1 else blocks[0]
197+
198+
all_sections = [sec1_text, sec_1, sec2_text, sec_2]
199+
model_description = dp.Text("mlforecast is a framework to perform time series forecasting using machine learning models"
200+
"with the option to scale to massive amounts of data using remote clusters."
201+
"Fastest implementations of feature engineering for time series forecasting in Python."
202+
"Support for exogenous variables and static covariates.")
203+
204+
return model_description, all_sections

0 commit comments

Comments
 (0)