Skip to content

Commit 4681ead

Browse files
committed
Added support for embedding onnx model.
1 parent 1fd0b99 commit 4681ead

File tree

9 files changed

+2314
-80
lines changed

9 files changed

+2314
-80
lines changed

ads/model/__init__.py

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,26 @@
11
#!/usr/bin/env python
2-
# -*- coding: utf-8 -*--
32

4-
# Copyright (c) 2021, 2023 Oracle and/or its affiliates.
3+
# Copyright (c) 2021, 2024 Oracle and/or its affiliates.
54
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
65

7-
from ads.model.generic_model import GenericModel, ModelState
86
from ads.model.datascience_model import DataScienceModel
9-
from ads.model.model_properties import ModelProperties
7+
from ads.model.deployment.model_deployer import ModelDeployer
8+
from ads.model.deployment.model_deployment import ModelDeployment
9+
from ads.model.deployment.model_deployment_properties import ModelDeploymentProperties
1010
from ads.model.framework.automl_model import AutoMLModel
11+
from ads.model.framework.embedding_onnx_model import EmbeddingONNXModel
12+
from ads.model.framework.huggingface_model import HuggingFacePipelineModel
1113
from ads.model.framework.lightgbm_model import LightGBMModel
1214
from ads.model.framework.pytorch_model import PyTorchModel
1315
from ads.model.framework.sklearn_model import SklearnModel
16+
from ads.model.framework.spark_model import SparkPipelineModel
1417
from ads.model.framework.tensorflow_model import TensorFlowModel
1518
from ads.model.framework.xgboost_model import XGBoostModel
16-
from ads.model.framework.spark_model import SparkPipelineModel
17-
from ads.model.framework.huggingface_model import HuggingFacePipelineModel
18-
19-
from ads.model.deployment.model_deployer import ModelDeployer
20-
from ads.model.deployment.model_deployment import ModelDeployment
21-
from ads.model.deployment.model_deployment_properties import ModelDeploymentProperties
22-
19+
from ads.model.generic_model import GenericModel, ModelState
20+
from ads.model.model_properties import ModelProperties
21+
from ads.model.model_version_set import ModelVersionSet, experiment
2322
from ads.model.serde.common import SERDE
2423
from ads.model.serde.model_input import ModelInputSerializer
25-
26-
from ads.model.model_version_set import ModelVersionSet, experiment
2724
from ads.model.service.oci_datascience_model_version_set import (
2825
ModelVersionSetNotExists,
2926
ModelVersionSetNotSaved,
@@ -42,6 +39,7 @@
4239
"XGBoostModel",
4340
"SparkPipelineModel",
4441
"HuggingFacePipelineModel",
42+
"EmbeddingONNXModel",
4543
"ModelDeployer",
4644
"ModelDeployment",
4745
"ModelDeploymentProperties",

ads/model/artifact.py

Lines changed: 47 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,28 @@
11
#!/usr/bin/env python
2-
# -*- coding: utf-8 -*--
32

4-
# Copyright (c) 2022, 2023 Oracle and/or its affiliates.
3+
# Copyright (c) 2022, 2024 Oracle and/or its affiliates.
54
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
65

76
import fnmatch
87
import importlib
98
import os
10-
import sys
119
import shutil
10+
import sys
1211
import tempfile
1312
import uuid
14-
import fsspec
13+
from datetime import datetime
1514
from typing import Dict, Optional, Tuple
15+
16+
import fsspec
17+
from jinja2 import Environment, PackageLoader
18+
19+
from ads import __version__
1620
from ads.common import auth as authutil
1721
from ads.common import logger, utils
1822
from ads.common.object_storage_details import ObjectStorageDetails
1923
from ads.config import CONDA_BUCKET_NAME, CONDA_BUCKET_NS
2024
from ads.model.runtime.env_info import EnvInfo, InferenceEnvInfo, TrainingEnvInfo
2125
from ads.model.runtime.runtime_info import RuntimeInfo
22-
from jinja2 import Environment, PackageLoader
23-
import warnings
24-
from ads import __version__
25-
from datetime import datetime
2626

2727
MODEL_ARTIFACT_VERSION = "3.0"
2828
REQUIRED_ARTIFACT_FILES = ("runtime.yaml", "score.py")
@@ -378,6 +378,45 @@ def prepare_score_py(
378378
) as f:
379379
f.write(scorefn_template.render(context))
380380

381+
def prepare_schema(self, schema_name: str):
382+
"""Copies schema to artifact directory.
383+
384+
Parameters
385+
----------
386+
schema_name: str
387+
The schema name
388+
389+
Returns
390+
-------
391+
None
392+
393+
Raises
394+
------
395+
FileExistsError
396+
If `schema_name` doesn't exist.
397+
"""
398+
uri_src = os.path.join(
399+
os.path.abspath(os.path.join(os.path.dirname(__file__), "..")),
400+
"templates",
401+
"schemas",
402+
f"{schema_name}",
403+
)
404+
405+
if not os.path.exists(uri_src):
406+
raise FileExistsError(
407+
f"{schema_name} does not exists. "
408+
"Ensure the schema name is valid or specify a different one."
409+
)
410+
411+
uri_dst = os.path.join(self.artifact_dir, os.path.basename(uri_src))
412+
413+
utils.copy_file(
414+
uri_src=uri_src,
415+
uri_dst=uri_dst,
416+
force_overwrite=True,
417+
auth=self.auth,
418+
)
419+
381420
def reload(self):
382421
"""Syncs the `score.py` to reload the model and predict function.
383422

ads/model/extractor/embedding_onnx_extractor.py

Lines changed: 30 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,35 @@
33
# Copyright (c) 2024 Oracle and/or its affiliates.
44
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
55

6+
from ads.common.decorator.runtime_dependency import (
7+
OptionalDependency,
8+
runtime_dependency,
9+
)
610
from ads.model.extractor.model_info_extractor import ModelInfoExtractor
11+
from ads.model.model_metadata import Framework
712

813

914
class EmbeddingONNXExtractor(ModelInfoExtractor):
10-
def __init__(self, model):
15+
"""Class that extract model metadata from EmbeddingONNXModel models.
16+
17+
Attributes
18+
----------
19+
model: object
20+
The model to extract metadata from.
21+
22+
Methods
23+
-------
24+
framework(self) -> str
25+
Returns the framework of the model.
26+
algorithm(self) -> object
27+
Returns the algorithm of the model.
28+
version(self) -> str
29+
Returns the version of framework of the model.
30+
hyperparameter(self) -> dict
31+
Returns the hyperparameter of the model.
32+
"""
33+
34+
def __init__(self, model=None):
1135
self.model = model
1236

1337
@property
@@ -19,7 +43,7 @@ def framework(self):
1943
str:
2044
The framework of the model.
2145
"""
22-
pass
46+
return Framework.EMBEDDING_ONNX
2347

2448
@property
2549
def algorithm(self):
@@ -30,9 +54,10 @@ def algorithm(self):
3054
object:
3155
The algorithm of the model.
3256
"""
33-
pass
57+
return "Embedding_ONNX"
3458

3559
@property
60+
@runtime_dependency(module="onnxruntime", install_from=OptionalDependency.ONNX)
3661
def version(self):
3762
"""Extracts the framework version of the model.
3863
@@ -41,7 +66,7 @@ def version(self):
4166
str:
4267
The framework version of the model.
4368
"""
44-
pass
69+
return onnxruntime.__version__
4570

4671
@property
4772
def hyperparameter(self):
@@ -52,4 +77,4 @@ def hyperparameter(self):
5277
dict:
5378
The hyperparameters of the model.
5479
"""
55-
pass
80+
return None

0 commit comments

Comments
 (0)