Skip to content

Commit 24d6fbe

Browse files
authored
Merge branch 'main' into ahosler-patch-1
2 parents 3feee65 + 2bd9cbe commit 24d6fbe

19 files changed

+385
-173
lines changed

ads/aqua/app.py

Lines changed: 23 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,14 @@
66
import os
77
import traceback
88
from dataclasses import fields
9-
from typing import Dict, Union
9+
from typing import Dict, Optional, Union
1010

1111
import oci
1212
from oci.data_science.models import UpdateModelDetails, UpdateModelProvenanceDetails
1313

1414
from ads import set_auth
1515
from ads.aqua import logger
16-
from ads.aqua.common.enums import Tags
16+
from ads.aqua.common.enums import ConfigFolder, Tags
1717
from ads.aqua.common.errors import AquaRuntimeError, AquaValueError
1818
from ads.aqua.common.utils import (
1919
_is_valid_mvs,
@@ -268,7 +268,12 @@ def if_artifact_exist(self, model_id: str, **kwargs) -> bool:
268268
logger.info(f"Artifact not found in model {model_id}.")
269269
return False
270270

271-
def get_config(self, model_id: str, config_file_name: str) -> Dict:
271+
def get_config(
272+
self,
273+
model_id: str,
274+
config_file_name: str,
275+
config_folder: Optional[str] = ConfigFolder.CONFIG,
276+
) -> Dict:
272277
"""Gets the config for the given Aqua model.
273278
274279
Parameters
@@ -277,12 +282,17 @@ def get_config(self, model_id: str, config_file_name: str) -> Dict:
277282
The OCID of the Aqua model.
278283
config_file_name: str
279284
name of the config file
285+
config_folder: (str, optional):
286+
subfolder path where config_file_name needs to be searched
287+
Defaults to `ConfigFolder.CONFIG`.
288+
When searching inside model artifact directory , the value is ConfigFolder.ARTIFACT`
280289
281290
Returns
282291
-------
283292
Dict:
284293
A dict of allowed configs.
285294
"""
295+
config_folder = config_folder or ConfigFolder.CONFIG
286296
oci_model = self.ds_client.get_model(model_id).data
287297
oci_aqua = (
288298
(
@@ -304,23 +314,26 @@ def get_config(self, model_id: str, config_file_name: str) -> Dict:
304314
f"Base model found for the model: {oci_model.id}. "
305315
f"Loading {config_file_name} for base model {base_model_ocid}."
306316
)
307-
base_model = self.ds_client.get_model(base_model_ocid).data
308-
artifact_path = get_artifact_path(base_model.custom_metadata_list)
317+
if config_folder == ConfigFolder.ARTIFACT:
318+
artifact_path = get_artifact_path(oci_model.custom_metadata_list)
319+
else:
320+
base_model = self.ds_client.get_model(base_model_ocid).data
321+
artifact_path = get_artifact_path(base_model.custom_metadata_list)
309322
else:
310323
logger.info(f"Loading {config_file_name} for model {oci_model.id}...")
311324
artifact_path = get_artifact_path(oci_model.custom_metadata_list)
312-
313325
if not artifact_path:
314326
logger.debug(
315327
f"Failed to get artifact path from custom metadata for the model: {model_id}"
316328
)
317329
return config
318330

319-
config_path = f"{os.path.dirname(artifact_path)}/config/"
331+
config_path = os.path.join(os.path.dirname(artifact_path), config_folder)
320332
if not is_path_exists(config_path):
321-
config_path = f"{artifact_path.rstrip('/')}/config/"
322-
323-
config_file_path = f"{config_path}{config_file_name}"
333+
config_path = os.path.join(artifact_path.rstrip("/"), config_folder)
334+
if not is_path_exists(config_path):
335+
config_path = f"{artifact_path.rstrip('/')}/"
336+
config_file_path = os.path.join(config_path, config_file_name)
324337
if is_path_exists(config_file_path):
325338
try:
326339
config = load_config(

ads/aqua/common/enums.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,3 +92,8 @@ class TextEmbeddingInferenceContainerParams(ExtendedEnum):
9292

9393
MODEL_ID = "model-id"
9494
PORT = "port"
95+
96+
97+
class ConfigFolder(ExtendedEnum):
98+
CONFIG = "config"
99+
ARTIFACT = "artifact"

ads/aqua/constants.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
README = "README.md"
1111
LICENSE_TXT = "config/LICENSE.txt"
1212
DEPLOYMENT_CONFIG = "deployment_config.json"
13+
AQUA_MODEL_TOKENIZER_CONFIG = "tokenizer_config.json"
1314
COMPARTMENT_MAPPING_KEY = "service-model-compartment"
1415
CONTAINER_INDEX = "container_index.json"
1516
EVALUATION_REPORT_JSON = "report.json"

ads/aqua/extension/model_handler.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from ads.aqua.common.errors import AquaRuntimeError, AquaValueError
1515
from ads.aqua.common.utils import (
1616
get_hf_model_info,
17+
is_valid_ocid,
1718
list_hf_models,
1819
)
1920
from ads.aqua.extension.base_handler import AquaAPIhandler
@@ -316,8 +317,30 @@ def post(self, *args, **kwargs): # noqa: ARG002
316317
)
317318

318319

320+
class AquaModelTokenizerConfigHandler(AquaAPIhandler):
321+
def get(self, model_id):
322+
"""
323+
Handles requests for retrieving the Hugging Face tokenizer configuration of a specified model.
324+
Expected request format: GET /aqua/models/<model-ocid>/tokenizer
325+
326+
"""
327+
328+
path_list = urlparse(self.request.path).path.strip("/").split("/")
329+
# Path should be /aqua/models/ocid1.iad.ahdxxx/tokenizer
330+
# path_list=['aqua','models','<model-ocid>','tokenizer']
331+
if (
332+
len(path_list) == 4
333+
and is_valid_ocid(path_list[2])
334+
and path_list[3] == "tokenizer"
335+
):
336+
return self.finish(AquaModelApp().get_hf_tokenizer_config(model_id))
337+
338+
raise HTTPError(400, f"The request {self.request.path} is invalid.")
339+
340+
319341
__handlers__ = [
320342
("model/?([^/]*)", AquaModelHandler),
321343
("model/?([^/]*)/license", AquaModelLicenseHandler),
344+
("model/?([^/]*)/tokenizer", AquaModelTokenizerConfigHandler),
322345
("model/hf/search/?([^/]*)", AquaHuggingFaceHandler),
323346
]

ads/aqua/model/model.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from ads.aqua import ODSC_MODEL_COMPARTMENT_OCID, logger
1616
from ads.aqua.app import AquaApp
1717
from ads.aqua.common.enums import (
18+
ConfigFolder,
1819
CustomInferenceContainerTypeFamily,
1920
FineTuningContainerTypeFamily,
2021
InferenceContainerTypeFamily,
@@ -44,6 +45,7 @@
4445
AQUA_MODEL_ARTIFACT_CONFIG_MODEL_NAME,
4546
AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE,
4647
AQUA_MODEL_ARTIFACT_FILE,
48+
AQUA_MODEL_TOKENIZER_CONFIG,
4749
AQUA_MODEL_TYPE_CUSTOM,
4850
HF_METADATA_FOLDER,
4951
LICENSE_TXT,
@@ -568,6 +570,26 @@ def _build_ft_metrics(
568570
training_final,
569571
]
570572

573+
def get_hf_tokenizer_config(self, model_id):
574+
"""Gets the default chat template for the given Aqua model.
575+
576+
Parameters
577+
----------
578+
model_id: str
579+
The OCID of the Aqua model.
580+
581+
Returns
582+
-------
583+
str:
584+
Chat template string.
585+
"""
586+
config = self.get_config(
587+
model_id, AQUA_MODEL_TOKENIZER_CONFIG, ConfigFolder.ARTIFACT
588+
)
589+
if not config:
590+
logger.debug(f"Tokenizer config for model: {model_id} is not available.")
591+
return config
592+
571593
@staticmethod
572594
def to_aqua_model(
573595
model: Union[

ads/aqua/modeldeployment/entities.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ class AquaDeployment(DataClassSerializable):
4141
id: str = None
4242
display_name: str = None
4343
aqua_service_model: bool = None
44+
model_id: str = None
4445
aqua_model_name: str = None
4546
state: str = None
4647
description: str = None
@@ -97,7 +98,7 @@ def from_oci_model_deployment(
9798
else None
9899
),
99100
)
100-
101+
model_id = oci_model_deployment._model_deployment_configuration_details.model_configuration_details.model_id
101102
tags = {}
102103
tags.update(oci_model_deployment.freeform_tags or UNKNOWN_DICT)
103104
tags.update(oci_model_deployment.defined_tags or UNKNOWN_DICT)
@@ -110,6 +111,7 @@ def from_oci_model_deployment(
110111

111112
return AquaDeployment(
112113
id=oci_model_deployment.id,
114+
model_id=model_id,
113115
display_name=oci_model_deployment.display_name,
114116
aqua_service_model=aqua_service_model_tag is not None,
115117
aqua_model_name=aqua_model_name,

ads/model/artifact_downloader.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
#!/usr/bin/env python
2-
# -*- coding: utf-8; -*-
32

43
# Copyright (c) 2022, 2024 Oracle and/or its affiliates.
54
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
@@ -12,9 +11,9 @@
1211
from zipfile import ZipFile
1312

1413
from ads.common import utils
14+
from ads.common.object_storage_details import ObjectStorageDetails
1515
from ads.common.utils import extract_region
1616
from ads.model.service.oci_datascience_model import OCIDataScienceModel
17-
from ads.common.object_storage_details import ObjectStorageDetails
1817

1918

2019
class ArtifactDownloader(ABC):
@@ -169,9 +168,9 @@ def __init__(
169168

170169
def _download(self):
171170
"""Downloads model artifacts."""
172-
self.progress.update(f"Importing model artifacts from catalog")
171+
self.progress.update("Importing model artifacts from catalog")
173172

174-
if self.dsc_model.is_model_by_reference() and self.model_file_description:
173+
if self.dsc_model._is_model_by_reference() and self.model_file_description:
175174
self.download_from_model_file_description()
176175
self.progress.update()
177176
return

ads/model/datascience_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1778,7 +1778,7 @@ def _update_from_oci_dsc_model(
17781778
artifact_info = self.dsc_model.get_artifact_info()
17791779
_, file_name_info = cgi.parse_header(artifact_info["Content-Disposition"])
17801780

1781-
if self.dsc_model.is_model_by_reference():
1781+
if self.dsc_model._is_model_by_reference():
17821782
_, file_extension = os.path.splitext(file_name_info["filename"])
17831783
if file_extension.lower() == ".json":
17841784
bucket_uri, _ = self._download_file_description_artifact()

ads/model/model_metadata.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1509,7 +1509,10 @@ def _from_oci_metadata(cls, metadata_list):
15091509
metadata = cls()
15101510
for oci_item in metadata_list:
15111511
item = ModelTaxonomyMetadataItem._from_oci_metadata(oci_item)
1512-
metadata[item.key].update(value=item.value)
1512+
if item.key in metadata.keys:
1513+
metadata[item.key].update(value=item.value)
1514+
else:
1515+
metadata._items.add(item)
15131516
return metadata
15141517

15151518
def to_dataframe(self) -> pd.DataFrame:
@@ -1562,7 +1565,10 @@ def from_dict(cls, data: Dict) -> "ModelTaxonomyMetadata":
15621565
metadata = cls()
15631566
for item in data["data"]:
15641567
item = ModelTaxonomyMetadataItem.from_dict(item)
1565-
metadata[item.key].update(value=item.value)
1568+
if item.key in metadata.keys:
1569+
metadata[item.key].update(value=item.value)
1570+
else:
1571+
metadata._items.add(item)
15661572
return metadata
15671573

15681574

ads/model/service/oci_datascience_model.py

Lines changed: 12 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,35 +1,32 @@
11
#!/usr/bin/env python
2-
# -*- coding: utf-8; -*-
32

43
# Copyright (c) 2022, 2024 Oracle and/or its affiliates.
54
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
65

76
import logging
8-
import time
97
from functools import wraps
108
from io import BytesIO
119
from typing import Callable, Dict, List, Optional
1210

1311
import oci.data_science
14-
from ads.common import utils
15-
from ads.common.object_storage_details import ObjectStorageDetails
16-
from ads.common.oci_datascience import OCIDataScienceMixin
17-
from ads.common.oci_mixin import OCIWorkRequestMixin
18-
from ads.common.oci_resource import SEARCH_TYPE, OCIResource
19-
from ads.common.utils import extract_region
20-
from ads.common.work_request import DataScienceWorkRequest
21-
from ads.model.deployment import ModelDeployment
2212
from oci.data_science.models import (
2313
ArtifactExportDetailsObjectStorage,
2414
ArtifactImportDetailsObjectStorage,
2515
CreateModelDetails,
2616
ExportModelArtifactDetails,
2717
ImportModelArtifactDetails,
2818
UpdateModelDetails,
29-
WorkRequest,
3019
)
3120
from oci.exceptions import ServiceError
3221

22+
from ads.common.object_storage_details import ObjectStorageDetails
23+
from ads.common.oci_datascience import OCIDataScienceMixin
24+
from ads.common.oci_mixin import OCIWorkRequestMixin
25+
from ads.common.oci_resource import SEARCH_TYPE, OCIResource
26+
from ads.common.utils import extract_region
27+
from ads.common.work_request import DataScienceWorkRequest
28+
from ads.model.deployment import ModelDeployment
29+
3330
logger = logging.getLogger(__name__)
3431

3532
_REQUEST_INTERVAL_IN_SEC = 3
@@ -282,7 +279,7 @@ def get_artifact_info(self) -> Dict:
282279
msg="Model needs to be restored before the archived artifact content can be accessed."
283280
)
284281
def restore_archived_model_artifact(
285-
self, restore_model_for_hours_specified: Optional[int] = None
282+
self, restore_model_for_hours_specified: Optional[int] = None
286283
) -> None:
287284
"""Restores the archived model artifact.
288285
@@ -304,7 +301,8 @@ def restore_archived_model_artifact(
304301
"""
305302
return self.client.restore_archived_model_artifact(
306303
model_id=self.id,
307-
restore_model_for_hours_specified=restore_model_for_hours_specified).headers["opc-work-request-id"]
304+
restore_model_for_hours_specified=restore_model_for_hours_specified,
305+
).headers["opc-work-request-id"]
308306

309307
@check_for_model_id(
310308
msg="Model needs to be saved to the Model Catalog before the artifact content can be read."
@@ -581,7 +579,7 @@ def from_id(cls, ocid: str) -> "OCIDataScienceModel":
581579
raise ValueError("Model OCID not provided.")
582580
return super().from_ocid(ocid)
583581

584-
def is_model_by_reference(self):
582+
def _is_model_by_reference(self):
585583
"""Checks if model is created by reference
586584
Returns
587585
-------

0 commit comments

Comments
 (0)