Skip to content

Commit e84715c

Browse files
committed
Updated pr.
1 parent 3803600 commit e84715c

File tree

3 files changed

+74
-44
lines changed

3 files changed

+74
-44
lines changed

ads/model/framework/embedding_onnx_model.py

Lines changed: 33 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
# Copyright (c) 2024 Oracle and/or its affiliates.
44
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
55

6-
from typing import Dict
6+
from typing import Dict, Optional
77

88
from ads.model.extractor.embedding_onnx_extractor import EmbeddingONNXExtractor
99
from ads.model.generic_model import FrameworkSpecificModel
@@ -108,18 +108,26 @@ class EmbeddingONNXModel(FrameworkSpecificModel):
108108
>>> from huggingface_hub import snapshot_download
109109
110110
>>> local_dir=tempfile.mkdtemp()
111-
>>> # download sentence-transformers/all-MiniLM-L6-v2 from huggingface
111+
>>> allow_patterns=[
112+
... "onnx/model.onnx",
113+
... "config.json",
114+
... "special_tokens_map.json",
115+
... "tokenizer_config.json",
116+
... "tokenizer.json",
117+
... "vocab.txt"
118+
... ]
119+
120+
>>> # download files needed for this demonstration to local folder
112121
>>> snapshot_download(
113122
... repo_id="sentence-transformers/all-MiniLM-L6-v2",
114-
... local_dir=local_dir
123+
... local_dir=local_dir,
124+
... allow_patterns=allow_patterns
115125
... )
116126
117-
>>> # copy all files from local_dir to artifact_dir
118127
>>> artifact_dir = tempfile.mkdtemp()
119-
>>> for root, dirs, files in os.walk(local_dir):
120-
>>> for file in files:
121-
>>> src_path = os.path.join(root, file)
122-
>>> shutil.copy(src_path, artifact_dir)
128+
>>> # copy all downloaded files to artifact folder
129+
>>> for file in allow_patterns:
130+
>>> shutil.copy(local_dir + "/" + file, artifact_dir)
123131
124132
>>> model = EmbeddingONNXModel(artifact_dir=artifact_dir)
125133
>>> model.summary_status()
@@ -157,8 +165,8 @@ class EmbeddingONNXModel(FrameworkSpecificModel):
157165

