Skip to content

Commit 2098b24

Browse files
[AQUA] Time series forecasting model support (#1220)
2 parents d55e593 + 3383bca commit 2098b24

File tree

2 files changed

+14
-1
lines changed

2 files changed

+14
-1
lines changed

ads/aqua/model/constants.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ class ModelTask(ExtendedEnum):
2626
TEXT_GENERATION = "text-generation"
2727
IMAGE_TEXT_TO_TEXT = "image-text-to-text"
2828
IMAGE_TO_TEXT = "image-to-text"
29+
TIME_SERIES_FORECASTING = "time-series-forecasting"
2930

3031

3132
class FineTuningMetricCategories(ExtendedEnum):

ads/aqua/modeldeployment/deployment.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,11 @@
4848
)
4949
from ads.aqua.data import AquaResourceIdentifier
5050
from ads.aqua.model import AquaModelApp
51-
from ads.aqua.model.constants import AquaModelMetadataKeys, ModelCustomMetadataFields
51+
from ads.aqua.model.constants import (
52+
AquaModelMetadataKeys,
53+
ModelCustomMetadataFields,
54+
ModelTask,
55+
)
5256
from ads.aqua.model.utils import (
5357
extract_base_model_from_ft,
5458
extract_fine_tune_artifacts_path,
@@ -215,6 +219,14 @@ def create(
215219
freeform_tags=freeform_tags,
216220
defined_tags=defined_tags,
217221
)
222+
task_tag = aqua_model.freeform_tags.get(Tags.TASK, UNKNOWN)
223+
if (
224+
task_tag == ModelTask.TIME_SERIES_FORECASTING
225+
or task_tag == ModelTask.TIME_SERIES_FORECASTING.replace("-", "_")
226+
):
227+
create_deployment_details.env_var.update(
228+
{Tags.TASK.upper(): ModelTask.TIME_SERIES_FORECASTING}
229+
)
218230
return self._create(
219231
aqua_model=aqua_model,
220232
create_deployment_details=create_deployment_details,

0 commit comments

Comments
 (0)