Skip to content

Commit 723a763

Browse files
authored
Merge branch 'main' into dataflow_changes
2 parents e14237f + beef7b1 commit 723a763

File tree

14 files changed

+2988
-68
lines changed

14 files changed

+2988
-68
lines changed

ads/aqua/__init__.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
#!/usr/bin/env python
2-
# -*- coding: utf-8 -*-
3-
# Copyright (c) 2024 Oracle and/or its affiliates.
2+
# Copyright (c) 2024, 2025 Oracle and/or its affiliates.
43
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
54

65

76
import os
7+
from logging import getLogger
88

9-
from ads import logger, set_auth
9+
from ads import set_auth
1010
from ads.aqua.common.utils import fetch_service_compartment
1111
from ads.config import OCI_RESOURCE_PRINCIPAL_VERSION
1212

@@ -19,6 +19,7 @@ def get_logger_level():
1919
return level
2020

2121

22+
logger = getLogger(__name__)
2223
logger.setLevel(get_logger_level())
2324

2425

@@ -27,7 +28,6 @@ def set_log_level(log_level: str):
2728

2829
log_level = log_level.upper()
2930
logger.setLevel(log_level.upper())
30-
logger.handlers[0].setLevel(log_level)
3131

3232

3333
if OCI_RESOURCE_PRINCIPAL_VERSION:

ads/common/auth.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -424,7 +424,7 @@ def create_signer(
424424
"signer": signer,
425425
"client_kwargs": client_kwargs,
426426
}
427-
logger.info(f"Using authentication signer type {type(signer)}.")
427+
logger.debug(f"Using authentication signer type {type(signer)}.")
428428
return signer_dict
429429
else:
430430
signer_args = dict(
@@ -492,7 +492,7 @@ def default_signer(client_kwargs: Optional[Dict] = None) -> Dict:
492492
**(client_kwargs or {}),
493493
},
494494
}
495-
logger.info(f"Using authentication signer type {type(signer)}.")
495+
logger.debug(f"Using authentication signer type {type(signer)}.")
496496
return signer_dict
497497
else:
498498
signer_args = dict(
@@ -621,7 +621,7 @@ def create_signer(self) -> Dict:
621621
)
622622

623623
oci.config.validate_config(configuration)
624-
logger.info(f"Using 'api_key' authentication.")
624+
logger.debug(f"Using 'api_key' authentication.")
625625
return {
626626
"config": configuration,
627627
"signer": oci.signer.Signer(
@@ -684,7 +684,7 @@ def create_signer(self) -> Dict:
684684
"signer": oci.auth.signers.get_resource_principals_signer(),
685685
"client_kwargs": self.client_kwargs,
686686
}
687-
logger.info(f"Using 'resource_principal' authentication.")
687+
logger.debug(f"Using 'resource_principal' authentication.")
688688
return signer_dict
689689

690690
@staticmethod
@@ -747,7 +747,7 @@ def create_signer(self) -> Dict:
747747
),
748748
"client_kwargs": self.client_kwargs,
749749
}
750-
logger.info(f"Using 'instance_principal' authentication.")
750+
logger.debug(f"Using 'instance_principal' authentication.")
751751
return signer_dict
752752

753753

@@ -814,7 +814,7 @@ def create_signer(self) -> Dict:
814814
oci.config.from_file(self.oci_config_location, self.oci_key_profile)
815815
)
816816

817-
logger.info(f"Using 'security_token' authentication.")
817+
logger.debug(f"Using 'security_token' authentication.")
818818

819819
for parameter in self.SECURITY_TOKEN_REQUIRED:
820820
if parameter not in configuration:
@@ -883,7 +883,7 @@ def _validate_and_refresh_token(self, configuration: Dict[str, Any]):
883883
)
884884

885885
date_time = datetime.fromtimestamp(time_expired).strftime("%Y-%m-%d %H:%M:%S")
886-
logger.info(f"Session is valid until {date_time}.")
886+
logger.debug(f"Session is valid until {date_time}.")
887887