158166
def __init__(
159167
self,
160-
artifact_dir: str | None = None,
161-
auth: Dict | None = None,
168+
artifact_dir: Optional[str] = None,
169+
auth: Optional[Dict] = None,
162170
serialize: bool = False,
163171
**kwargs: dict,
164172
):
@@ -191,18 +199,26 @@ def __init__(
191199
>>> from huggingface_hub import snapshot_download
192200
193201
>>> local_dir=tempfile.mkdtemp()
194-
>>> # download sentence-transformers/all-MiniLM-L6-v2 from huggingface
202+
>>> allow_patterns=[
203+
... "onnx/model.onnx",
204+
... "config.json",
205+
... "special_tokens_map.json",
206+
... "tokenizer_config.json",
207+
... "tokenizer.json",
208+
... "vocab.txt"
209+
... ]
210+
211+
>>> # download files needed for this demonstration to local folder
195212
>>> snapshot_download(
196213
... repo_id="sentence-transformers/all-MiniLM-L6-v2",
197-
... local_dir=local_dir
214+
... local_dir=local_dir,
215+
... allow_patterns=allow_patterns
198216
... )
199217
200-
>>> # copy all files from subdirectory to artifact_dir
201218
>>> artifact_dir = tempfile.mkdtemp()
202-
>>> for root, dirs, files in os.walk(local_dir):
203-
>>> for file in files:
204-
>>> src_path = os.path.join(root, file)
205-
>>> shutil.copy(src_path, artifact_dir)
219+
>>> # copy all downloaded files to artifact folder
220+
>>> for file in allow_patterns:
221+
>>> shutil.copy(local_dir + "/" + file, artifact_dir)
206222
207223
>>> model = EmbeddingONNXModel(artifact_dir=artifact_dir)
208224
>>> model.summary_status()

ads/templates/score_embedding_onnx.jinja2

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import os
33
import sys
44
import json
5+
import subprocess
56
from functools import lru_cache
67
import onnxruntime as ort
78
import jsonschema
@@ -33,13 +34,26 @@ def load_model(model_file_name=model_name):
3334
contents = os.listdir(model_dir)
3435
if model_file_name in contents:
3536
print(f'Start loading {model_file_name} from model directory {model_dir} ...')
36-
model = ort.InferenceSession(os.path.join(model_dir, model_file_name), providers=['CUDAExecutionProvider','CPUExecutionProvider'])
37+
providers= ['CPUExecutionProvider']
38+
if is_gpu_available():
39+
providers=['CUDAExecutionProvider','CPUExecutionProvider']
40+
model = ort.InferenceSession(os.path.join(model_dir, model_file_name), providers=providers)
3741
print("Model is successfully loaded.")
3842
return model
3943
else:
4044
raise Exception(f'{model_file_name} is not found in model directory {model_dir}')
4145

4246

47+
def is_gpu_available():
48+
"""Check if gpu is available on the infrastructure."""
49+
try:
50+
result = subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
51+
if result.returncode == 0:
52+
return True
53+
except FileNotFoundError:
54+
return False
55+
56+
4357
@lru_cache(maxsize=1)
4458
def load_tokenizer(model_full_name):
4559

docs/source/user_guide/model_registration/frameworks/embeddingonnxmodel.rst

Lines changed: 26 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ See `API Documentation <../../../ads.model.framework.html#ads.model.framework.em
66
Overview
77
========
88

9-
The ``ads.model.framework.embedding_onnx_model.EmbeddingONNXModel`` class in ADS is designed to rapidly get an Embedding ONNX Model into production. The ``.prepare()`` method creates the model artifacts that are needed without configuring it or writing code. However, you can customize the required ``score.py`` file.
9+
The ``ads.model.framework.embedding_onnx_model.EmbeddingONNXModel`` class in ADS is designed to rapidly get an Embedding ONNX Model into production. The ``.prepare()`` method creates the model artifacts that are needed without configuring it or writing code. ``EmbeddingONNXModel`` supports `OpenAI spec <https://github.com/huggingface/text-embeddings-inference/blob/main/docs/openapi.json>`_ for embeddings endpoint.
1010

1111
.. include:: ../_template/overview.rst
1212

@@ -24,26 +24,26 @@ The following steps take the `sentence-transformers/all-MiniLM-L6-v2 <https://hu
2424
2525
local_dir = tempfile.mkdtemp()
2626
27+
allow_patterns=[
28+
"onnx/model.onnx",
29+
"config.json",
30+
"special_tokens_map.json",
31+
"tokenizer_config.json",
32+
"tokenizer.json",
33+
"vocab.txt"
34+
]
35+
2736
# download files needed for this demonstration to local folder
2837
snapshot_download(
2938
repo_id="sentence-transformers/all-MiniLM-L6-v2",
3039
local_dir=local_dir,
31-
allow_patterns=[
32-
"onnx/model.onnx",
33-
"config.json",
34-
"special_tokens_map.json",
35-
"tokenizer_config.json",
36-
"tokenizer.json",
37-
"vocab.txt"
38-
]
40+
allow_patterns=allow_patterns
3941
)
4042
4143
artifact_dir = tempfile.mkdtemp()
4244
# copy all downloaded files to artifact folder
43-
for root, dirs, files in os.walk(local_dir):
44-
for file in files:
45-
src_path = os.path.join(root, file)
46-
shutil.copy(src_path, artifact_dir)
45+
for file in allow_patterns:
46+
shutil.copy(local_dir + "/" + file, artifact_dir)
4747
4848
4949
Install Conda Pack
@@ -213,26 +213,26 @@ Example
213213
214214
local_dir = tempfile.mkdtemp()
215215
216-
# download files needed for the demostration to local folder
216+
allow_patterns=[
217+
"onnx/model.onnx",
218+
"config.json",
219+
"special_tokens_map.json",
220+
"tokenizer_config.json",
221+
"tokenizer.json",
222+
"vocab.txt"
223+
]
224+
225+
# download files needed for this demonstration to local folder
217226
snapshot_download(
218227
repo_id="sentence-transformers/all-MiniLM-L6-v2",
219228
local_dir=local_dir,
220-
allow_patterns=[
221-
"onnx/model.onnx",
222-
"config.json",
223-
"special_tokens_map.json",
224-
"tokenizer_config.json",
225-
"tokenizer.json",
226-
"vocab.txt"
227-
]
229+
allow_patterns=allow_patterns
228230
)
229231
230232
artifact_dir = tempfile.mkdtemp()
231233
# copy all downloaded files to artifact folder
232-
for root, dirs, files in os.walk(local_dir):
233-
for file in files:
234-
src_path = os.path.join(root, file)
235-
shutil.copy(src_path, artifact_dir)
234+
for file in allow_patterns:
235+
shutil.copy(local_dir + "/" + file, artifact_dir)
236236
237237
# initialize EmbeddingONNXModel instance and prepare score.py, runtime.yaml and openapi.json files.
238238
embedding_onnx_model = EmbeddingONNXModel(artifact_dir=artifact_dir)

0 commit comments

Comments
 (0)