Skip to content

Commit 754bb71

Browse files
authored
Merge branch 'main' into operators/automlx_24.2
2 parents e8e09b9 + d410494 commit 754bb71

File tree

13 files changed

+377
-9
lines changed

13 files changed

+377
-9
lines changed

.github/workflows/run-forecast-unit-tests.yml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,10 @@ jobs:
3131

3232
steps:
3333
- uses: actions/checkout@v4
34+
with:
35+
fetch-depth: 0
36+
ref: ${{ github.event.pull_request.head.sha }}
37+
3438

3539
- uses: actions/setup-python@v5
3640
with:

THIRD_PARTY_LICENSES.txt

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -447,6 +447,12 @@ python-fire
447447
* Source code: https://github.com/google/python-fire
448448
* Project home: https://github.com/google/python-fire
449449

450+
mlforecast
451+
* Copyright 2024 Nixtla
452+
* License: Apache License 2.0
453+
* Source code: https://github.com/Nixtla/mlforecast
454+
* Project home: https://github.com/Nixtla/mlforecast
455+
450456
=======
451457
=============================== Licenses ===============================
452458
------------------------------------------------------------------------

ads/jobs/builders/infrastructure/dsc_job.py

Lines changed: 49 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import inspect
1010
import logging
1111
import os
12+
import re
1213
import time
1314
import traceback
1415
import uuid
@@ -375,12 +376,13 @@ def delete(self, force_delete: bool = False) -> DSCJob:
375376
"""
376377
runs = self.run_list()
377378
for run in runs:
378-
if run.lifecycle_state in [
379-
DataScienceJobRun.LIFECYCLE_STATE_ACCEPTED,
380-
DataScienceJobRun.LIFECYCLE_STATE_IN_PROGRESS,
381-
DataScienceJobRun.LIFECYCLE_STATE_NEEDS_ATTENTION,
382-
]:
383-
run.cancel(wait_for_completion=True)
379+
if force_delete:
380+
if run.lifecycle_state in [
381+
DataScienceJobRun.LIFECYCLE_STATE_ACCEPTED,
382+
DataScienceJobRun.LIFECYCLE_STATE_IN_PROGRESS,
383+
DataScienceJobRun.LIFECYCLE_STATE_NEEDS_ATTENTION,
384+
]:
385+
run.cancel(wait_for_completion=True)
384386
run.delete()
385387
self.client.delete_job(self.id)
386388
return self
@@ -582,6 +584,25 @@ def logging(self) -> OCILog:
582584
id=self.log_id, log_group_id=self.log_details.log_group_id, **auth
583585
)
584586

587+
@property
588+
def exit_code(self):
589+
"""The exit code of the job run from the lifecycle details.
590+
Note that,
591+
None will be returned if the job run is not finished or failed without exit code.
592+
0 will be returned if job run succeeded.
593+
"""
594+
if self.lifecycle_state == self.LIFECYCLE_STATE_SUCCEEDED:
595+
return 0
596+
if not self.lifecycle_details:
597+
return None
598+
match = re.search(r"exit code (\d+)", self.lifecycle_details)
599+
if not match:
600+
return None
601+
try:
602+
return int(match.group(1))
603+
except Exception:
604+
return None
605+
585606
@staticmethod
586607
def _format_log(message: str, date_time: datetime.datetime) -> dict:
587608
"""Formats a message as log record with datetime.
@@ -655,6 +676,22 @@ def _check_and_print_status(self, prev_status) -> str:
655676
print(f"{timestamp} - {status}")
656677
return status
657678

679+
def wait(self, interval: float = SLEEP_INTERVAL):
680+
"""Waits for the job run until if finishes.
681+
682+
Parameters
683+
----------
684+
interval : float
685+
Time interval in seconds between each request to update the logs.
686+
Defaults to 3 (seconds).
687+
688+
"""
689+
self.sync()
690+
while self.status not in self.TERMINAL_STATES:
691+
time.sleep(interval)
692+
self.sync()
693+
return self
694+
658695
def watch(
659696
self,
660697
interval: float = SLEEP_INTERVAL,
@@ -830,6 +867,12 @@ def download(self, to_dir):
830867
self.job.download(to_dir)
831868
return self
832869

870+
def delete(self, force_delete: bool = False):
871+
if force_delete:
872+
self.cancel(wait_for_completion=True)
873+
super().delete()
874+
return
875+
833876

834877
# This is for backward compatibility
835878
DSCJobRun = DataScienceJobRun

ads/opctl/operator/lowcode/forecast/const.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ class SupportedModels(str, metaclass=ExtendedEnumMeta):
1414
Prophet = "prophet"
1515
Arima = "arima"
1616
NeuralProphet = "neuralprophet"
17+
MLForecast = "mlforecast"
1718
AutoMLX = "automlx"
1819
AutoTS = "autots"
1920
Auto = "auto"

ads/opctl/operator/lowcode/forecast/environment.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ dependencies:
88
- oracle-ads>=2.9.0
99
- prophet
1010
- neuralprophet
11+
- mlforecast
1112
- pmdarima
1213
- statsmodels
1314
- report-creator

ads/opctl/operator/lowcode/forecast/model/factory.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from .base_model import ForecastOperatorBaseModel
1313
from .neuralprophet import NeuralProphetOperatorModel
1414
from .prophet import ProphetOperatorModel
15+
from .ml_forecast import MLForecastOperatorModel
1516
from ..utils import select_auto_model
1617
from .forecast_datasets import ForecastDatasets
1718

@@ -32,6 +33,7 @@ class ForecastOperatorModelFactory:
3233
SupportedModels.Prophet: ProphetOperatorModel,
3334
SupportedModels.Arima: ArimaOperatorModel,
3435
SupportedModels.NeuralProphet: NeuralProphetOperatorModel,
36+
SupportedModels.MLForecast: MLForecastOperatorModel,
3537
SupportedModels.AutoMLX: AutoMLXOperatorModel,
3638
SupportedModels.AutoTS: AutoTSOperatorModel
3739
}

ads/opctl/operator/lowcode/forecast/model/forecast_datasets.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,10 @@ def create_horizon(self, spec, historical_data):
8686
pd.date_range(
8787
start=historical_data.get_max_time(),
8888
periods=spec.horizon + 1,
89-
freq=historical_data.freq,
89+
freq=historical_data.freq
90+
or pd.infer_freq(
91+
historical_data.data.reset_index()[spec.datetime_column.name][-5:]
92+
),
9093
),
9194
name=spec.datetime_column.name,
9295
)
@@ -135,6 +138,7 @@ def __init__(self, config: ForecastOperatorConfig):
135138

136139
self._horizon = config.spec.horizon
137140
self._datetime_column_name = config.spec.datetime_column.name
141+
self._target_col = config.spec.target_column
138142
self._load_data(config.spec)
139143

140144
def _load_data(self, spec):
@@ -158,6 +162,16 @@ def get_all_data_long(self, include_horizon=True):
158162
on=[self._datetime_column_name, ForecastOutputColumns.SERIES],
159163
).reset_index()
160164

165+
def get_all_data_long_forecast_horizon(self):
166+
"""Returns all data in long format for the forecast horizon."""
167+
test_data = pd.merge(
168+
self.historical_data.data,
169+
self.additional_data.data,
170+
how="outer",
171+
on=[self._datetime_column_name, ForecastOutputColumns.SERIES],
172+
).reset_index()
173+
return test_data[test_data[self._target_col].isnull()].reset_index(drop=True)
174+
161175
def get_data_multi_indexed(self):
162176
return pd.concat(
163177
[

0 commit comments

Comments
 (0)