Skip to content

[ML-47076] Support custom frequencies for Prophet #163

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 14 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file.
149 changes: 149 additions & 0 deletions runtime/databricks/automl_runtime/forecast/deepar/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
#
# Copyright (C) 2024 Databricks, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import List, Optional

import gluonts
import mlflow
import pandas as pd
from gluonts.dataset.pandas import PandasDataset
from gluonts.torch.model.predictor import PyTorchPredictor
from mlflow.utils.environment import _mlflow_conda_env

from databricks.automl_runtime import version
from databricks.automl_runtime.forecast.model import ForecastModel, mlflow_forecast_log_model
from databricks.automl_runtime.forecast.deepar.utils import set_index_and_fill_missing_time_steps

DEEPAR_ADDITIONAL_PIP_DEPS = [
f"gluonts[torch]=={gluonts.__version__}",
f"pandas=={pd.__version__}",
f"databricks-automl-runtime=={version.__version__}"
]

DEEPAR_CONDA_ENV = _mlflow_conda_env(
additional_pip_deps=DEEPAR_ADDITIONAL_PIP_DEPS
)


class DeepARModel(ForecastModel):
"""
DeepAR mlflow model wrapper for forecasting.
"""

def __init__(self, model: PyTorchPredictor, horizon: int, frequency: str,
num_samples: int,
target_col: str, time_col: str,
id_cols: Optional[List[str]] = None) -> None:
"""
Initialize the DeepAR mlflow Python model wrapper
:param model: DeepAR model
:param horizon: the number of periods to forecast forward
:param frequency: the frequency of the time series
:param num_samples: the number of samples to draw from the distribution
:param target_col: the target column name
:param time_col: the time column name
:param id_cols: the column names of the identity columns for multi-series time series; None for single series
"""

super().__init__()
self._model = model
self._horizon = horizon
self._frequency = frequency
self._num_samples = num_samples
self._target_col = target_col
self._time_col = time_col
self._id_cols = id_cols

@property
def model_env(self):
return DEEPAR_CONDA_ENV

def predict(self,
context: mlflow.pyfunc.model.PythonModelContext,
model_input: pd.DataFrame) -> pd.DataFrame:
"""
Predict the future dataframe given the history dataframe
:param context: A :class:`~PythonModelContext` instance containing artifacts that the model
can use to perform inference.
:param model_input: Input dataframe that contains the history data
:return: predicted pd.DataFrame that starts after the last timestamp in the input dataframe,
and predicts the horizon using the mean of the samples
"""
required_cols = [self._target_col, self._time_col]
if self._id_cols:
required_cols += self._id_cols
self._validate_cols(model_input, required_cols)

forecast_sample_list = self.predict_samples(model_input, num_samples=self._num_samples)

pred_df = pd.concat(
[
forecast.mean_ts.rename('yhat').reset_index().assign(item_id=forecast.item_id)
for forecast in forecast_sample_list
],
ignore_index=True
)

pred_df = pred_df.rename(columns={'index': self._time_col})
if self._id_cols:
id_col_name = '-'.join(self._id_cols)
pred_df = pred_df.rename(columns={'item_id': id_col_name})
else:
pred_df = pred_df.drop(columns='item_id')

pred_df[self._time_col] = pred_df[self._time_col].dt.to_timestamp()

return pred_df

def predict_samples(self,
model_input: pd.DataFrame,
num_samples: int = None) -> List[gluonts.model.forecast.SampleForecast]:
"""
Predict the future samples given the history dataframe
:param model_input: Input dataframe that contains the history data
:param num_samples: the number of samples to draw from the distribution
:return: List of SampleForecast, where each SampleForecast contains num_samples sampled forecasts
"""
if num_samples is None:
num_samples = self._num_samples

# Group by the time column in case there are multiple rows for each time column,
# for example, the user didn't provide all the identity columns for a multi-series dataset
group_cols = [self._time_col]
if self._id_cols:
group_cols += self._id_cols
model_input = model_input.groupby(group_cols).agg({self._target_col: "mean"}).reset_index()

model_input_transformed = set_index_and_fill_missing_time_steps(model_input,
self._time_col,
self._frequency,
self._id_cols)

test_ds = PandasDataset(model_input_transformed, target=self._target_col)