888888
def _read_security_token_file(self, security_token_file: str) -> str:
889889
"""Reads security token from file.
@@ -1020,10 +1020,10 @@ def __enter__(self):
10201020
"""
10211021
if self.profile:
10221022
ads.set_auth(auth=AuthType.API_KEY, profile=self.profile)
1023-
logger.info(f"OCI profile set to {self.profile}")
1023+
logger.debug(f"OCI profile set to {self.profile}")
10241024
else:
10251025
ads.set_auth(auth=AuthType.RESOURCE_PRINCIPAL)
1026-
logger.info(f"OCI auth set to resource principal")
1026+
logger.debug(f"OCI auth set to resource principal")
10271027
return self
10281028

10291029
def __exit__(self, exc_type, exc_val, exc_tb):

ads/model/__init__.py

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,26 @@
11
#!/usr/bin/env python
2-
# -*- coding: utf-8 -*--
32

4-
# Copyright (c) 2021, 2023 Oracle and/or its affiliates.
3+
# Copyright (c) 2021, 2025 Oracle and/or its affiliates.
54
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
65

7-
from ads.model.generic_model import GenericModel, ModelState
86
from ads.model.datascience_model import DataScienceModel
9-
from ads.model.model_properties import ModelProperties
7+
from ads.model.deployment.model_deployer import ModelDeployer
8+
from ads.model.deployment.model_deployment import ModelDeployment
9+
from ads.model.deployment.model_deployment_properties import ModelDeploymentProperties
1010
from ads.model.framework.automl_model import AutoMLModel
11+
from ads.model.framework.embedding_onnx_model import EmbeddingONNXModel
12+
from ads.model.framework.huggingface_model import HuggingFacePipelineModel
1113
from ads.model.framework.lightgbm_model import LightGBMModel
1214
from ads.model.framework.pytorch_model import PyTorchModel
1315
from ads.model.framework.sklearn_model import SklearnModel
16+
from ads.model.framework.spark_model import SparkPipelineModel
1417
from ads.model.framework.tensorflow_model import TensorFlowModel
1518
from ads.model.framework.xgboost_model import XGBoostModel
16-
from ads.model.framework.spark_model import SparkPipelineModel
17-
from ads.model.framework.huggingface_model import HuggingFacePipelineModel
18-
19-
from ads.model.deployment.model_deployer import ModelDeployer
20-
from ads.model.deployment.model_deployment import ModelDeployment
21-
from ads.model.deployment.model_deployment_properties import ModelDeploymentProperties
22-
19+
from ads.model.generic_model import GenericModel, ModelState
20+
from ads.model.model_properties import ModelProperties
21+
from ads.model.model_version_set import ModelVersionSet, experiment
2322
from ads.model.serde.common import SERDE
2423
from ads.model.serde.model_input import ModelInputSerializer
25-
26-
from ads.model.model_version_set import ModelVersionSet, experiment
2724
from ads.model.service.oci_datascience_model_version_set import (
2825
ModelVersionSetNotExists,
2926
ModelVersionSetNotSaved,
@@ -42,6 +39,7 @@
4239
"XGBoostModel",
4340
"SparkPipelineModel",
4441
"HuggingFacePipelineModel",
42+
"EmbeddingONNXModel",
4543
"ModelDeployer",
4644
"ModelDeployment",
4745
"ModelDeploymentProperties",

ads/model/artifact.py

Lines changed: 47 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,28 @@
11
#!/usr/bin/env python
2-
# -*- coding: utf-8 -*--
32

4-
# Copyright (c) 2022, 2023 Oracle and/or its affiliates.
3+
# Copyright (c) 2022, 2025 Oracle and/or its affiliates.
54
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
65

76
import fnmatch
87
import importlib
98
import os
10-
import sys
119
import shutil
10+
import sys
1211
import tempfile
1312
import uuid
14-
import fsspec
13+
from datetime import datetime
1514
from typing import Dict, Optional, Tuple
15+
16+
import fsspec
17+
from jinja2 import Environment, PackageLoader
18+
19+
from ads import __version__
1620
from ads.common import auth as authutil
1721
from ads.common import logger, utils
1822
from ads.common.object_storage_details import ObjectStorageDetails
1923
from ads.config import CONDA_BUCKET_NAME, CONDA_BUCKET_NS
2024
from ads.model.runtime.env_info import EnvInfo, InferenceEnvInfo, TrainingEnvInfo
2125
from ads.model.runtime.runtime_info import RuntimeInfo
22-
from jinja2 import Environment, PackageLoader
23-
import warnings
24-
from ads import __version__
25-
from datetime import datetime
2626

2727
MODEL_ARTIFACT_VERSION = "3.0"
2828
REQUIRED_ARTIFACT_FILES = ("runtime.yaml", "score.py")
@@ -378,6 +378,45 @@ def prepare_score_py(
378378
) as f:
379379
f.write(scorefn_template.render(context))
380380

381+
def prepare_schema(self, schema_name: str):
382+
"""Copies schema to artifact directory.
383+
384+
Parameters
385+
----------
386+
schema_name: str
387+
The schema name
388+
389+
Returns
390+
-------
391+
None
392+
393+
Raises
394+
------
395+
FileExistsError
396+
If `schema_name` doesn't exist.
397+
"""
398+
uri_src = os.path.join(
399+
os.path.abspath(os.path.join(os.path.dirname(__file__), "..")),
400+
"templates",
401+
"schemas",
402+
f"{schema_name}",
403+
)
404+
405+
if not os.path.exists(uri_src):
406+
raise FileExistsError(
407+
f"{schema_name} does not exists. "
408+
"Ensure the schema name is valid or specify a different one."
409+
)
410+
411+
uri_dst = os.path.join(self.artifact_dir, os.path.basename(uri_src))
412+
413+
utils.copy_file(
414+
uri_src=uri_src,
415+
uri_dst=uri_dst,
416+
force_overwrite=True,
417+
auth=self.auth,
418+
)
419+
381420
def reload(self):
382421
"""Syncs the `score.py` to reload the model and predict function.
383422
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
#!/usr/bin/env python
2+
3+
# Copyright (c) 2025 Oracle and/or its affiliates.
4+
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
5+
6+
from ads.common.decorator.runtime_dependency import (
7+
OptionalDependency,
8+
runtime_dependency,
9+
)
10+
from ads.model.extractor.model_info_extractor import ModelInfoExtractor
11+
from ads.model.model_metadata import Framework
12+
13+
14+
class EmbeddingONNXExtractor(ModelInfoExtractor):
15+
"""Class that extract model metadata from EmbeddingONNXModel models.
16+
17+
Attributes
18+
----------
19+
model: object
20+
The model to extract metadata from.
21+
22+
Methods
23+
-------
24+
framework(self) -> str
25+
Returns the framework of the model.
26+
algorithm(self) -> object
27+
Returns the algorithm of the model.
28+
version(self) -> str
29+
Returns the version of framework of the model.
30+
hyperparameter(self) -> dict
31+
Returns the hyperparameter of the model.
32+
"""
33+
34+
def __init__(self, model=None):
35+
self.model = model
36+
37+
@property
38+
def framework(self):
39+
"""Extracts the framework of the model.
40+
41+
Returns
42+
----------
43+
str:
44+
The framework of the model.
45+
"""
46+
return Framework.EMBEDDING_ONNX
47+
48+
@property
49+
def algorithm(self):
50+
"""Extracts the algorithm of the model.
51+
52+
Returns
53+
----------
54+
object:
55+
The algorithm of the model.
56+
"""
57+
return "Embedding_ONNX"
58+
59+
@property
60+
@runtime_dependency(module="onnxruntime", install_from=OptionalDependency.ONNX)
61+
def version(self):
62+
"""Extracts the framework version of the model.
63+
64+
Returns
65+
----------
66+
str:
67+
The framework version of the model.
68+
"""
69+
return onnxruntime.__version__
70+
71+
@property
72+
def hyperparameter(self):
73+
"""Extracts the hyperparameters of the model.
74+
75+
Returns
76+
----------
77+
dict:
78+
The hyperparameters of the model.
79+
"""
80+
return None

0 commit comments

Comments
 (0)