Skip to content

Commit 35c03c6

Browse files
authored
[AQUA] Enhance get_config to Return Model Details and Configuration in a Pydantic format (#1107)
2 parents 88d6feb + 3e64873 commit 35c03c6

File tree

8 files changed

+63
-32
lines changed

8 files changed

+63
-32
lines changed

ads/aqua/app.py

Lines changed: 23 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,14 @@
66
import os
77
import traceback
88
from dataclasses import fields
9-
from typing import Dict, Optional, Union
9+
from typing import Any, Dict, Optional, Union
1010

1111
import oci
1212
from oci.data_science.models import UpdateModelDetails, UpdateModelProvenanceDetails
1313

1414
from ads import set_auth
1515
from ads.aqua import logger
16+
from ads.aqua.common.entities import ModelConfigResult
1617
from ads.aqua.common.enums import ConfigFolder, Tags
1718
from ads.aqua.common.errors import AquaRuntimeError, AquaValueError
1819
from ads.aqua.common.utils import (
@@ -273,24 +274,24 @@ def get_config(
273274
model_id: str,
274275
config_file_name: str,
275276
config_folder: Optional[str] = ConfigFolder.CONFIG,
276-
) -> Dict:
277-
"""Gets the config for the given Aqua model.
277+
) -> ModelConfigResult:
278+
"""
279+
Gets the configuration for the given Aqua model along with the model details.
278280
279281
Parameters
280282
----------
281-
model_id: str
283+
model_id : str
282284
The OCID of the Aqua model.
283-
config_file_name: str
284-
name of the config file
285-
config_folder: (str, optional):
286-
subfolder path where config_file_name needs to be searched
287-
Defaults to `ConfigFolder.CONFIG`.
288-
When searching inside model artifact directory , the value is ConfigFolder.ARTIFACT`
285+
config_file_name : str
286+
The name of the configuration file.
287+
config_folder : Optional[str]
288+
The subfolder path where config_file_name is searched.
289+
Defaults to ConfigFolder.CONFIG. For model artifact directories, use ConfigFolder.ARTIFACT.
289290
290291
Returns
291292
-------
292-
Dict:
293-
A dict of allowed configs.
293+
ModelConfigResult
294+
A Pydantic model containing the model_details (extracted from OCI) and the config dictionary.
294295
"""
295296
config_folder = config_folder or ConfigFolder.CONFIG
296297
oci_model = self.ds_client.get_model(model_id).data
@@ -302,11 +303,11 @@ def get_config(
302303
if oci_model.freeform_tags
303304
else False
304305
)
305-
306306
if not oci_aqua:
307-
raise AquaRuntimeError(f"Target model {oci_model.id} is not Aqua model.")
307+
raise AquaRuntimeError(f"Target model {oci_model.id} is not an Aqua model.")
308+
309+
config: Dict[str, Any] = {}
308310

309-
config = {}
310311
# if the current model has a service model tag, then
311312
if Tags.AQUA_SERVICE_MODEL_TAG in oci_model.freeform_tags:
312313
base_model_ocid = oci_model.freeform_tags[Tags.AQUA_SERVICE_MODEL_TAG]
@@ -326,7 +327,7 @@ def get_config(
326327
logger.debug(
327328
f"Failed to get artifact path from custom metadata for the model: {model_id}"
328329
)
329-
return config
330+
return ModelConfigResult(config=config, model_details=oci_model)
330331

331332
config_path = os.path.join(os.path.dirname(artifact_path), config_folder)
332333
if not is_path_exists(config_path):
@@ -351,9 +352,8 @@ def get_config(
351352
f"{config_file_name} is not available for the model: {model_id}. "
352353
f"Check if the custom metadata has the artifact path set."
353354
)
354-
return config
355355

356-
return config
356+
return ModelConfigResult(config=config, model_details=oci_model)
357357

358358
@property
359359
def telemetry(self):
@@ -375,9 +375,11 @@ def build_cli(self) -> str:
375375
"""
376376
cmd = f"ads aqua {self._command}"
377377
params = [
378-
f"--{field.name} {json.dumps(getattr(self, field.name))}"
379-
if isinstance(getattr(self, field.name), dict)
380-
else f"--{field.name} {getattr(self, field.name)}"
378+
(
379+
f"--{field.name} {json.dumps(getattr(self, field.name))}"
380+
if isinstance(getattr(self, field.name), dict)
381+
else f"--{field.name} {getattr(self, field.name)}"
382+
)
381383
for field in fields(self.__class__)
382384
if getattr(self, field.name) is not None
383385
]

ads/aqua/common/entities.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,12 @@
11
#!/usr/bin/env python
2-
# Copyright (c) 2024 Oracle and/or its affiliates.
2+
# Copyright (c) 2024, 2025 Oracle and/or its affiliates.
33
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
44

5+
from typing import Any, Dict, Optional
6+
7+
from oci.data_science.models import Model
8+
from pydantic import BaseModel, Field
9+
510

611
class ContainerSpec:
712
"""
@@ -15,3 +20,25 @@ class ContainerSpec:
1520
ENV_VARS = "envVars"
1621
RESTRICTED_PARAMS = "restrictedParams"
1722
EVALUATION_CONFIGURATION = "evaluationConfiguration"
23+
24+
25+
class ModelConfigResult(BaseModel):
26+
"""
27+
Represents the result of getting the AQUA model configuration.
28+
29+
Attributes:
30+
model_details (Dict[str, Any]): A dictionary containing model details extracted from OCI.
31+
config (Dict[str, Any]): A dictionary of the loaded configuration.
32+
"""
33+
34+
config: Optional[Dict[str, Any]] = Field(
35+
None, description="Loaded configuration dictionary."
36+
)
37+
model_details: Optional[Model] = Field(
38+
None, description="Details of the model from OCI."
39+
)
40+
41+
class Config:
42+
extra = "ignore"
43+
arbitrary_types_allowed = True
44+
protected_namespaces = ()

ads/aqua/finetuning/finetuning.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -592,7 +592,7 @@ def get_finetuning_config(self, model_id: str) -> Dict:
592592
Dict:
593593
A dict of allowed finetuning configs.
594594
"""
595-
config = self.get_config(model_id, AQUA_MODEL_FINETUNING_CONFIG)
595+
config = self.get_config(model_id, AQUA_MODEL_FINETUNING_CONFIG).config
596596
if not config:
597597
logger.debug(
598598
f"Fine-tuning config for custom model: {model_id} is not available. Use defaults."

ads/aqua/model/model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -586,7 +586,7 @@ def get_hf_tokenizer_config(self, model_id):
586586
"""
587587
config = self.get_config(
588588
model_id, AQUA_MODEL_TOKENIZER_CONFIG, ConfigFolder.ARTIFACT
589-
)
589+
).config
590590
if not config:
591591
logger.debug(f"Tokenizer config for model: {model_id} is not available.")
592592
return config

ads/aqua/modeldeployment/deployment.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -654,7 +654,7 @@ def get_deployment_config(self, model_id: str) -> Dict:
654654
Dict:
655655
A dict of allowed deployment configs.
656656
"""
657-
config = self.get_config(model_id, AQUA_MODEL_DEPLOYMENT_CONFIG)
657+
config = self.get_config(model_id, AQUA_MODEL_DEPLOYMENT_CONFIG).config
658658
if not config:
659659
logger.debug(
660660
f"Deployment config for custom model: {model_id} is not available. Use defaults."

tests/unitary/with_extras/aqua/test_config.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -96,8 +96,8 @@ def test_load_config(
9696
model_id="test_model_id", config_file_name="test_config_file_name"
9797
)
9898
if not path_exists:
99-
assert result == {}
99+
assert result.config == {}
100100
if not custom_metadata:
101-
assert result == {}
101+
assert result.config == {}
102102
if path_exists and custom_metadata:
103-
assert result == {"config_key": "config_value"}
103+
assert result.config == {"config_key": "config_value"}

tests/unitary/with_extras/aqua/test_deployment.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
#!/usr/bin/env python
22
# -*- coding: utf-8 -*--
33

4-
# Copyright (c) 2024 Oracle and/or its affiliates.
4+
# Copyright (c) 2024, 2025 Oracle and/or its affiliates.
55
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
66

77
import copy
@@ -16,6 +16,7 @@
1616
import pytest
1717
from parameterized import parameterized
1818

19+
from ads.aqua.common.entities import ModelConfigResult
1920
import ads.aqua.modeldeployment.deployment
2021
import ads.config
2122
from ads.aqua.modeldeployment import AquaDeploymentApp, MDInferenceResponse
@@ -438,11 +439,11 @@ def test_get_deployment_config(self):
438439
with open(config_json, "r") as _file:
439440
config = json.load(_file)
440441

441-
self.app.get_config = MagicMock(return_value=config)
442+
self.app.get_config = MagicMock(return_value=ModelConfigResult(config=config))
442443
result = self.app.get_deployment_config(TestDataset.MODEL_ID)
443444
assert result == config
444445

445-
self.app.get_config = MagicMock(return_value=None)
446+
self.app.get_config = MagicMock(return_value=ModelConfigResult(config=None))
446447
result = self.app.get_deployment_config(TestDataset.MODEL_ID)
447448
assert result == None
448449

tests/unitary/with_extras/aqua/test_finetuning.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import ads.aqua.finetuning.finetuning
1818
import ads.config
1919
from ads.aqua.app import AquaApp
20+
from ads.aqua.common.entities import ModelConfigResult
2021
from ads.aqua.common.errors import AquaValueError
2122
from ads.aqua.finetuning import AquaFineTuningApp
2223
from ads.aqua.finetuning.constants import FineTuneCustomMetadata
@@ -279,7 +280,7 @@ def test_get_finetuning_config(self):
279280
with open(config_json, "r") as _file:
280281
config = json.load(_file)
281282

282-
self.app.get_config = MagicMock(return_value=config)
283+
self.app.get_config = MagicMock(return_value=ModelConfigResult(config=config))
283284
result = self.app.get_finetuning_config(model_id="test-model-id")
284285
assert result == config
285286

0 commit comments

Comments
 (0)