Skip to content

Commit 28e71e9

Browse files
Merge branch 'main' into ODSC-66852/watch_aqua_ft_job
2 parents eff4d17 + 5451c2c commit 28e71e9

33 files changed

+3619
-372
lines changed

.github/workflows/run-forecast-unit-tests.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,6 @@ jobs:
5656
$CONDA/bin/conda init
5757
source /home/runner/.bashrc
5858
pip install -r test-requirements-operators.txt
59-
pip install "oracle-automlx[forecasting]>=24.4.0"
59+
pip install "oracle-automlx[forecasting]>=24.4.1"
6060
pip install pandas>=2.2.0
6161
python -m pytest -v -p no:warnings --durations=5 tests/operators/forecast

ads/aqua/common/enums.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,9 @@ class InferenceContainerTypeFamily(str, metaclass=ExtendedEnumMeta):
5252
AQUA_VLLM_CONTAINER_FAMILY = "odsc-vllm-serving"
5353
AQUA_TGI_CONTAINER_FAMILY = "odsc-tgi-serving"
5454
AQUA_LLAMA_CPP_CONTAINER_FAMILY = "odsc-llama-cpp-serving"
55+
56+
57+
class CustomInferenceContainerTypeFamily(str, metaclass=ExtendedEnumMeta):
5558
AQUA_TEI_CONTAINER_FAMILY = "odsc-tei-serving"
5659

5760

