Skip to content

Commit 192248a

Browse files
Merge branch 'main' into feature/aqua_ms_changes_2
2 parents b4999d0 + 35c03c6 commit 192248a

File tree

8 files changed

+63
-32
lines changed

8 files changed

+63
-32
lines changed

ads/aqua/app.py

Lines changed: 23 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,14 @@
66
import os
77
import traceback
88
from dataclasses import fields
9-
from typing import Dict, Optional, Union
9+
from typing import Any, Dict, Optional, Union
1010

1111
import oci
1212
from oci.data_science.models import UpdateModelDetails, UpdateModelProvenanceDetails
1313

1414
from ads import set_auth
1515
from ads.aqua import logger
16+
from ads.aqua.common.entities import ModelConfigResult
1617
from ads.aqua.common.enums import ConfigFolder, Tags
1718
from ads.aqua.common.errors import AquaRuntimeError, AquaValueError
1819
from ads.aqua.common.utils import (
@@ -272,24 +273,24 @@ def get_config(
272273
model_id: str,
273274
config_file_name: str,
274275
config_folder: Optional[str] = ConfigFolder.CONFIG,
275-
) -> Dict:
276-
"""Gets the config for the given Aqua model.
276+
) -> ModelConfigResult:
277+
"""
278+
Gets the configuration for the given Aqua model along with the model details.
277279
278280
Parameters
279281
----------
280-
model_id: str
282+
model_id : str
281283
The OCID of the Aqua model.
282-
config_file_name: str
283-
name of the config file
284-
config_folder: (str, optional):
285-
subfolder path where config_file_name needs to be searched
286-
Defaults to `ConfigFolder.CONFIG`.
287-
When searching inside model artifact directory , the value is ConfigFolder.ARTIFACT`
284+
config_file_name : str
285+
The name of the configuration file.
286+
config_folder : Optional[str]
287+
The subfolder path where config_file_name is searched.
288+
Defaults to ConfigFolder.CONFIG. For model artifact directories, use ConfigFolder.ARTIFACT.
288289
289290
Returns
290291
-------
291-
Dict:
292-
A dict of allowed configs.
292+
ModelConfigResult
293+
A Pydantic model containing the model_details (extracted from OCI) and the config dictionary.
293294
"""
294295
config_folder = config_folder or ConfigFolder.CONFIG
295296
oci_model = self.ds_client.get_model(model_id).data
@@ -301,11 +302,11 @@ def get_config(
301302
if oci_model.freeform_tags
302303
else False
303304
)
304-
305305
if not oci_aqua:
306-
raise AquaRuntimeError(f"Target model {oci_model.id} is not Aqua model.")
306+
raise AquaRuntimeError(f"Target model {oci_model.id} is not an Aqua model.")
307+
308+
config: Dict[str, Any] = {}
307309

308-
config = {}
309310
# if the current model has a service model tag, then
310311
if Tags.AQUA_SERVICE_MODEL_TAG in oci_model.freeform_tags:
311312
base_model_ocid = oci_model.freeform_tags[Tags.AQUA_SERVICE_MODEL_TAG]
@@ -325,7 +326,7 @@ def get_config(
325326
logger.debug(
326327
f"Failed to get artifact path from custom metadata for the model: {model_id}"
327328
)
328-
return config
329+
return ModelConfigResult(config=config, model_details=oci_model)
329330

330331
config_path = os.path.join(os.path.dirname(artifact_path), config_folder)
331332
if not is_path_exists(config_path):
@@ -350,9 +351,8 @@ def get_config(
350351
f"{config_file_name} is not available for the model: {model_id}. "
351352
f"Check if the custom metadata has the artifact path set."
352353
)
353-
return config
354354

355-
return config
355+
return ModelConfigResult(config=config, model_details=oci_model)
356356

357357
@property
358358
def telemetry(self):
@@ -374,9 +374,11 @@ def build_cli(self) -> str:
374374
"""
375375
cmd = f"ads aqua {self._command}"
376376
params = [
377-
f"--{field.name} {json.dumps(getattr(self, field.name))}"
378-
if isinstance(getattr(self, field.name), dict)
379-
else f"--{field.name} {getattr(self, field.name)}"
377+
(
378+
f"--{field.name} {json.dumps(getattr(self, field.name))}"
379+
if isinstance(getattr(self, field.name), dict)
380+
else f"--{field.name} {getattr(self, field.name)}"
381+
)
380382
for field in fields(self.__class__)
381383
if getattr(self, field.name) is not None
382384
]

ads/aqua/common/entities.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,12 @@
11
#!/usr/bin/env python
2-
# Copyright (c) 2024 Oracle and/or its affiliates.
2+
# Copyright (c) 2024, 2025 Oracle and/or its affiliates.
33
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
44

5+
from typing import Any, Dict, Optional
6+
7+
from oci.data_science.models import Model
8+
from pydantic import BaseModel, Field
9+
510

611
class ContainerSpec:
712
"""
@@ -15,3 +20,25 @@ class ContainerSpec:
1520
ENV_VARS = "envVars"
1621
RESTRICTED_PARAMS = "restrictedParams"
1722
EVALUATION_CONFIGURATION = "evaluationConfiguration"
23+
24+
25+
class ModelConfigResult(BaseModel):
26+
"""
27+
Represents the result of getting the AQUA model configuration.
28+
29+
Attributes:
30+
model_details (Dict[str, Any]): A dictionary containing model details extracted from OCI.
31+
config (Dict[str, Any]): A dictionary of the loaded configuration.
32+
"""
33+
34+
config: Optional[Dict[str, Any]] = Field(
35+
None, description="Loaded configuration dictionary."
36+
)
37+
model_details: Optional[Model] = Field(
38+
None, description="Details of the model from OCI."
39+
)
40+
41+
class Config:
42+
extra = "ignore"
43+
arbitrary_types_allowed = True
44+
protected_namespaces = ()

ads/aqua/finetuning/finetuning.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -591,7 +591,7 @@ def get_finetuning_config(self, model_id: str) -> Dict:
591591
Dict:
592592
A dict of allowed finetuning configs.
593593
"""
594-
config = self.get_config(model_id, AQUA_MODEL_FINETUNING_CONFIG)
594+
config = self.get_config(model_id, AQUA_MODEL_FINETUNING_CONFIG).config
595595
if not config:
596596
logger.debug(
597597
f"Fine-tuning config for custom model: {model_id} is not available. Use defaults."

ads/aqua/model/model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -585,7 +585,7 @@ def get_hf_tokenizer_config(self, model_id):
585585
"""
586586
config = self.get_config(
587587
model_id, AQUA_MODEL_TOKENIZER_CONFIG, ConfigFolder.ARTIFACT
588-
)
588+
).config
589589
if not config:
590590
logger.debug(f"Tokenizer config for model: {model_id} is not available.")
591591
return config

ads/aqua/modeldeployment/deployment.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -653,7 +653,7 @@ def get_deployment_config(self, model_id: str) -> Dict:
653653
Dict:
654654
A dict of allowed deployment configs.
655655
"""
656-
config = self.get_config(model_id, AQUA_MODEL_DEPLOYMENT_CONFIG)
656+
config = self.get_config(model_id, AQUA_MODEL_DEPLOYMENT_CONFIG).config
657657
if not config:
658658
logger.debug(
659659
f"Deployment config for custom model: {model_id} is not available. Use defaults."

tests/unitary/with_extras/aqua/test_config.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -96,8 +96,8 @@ def test_load_config(
9696
model_id="test_model_id", config_file_name="test_config_file_name"
9797
)
9898
if not path_exists:
99-
assert result == {}
99+
assert result.config == {}
100100
if not custom_metadata:
101-
assert result == {}
101+
assert result.config == {}
102102
if path_exists and custom_metadata:
103-
assert result == {"config_key": "config_value"}
103+
assert result.config == {"config_key": "config_value"}

tests/unitary/with_extras/aqua/test_deployment.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
#!/usr/bin/env python
22
# -*- coding: utf-8 -*--
33

4-
# Copyright (c) 2024 Oracle and/or its affiliates.
4+
# Copyright (c) 2024, 2025 Oracle and/or its affiliates.
55
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
66

77
import copy
@@ -16,6 +16,7 @@
1616
import pytest
1717
from parameterized import parameterized
1818

19+
from ads.aqua.common.entities import ModelConfigResult
1920
import ads.aqua.modeldeployment.deployment
2021
import ads.config
2122
from ads.aqua.modeldeployment import AquaDeploymentApp, MDInferenceResponse
@@ -438,11 +439,11 @@ def test_get_deployment_config(self):
438439
with open(config_json, "r") as _file:
439440
config = json.load(_file)
440441

441-
self.app.get_config = MagicMock(return_value=config)
442+
self.app.get_config = MagicMock(return_value=ModelConfigResult(config=config))
442443
result = self.app.get_deployment_config(TestDataset.MODEL_ID)
443444
assert result == config
444445

445-
self.app.get_config = MagicMock(return_value=None)
446+
self.app.get_config = MagicMock(return_value=ModelConfigResult(config=None))
446447
result = self.app.get_deployment_config(TestDataset.MODEL_ID)
447448
assert result == None
448449

tests/unitary/with_extras/aqua/test_finetuning.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import ads.aqua.finetuning.finetuning
1818
import ads.config
1919
from ads.aqua.app import AquaApp
20+
from ads.aqua.common.entities import ModelConfigResult
2021
from ads.aqua.common.errors import AquaValueError
2122
from ads.aqua.finetuning import AquaFineTuningApp
2223
from ads.aqua.finetuning.constants import FineTuneCustomMetadata
@@ -279,7 +280,7 @@ def test_get_finetuning_config(self):
279280
with open(config_json, "r") as _file:
280281
config = json.load(_file)
281282

282-
self.app.get_config = MagicMock(return_value=config)
283+
self.app.get_config = MagicMock(return_value=ModelConfigResult(config=config))
283284
result = self.app.get_finetuning_config(model_id="test-model-id")
284285
assert result == config
285286

0 commit comments

Comments
 (0)