forecast_iter = self._model.predict(test_ds, num_samples=num_samples)
forecast_sample_list = list(forecast_iter)

return forecast_sample_list


def mlflow_deepar_log_model(deepar_model: DeepARModel,
sample_input: pd.DataFrame = None) -> None:
"""
Log the DeepAR model to mlflow
:param deepar_model: DeepAR mlflow PythonModel wrapper
:param sample_input: sample input Dataframes for model inference
"""
mlflow_forecast_log_model(deepar_model, sample_input)
112 changes: 112 additions & 0 deletions runtime/databricks/automl_runtime/forecast/deepar/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
#
# Copyright (C) 2024 Databricks, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import List, Optional

import pandas as pd


def validate_and_generate_index(df: pd.DataFrame, time_col: str, frequency: str):
"""
Generate a complete time index for the given DataFrame based on the specified frequency.
- Ensures the time column is in datetime format.
- Validates consistency in the day of the month if frequency is "MS" (month start).
- Generates a new time index from the minimum to the maximum timestamp in the data.
:param df: The input DataFrame containing the time column.
:param time_col: The name of the time column.
:param frequency: The frequency of the time series.
:return: A complete time index covering the full range of the dataset.
:raises ValueError: If the day-of-month pattern is inconsistent for "MS" frequency.
"""
if frequency.upper() != "MS":
return pd.date_range(df[time_col].min(), df[time_col].max(), freq=frequency)

df[time_col] = pd.to_datetime(df[time_col]) # Ensure datetime format

# Extract unique days
unique_days = df[time_col].dt.day.unique()

if len(unique_days) == 1:
# All dates have the same day-of-month, considered consistent
day_of_month = unique_days[0]
else:
# Check if all dates are last days of their respective months
is_last_day = (df[time_col] + pd.offsets.MonthEnd(0)) == df[time_col]
if is_last_day.all():
day_of_month = "MonthEnd"
else:
raise ValueError("Inconsistent day of the month found in time column.")

# Generate new index based on detected pattern
total_min, total_max = df[time_col].min(), df[time_col].max()
month_starts = pd.date_range(start=total_min.to_period("M").to_timestamp(),
end=total_max.to_period("M").to_timestamp(),
freq="MS")

if day_of_month == "MonthEnd":
new_index_full = month_starts + pd.offsets.MonthEnd(0)
else:
new_index_full = month_starts.map(lambda d: d.replace(day=day_of_month))

return new_index_full

def set_index_and_fill_missing_time_steps(df: pd.DataFrame, time_col: str,
frequency: str,
id_cols: Optional[List[str]] = None):
"""
Transform the input dataframe to an acceptable format for the GluonTS library.

- Set the time column as the index
- Impute missing time steps between the min and max time steps

:param df: the input dataframe that contains time_col
:param time_col: time column name
:param frequency: the frequency of the time series
:param id_cols: the column names of the identity columns for multi-series time series; None for single series
:return: single-series - transformed dataframe;
multi-series - dictionary of transformed dataframes, each key is the (concatenated) id of the time series
"""
total_min, total_max = df[time_col].min(), df[time_col].max()

# We need to adjust the frequency for pd.date_range if it is weekly,
# otherwise it would always be "W-SUN"
if frequency.upper() == "W":
weekday_name = total_min.strftime("%a").upper() # e.g., "FRI"
frequency = f"W-{weekday_name}"

new_index_full = validate_and_generate_index(df=df, time_col=time_col, frequency=frequency)

if id_cols is not None:
df_dict = {}
for grouped_id, grouped_df in df.groupby(id_cols):
if isinstance(grouped_id, tuple):
ts_id = "-".join([str(x) for x in grouped_id])
else:
ts_id = str(grouped_id)
df_dict[ts_id] = (grouped_df.set_index(time_col).sort_index()
.reindex(new_index_full).drop(id_cols, axis=1))

return df_dict

df = df.set_index(time_col).sort_index()

# Fill in missing time steps between the min and max time steps
df = df.reindex(new_index_full)

if frequency.upper() == "MS":
# Truncate the day of month to avoid issues with pandas frequency check
df = df.to_period("M")

