Commit 1fd0b99

Initial commit.

1 parent 8d358c7 commit 1fd0b99

File tree

3 files changed: +286 -0 lines changed
ads/model/extractor/embedding_onnx_extractor.py
Lines changed: 55 additions & 0 deletions
#!/usr/bin/env python

# Copyright (c) 2024 Oracle and/or its affiliates.
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/

from ads.model.extractor.model_info_extractor import ModelInfoExtractor


class EmbeddingONNXExtractor(ModelInfoExtractor):
    def __init__(self, model):
        self.model = model

    @property
    def framework(self):
        """Extracts the framework of the model.

        Returns
        -------
        str:
            The framework of the model.
        """
        pass

    @property
    def algorithm(self):
        """Extracts the algorithm of the model.

        Returns
        -------
        object:
            The algorithm of the model.
        """
        pass

    @property
    def version(self):
        """Extracts the framework version of the model.

        Returns
        -------
        str:
            The framework version of the model.
        """
        pass

    @property
    def hyperparameter(self):
        """Extracts the hyperparameters of the model.

        Returns
        -------
        dict:
            The hyperparameters of the model.
        """
        pass
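
All four properties are stubs in this initial commit, so each one falls through `pass` and returns `None`. A quick illustrative check (not part of the commit):

    extractor = EmbeddingONNXExtractor(model=None)
    # every metadata property is still a stub, so they all evaluate to None
    assert extractor.framework is None
    assert extractor.hyperparameter is None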
Lines changed: 41 additions & 0 deletions
#!/usr/bin/env python

# Copyright (c) 2024 Oracle and/or its affiliates.
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/

from typing import Any, Callable, Dict

from ads.model.extractor.embedding_onnx_extractor import EmbeddingONNXExtractor
from ads.model.generic_model import FrameworkSpecificModel
from ads.model.model_properties import ModelProperties
from ads.model.serde.common import SERDE


class EmbeddingONNXModel(FrameworkSpecificModel):
    def __init__(
        self,
        estimator: Callable[..., Any] | None = None,
        artifact_dir: str | None = None,
        properties: ModelProperties | None = None,
        auth: Dict | None = None,
        serialize: bool = True,
        model_save_serializer: SERDE | None = None,
        model_input_serializer: SERDE | None = None,
        **kwargs: dict,
    ) -> None:
        super().__init__(
            estimator,
            artifact_dir,
            properties,
            auth,
            serialize,
            model_save_serializer,
            model_input_serializer,
            **kwargs,
        )

        # populate taxonomy metadata from the extractor
        self._extractor = EmbeddingONNXExtractor(estimator)
        self.framework = self._extractor.framework
        self.algorithm = self._extractor.algorithm
        self.version = self._extractor.version
        self.hyperparameter = self._extractor.hyperparameter
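
A minimal usage sketch (the artifact directory below is an assumed path for illustration, not something this commit creates):

    # illustrative only; "./embedding_onnx_artifact" is an assumed local path
    model = EmbeddingONNXModel(artifact_dir="./embedding_onnx_artifact")
    print(model.framework)  # None for now, since the extractor properties are stubs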
score.py
Lines changed: 190 additions & 0 deletions
# score.py 1.0 generated by ADS 2.11.10 on 20241002_212041
import os
import sys
import json
from functools import lru_cache
import onnxruntime as ort
import jsonschema
from jsonschema import validate, ValidationError
from transformers import AutoTokenizer
import logging

model_name = 'model.onnx'
openapi_schema = ''  # name of the OpenAPI schema file shipped in the model artifact


"""
Inference script. This script is used for prediction by the scoring server when the schema is known.
"""


@lru_cache(maxsize=10)
def load_model(model_file_name=model_name):
    """
    Loads the model from its serialized format.

    Returns
    -------
    model: a model instance on which the predict API can be invoked
    """
    model_dir = os.path.dirname(os.path.realpath(__file__))
    if model_dir not in sys.path:
        sys.path.insert(0, model_dir)
    contents = os.listdir(model_dir)
    if model_file_name in contents:
        # print(f'Start loading {model_file_name} from model directory {model_dir} ...')
        model = ort.InferenceSession(
            os.path.join(model_dir, model_file_name),
            providers=['CUDAExecutionProvider', 'CPUExecutionProvider'],
        )
        # print("Model is successfully loaded.")
        return model
    else:
        raise Exception(f'{model_file_name} is not found in model directory {model_dir}')


@lru_cache(maxsize=1)
def load_tokenizer(model_full_name):
    # todo: do we need model_full_name or have configs in artifact dir?
    model_dir = os.path.dirname(os.path.realpath(__file__))
    # initialize tokenizer from the files bundled in the artifact directory
    return AutoTokenizer.from_pretrained(model_dir, clean_up_tokenization_spaces=True)


@lru_cache(maxsize=1)
def load_openapi_schema():
    """
    Loads the input schema for the incoming request.

    Returns
    -------
    schema: the OpenAPI schema as JSON
    """
    model_dir = os.path.dirname(os.path.realpath(__file__))
    if model_dir not in sys.path:
        sys.path.insert(0, model_dir)

    try:
        with open(os.path.join(model_dir, openapi_schema), 'r') as file:
            return json.load(file)
    except (OSError, json.JSONDecodeError) as e:
        raise Exception(f'{openapi_schema} is not found in model directory {model_dir}') from e


def validate_inputs(data):
    api_schema = load_openapi_schema()

    # use a reference resolver for internal $refs
    resolver = jsonschema.RefResolver.from_schema(api_schema)

    # get the actual schema part to validate against
    request_schema = api_schema["components"]["schemas"]["OpenAICompatRequest"]

    try:
        # validate the input JSON
        validate(instance=data, schema=request_schema, resolver=resolver)
    except ValidationError as e:
        # todo: add custom error code and message in error handler
        example_value = {
            "input": ["What are activation functions?"],
            "encoding_format": "float",
            "model": "sentence-transformers/all-MiniLM-L6-v2",
            "user": "user"
        }
        message = f"JSON is invalid. Error: {e.message}\nAn example of the expected format for 'OpenAICompatRequest' looks like:\n{json.dumps(example_value, indent=2)}"
        raise ValueError(message) from e


def pre_inference(data):
    """
    Preprocess data

    Parameters
    ----------
    data: Data format as expected by the predict API.

    Returns
    -------
    onnx_inputs: Data format after any processing
    total_tokens: total tokens that will be processed by the model
    """
    validate_inputs(data)

    tokenizer = load_tokenizer(data['model'])
    inputs = tokenizer(data['input'], return_tensors="np", padding=True)

    # count only non-padding tokens toward usage
    padding_token_id = tokenizer.pad_token_id
    total_tokens = (inputs["input_ids"] != padding_token_id).sum().item()
    onnx_inputs = {key: [l.tolist() for l in inputs[key]] for key in inputs}

    return onnx_inputs, total_tokens
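

# Illustrative shape of the return values (assumed token ids, not from a real run):
#   data = {"input": ["hello"], "model": "sentence-transformers/all-MiniLM-L6-v2"}
#   onnx_inputs -> {"input_ids": [[101, 7592, 102]], "token_type_ids": [[0, 0, 0]], "attention_mask": [[1, 1, 1]]}
#   total_tokens -> 3, since padding tokens are excluded from the count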

def convert_embeddings_to_openapi_format(embeddings, model_name, total_tokens):
    formatted_data = []
    for idx, embedding in enumerate(embeddings):
        formatted_embedding = {
            "object": "embedding",
            "embedding": embedding,
            "index": idx
        }
        formatted_data.append(formatted_embedding)

    # create the final OpenAICompatResponse format
    openai_compat_response = {
        "object": "list",
        "data": formatted_data,
        "model": model_name,  # use the provided model name
        "usage": {
            "prompt_tokens": total_tokens,  # token count for just the text input
            "total_tokens": total_tokens  # total tokens in the request; the same for embeddings
        }
    }

    return openai_compat_response


def post_inference(outputs, model_name, total_tokens):
    """
    Post-process the model results

    Parameters
    ----------
    outputs: Data format after calling model.run
    model_name: name of the model
    total_tokens: total tokens that were processed by the model

    Returns
    -------
    outputs: Data format after any processing.
    """
    results = [embed.tolist() for embed in outputs]
    response = convert_embeddings_to_openapi_format(results, model_name, total_tokens)
    return response


def predict(data, model=load_model()):
    """
    Returns a prediction given the model and the data to predict on.

    Parameters
    ----------
    model: Model instance returned by the load_model API.
    data: Data format as expected by the predict API of the core estimator. For example, in the case of scikit-learn models it could be a numpy array, a list of lists, or a Pandas DataFrame.

    Returns
    -------
    predictions: Output from the scoring server
        Format: {'prediction': output from model.predict method}
    """
    # inputs contains 'input_ids', 'token_type_ids', and 'attention_mask'; 'token_type_ids' is optional
    inputs, total_tokens = pre_inference(data)

    # feed only the inputs the ONNX session actually declares
    onnx_inputs = [inp.name for inp in model.get_inputs()]
    embeddings = model.run(None, {key: inputs[key] if key in inputs else None for key in onnx_inputs})[0]

    response = post_inference(embeddings, data['model'], total_tokens)
    return response
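
A hedged end-to-end sketch, run from inside a prepared artifact directory (assumes `model.onnx`, the tokenizer files, and the schema file named by `openapi_schema` are all present; the payload reuses the `OpenAICompatRequest` example above):

    payload = {
        "input": ["What are activation functions?"],
        "encoding_format": "float",
        "model": "sentence-transformers/all-MiniLM-L6-v2",
        "user": "user",
    }
    response = predict(payload)
    print(len(response["data"]), response["usage"]["total_tokens"])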
