Skip to content

Commit a5cb572

Browse files
Addressing review comments
1 parent 035c4ee commit a5cb572

File tree

4 files changed

+17
-6
lines changed

4 files changed

+17
-6
lines changed

ads/aqua/app.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313

1414
from ads import set_auth
1515
from ads.aqua import logger
16-
from ads.aqua.common.enums import Tags
16+
from ads.aqua.common.enums import ConfigFolder, Tags
1717
from ads.aqua.common.errors import AquaRuntimeError, AquaValueError
1818
from ads.aqua.common.utils import (
1919
_is_valid_mvs,
@@ -272,7 +272,7 @@ def get_config(
272272
self,
273273
model_id: str,
274274
config_file_name: str,
275-
config_folder: Optional[str] = "config",
275+
config_folder: Optional[str] = ConfigFolder.CONFIG,
276276
) -> Dict:
277277
"""Gets the config for the given Aqua model.
278278
@@ -282,15 +282,16 @@ def get_config(
282282
The OCID of the Aqua model.
283283
config_file_name: str
284284
name of the config file
285-
config_folder: Optional[str]
285+
config_folder: (str, optional):
286286
subfolder path where config_file_name needs to be searched
287-
default value: config
287+
Defaults to `ConfigFolder.CONFIG`.
288288
289289
Returns
290290
-------
291291
Dict:
292292
A dict of allowed configs.
293293
"""
294+
config_folder = config_folder or ConfigFolder.CONFIG
294295
oci_model = self.ds_client.get_model(model_id).data
295296
oci_aqua = (
296297
(
@@ -314,7 +315,7 @@ def get_config(
314315
)
315316
base_model = self.ds_client.get_model(base_model_ocid).data
316317
artifact_path = get_artifact_path(base_model.custom_metadata_list)
317-
if config_folder == "artifact":
318+
if config_folder == ConfigFolder.ARTIFACT:
318319
artifact_path = get_artifact_path(oci_model.custom_metadata_list)
319320
else:
320321
logger.info(f"Loading {config_file_name} for model {oci_model.id}...")

ads/aqua/common/enums.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,3 +92,8 @@ class TextEmbeddingInferenceContainerParams(str, metaclass=ExtendedEnumMeta):
9292

9393
MODEL_ID = "model-id"
9494
PORT = "port"
95+
96+
97+
class ConfigFolder(str, metaclass=ExtendedEnumMeta):
98+
CONFIG = "config"
99+
ARTIFACT = "artifact"

ads/aqua/extension/model_handler.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -322,6 +322,8 @@ def get(self, model_id):
322322
url_parse = urlparse(self.request.path)
323323
paths = url_parse.path.strip("/")
324324
path_list = paths.split("/")
325+
# Path should be /aqua/models/ocid1.iad.ahdxxx/tokenizer
326+
# path_list=['aqua','models','<model-ocid>','tokenizer']
325327
if (
326328
len(path_list) == 4
327329
and is_valid_ocid(path_list[2])

ads/aqua/model/model.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from ads.aqua import ODSC_MODEL_COMPARTMENT_OCID, logger
1616
from ads.aqua.app import AquaApp
1717
from ads.aqua.common.enums import (
18+
ConfigFolder,
1819
CustomInferenceContainerTypeFamily,
1920
FineTuningContainerTypeFamily,
2021
InferenceContainerTypeFamily,
@@ -582,7 +583,9 @@ def get_hf_tokenizer_config(self, model_id):
582583
str:
583584
Chat template string.
584585
"""
585-
config = self.get_config(model_id, AQUA_MODEL_TOKENIZER_CONFIG, "artifact")
586+
config = self.get_config(
587+
model_id, AQUA_MODEL_TOKENIZER_CONFIG, ConfigFolder.ARTIFACT
588+
)
586589
if not config:
587590
logger.debug(f"Tokenizer config for model: {model_id} is not available.")
588591
return config

0 commit comments

Comments
 (0)