1
1
#!/usr/bin/env python
2
2
3
- # Copyright (c) 2024 Oracle and/or its affiliates.
3
+ # Copyright (c) 2025 Oracle and/or its affiliates.
4
4
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
5
5
6
+ import logging
7
+ import os
8
+ from pathlib import Path
6
9
from typing import Dict , Optional
7
10
8
11
from ads .model .extractor .embedding_onnx_extractor import EmbeddingONNXExtractor
9
12
from ads .model .generic_model import FrameworkSpecificModel
10
13
14
+ logger = logging .getLogger (__name__ )
15
+
16
+ CONFIG = "config.json"
17
+ TOKENIZERS = [
18
+ "tokenizer.json" ,
19
+ "tokenizer_config.json" ,
20
+ "spiece.model" ,
21
+ "vocab.txt" ,
22
+ "vocab.json" ,
23
+ ]
24
+
11
25
12
26
class EmbeddingONNXModel (FrameworkSpecificModel ):
13
27
"""EmbeddingONNXModel class for embedding onnx model.
@@ -18,6 +32,12 @@ class EmbeddingONNXModel(FrameworkSpecificModel):
18
32
The algorithm of the model.
19
33
artifact_dir: str
20
34
Artifact directory to store the files needed for deployment.
35
+ model_file_name: str
36
+ Path to the model artifact.
37
+ config_json: str
38
+ Path to the config.json file.
39
+ tokenizer_dir: str
40
+ Path to the tokenizer directory.
21
41
auth: Dict
22
42
Default authentication is set using the `ads.set_auth` API. To override the
23
43
default, use the `ads.common.auth.api_keys` or `ads.common.auth.resource_principal` to create
@@ -166,6 +186,9 @@ class EmbeddingONNXModel(FrameworkSpecificModel):
166
186
def __init__ (
167
187
self ,
168
188
artifact_dir : Optional [str ] = None ,
189
+ model_file_name : Optional [str ] = None ,
190
+ config_json : Optional [str ] = None ,
191
+ tokenizer_dir : Optional [str ] = None ,
169
192
auth : Optional [Dict ] = None ,
170
193
serialize : bool = False ,
171
194
** kwargs : dict ,
@@ -175,8 +198,14 @@ def __init__(
175
198
176
199
Parameters
177
200
----------
178
- artifact_dir: str
201
+ artifact_dir: ( str, optional). Defaults to None.
179
202
Directory for generate artifact.
203
+ model_file_name: (str, optional). Defaults to None.
204
+ Path to the model artifact.
205
+ config_json: (str, optional). Defaults to None.
206
+ Path to the config.json file.
207
+ tokenizer_dir: (str, optional). Defaults to None.
208
+ Path to the tokenizer directory.
180
209
auth: (Dict, optional). Defaults to None.
181
210
The default authetication is set using `ads.set_auth` API. If you need to override the
182
211
default, use the `ads.common.auth.api_keys` or `ads.common.auth.resource_principal` to create appropriate
@@ -260,12 +289,63 @@ def __init__(
260
289
** kwargs ,
261
290
)
262
291
292
+ self ._validate_artifact_directory (
293
+ model_file_name = model_file_name ,
294
+ config_json = config_json ,
295
+ tokenizer_dir = tokenizer_dir ,
296
+ )
297
+
263
298
self ._extractor = EmbeddingONNXExtractor ()
264
299
self .framework = self ._extractor .framework
265
300
self .algorithm = self ._extractor .algorithm
266
301
self .version = self ._extractor .version
267
302
self .hyperparameter = self ._extractor .hyperparameter
268
303
304
+ def _validate_artifact_directory (
305
+ self ,
306
+ model_file_name : str = None ,
307
+ config_json : str = None ,
308
+ tokenizer_dir : str = None ,
309
+ ):
310
+ artifacts = []
311
+ for _ , _ , files in os .walk (self .artifact_dir ):
312
+ artifacts .extend (files )
313
+
314
+ if not artifacts :
315
+ raise ValueError (
316
+ f"No files found in { self .artifact_dir } . Specify a valid `artifact_dir`."
317
+ )
318
+
319
+ if not model_file_name :
320
+ has_model_file = False
321
+ for artifact in artifacts :
322
+ if Path (artifact ).suffix .lstrip ("." ).lower () == "onnx" :
323
+ has_model_file = True
324
+ break
325
+
326
+ if not has_model_file :
327
+ raise ValueError (
328
+ f"No onnx model found in { self .artifact_dir } . Specify a valid `artifact_dir` or `model_file_name`."
329
+ )
330
+
331
+ if not config_json :
332
+ if CONFIG not in artifacts :
333
+ logger .warning (
334
+ f"No { CONFIG } found in { self .artifact_dir } . Specify a valid `artifact_dir` or `config_json`."
335
+ )
336
+
337
+ if not tokenizer_dir :
338
+ has_tokenizer = False
339
+ for artifact in artifacts :
340
+ if artifact in TOKENIZERS :
341
+ has_tokenizer = True
342
+ break
343
+
344
+ if not has_tokenizer :
345
+ logger .warning (
346
+ f"No tokenizer found in { self .artifact_dir } . Specify a valid `artifact_dir` or `tokenizer_dir`."
347
+ )
348
+
269
349
def verify (
270
350
self , data = None , reload_artifacts = True , auto_serialize_data = False , ** kwargs
271
351
):
0 commit comments