return df
10 changes: 7 additions & 3 deletions runtime/databricks/automl_runtime/forecast/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,10 +57,14 @@ def mlflow_forecast_log_model(forecast_model: ForecastModel,
:param forecast_model: Forecast model wrapper
:param sample_input: sample input Dataframes for model inference
"""
# log the model without signature if infer_signature is failed.
# TODO: [ML-46185] we should not be logging without a signature since it cannot be registered to UC then
try:
signature = forecast_model.infer_signature(sample_input)
except Exception: # noqa
signature = None
mlflow.pyfunc.log_model("model", conda_env=forecast_model.model_env,
python_model=forecast_model, signature=signature)
mlflow.pyfunc.log_model(
artifact_path="model",
conda_env=forecast_model.model_env,
python_model=forecast_model,
signature=signature
)
12 changes: 8 additions & 4 deletions runtime/databricks/automl_runtime/forecast/pmdarima/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,17 @@
from databricks.automl_runtime.forecast.model import ForecastModel, mlflow_forecast_log_model
from databricks.automl_runtime.forecast.utils import calculate_period_differences, is_frequency_consistency, \
make_future_dataframe, make_single_future_dataframe
from databricks.automl_runtime import version


ARIMA_ADDITIONAL_PIP_DEPS = [
f"pmdarima=={pmdarima.__version__}",
f"pandas=={pd.__version__}",
f"databricks-automl-runtime=={version.__version__}"
]

ARIMA_CONDA_ENV = _mlflow_conda_env(
additional_pip_deps=[
f"pmdarima=={pmdarima.__version__}",
f"pandas=={pd.__version__}",
]
additional_pip_deps=ARIMA_ADDITIONAL_PIP_DEPS
)


Expand Down
28 changes: 21 additions & 7 deletions runtime/databricks/automl_runtime/forecast/pmdarima/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ class ArimaEstimator:
"""

def __init__(self, horizon: int, frequency_unit: str, metric: str, seasonal_periods: List[int],
num_folds: int = 20, max_steps: int = 150, exogenous_cols: Optional[List[str]] = None) -> None:
num_folds: int = 20, max_steps: int = 150, exogenous_cols: Optional[List[str]] = None,
split_cutoff: Optional[pd.Timestamp] = None) -> None:
"""
:param horizon: Number of periods to forecast forward
:param frequency_unit: Frequency of the time series
Expand All @@ -45,6 +46,10 @@ def __init__(self, horizon: int, frequency_unit: str, metric: str, seasonal_peri
:param max_steps: Max steps for stepwise auto_arima
:param exogenous_cols: Optional list of column names of exogenous variables. If provided, these columns are
used as additional features in arima model.
:param split_cutoff: Optional cutoff specified by user. If provided,
it is the starting point of cutoffs for cross validation.
For tuning job, it is the cutoff between train and validate split.
For training job, it is the cutoff bewteen validate and test split.
"""
self._horizon = horizon
self._frequency_unit = OFFSET_ALIAS_MAP[frequency_unit]
Expand All @@ -53,6 +58,7 @@ def __init__(self, horizon: int, frequency_unit: str, metric: str, seasonal_peri
self._num_folds = num_folds
self._max_steps = max_steps
self._exogenous_cols = exogenous_cols
self._split_cutoff = split_cutoff

def fit(self, df: pd.DataFrame) -> pd.DataFrame:
"""
Expand Down Expand Up @@ -88,12 +94,20 @@ def fit(self, df: pd.DataFrame) -> pd.DataFrame:
# so the minimum valid seasonality period is always 1

validation_horizon = utils.get_validation_horizon(history_pd, self._horizon, self._frequency_unit)
cutoffs = utils.generate_cutoffs(
history_pd,
horizon=validation_horizon,
unit=self._frequency_unit,
num_folds=self._num_folds,
)
if self._split_cutoff:
cutoffs = utils.generate_custom_cutoffs(
history_pd,
horizon=validation_horizon,
unit=self._frequency_unit,
split_cutoff=self._split_cutoff
)
else:
cutoffs = utils.generate_cutoffs(
history_pd,
horizon=validation_horizon,
unit=self._frequency_unit,
num_folds=self._num_folds,
)

result = self._fit_predict(history_pd, cutoffs=cutoffs, seasonal_period=m, max_steps=self._max_steps)
metric = result["metrics"]["smape"]
Expand Down
Loading
Loading