Skip to content

[ML-47076] Support custom frequencies #162

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Feb 12, 2025
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion runtime/databricks/automl_runtime/forecast/deepar/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ class DeepARModel(ForecastModel):
DeepAR mlflow model wrapper for forecasting.
"""

def __init__(self, model: PyTorchPredictor, horizon: int, frequency: str,
def __init__(self, model: PyTorchPredictor, horizon: int, frequency: str, frequency_quantity: int,
num_samples: int,
target_col: str, time_col: str,
id_cols: Optional[List[str]] = None) -> None:
Expand All @@ -51,6 +51,7 @@ def __init__(self, model: PyTorchPredictor, horizon: int, frequency: str,
:param model: DeepAR model
:param horizon: the number of periods to forecast forward
:param frequency: the frequency of the time series
:param frequency_quantity: the frequency quantity of the time series
:param num_samples: the number of samples to draw from the distribution
:param target_col: the target column name
:param time_col: the time column name
Expand All @@ -61,6 +62,7 @@ def __init__(self, model: PyTorchPredictor, horizon: int, frequency: str,
self._model = model
self._horizon = horizon
self._frequency = frequency
self._frequency_quantity = frequency_quantity
self._num_samples = num_samples
self._target_col = target_col
self._time_col = time_col
Expand Down Expand Up @@ -129,6 +131,7 @@ def predict_samples(self,
model_input_transformed = set_index_and_fill_missing_time_steps(model_input,
self._time_col,
self._frequency,
self._frequency_quantity,
self._id_cols)

test_ds = PandasDataset(model_input_transformed, target=self._target_col)
Expand Down
7 changes: 4 additions & 3 deletions runtime/databricks/automl_runtime/forecast/deepar/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import pandas as pd


def validate_and_generate_index(df: pd.DataFrame, time_col: str, frequency: str):
def validate_and_generate_index(df: pd.DataFrame, time_col: str, frequency: str, frequency_quantity: int):
"""
Generate a complete time index for the given DataFrame based on the specified frequency.
- Ensures the time column is in datetime format.
Expand All @@ -31,7 +31,7 @@ def validate_and_generate_index(df: pd.DataFrame, time_col: str, frequency: str)
:raises ValueError: If the day-of-month pattern is inconsistent for "MS" frequency.
"""
if frequency.upper() != "MS":
return pd.date_range(df[time_col].min(), df[time_col].max(), freq=frequency)
return pd.date_range(df[time_col].min(), df[time_col].max(), freq=f"{frequency_quantity}{frequency}")

df[time_col] = pd.to_datetime(df[time_col]) # Ensure datetime format

Expand Down Expand Up @@ -64,6 +64,7 @@ def validate_and_generate_index(df: pd.DataFrame, time_col: str, frequency: str)

def set_index_and_fill_missing_time_steps(df: pd.DataFrame, time_col: str,
frequency: str,
frequency_quantity: int,
id_cols: Optional[List[str]] = None):
"""
Transform the input dataframe to an acceptable format for the GluonTS library.
Expand All @@ -86,7 +87,7 @@ def set_index_and_fill_missing_time_steps(df: pd.DataFrame, time_col: str,
weekday_name = total_min.strftime("%a").upper() # e.g., "FRI"
frequency = f"W-{weekday_name}"

new_index_full = validate_and_generate_index(df=df, time_col=time_col, frequency=frequency)
new_index_full = validate_and_generate_index(df=df, time_col=time_col, frequency=frequency, frequency_quantity=frequency_quantity)

if id_cols is not None:
df_dict = {}
Expand Down
29 changes: 17 additions & 12 deletions runtime/databricks/automl_runtime/forecast/pmdarima/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def model_env(self):
return ARIMA_CONDA_ENV

@staticmethod
def _get_ds_indices(start_ds: pd.Timestamp, periods: int, frequency: str) -> pd.DatetimeIndex:
def _get_ds_indices(start_ds: pd.Timestamp, periods: int, frequency: str, frequency_quantity: int) -> pd.DatetimeIndex:
"""
Create a DatetimeIndex with specified starting time and frequency, whose length is the given periods.
:param start_ds: the pd.Timestamp as the start of the DatetimeIndex.
Expand All @@ -75,7 +75,7 @@ def _get_ds_indices(start_ds: pd.Timestamp, periods: int, frequency: str) -> pd.
ds_indices = pd.date_range(
start=start_ds,
periods=periods,
freq=pd.DateOffset(**DATE_OFFSET_KEYWORD_MAP[frequency])
freq=pd.DateOffset(**DATE_OFFSET_KEYWORD_MAP[frequency]) * frequency_quantity
)
modified_start_ds = ds_indices.min()
if start_ds != modified_start_ds:
Expand All @@ -90,7 +90,7 @@ class ArimaModel(AbstractArimaModel):
"""

def __init__(self, pickled_model: bytes, horizon: int, frequency: str,
start_ds: pd.Timestamp, end_ds: pd.Timestamp,
frequency_quantity: int, start_ds: pd.Timestamp, end_ds: pd.Timestamp,
time_col: str, exogenous_cols: Optional[List[str]] = None) -> None:
"""
Initialize the mlflow Python model wrapper for ARIMA.
Expand All @@ -107,6 +107,7 @@ def __init__(self, pickled_model: bytes, horizon: int, frequency: str,
self._pickled_model = pickled_model
self._horizon = horizon
self._frequency = OFFSET_ALIAS_MAP[frequency]
self._frequency_quantity = frequency_quantity
self._start_ds = pd.to_datetime(start_ds)
self._end_ds = pd.to_datetime(end_ds)
self._time_col = time_col
Expand Down Expand Up @@ -158,6 +159,7 @@ def make_future_dataframe(self, horizon: int = None, include_history: bool = Tru
end_time=self._end_ds,
horizon=horizon or self._horizon,
frequency=self._frequency,
frequency_quantity=self._frequency_quantity,
include_history=include_history
)

Expand Down Expand Up @@ -192,7 +194,7 @@ def _predict_impl(self, input_df: pd.DataFrame) -> pd.DataFrame:
)
# Check if the time has correct frequency
consistency = df["ds"].apply(lambda x:
is_frequency_consistency(self._start_ds, x, self._frequency)
is_frequency_consistency(self._start_ds, x, self._frequency, self._frequency_quantity)
).all()
if not consistency:
raise MlflowException(
Expand All @@ -203,7 +205,7 @@ def _predict_impl(self, input_df: pd.DataFrame) -> pd.DataFrame:
)
preds_pds = []
# Out-of-sample prediction if needed
horizon = calculate_period_differences(self._end_ds, max(df["ds"]), self._frequency)
horizon = calculate_period_differences(self._end_ds, max(df["ds"]), self._frequency, self._frequency_quantity)
if horizon > 0:
X_future = df[df["ds"] > self._end_ds].set_index("ds")
future_pd = self._forecast(
Expand All @@ -229,8 +231,8 @@ def _predict_in_sample(
end_ds: pd.Timestamp = None,
X: pd.DataFrame = None) -> pd.DataFrame:
if start_ds and end_ds:
start_idx = calculate_period_differences(self._start_ds, start_ds, self._frequency)
end_idx = calculate_period_differences(self._start_ds, end_ds, self._frequency)
start_idx = calculate_period_differences(self._start_ds, start_ds, self._frequency, self._frequency_quantity)
end_idx = calculate_period_differences(self._start_ds, end_ds, self._frequency, self._frequency_quantity)
else:
start_ds = self._start_ds
end_ds = self._end_ds
Expand All @@ -242,8 +244,8 @@ def _predict_in_sample(
start=start_idx,
end=end_idx,
return_conf_int=True)
periods = calculate_period_differences(self._start_ds, end_ds, self._frequency) + 1
ds_indices = self._get_ds_indices(start_ds=self._start_ds, periods=periods, frequency=self._frequency)[start_idx:]
periods = calculate_period_differences(self._start_ds, end_ds, self._frequency, self._frequency_quantity) + 1
ds_indices = self._get_ds_indices(start_ds=self._start_ds, periods=periods, frequency=self._frequency, frequency_quantity=self._frequency_quantity)[start_idx:]
in_sample_pd = pd.DataFrame({'ds': ds_indices, 'yhat': preds_in_sample})
in_sample_pd[["yhat_lower", "yhat_upper"]] = conf_in_sample
return in_sample_pd
Expand All @@ -257,7 +259,7 @@ def _forecast(
horizon,
X=X,
return_conf_int=True)
ds_indices = self._get_ds_indices(start_ds=self._end_ds, periods=horizon + 1, frequency=self._frequency)[1:]
ds_indices = self._get_ds_indices(start_ds=self._end_ds, periods=horizon + 1, frequency=self._frequency, frequency_quantity=self._frequency_quantity)[1:]
preds_pd = pd.DataFrame({'ds': ds_indices, 'yhat': preds})
preds_pd[["yhat_lower", "yhat_upper"]] = conf
return preds_pd
Expand All @@ -268,7 +270,7 @@ class MultiSeriesArimaModel(AbstractArimaModel):
ARIMA mlflow model wrapper for multivariate forecasting.
"""

def __init__(self, pickled_model_dict: Dict[Tuple, bytes], horizon: int, frequency: str,
def __init__(self, pickled_model_dict: Dict[Tuple, bytes], horizon: int, frequency: str, frequency_quantity: int,
start_ds_dict: Dict[Tuple, pd.Timestamp], end_ds_dict: Dict[Tuple, pd.Timestamp],
time_col: str, id_cols: List[str], exogenous_cols: Optional[List[str]] = None) -> None:
"""
Expand All @@ -287,6 +289,7 @@ def __init__(self, pickled_model_dict: Dict[Tuple, bytes], horizon: int, frequen
self._pickled_models = pickled_model_dict
self._horizon = horizon
self._frequency = frequency
self._frequency_quantity = frequency_quantity
self._starts = start_ds_dict
self._ends = end_ds_dict
self._time_col = time_col
Expand Down Expand Up @@ -330,6 +333,7 @@ def make_future_dataframe(
end_time=self._ends,
horizon=horizon,
frequency=self._frequency,
frequency_quantity=self._frequency_quantity,
include_history=include_history,
groups=groups,
identity_column_names=self._id_cols
Expand Down Expand Up @@ -360,7 +364,7 @@ def _predict_timeseries_single_id(
horizon: int,
include_history: bool = True,
df: Optional[pd.DataFrame] = None) -> pd.DataFrame:
arima_model_single_id = ArimaModel(self._pickled_models[id_], self._horizon, self._frequency,
arima_model_single_id = ArimaModel(self._pickled_models[id_], self._horizon, self._frequency, self._frequency_quantity,
self._starts[id_], self._ends[id_], self._time_col, self._exogenous_cols)
preds_df = arima_model_single_id.predict_timeseries(horizon, include_history, df)
for id, col_name in zip(id_, self._id_cols):
Expand Down Expand Up @@ -402,6 +406,7 @@ def _predict_single_id(self, df: pd.DataFrame) -> pd.DataFrame:
arima_model_single_id = ArimaModel(self._pickled_models[id_],
self._horizon,
self._frequency,
self._frequency_quantity,
self._starts[id_],
self._ends[id_],
self._time_col,
Expand Down
27 changes: 15 additions & 12 deletions runtime/databricks/automl_runtime/forecast/pmdarima/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ class ArimaEstimator:

def __init__(self, horizon: int, frequency_unit: str, metric: str, seasonal_periods: List[int],
num_folds: int = 20, max_steps: int = 150, exogenous_cols: Optional[List[str]] = None,
split_cutoff: Optional[pd.Timestamp] = None) -> None:
split_cutoff: Optional[pd.Timestamp] = None, frequency_quantity: int = 1) -> None:
"""
:param horizon: Number of periods to forecast forward
:param frequency_unit: Frequency of the time series
Expand All @@ -53,6 +53,7 @@ def __init__(self, horizon: int, frequency_unit: str, metric: str, seasonal_peri
"""
self._horizon = horizon
self._frequency_unit = OFFSET_ALIAS_MAP[frequency_unit]
self._frequency_quantity = frequency_quantity
self._metric = metric
self._seasonal_periods = seasonal_periods
self._num_folds = num_folds
Expand All @@ -70,14 +71,14 @@ def fit(self, df: pd.DataFrame) -> pd.DataFrame:
history_pd["ds"] = pd.to_datetime(history_pd["ds"])

# Check if the time has consistent frequency
self._validate_ds_freq(history_pd, self._frequency_unit)
self._validate_ds_freq(history_pd, self._frequency_unit, self._frequency_quantity)

history_periods = utils.calculate_period_differences(
history_pd['ds'].min(), history_pd['ds'].max(), self._frequency_unit
history_pd['ds'].min(), history_pd['ds'].max(), self._frequency_unit, self._frequency_quantity
)
if history_periods + 1 != history_pd['ds'].size:
# Impute missing time steps
history_pd = self._fill_missing_time_steps(history_pd, self._frequency_unit)
history_pd = self._fill_missing_time_steps(history_pd, self._frequency_unit, self._frequency_quantity)


# Tune seasonal periods
Expand All @@ -87,26 +88,28 @@ def fit(self, df: pd.DataFrame) -> pd.DataFrame:
try:
# this check mirrors the the default behavior by prophet
if history_periods < 2 * m:
_logger.warning(f"Skipping seasonal_period={m} ({self._frequency_unit}). Dataframe timestamps must span at least two seasonality periods, but only spans {history_periods} {self._frequency_unit}""")
_logger.warning(f"Skipping seasonal_period={m} ({self._frequency_quantity}{self._frequency_unit}). Dataframe timestamps must span at least two seasonality periods, but only spans {history_periods} {self._frequency_quantity}{self._frequency_unit}""")
continue
# Prophet also rejects the seasonality periods if the seasonality period timedelta is less than the shortest timedelta in the dataframe.
# However, this cannot happen in ARIMA because _fill_missing_time_steps imputes values for each _frequency_unit,
# so the minimum valid seasonality period is always 1

validation_horizon = utils.get_validation_horizon(history_pd, self._horizon, self._frequency_unit)
validation_horizon = utils.get_validation_horizon(history_pd, self._horizon, self._frequency_unit, self._frequency_quantity)
if self._split_cutoff:
cutoffs = utils.generate_custom_cutoffs(
history_pd,
horizon=validation_horizon,
unit=self._frequency_unit,
split_cutoff=self._split_cutoff
split_cutoff=self._split_cutoff,
frequency_quantity=self._frequency_quantity,
)
else:
cutoffs = utils.generate_cutoffs(
history_pd,
horizon=validation_horizon,
unit=self._frequency_unit,
num_folds=self._num_folds,
frequency_quantity=self._frequency_quantity,
)

result = self._fit_predict(history_pd, cutoffs=cutoffs, seasonal_period=m, max_steps=self._max_steps)
Expand Down Expand Up @@ -150,22 +153,22 @@ def _fit_predict(self, df: pd.DataFrame, cutoffs: List[pd.Timestamp], seasonal_p
return {"metrics": metrics, "model": arima_model}

@staticmethod
def _fill_missing_time_steps(df: pd.DataFrame, frequency: str, frequency_quantity: int):
    """
    Impute missing time steps by resampling and forward filling.

    :param df: pd.DataFrame with a "ds" time column and value columns.
    :param frequency: frequency unit alias of the time series (a key of OFFSET_ALIAS_MAP).
    :param frequency_quantity: multiplier of the frequency unit, e.g. 5 with "min" gives "5min".
    :return: pd.DataFrame resampled onto a complete "<frequency_quantity><frequency>" grid,
        with gaps forward-filled from the previous observation.
    """
    # Forward fill missing time steps. ffill() replaces pad(), which was
    # deprecated in pandas 1.4 and removed in pandas 2.0.
    rule = f"{frequency_quantity}{OFFSET_ALIAS_MAP[frequency]}"
    df_filled = df.set_index("ds").resample(rule=rule).ffill().reset_index()
    start_ds, modified_start_ds = df["ds"].min(), df_filled["ds"].min()
    # resample() may anchor the grid at a bin boundary before the first
    # observation (e.g. weekly bins); shift the index back so the first
    # filled timestamp matches the original start timestamp.
    if start_ds != modified_start_ds:
        offset = modified_start_ds - start_ds
        df_filled["ds"] = df_filled["ds"] - offset
    return df_filled

@staticmethod
def _validate_ds_freq(df: pd.DataFrame, frequency: str, frequency_quantity: int):
    """
    Validate that every timestamp in the "ds" column is consistent with the
    expected frequency, measured from the earliest timestamp.

    :param df: pd.DataFrame with a "ds" time column.
    :param frequency: frequency unit alias of the time series.
    :param frequency_quantity: multiplier of the frequency unit.
    :raises ValueError: if any timestamp is not a whole number of
        "<frequency_quantity><frequency>" periods from the start timestamp.
    """
    # NOTE: parameter renamed from the misspelled "frequency_qantity"; the
    # visible caller passes it positionally, so the rename is compatible.
    start_ds = df["ds"].min()
    consistency = df["ds"].apply(
        lambda x: utils.is_frequency_consistency(start_ds, x, frequency, frequency_quantity)
    ).all()
    if not consistency:
        raise ValueError(
            f"Input time column includes different frequency than the specified frequency {frequency_quantity}{frequency}."
        )
Loading
Loading