ads/aqua/common/utils.py

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -838,21 +838,31 @@ def cleanup_local_hf_model_artifact(
838838
"""
839839
if local_dir and os.path.exists(local_dir):
840840
model_dir = os.path.join(local_dir, model_name)
841+
model_dir = (
842+
os.path.dirname(model_dir)
843+
if "/" in model_name or os.sep in model_name
844+
else model_dir
845+
)
846+
shutil.rmtree(model_dir, ignore_errors=True)
841847
if os.path.exists(model_dir):
842-
shutil.rmtree(model_dir)
843-
logger.debug(f"Deleted local model artifact directory: {model_dir}")
844-
845-
if not os.listdir(local_dir):
846-
shutil.rmtree(local_dir)
847-
logger.debug(f"Deleted local directory {model_dir} as it is empty.")
848+
logger.debug(
849+
f"Could not delete local model artifact directory: {model_dir}"
850+
)
851+
else:
852+
logger.debug(f"Deleted local model artifact directory: {model_dir}.")
848853

849854
hf_local_path = os.path.join(
850855
HF_HUB_CACHE, repo_folder_name(repo_id=model_name, repo_type="model")
851856
)
857+
shutil.rmtree(hf_local_path, ignore_errors=True)
858+
852859
if os.path.exists(hf_local_path):
853-
shutil.rmtree(hf_local_path)
854860
logger.debug(
855-
f"Deleted local Hugging Face cache directory {hf_local_path} for the model {model_name} "
861+
f"Could not clear the local Hugging Face cache directory {hf_local_path} for the model {model_name}."
862+
)
863+
else:
864+
logger.debug(
865+
f"Cleared contents of local Hugging Face cache directory {hf_local_path} for the model {model_name}."
856866
)
857867

858868

ads/aqua/extension/model_handler.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
#!/usr/bin/env python
2-
# Copyright (c) 2024 Oracle and/or its affiliates.
2+
# Copyright (c) 2024, 2025 Oracle and/or its affiliates.
33
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
44

55
from typing import Optional
@@ -8,6 +8,9 @@
88
from tornado.web import HTTPError
99

1010
from ads.aqua.common.decorator import handle_exceptions
11+
from ads.aqua.common.enums import (
12+
CustomInferenceContainerTypeFamily,
13+
)
1114
from ads.aqua.common.errors import AquaRuntimeError, AquaValueError
1215
from ads.aqua.common.utils import (
1316
get_hf_model_info,
@@ -128,6 +131,10 @@ def post(self, *args, **kwargs): # noqa: ARG002
128131
download_from_hf = (
129132
str(input_data.get("download_from_hf", "false")).lower() == "true"
130133
)
134+
local_dir = input_data.get("local_dir")
135+
cleanup_model_cache = (
136+
str(input_data.get("cleanup_model_cache", "true")).lower() == "true"
137+
)
131138
inference_container_uri = input_data.get("inference_container_uri")
132139
allow_patterns = input_data.get("allow_patterns")
133140
ignore_patterns = input_data.get("ignore_patterns")
@@ -139,6 +146,8 @@ def post(self, *args, **kwargs): # noqa: ARG002
139146
model=model,
140147
os_path=os_path,
141148
download_from_hf=download_from_hf,
149+
local_dir=local_dir,
150+
cleanup_model_cache=cleanup_model_cache,
142151
inference_container=inference_container,
143152
finetuning_container=finetuning_container,
144153
compartment_id=compartment_id,
@@ -163,7 +172,9 @@ def put(self, id):
163172
raise HTTPError(400, Errors.NO_INPUT_DATA)
164173

165174
inference_container = input_data.get("inference_container")
175+
inference_container_uri = input_data.get("inference_container_uri")
166176
inference_containers = AquaModelApp.list_valid_inference_containers()
177+
inference_containers.extend(CustomInferenceContainerTypeFamily.values())
167178
if (
168179
inference_container is not None
169180
and inference_container not in inference_containers
@@ -176,7 +187,13 @@ def put(self, id):
176187
task = input_data.get("task")
177188
app = AquaModelApp()
178189
self.finish(
179-
app.edit_registered_model(id, inference_container, enable_finetuning, task)
190+
app.edit_registered_model(
191+
id,
192+
inference_container,
193+
inference_container_uri,
194+
enable_finetuning,
195+
task,
196+
)
180197
)
181198
app.clear_model_details_cache(model_id=id)
182199

ads/aqua/model/entities.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -283,7 +283,7 @@ class ImportModelDetails(CLIBuilderMixin):
283283
os_path: str
284284
download_from_hf: Optional[bool] = True
285285
local_dir: Optional[str] = None
286-
delete_from_local: Optional[bool] = True
286+
cleanup_model_cache: Optional[bool] = True
287287
inference_container: Optional[str] = None
288288
finetuning_container: Optional[str] = None
289289
compartment_id: Optional[str] = None

ads/aqua/model/model.py

Lines changed: 47 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from ads.aqua import ODSC_MODEL_COMPARTMENT_OCID, logger
1616
from ads.aqua.app import AquaApp
1717
from ads.aqua.common.enums import (
18+
CustomInferenceContainerTypeFamily,
1819
FineTuningContainerTypeFamily,
1920
InferenceContainerTypeFamily,
2021
Tags,
@@ -377,8 +378,10 @@ def delete_model(self, model_id):
377378
f"Failed to delete model:{model_id}. Only registered models or finetuned model can be deleted."
378379
)
379380

380-
@telemetry(entry_point="plugin=model&action=delete", name="aqua")
381-
def edit_registered_model(self, id, inference_container, enable_finetuning, task):
381+
@telemetry(entry_point="plugin=model&action=edit", name="aqua")
382+
def edit_registered_model(
383+
self, id, inference_container, inference_container_uri, enable_finetuning, task
384+
):
382385
"""Edits the default config of unverified registered model.
383386
384387
Parameters
@@ -387,6 +390,8 @@ def edit_registered_model(self, id, inference_container, enable_finetuning, task
387390
The model OCID.
388391
inference_container: str.
389392
The inference container family name
393+
inference_container_uri: str
394+
The inference container uri for embedding models
390395
enable_finetuning: str
391396
Flag to enable or disable finetuning over the model. Defaults to None
392397
task:
@@ -402,19 +407,44 @@ def edit_registered_model(self, id, inference_container, enable_finetuning, task
402407
if ds_model.freeform_tags.get(Tags.BASE_MODEL_CUSTOM, None):
403408
if ds_model.freeform_tags.get(Tags.AQUA_SERVICE_MODEL_TAG, None):
404409
raise AquaRuntimeError(
405-
f"Failed to edit model:{id}. Only registered unverified models can be edited."
410+
"Only registered unverified models can be edited."
406411
)
407412
else:
408413
custom_metadata_list = ds_model.custom_metadata_list
409414
freeform_tags = ds_model.freeform_tags
410415
if inference_container:
411-
custom_metadata_list.add(
412-
key=ModelCustomMetadataFields.DEPLOYMENT_CONTAINER,
413-
value=inference_container,
414-
category=MetadataCustomCategory.OTHER,
415-
description="Deployment container mapping for SMC",
416-
replace=True,
417-
)
416+
if (
417+
inference_container in CustomInferenceContainerTypeFamily
418+
and inference_container_uri is None
419+
):
420+
raise AquaRuntimeError(
421+
"Inference container URI must be provided."
422+
)
423+
else:
424+
custom_metadata_list.add(
425+
key=ModelCustomMetadataFields.DEPLOYMENT_CONTAINER,
426+
value=inference_container,
427+
category=MetadataCustomCategory.OTHER,
428+
description="Deployment container mapping for SMC",
429+
replace=True,
430+
)
431+
if inference_container_uri:
432+
if (
433+
inference_container in CustomInferenceContainerTypeFamily
434+
or inference_container is None
435+
):
436+
custom_metadata_list.add(
437+
key=ModelCustomMetadataFields.DEPLOYMENT_CONTAINER_URI,
438+
value=inference_container_uri,
439+
category=MetadataCustomCategory.OTHER,
440+
description=f"Inference container URI for {ds_model.display_name}",
441+
replace=True,
442+
)
443+
else:
444+
raise AquaRuntimeError(
445+
f"Inference container URI can be edited only with container values: {CustomInferenceContainerTypeFamily.values()}"
446+
)
447+
418448
if enable_finetuning is not None:
419449
if enable_finetuning.lower() == "true":
420450
custom_metadata_list.add(
@@ -449,9 +479,7 @@ def edit_registered_model(self, id, inference_container, enable_finetuning, task
449479
)
450480
AquaApp().update_model(id, update_model_details)
451481
else:
452-
raise AquaRuntimeError(
453-
f"Failed to edit model:{id}. Only registered unverified models can be edited."
454-
)
482+
raise AquaRuntimeError("Only registered unverified models can be edited.")
455483

456484
def _fetch_metric_from_metadata(
457485
self,
@@ -870,8 +898,7 @@ def _create_model_catalog_entry(
870898
# only add cmd vars if inference container is not an SMC
871899
if (
872900
inference_container not in smc_container_set
873-
and inference_container
874-
== InferenceContainerTypeFamily.AQUA_TEI_CONTAINER_FAMILY
901+
and inference_container in CustomInferenceContainerTypeFamily.values()
875902
):
876903
cmd_vars = generate_tei_cmd_var(os_path)
877904
metadata.add(
@@ -1328,7 +1355,9 @@ def _download_model_from_hf(
13281355
if local_dir:
13291356
local_dir = os.path.join(local_dir, model_name)
13301357
os.makedirs(local_dir, exist_ok=True)
1331-
snapshot_download(
1358+
1359+
# if local_dir is not set, the return value points to the cached data folder
1360+
local_dir = snapshot_download(
13321361
repo_id=model_name,
13331362
local_dir=local_dir,
13341363
allow_patterns=allow_patterns,
@@ -1364,7 +1393,7 @@ def register(
13641393
ignore_patterns (list): Model files matching any of the patterns are not downloaded.
13651394
Example: ["*.json"] will ignore all .json files. ["folder/*"] will ignore all files under `folder`.
13661395
Patterns are Standard Wildcards (globbing patterns) and rules can be found here: https://docs.python.org/3/library/fnmatch.html
1367-
delete_from_local (bool): Deletes downloaded files from local machine after model is successfully
1396+
cleanup_model_cache (bool): Deletes downloaded files from local machine after model is successfully
13681397
registered. Set to True by default.
13691398
13701399
Returns:
@@ -1477,7 +1506,7 @@ def register(
14771506

14781507
if (
14791508
import_model_details.download_from_hf
1480-
and import_model_details.delete_from_local
1509+
and import_model_details.cleanup_model_cache
14811510
):
14821511
cleanup_local_hf_model_artifact(
14831512
model_name=model_name, local_dir=import_model_details.local_dir

ads/model/__init__.py

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,26 @@
11
#!/usr/bin/env python
2-
# -*- coding: utf-8 -*--
32

4-
# Copyright (c) 2021, 2023 Oracle and/or its affiliates.
3+
# Copyright (c) 2021, 2025 Oracle and/or its affiliates.
54
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
65

7-
from ads.model.generic_model import GenericModel, ModelState
86
from ads.model.datascience_model import DataScienceModel
9-
from ads.model.model_properties import ModelProperties
7+
from ads.model.deployment.model_deployer import ModelDeployer
8+
from ads.model.deployment.model_deployment import ModelDeployment
9+
from ads.model.deployment.model_deployment_properties import ModelDeploymentProperties
1010
from ads.model.framework.automl_model import AutoMLModel
11+
from ads.model.framework.embedding_onnx_model import EmbeddingONNXModel
12+
from ads.model.framework.huggingface_model import HuggingFacePipelineModel
1113
from ads.model.framework.lightgbm_model import LightGBMModel
1214
from ads.model.framework.pytorch_model import PyTorchModel
1315
from ads.model.framework.sklearn_model import SklearnModel
16+
from ads.model.framework.spark_model import SparkPipelineModel
1417
from ads.model.framework.tensorflow_model import TensorFlowModel
1518
from ads.model.framework.xgboost_model import XGBoostModel
16-
from ads.model.framework.spark_model import SparkPipelineModel
17-
from ads.model.framework.huggingface_model import HuggingFacePipelineModel
18-
19-
from ads.model.deployment.model_deployer import ModelDeployer
20-
from ads.model.deployment.model_deployment import ModelDeployment
21-
from ads.model.deployment.model_deployment_properties import ModelDeploymentProperties
22-
19+
from ads.model.generic_model import GenericModel, ModelState
20+
from ads.model.model_properties import ModelProperties
21+
from ads.model.model_version_set import ModelVersionSet, experiment
2322
from ads.model.serde.common import SERDE
2423
from ads.model.serde.model_input import ModelInputSerializer
25-
26-
from ads.model.model_version_set import ModelVersionSet, experiment
2724
from ads.model.service.oci_datascience_model_version_set import (
2825
ModelVersionSetNotExists,
2926
ModelVersionSetNotSaved,
@@ -42,6 +39,7 @@
4239
"XGBoostModel",
4340
"SparkPipelineModel",
4441
"HuggingFacePipelineModel",
42+
"EmbeddingONNXModel",
4543
"ModelDeployer",
4644
"ModelDeployment",
4745
"ModelDeploymentProperties",

ads/model/artifact.py

Lines changed: 47 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,28 @@
11
#!/usr/bin/env python
2-
# -*- coding: utf-8 -*--
32

4-
# Copyright (c) 2022, 2023 Oracle and/or its affiliates.
3+
# Copyright (c) 2022, 2025 Oracle and/or its affiliates.
54
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
65

76
import fnmatch
87
import importlib
98
import os
10-
import sys
119
import shutil
10+
import sys
1211
import tempfile
1312
import uuid
14-
import fsspec
13+
from datetime import datetime
1514
from typing import Dict, Optional, Tuple
15+
16+
import fsspec
17+
from jinja2 import Environment, PackageLoader
18+
19+
from ads import __version__
1620
from ads.common import auth as authutil
1721
from ads.common import logger, utils
1822
from ads.common.object_storage_details import ObjectStorageDetails
1923
from ads.config import CONDA_BUCKET_NAME, CONDA_BUCKET_NS
2024
from ads.model.runtime.env_info import EnvInfo, InferenceEnvInfo, TrainingEnvInfo
2125
from ads.model.runtime.runtime_info import RuntimeInfo
22-
from jinja2 import Environment, PackageLoader
23-
import warnings
24-
from ads import __version__
25-
from datetime import datetime
2626

2727
MODEL_ARTIFACT_VERSION = "3.0"
2828
REQUIRED_ARTIFACT_FILES = ("runtime.yaml", "score.py")
@@ -378,6 +378,45 @@ def prepare_score_py(
378378
) as f:
379379
f.write(scorefn_template.render(context))
380380

381+
def prepare_schema(self, schema_name: str):
382+
"""Copies schema to artifact directory.
383+
384+
Parameters
385+
----------
386+
schema_name: str
387+
The schema name
388+
389+
Returns
390+
-------
391+
None
392+
393+
Raises
394+
------
395+
FileExistsError
396+
If `schema_name` doesn't exist.
397+
"""
398+
uri_src = os.path.join(
399+
os.path.abspath(os.path.join(os.path.dirname(__file__), "..")),
400+
"templates",
401+
"schemas",
402+
f"{schema_name}",
403+
)
404+
405+
if not os.path.exists(uri_src):
406+
raise FileExistsError(
407+
f"{schema_name} does not exists. "
408+
"Ensure the schema name is valid or specify a different one."
409+
)
410+
411+
uri_dst = os.path.join(self.artifact_dir, os.path.basename(uri_src))
412+
413+
utils.copy_file(
414+
uri_src=uri_src,
415+
uri_dst=uri_dst,
416+
force_overwrite=True,
417+
auth=self.auth,
418+
)
419+
381420
def reload(self):
382421
"""Syncs the `score.py` to reload the model and predict function.
383422

0 commit comments

Comments
 (0)