[ML-50316] Refactor frequency_unit and frequency_quantity in automl runtime #165

Closed
wants to merge 6 commits into from
Changes from 2 commits
12 changes: 5 additions & 7 deletions runtime/databricks/automl_runtime/forecast/deepar/model.py
@@ -23,6 +23,7 @@
from mlflow.utils.environment import _mlflow_conda_env

from databricks.automl_runtime import version
from databricks.automl_runtime.forecast.frequency import Frequency
from databricks.automl_runtime.forecast.model import ForecastModel, mlflow_forecast_log_model
from databricks.automl_runtime.forecast.deepar.utils import set_index_and_fill_missing_time_steps

@@ -42,16 +43,15 @@ class DeepARModel(ForecastModel):
DeepAR mlflow model wrapper for forecasting.
"""

def __init__(self, model: PyTorchPredictor, horizon: int, frequency_unit: str, frequency_quantity: int,
def __init__(self, model: PyTorchPredictor, horizon: int, frequency: Frequency,
num_samples: int,
target_col: str, time_col: str,
id_cols: Optional[List[str]] = None) -> None:
"""
Initialize the DeepAR mlflow Python model wrapper
:param model: DeepAR model
:param horizon: the number of periods to forecast forward
:param frequency_unit: the frequency unit of the time series
:param frequency_quantity: the frequency quantity of the time series
:param frequency: the frequency of the time series
:param num_samples: the number of samples to draw from the distribution
:param target_col: the target column name
:param time_col: the time column name
@@ -61,8 +61,7 @@ def __init__(self, model: PyTorchPredictor, horizon: int, frequency_unit: str, f
super().__init__()
self._model = model
self._horizon = horizon
self._frequency_unit = frequency_unit
self._frequency_quantity = frequency_quantity
self._frequency = frequency
self._num_samples = num_samples
self._target_col = target_col
self._time_col = time_col
@@ -130,8 +129,7 @@ def predict_samples(self,

model_input_transformed = set_index_and_fill_missing_time_steps(model_input,
self._time_col,
self._frequency_unit,
self._frequency_quantity,
self._frequency,
self._id_cols)

test_ds = PandasDataset(model_input_transformed, target=self._target_col)
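
A quick before/after sketch of how callers construct the wrapper once this change lands. The predictor, horizon, and column names below are illustrative placeholders, not values taken from this PR:

```python
from databricks.automl_runtime.forecast.frequency import Frequency
from databricks.automl_runtime.forecast.deepar.model import DeepARModel

trained_predictor = None  # placeholder for an already-fitted gluonts PyTorchPredictor

# Before: DeepARModel(model, horizon, frequency_unit="D", frequency_quantity=1, ...)
# After: the two scalars collapse into a single Frequency value.
deepar_model = DeepARModel(
    model=trained_predictor,
    horizon=7,            # illustrative
    frequency=Frequency(frequency_unit="D", frequency_quantity=1),
    num_samples=100,      # illustrative
    target_col="y",       # illustrative column names
    time_col="ds",
    id_cols=None,
)
```
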
27 changes: 13 additions & 14 deletions runtime/databricks/automl_runtime/forecast/deepar/utils.py
@@ -16,26 +16,25 @@
from typing import List, Optional

import pandas as pd
from databricks.automl_runtime.forecast.frequency import Frequency


def validate_and_generate_index(df: pd.DataFrame,
time_col: str,
frequency_unit: str,
frequency_quantity: int):
frequency: Frequency):
"""
Generate a complete time index for the given DataFrame based on the specified frequency.
- Ensures the time column is in datetime format.
- Validates consistency in the day of the month if frequency is "MS" (month start).
- Generates a new time index from the minimum to the maximum timestamp in the data.
:param df: The input DataFrame containing the time column.
:param time_col: The name of the time column.
:param frequency_unit: The frequency unit of the time series.
:param frequency_quantity: The frequency quantity of the time series.
:param frequency: The frequency of the time series.
:return: A complete time index covering the full range of the dataset.
:raises ValueError: If the day-of-month pattern is inconsistent for "MS" frequency.
"""
if frequency_unit.upper() != "MS":
return pd.date_range(df[time_col].min(), df[time_col].max(), freq=f"{frequency_quantity}{frequency_unit}")
if frequency.frequency_unit.upper() != "MS":
return pd.date_range(df[time_col].min(), df[time_col].max(), freq=f"{frequency.frequency_quantity}{frequency.frequency_unit}")

df[time_col] = pd.to_datetime(df[time_col]) # Ensure datetime format

@@ -67,8 +66,7 @@ def validate_and_generate_index(df: pd.DataFrame,
return new_index_full

def set_index_and_fill_missing_time_steps(df: pd.DataFrame, time_col: str,
frequency_unit: str,
frequency_quantity: int,
frequency: Frequency,
id_cols: Optional[List[str]] = None):
"""
Transform the input dataframe to an acceptable format for the GluonTS library.
@@ -78,8 +76,7 @@ def set_index_and_fill_missing_time_steps(df: pd.DataFrame, time_col: str,

:param df: the input dataframe that contains time_col
:param time_col: time column name
:param frequency_unit: the frequency unit of the time series
:param frequency_quantity: the frequency quantity of the time series
:param frequency: the frequency of the time series
:param id_cols: the column names of the identity columns for multi-series time series; None for single series
:return: single-series - transformed dataframe;
multi-series - dictionary of transformed dataframes, each key is the (concatenated) id of the time series
@@ -88,11 +85,13 @@ def set_index_and_fill_missing_time_steps(df: pd.DataFrame, time_col: str,

# We need to adjust the frequency_unit for pd.date_range if it is weekly,
# otherwise it would always be "W-SUN"
if frequency_unit.upper() == "W":
if frequency.frequency_unit.upper() == "W":
weekday_name = total_min.strftime("%a").upper() # e.g., "FRI"
frequency_unit = f"W-{weekday_name}"
adjusted_frequency = Frequency(frequency_unit=f"W-{weekday_name}", frequency_quantity=frequency.frequency_quantity)
else:
adjusted_frequency = Frequency(frequency_unit=frequency.frequency_unit, frequency_quantity=frequency.frequency_quantity)

valid_index = validate_and_generate_index(df=df, time_col=time_col, frequency_unit=frequency_unit, frequency_quantity=frequency_quantity)
valid_index = validate_and_generate_index(df=df, time_col=time_col, frequency=adjusted_frequency)

if id_cols is not None:
df_dict = {}
@@ -111,7 +110,7 @@ def set_index_and_fill_missing_time_steps(df: pd.DataFrame, time_col: str,
# Fill in missing time steps between the min and max time steps
df = df.reindex(valid_index)

if frequency_unit.upper() == "MS":
if frequency.frequency_unit.upper() == "MS":
# Truncate the day of month to avoid issues with pandas frequency check
df = df.to_period("M")

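
The weekday re-anchoring in set_index_and_fill_missing_time_steps matters because pandas treats a bare "W" alias as week-ending-Sunday, which would shift timestamps for series observed on any other weekday; and since Frequency is frozen, the code builds a new adjusted_frequency instead of mutating the input. A standalone pandas sketch (dates chosen purely for illustration):

```python
import pandas as pd

start = pd.Timestamp("2025-01-03")   # a Friday
print(start.strftime("%a").upper())  # "FRI", the weekday name used to build the alias

print(pd.date_range(start, periods=3, freq="W"))
# Bare "W" anchors on Sundays: 2025-01-05, 2025-01-12, 2025-01-19

print(pd.date_range(start, periods=3, freq="W-FRI"))
# Anchoring on the observed weekday keeps the original timestamps: 2025-01-03, 2025-01-10, 2025-01-17
```
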
73 changes: 73 additions & 0 deletions runtime/databricks/automl_runtime/forecast/frequency.py
@@ -0,0 +1,73 @@
#
# Copyright (C) 2022 Databricks, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from dataclasses import dataclass
from typing import ClassVar, Set

@dataclass(frozen=True)
class Frequency:
"""
Represents the frequency of a time series.

Attributes:
frequency_unit (str): The unit of time for the frequency.
frequency_quantity (int): The number of frequency_units in the period.

Valid frequency units: source of truth is OFFSET_ALIAS_MAP in forecast.__init__.py
- Weeks: "W", "W-SUN", "W-MON", "W-TUE", "W-WED", "W-THU", "W-FRI", "W-SAT" (the weekday-anchored forms are aliases for "W", used for DeepAR only)
- Days: "d", "D", "days", "day"
- Hours: "hours", "hour", "hr", "h", "H"
- Minutes: "m", "minute", "min", "minutes", "T"
- Seconds: "S", "seconds", "sec", "second"
- Months: "M", "MS", "month", "months"
- Quarters: "Q", "QS", "quarter", "quarters"
- Years: "Y", "YS", "year", "years"

Valid frequency quantities:
- For minutes: {1, 5, 10, 15, 30}
- For all other units: {1}
"""

VALID_FREQUENCY_UNITS: ClassVar[Set[str]] = {
"W", "W-SUN", "W-MON", "W-TUE", "W-WED", "W-THU", "W-FRI", "W-SAT",
"d", "D", "days", "day", "hours", "hour", "hr", "h", "H",
"m", "minute", "min", "minutes", "T", "S", "seconds",
"sec", "second", "M", "MS", "month", "months", "Q", "QS", "quarter",
"quarters", "Y", "YS", "year", "years"
}

VALID_MINUTE_QUANTITIES: ClassVar[Set[int]] = {1, 5, 10, 15, 30}
DEFAULT_QUANTITY: ClassVar[int] = 1 # Default for non-minute units

frequency_unit: str
frequency_quantity: int

def __post_init__(self):
if self.frequency_unit not in self.VALID_FREQUENCY_UNITS:
raise ValueError(f"Invalid frequency unit: {self.frequency_unit}")

if self.frequency_unit in {"m", "minute", "min", "minutes", "T"}:
if self.frequency_quantity not in self.VALID_MINUTE_QUANTITIES:
raise ValueError(
f"Invalid frequency quantity {self.frequency_quantity} for minutes. "
f"Allowed values: {sorted(self.VALID_MINUTE_QUANTITIES)}"
)
else:
if self.frequency_quantity != self.DEFAULT_QUANTITY:
raise ValueError(
f"Invalid frequency quantity {self.frequency_quantity} for {self.frequency_unit}. "
"Only 1 is allowed for this unit."
)
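
A quick usage sketch of the new dataclass as added above; construction fails fast on combinations the validator rejects:

```python
from databricks.automl_runtime.forecast.frequency import Frequency

daily = Frequency(frequency_unit="D", frequency_quantity=1)
half_hourly = Frequency(frequency_unit="min", frequency_quantity=30)

# __post_init__ rejects unsupported unit/quantity combinations at construction time:
try:
    Frequency(frequency_unit="D", frequency_quantity=7)
except ValueError as err:
    print(err)  # Invalid frequency quantity 7 for D. Only 1 is allowed for this unit.

# frozen=True makes instances immutable, hashable, and comparable by value.
assert daily == Frequency("D", 1)
```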

50 changes: 21 additions & 29 deletions runtime/databricks/automl_runtime/forecast/pmdarima/model.py
@@ -25,6 +25,7 @@
from mlflow.utils.environment import _mlflow_conda_env

from databricks.automl_runtime.forecast import OFFSET_ALIAS_MAP, DATE_OFFSET_KEYWORD_MAP
from databricks.automl_runtime.forecast.frequency import Frequency
from databricks.automl_runtime.forecast.model import ForecastModel, mlflow_forecast_log_model
from databricks.automl_runtime.forecast.utils import calculate_period_differences, is_frequency_consistency, \
make_future_dataframe, make_single_future_dataframe
@@ -64,19 +65,18 @@ def model_env(self):
return ARIMA_CONDA_ENV

@staticmethod
def _get_ds_indices(start_ds: pd.Timestamp, periods: int, frequency_unit: str, frequency_quantity: int) -> pd.DatetimeIndex:
def _get_ds_indices(start_ds: pd.Timestamp, periods: int, frequency: Frequency) -> pd.DatetimeIndex:
"""
Create a DatetimeIndex with specified starting time and frequency, whose length is the given periods.
:param start_ds: the pd.Timestamp as the start of the DatetimeIndex.
:param periods: the length of the DatetimeIndex.
:param frequency_unit: the frequency unit of the DatetimeIndex.
:param frequency_quantity: the frequency quantity of the DatetimeIndex.
:param frequency: the frequency of the DatetimeIndex.
:return: a DatetimeIndex.
"""
ds_indices = pd.date_range(
start=start_ds,
periods=periods,
freq=pd.DateOffset(**DATE_OFFSET_KEYWORD_MAP[frequency_unit]) * frequency_quantity
freq=pd.DateOffset(**DATE_OFFSET_KEYWORD_MAP[frequency.frequency_unit]) * frequency.frequency_quantity
)
modified_start_ds = ds_indices.min()
if start_ds != modified_start_ds:
@@ -90,15 +90,13 @@ class ArimaModel(AbstractArimaModel):
ARIMA mlflow model wrapper for univariate forecasting.
"""

def __init__(self, pickled_model: bytes, horizon: int, frequency_unit: str,
frequency_quantity: int, start_ds: pd.Timestamp, end_ds: pd.Timestamp,
def __init__(self, pickled_model: bytes, horizon: int, frequency: Frequency, start_ds: pd.Timestamp, end_ds: pd.Timestamp,
time_col: str, exogenous_cols: Optional[List[str]] = None) -> None:
"""
Initialize the mlflow Python model wrapper for ARIMA.
:param pickled_model: the pickled ARIMA model as a bytes object.
:param horizon: int number of periods to forecast forward.
:param frequency_unit: the frequency unit of the time series
:param frequency_quantity: the frequency quantity of the time series
:param frequency: the frequency of the time series
:param start_ds: the start time of training data
:param end_ds: the end time of training data
:param time_col: the column name of the time column
@@ -108,8 +106,7 @@ def __init__(self, pickled_model: bytes, horizon: int, frequency_unit: str,
super().__init__()
self._pickled_model = pickled_model
self._horizon = horizon
self._frequency_unit = OFFSET_ALIAS_MAP[frequency_unit]
self._frequency_quantity = frequency_quantity
self._frequency = Frequency(frequency_unit=OFFSET_ALIAS_MAP[frequency.frequency_unit], frequency_quantity=frequency.frequency_quantity)
self._start_ds = pd.to_datetime(start_ds)
self._end_ds = pd.to_datetime(end_ds)
self._time_col = time_col
@@ -160,8 +157,7 @@ def make_future_dataframe(self, horizon: int = None, include_history: bool = Tru
start_time=self._start_ds,
end_time=self._end_ds,
horizon=horizon or self._horizon,
frequency_unit=self._frequency_unit,
frequency_quantity=self._frequency_quantity,
frequency=self._frequency,
include_history=include_history
)

@@ -196,7 +192,7 @@ def _predict_impl(self, input_df: pd.DataFrame) -> pd.DataFrame:
)
# Check if the time has correct frequency
consistency = df["ds"].apply(lambda x:
is_frequency_consistency(self._start_ds, x, self._frequency_unit, self._frequency_quantity)
is_frequency_consistency(self._start_ds, x, self._frequency)
).all()
if not consistency:
raise MlflowException(
@@ -207,7 +203,7 @@ def _predict_impl(self, input_df: pd.DataFrame) -> pd.DataFrame:
)
preds_pds = []
# Out-of-sample prediction if needed
horizon = calculate_period_differences(self._end_ds, max(df["ds"]), self._frequency_unit, self._frequency_quantity)
horizon = calculate_period_differences(self._end_ds, max(df["ds"]), self._frequency)
if horizon > 0:
X_future = df[df["ds"] > self._end_ds].set_index("ds")
future_pd = self._forecast(
@@ -233,8 +229,8 @@ def _predict_in_sample(
end_ds: pd.Timestamp = None,
X: pd.DataFrame = None) -> pd.DataFrame:
if start_ds and end_ds:
start_idx = calculate_period_differences(self._start_ds, start_ds, self._frequency_unit, self._frequency_quantity)
end_idx = calculate_period_differences(self._start_ds, end_ds, self._frequency_unit, self._frequency_quantity)
start_idx = calculate_period_differences(self._start_ds, start_ds, self._frequency)
end_idx = calculate_period_differences(self._start_ds, end_ds, self._frequency)
else:
start_ds = self._start_ds
end_ds = self._end_ds
@@ -246,8 +242,8 @@ def _predict_in_sample(
start=start_idx,
end=end_idx,
return_conf_int=True)
periods = calculate_period_differences(self._start_ds, end_ds, self._frequency_unit, self._frequency_quantity) + 1
ds_indices = self._get_ds_indices(start_ds=self._start_ds, periods=periods, frequency_unit=self._frequency_unit, frequency_quantity=self._frequency_quantity)[start_idx:]
periods = calculate_period_differences(self._start_ds, end_ds, self._frequency) + 1
ds_indices = self._get_ds_indices(start_ds=self._start_ds, periods=periods, frequency=self._frequency)[start_idx:]
in_sample_pd = pd.DataFrame({'ds': ds_indices, 'yhat': preds_in_sample})
in_sample_pd[["yhat_lower", "yhat_upper"]] = conf_in_sample
return in_sample_pd
@@ -261,7 +257,7 @@ def _forecast(
horizon,
X=X,
return_conf_int=True)
ds_indices = self._get_ds_indices(start_ds=self._end_ds, periods=horizon + 1, frequency_unit=self._frequency_unit, frequency_quantity=self._frequency_quantity)[1:]
ds_indices = self._get_ds_indices(start_ds=self._end_ds, periods=horizon + 1, frequency=self._frequency)[1:]
preds_pd = pd.DataFrame({'ds': ds_indices, 'yhat': preds})
preds_pd[["yhat_lower", "yhat_upper"]] = conf
return preds_pd
@@ -272,15 +268,14 @@ class MultiSeriesArimaModel(AbstractArimaModel):
ARIMA mlflow model wrapper for multivariate forecasting.
"""

def __init__(self, pickled_model_dict: Dict[Tuple, bytes], horizon: int, frequency_unit: str, frequency_quantity: int,
def __init__(self, pickled_model_dict: Dict[Tuple, bytes], horizon: int, frequency: Frequency,
start_ds_dict: Dict[Tuple, pd.Timestamp], end_ds_dict: Dict[Tuple, pd.Timestamp],
time_col: str, id_cols: List[str], exogenous_cols: Optional[List[str]] = None) -> None:
"""
Initialize the mlflow Python model wrapper for multiseries ARIMA.
:param pickled_model_dict: the dictionary of binarized ARIMA models for different time series.
:param horizon: int number of periods to forecast forward.
:param frequency_unit: the frequency unit of the time series
:param frequency_quantity: the frequency quantity of the time series
:param frequency: the frequency of the time series
:param start_ds_dict: the dictionary of the starting time of each time series in training data.
:param end_ds_dict: the dictionary of the end time of each time series in training data.
:param time_col: the column name of the time column
@@ -291,8 +286,7 @@ def __init__(self, pickled_model_dict: Dict[Tuple, bytes], horizon: int, frequen
super().__init__()
self._pickled_models = pickled_model_dict
self._horizon = horizon
self._frequency_unit = frequency_unit
self._frequency_quantity = frequency_quantity
self._frequency = frequency
self._starts = start_ds_dict
self._ends = end_ds_dict
self._time_col = time_col
@@ -335,8 +329,7 @@ def make_future_dataframe(
start_time=self._starts,
end_time=self._ends,
horizon=horizon,
frequency_unit=self._frequency_unit,
frequency_quantity=self._frequency_quantity,
frequency=self._frequency,
include_history=include_history,
groups=groups,
identity_column_names=self._id_cols
@@ -367,7 +360,7 @@ def _predict_timeseries_single_id(
horizon: int,
include_history: bool = True,
df: Optional[pd.DataFrame] = None) -> pd.DataFrame:
arima_model_single_id = ArimaModel(self._pickled_models[id_], self._horizon, self._frequency_unit, self._frequency_quantity,
arima_model_single_id = ArimaModel(self._pickled_models[id_], self._horizon, self._frequency,
self._starts[id_], self._ends[id_], self._time_col, self._exogenous_cols)
preds_df = arima_model_single_id.predict_timeseries(horizon, include_history, df)
for id, col_name in zip(id_, self._id_cols):
@@ -408,8 +401,7 @@ def _predict_single_id(self, df: pd.DataFrame) -> pd.DataFrame:
id_ = df["ts_id"].to_list()[0]
arima_model_single_id = ArimaModel(self._pickled_models[id_],
self._horizon,
self._frequency_unit,
self._frequency_quantity,
self._frequency,
self._starts[id_],
self._ends[id_],
self._time_col,
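
The refactored _get_ds_indices keeps the same pandas mechanics; only the unit and quantity now come from the Frequency object. A standalone sketch of that mechanic, using a two-entry stand-in for DATE_OFFSET_KEYWORD_MAP (the real mapping in forecast.__init__.py remains the source of truth):

```python
import pandas as pd

# Illustrative stand-in for DATE_OFFSET_KEYWORD_MAP; entries assumed, not copied from the repo.
date_offset_kwargs = {"D": {"days": 1}, "min": {"minutes": 1}}

frequency_unit, frequency_quantity = "min", 15  # e.g. Frequency(frequency_unit="min", frequency_quantity=15)

ds_indices = pd.date_range(
    start=pd.Timestamp("2024-01-01"),
    periods=4,
    freq=pd.DateOffset(**date_offset_kwargs[frequency_unit]) * frequency_quantity,
)
print(ds_indices)
# 2024-01-01 00:00, 00:15, 00:30, 00:45 -- one step per (frequency_quantity x unit)
```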