Skip to content

Commit 5bd380c

Browse files
committed
Added artifacts validation.
1 parent e84715c commit 5bd380c

File tree

5 files changed

+103
-9
lines changed

5 files changed

+103
-9
lines changed

ads/model/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
#!/usr/bin/env python
22

3-
# Copyright (c) 2021, 2024 Oracle and/or its affiliates.
3+
# Copyright (c) 2021, 2025 Oracle and/or its affiliates.
44
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
55

66
from ads.model.datascience_model import DataScienceModel

ads/model/artifact.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
#!/usr/bin/env python
22

3-
# Copyright (c) 2022, 2024 Oracle and/or its affiliates.
3+
# Copyright (c) 2022, 2025 Oracle and/or its affiliates.
44
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
55

66
import fnmatch

ads/model/extractor/embedding_onnx_extractor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
#!/usr/bin/env python
22

3-
# Copyright (c) 2024 Oracle and/or its affiliates.
3+
# Copyright (c) 2025 Oracle and/or its affiliates.
44
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
55

66
from ads.common.decorator.runtime_dependency import (

ads/model/framework/embedding_onnx_model.py

Lines changed: 82 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,27 @@
11
#!/usr/bin/env python
22

3-
# Copyright (c) 2024 Oracle and/or its affiliates.
3+
# Copyright (c) 2025 Oracle and/or its affiliates.
44
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
55

6+
import logging
7+
import os
8+
from pathlib import Path
69
from typing import Dict, Optional
710

811
from ads.model.extractor.embedding_onnx_extractor import EmbeddingONNXExtractor
912
from ads.model.generic_model import FrameworkSpecificModel
1013

14+
logger = logging.getLogger(__name__)
15+
16+
CONFIG = "config.json"
17+
TOKENIZERS = [
18+
"tokenizer.json",
19+
"tokenizer_config.json",
20+
"spiece.model",
21+
"vocab.txt",
22+
"vocab.json",
23+
]
24+
1125

1226
class EmbeddingONNXModel(FrameworkSpecificModel):
1327
"""EmbeddingONNXModel class for embedding onnx model.
@@ -18,6 +32,12 @@ class EmbeddingONNXModel(FrameworkSpecificModel):
1832
The algorithm of the model.
1933
artifact_dir: str
2034
Artifact directory to store the files needed for deployment.
35+
model_file_name: str
36+
Path to the model artifact.
37+
config_json: str
38+
Path to the config.json file.
39+
tokenizer_dir: str
40+
Path to the tokenizer directory.
2141
auth: Dict
2242
Default authentication is set using the `ads.set_auth` API. To override the
2343
default, use the `ads.common.auth.api_keys` or `ads.common.auth.resource_principal` to create
@@ -166,6 +186,9 @@ class EmbeddingONNXModel(FrameworkSpecificModel):
166186
def __init__(
167187
self,
168188
artifact_dir: Optional[str] = None,
189+
model_file_name: Optional[str] = None,
190+
config_json: Optional[str] = None,
191+
tokenizer_dir: Optional[str] = None,
169192
auth: Optional[Dict] = None,
170193
serialize: bool = False,
171194
**kwargs: dict,
@@ -175,8 +198,14 @@ def __init__(
175198
176199
Parameters
177200
----------
178-
artifact_dir: str
201+
artifact_dir: (str, optional). Defaults to None.
179202
Directory for the generated artifacts.
203+
model_file_name: (str, optional). Defaults to None.
204+
Path to the model artifact.
205+
config_json: (str, optional). Defaults to None.
206+
Path to the config.json file.
207+
tokenizer_dir: (str, optional). Defaults to None.
208+
Path to the tokenizer directory.
180209
auth: (Dict, optional). Defaults to None.
181210
The default authentication is set using `ads.set_auth` API. If you need to override the
182211
default, use the `ads.common.auth.api_keys` or `ads.common.auth.resource_principal` to create appropriate
@@ -260,12 +289,63 @@ def __init__(
260289
**kwargs,
261290
)
262291

292+
self._validate_artifact_directory(
293+
model_file_name=model_file_name,
294+
config_json=config_json,
295+
tokenizer_dir=tokenizer_dir,
296+
)
297+
263298
self._extractor = EmbeddingONNXExtractor()
264299
self.framework = self._extractor.framework
265300
self.algorithm = self._extractor.algorithm
266301
self.version = self._extractor.version
267302
self.hyperparameter = self._extractor.hyperparameter
268303

304+
def _validate_artifact_directory(
    self,
    model_file_name: Optional[str] = None,
    config_json: Optional[str] = None,
    tokenizer_dir: Optional[str] = None,
) -> None:
    """Validates the content of the artifact directory.

    File names are collected from `self.artifact_dir` and all of its
    sub-directories, then checked for the pieces an embedding onnx model
    is expected to ship with. A check is skipped when the corresponding
    argument is provided (truthy); the provided path itself is not inspected.

    Parameters
    ----------
    model_file_name: (str, optional). Defaults to None.
        Path to the model artifact. When omitted, at least one file with an
        `onnx` extension (case-insensitive) must exist in the directory tree.
    config_json: (str, optional). Defaults to None.
        Path to the config.json file. When omitted, a missing `config.json`
        is reported with a warning.
    tokenizer_dir: (str, optional). Defaults to None.
        Path to the tokenizer directory. When omitted, the absence of every
        file listed in `TOKENIZERS` is reported with a warning.

    Raises
    ------
    ValueError
        If no files are found under `self.artifact_dir`, or if
        `model_file_name` is not provided and no onnx file is found.
    """
    # Only bare file names are kept, so the checks below match a file
    # regardless of how deep it sits inside the artifact directory.
    artifacts = []
    for _, _, files in os.walk(self.artifact_dir):
        artifacts.extend(files)

    if not artifacts:
        raise ValueError(
            f"No files found in {self.artifact_dir}. Specify a valid `artifact_dir`."
        )

    # A missing model is fatal: there is nothing to serve without it.
    if not model_file_name and not any(
        Path(artifact).suffix.lstrip(".").lower() == "onnx"
        for artifact in artifacts
    ):
        raise ValueError(
            f"No onnx model found in {self.artifact_dir}. Specify a valid `artifact_dir` or `model_file_name`."
        )

    # Config and tokenizer problems are only logged, never raised.
    if not config_json and CONFIG not in artifacts:
        logger.warning(
            f"No {CONFIG} found in {self.artifact_dir}. Specify a valid `artifact_dir` or `config_json`."
        )

    if not tokenizer_dir and not any(
        artifact in TOKENIZERS for artifact in artifacts
    ):
        logger.warning(
            f"No tokenizer found in {self.artifact_dir}. Specify a valid `artifact_dir` or `tokenizer_dir`."
        )
348+
269349
def verify(
270350
self, data=None, reload_artifacts=True, auto_serialize_data=False, **kwargs
271351
):

tests/unitary/with_extras/model/test_model_framework_embedding_onnx_model.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
#!/usr/bin/env python
22

3-
# Copyright (c) 2024 Oracle and/or its affiliates.
3+
# Copyright (c) 2025 Oracle and/or its affiliates.
44
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
55

66
import os
@@ -21,13 +21,20 @@ def setup_class(cls):
2121
cls.inference_conda = "oci://fake_bucket@fake_namespace/inference_conda"
2222
cls.training_conda = "oci://fake_bucket@fake_namespace/training_conda"
2323

24-
def test_init(self):
24+
@patch(
    "ads.model.framework.embedding_onnx_model.EmbeddingONNXModel._validate_artifact_directory"
)
def test_init(self, mock_validate):
    """Checks the attributes exposed by a freshly constructed model.

    The artifact validation hook is patched out, so the content of
    `self.tmp_model_dir` is not inspected here; the test only confirms
    that the hook was invoked during construction.
    """
    embedding_model = EmbeddingONNXModel(artifact_dir=self.tmp_model_dir)

    assert embedding_model.algorithm == "Embedding_ONNX"
    assert embedding_model.framework == Framework.EMBEDDING_ONNX

    mock_validate.assert_called()
2832

2933
@patch("ads.model.generic_model.GenericModel.verify")
30-
def test_prepare_and_verify(self, mock_verify):
34+
@patch(
35+
"ads.model.framework.embedding_onnx_model.EmbeddingONNXModel._validate_artifact_directory"
36+
)
37+
def test_prepare_and_verify(self, mock_validate, mock_verify):
3138
mock_verify.return_value = {"results": "successful"}
3239

3340
model = EmbeddingONNXModel(artifact_dir=self.tmp_model_dir)
@@ -87,11 +94,17 @@ def test_prepare_and_verify(self, mock_verify):
8794
reload_artifacts=True,
8895
auto_serialize_data=False,
8996
)
97+
mock_validate.assert_called()
9098

9199
@patch("ads.model.generic_model.GenericModel.predict")
92100
@patch("ads.model.generic_model.GenericModel.deploy")
93101
@patch("ads.model.generic_model.GenericModel.save")
94-
def test_prepare_save_deploy_predict(self, mock_save, mock_deploy, mock_predict):
102+
@patch(
103+
"ads.model.framework.embedding_onnx_model.EmbeddingONNXModel._validate_artifact_directory"
104+
)
105+
def test_prepare_save_deploy_predict(
106+
self, mock_validate, mock_save, mock_deploy, mock_predict
107+
):
95108
model = EmbeddingONNXModel(artifact_dir=self.tmp_model_dir)
96109
model.prepare(
97110
model_file_name="test_model_file_name",
@@ -127,6 +140,7 @@ def test_prepare_save_deploy_predict(self, mock_save, mock_deploy, mock_predict)
127140
deployment_ocpus=20,
128141
deployment_memory_in_gbs=256,
129142
)
143+
mock_validate.assert_called()
130144

131145
def teardown_class(cls):
132146
shutil.rmtree(cls.tmp_model_dir, ignore_errors=True)

0 commit comments

Comments
 (0)