Skip to content

Commit e5710ba

Browse files
committed
Enhance the evaluation service config.
1 parent 56abb06 commit e5710ba

File tree

7 files changed

+356
-439
lines changed

7 files changed

+356
-439
lines changed

ads/aqua/config/config.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,21 @@
44

55

66
from datetime import datetime, timedelta
7+
from typing import Optional
78

89
from cachetools import TTLCache, cached
910

10-
from ads.aqua.common.utils import service_config_path
11+
from ads.aqua.common.entities import ContainerSpec
12+
from ads.aqua.common.utils import get_container_config
1113
from ads.aqua.config.evaluation.evaluation_service_config import EvaluationServiceConfig
12-
from ads.aqua.constants import EVALUATION_SERVICE_CONFIG
14+
15+
DEFAULT_EVALUATION_CONTAINER = "odsc-llm-evaluate"
1316

1417

1518
@cached(cache=TTLCache(maxsize=1, ttl=timedelta(hours=5), timer=datetime.now))
16-
def evaluation_service_config() -> EvaluationServiceConfig:
19+
def evaluation_service_config(
20+
container: Optional[str] = DEFAULT_EVALUATION_CONTAINER,
21+
) -> EvaluationServiceConfig:
1722
"""
1823
Retrieves the common evaluation configuration.
1924
@@ -22,8 +27,10 @@ def evaluation_service_config() -> EvaluationServiceConfig:
2227
EvaluationServiceConfig: The evaluation common config.
2328
"""
2429

25-
return EvaluationServiceConfig.from_json(
26-
uri=f"{service_config_path()}/{EVALUATION_SERVICE_CONFIG}"
30+
return EvaluationServiceConfig(
31+
**get_container_config()
32+
.get(ContainerSpec.CONTAINER_SPEC, {})
33+
.get(container, {})
2734
)
2835

2936

ads/aqua/config/evaluation/evaluation_service_config.py

Lines changed: 91 additions & 109 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
55

66
from copy import deepcopy
7-
from typing import Any, Dict, List, Optional, Union
7+
from typing import Any, Dict, List, Optional
88

99
from pydantic import Field
1010

@@ -19,17 +19,6 @@
1919
INFERENCE_DELAY = 0
2020

2121

22-
class ModelParamItem(Serializable):
23-
"""Represents min, max, and default values for a model parameter."""
24-
25-
min: Optional[Union[int, float]] = None
26-
max: Optional[Union[int, float]] = None
27-
default: Optional[Union[int, float]] = None
28-
29-
class Config:
30-
extra = "ignore"
31-
32-
3322
class ModelParamsOverrides(Serializable):
3423
"""Defines overrides for model parameters, including exclusions and additional inclusions."""
3524

@@ -51,28 +40,11 @@ class Config:
5140
extra = "ignore"
5241

5342

54-
class ModelDefaultParams(Serializable):
55-
"""Defines default parameters for a model within a specific framework."""
56-
57-
model: Optional[str] = None
58-
max_tokens: Optional[ModelParamItem] = Field(default_factory=ModelParamItem)
59-
temperature: Optional[ModelParamItem] = Field(default_factory=ModelParamItem)
60-
top_p: Optional[ModelParamItem] = Field(default_factory=ModelParamItem)
61-
top_k: Optional[ModelParamItem] = Field(default_factory=ModelParamItem)
62-
presence_penalty: Optional[ModelParamItem] = Field(default_factory=ModelParamItem)
63-
frequency_penalty: Optional[ModelParamItem] = Field(default_factory=ModelParamItem)
64-
stop: List[str] = Field(default_factory=list)
43+
class ModelParamsContainer(Serializable):
44+
"""Represents a container's model configuration, including tasks, defaults, and versions."""
6545

66-
class Config:
67-
extra = "allow"
68-
69-
70-
class ModelFramework(Serializable):
71-
"""Represents a framework's model configuration, including tasks, defaults, and versions."""
72-
73-
framework: Optional[str] = None
74-
task: Optional[List[str]] = Field(default_factory=list)
75-
default: Optional[ModelDefaultParams] = Field(default_factory=ModelDefaultParams)
46+
name: Optional[str] = None
47+
default: Optional[Dict[str, Any]] = Field(default_factory=dict)
7648
versions: Optional[Dict[str, ModelParamsVersion]] = Field(default_factory=dict)
7749

7850
class Config:
@@ -93,10 +65,10 @@ class Config:
9365
extra = "allow"
9466

9567

96-
class InferenceFramework(Serializable):
97-
"""Represents the inference parameters specific to a framework."""
68+
class InferenceContainer(Serializable):
69+
"""Represents the inference parameters specific to a container."""
9870

99-
framework: Optional[str] = None
71+
name: Optional[str] = None
10072
params: Optional[Dict[str, Any]] = Field(default_factory=dict)
10173

10274
class Config:
@@ -113,70 +85,66 @@ class Config:
11385

11486

11587
class InferenceParamsConfig(Serializable):
116-
"""Combines default inference parameters with framework-specific configurations."""
88+
"""Combines default inference parameters with container-specific configurations."""
11789

11890
default: Optional[InferenceParams] = Field(default_factory=InferenceParams)
119-
frameworks: Optional[List[InferenceFramework]] = Field(default_factory=list)
91+
containers: Optional[List[InferenceContainer]] = Field(default_factory=list)
12092

121-
def get_merged_params(self, framework_name: str) -> InferenceParams:
93+
def get_merged_params(self, container_name: str) -> InferenceParams:
12294
"""
123-
Merges default inference params with those specific to the given framework.
95+
Merges default inference params with those specific to the given container.
12496
12597
Parameters
12698
----------
127-
framework_name (str): The name of the framework.
99+
container_name (str): The name of the container.
128100
129101
Returns
130102
-------
131103
InferenceParams: The merged inference parameters.
132104
"""
133105
merged_params = self.default.to_dict()
134-
for framework in self.frameworks:
135-
if framework.framework.lower() == framework_name.lower():
136-
merged_params.update(framework.params or {})
106+
for containers in self.containers:
107+
if containers.name.lower() == container_name.lower():
108+
merged_params.update(containers.params or {})
137109
break
138110
return InferenceParams(**merged_params)
139111

140112
class Config:
141113
extra = "ignore"
142114

143115

144-
class ModelParamsConfig(Serializable):
145-
"""Encapsulates the model parameters for different frameworks."""
116+
class InferenceModelParamsConfig(Serializable):
117+
"""Encapsulates the model parameters for different containers."""
146118

147119
default: Optional[Dict[str, Any]] = Field(default_factory=dict)
148-
frameworks: Optional[List[ModelFramework]] = Field(default_factory=list)
120+
containers: Optional[List[ModelParamsContainer]] = Field(default_factory=list)
149121

150-
def get_model_params(
122+
def get_merged_model_params(
151123
self,
152-
framework_name: str,
124+
container_name: str,
153125
version: Optional[str] = None,
154-
task: Optional[str] = None,
155126
) -> Dict[str, Any]:
156127
"""
157-
Gets the model parameters for a given framework, version, and tasks,
128+
Gets the model parameters for a given container, version,
158129
merged with the defaults.
159130
160131
Parameters
161132
----------
162-
framework_name (str): The name of the framework.
163-
version (Optional[str]): The specific version of the framework.
164-
task (Optional[str]): The specific task.
133+
container_name (str): The name of the container.
134+
version (Optional[str]): The specific version of the container.
165135
166136
Returns
167137
-------
168138
Dict[str, Any]: The merged model parameters.
169139
"""
170140
params = deepcopy(self.default)
171141

172-
for framework in self.frameworks:
173-
if framework.framework.lower() == framework_name.lower() and (
174-
not task or task.lower() in framework.task
175-
):
176-
params.update(framework.default.to_dict())
142+
for container in self.containers:
143+
if container.name.lower() == container_name.lower():
144+
params.update(container.default)
177145

178-
if version and version in framework.versions:
179-
version_overrides = framework.versions[version].overrides
146+
if version and version in container.versions:
147+
version_overrides = container.versions[version].overrides
180148
if version_overrides:
181149
if version_overrides.include:
182150
params.update(version_overrides.include)
@@ -228,59 +196,17 @@ class Config:
228196
extra = "ignore"
229197

230198

231-
class EvaluationServiceConfig(Serializable):
232-
"""
233-
Root configuration class for evaluation setup including model,
234-
inference, and shape configurations.
235-
"""
199+
class ModelParamsConfig(Serializable):
200+
"""Encapsulates the default model parameters."""
236201

237-
version: Optional[str] = "1.0"
238-
kind: Optional[str] = "evaluation"
239-
report_params: Optional[ReportParams] = Field(default_factory=ReportParams)
240-
inference_params: Optional[InferenceParamsConfig] = Field(
241-
default_factory=InferenceParamsConfig
242-
)
202+
default: Optional[Dict[str, Any]] = Field(default_factory=dict)
203+
204+
205+
class UIConfig(Serializable):
243206
model_params: Optional[ModelParamsConfig] = Field(default_factory=ModelParamsConfig)
244207
shapes: List[ShapeConfig] = Field(default_factory=list)
245208
metrics: List[MetricConfig] = Field(default_factory=list)
246209

247-
def get_merged_inference_params(self, framework_name: str) -> InferenceParams:
248-
"""
249-
Merges default inference params with those specific to the given framework.
250-
251-
Params
252-
------
253-
framework_name (str): The name of the framework.
254-
255-
Returns
256-
-------
257-
InferenceParams: The merged inference parameters.
258-
"""
259-
return self.inference_params.get_merged_params(framework_name=framework_name)
260-
261-
def get_merged_model_params(
262-
self,
263-
framework_name: str,
264-
version: Optional[str] = None,
265-
task: Optional[str] = None,
266-
) -> Dict[str, Any]:
267-
"""
268-
Gets the model parameters for a given framework, version, and task, merged with the defaults.
269-
270-
Parameters
271-
----------
272-
framework_name (str): The name of the framework.
273-
version (Optional[str]): The specific version of the framework.
274-
task (Optional[str]): The task.
275-
276-
Returns
277-
-------
278-
Dict[str, Any]: The merged model parameters.
279-
"""
280-
return self.model_params.get_model_params(
281-
framework_name=framework_name, version=version, task=task
282-
)
283-
284210
def search_shapes(
285211
self,
286212
evaluation_container: Optional[str] = None,
@@ -315,3 +241,59 @@ def search_shapes(
315241

316242
class Config:
317243
extra = "ignore"
244+
245+
246+
class EvaluationServiceConfig(Serializable):
247+
"""
248+
Root configuration class for evaluation setup including model,
249+
inference, and shape configurations.
250+
"""
251+
252+
version: Optional[str] = "1.0"
253+
kind: Optional[str] = "evaluation"
254+
report_params: Optional[ReportParams] = Field(default_factory=ReportParams)
255+
inference_params: Optional[InferenceParamsConfig] = Field(
256+
default_factory=InferenceParamsConfig
257+
)
258+
inference_model_params: Optional[InferenceModelParamsConfig] = Field(
259+
default_factory=InferenceModelParamsConfig
260+
)
261+
ui_config: Optional[UIConfig] = Field(default_factory=UIConfig)
262+
263+
def get_merged_inference_params(self, container_name: str) -> InferenceParams:
264+
"""
265+
Merges default inference params with those specific to the given container.
266+
267+
Params
268+
------
269+
container_name (str): The name of the container.
270+
271+
Returns
272+
-------
273+
InferenceParams: The merged inference parameters.
274+
"""
275+
return self.inference_params.get_merged_params(container_name=container_name)
276+
277+
def get_merged_inference_model_params(
278+
self,
279+
container_name: str,
280+
version: Optional[str] = None,
281+
) -> Dict[str, Any]:
282+
"""
283+
Gets the model parameters for a given container, version, and task, merged with the defaults.
284+
285+
Parameters
286+
----------
287+
container_name (str): The name of the container.
288+
version (Optional[str]): The specific version of the container.
289+
290+
Returns
291+
-------
292+
Dict[str, Any]: The merged model parameters.
293+
"""
294+
return self.inference_model_params.get_merged_model_params(
295+
container_name=container_name, version=version
296+
)
297+
298+
class Config:
299+
extra = "ignore"

ads/aqua/constants.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
EVALUATION_REPORT_JSON = "report.json"
1616
EVALUATION_REPORT_MD = "report.md"
1717
EVALUATION_REPORT = "report.html"
18-
EVALUATION_SERVICE_CONFIG = "evaluation.json"
1918
UNKNOWN_JSON_STR = "{}"
2019
FINE_TUNING_RUNTIME_CONTAINER = "iad.ocir.io/ociodscdev/aqua_ft_cuda121:0.3.17.20"
2120
DEFAULT_FT_BLOCK_STORAGE_SIZE = 750

tests/unitary/with_extras/aqua/test_config.py

Lines changed: 26 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2,28 +2,38 @@
22
# Copyright (c) 2024 Oracle and/or its affiliates.
33
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
44

5+
import json
6+
import os
7+
from unittest.mock import patch
58

6-
from unittest.mock import MagicMock, patch
7-
8-
from ads.aqua.common.utils import service_config_path
9-
from ads.aqua.config.config import evaluation_config
10-
from ads.aqua.config.evaluation.evaluation_service_config import EvaluationServiceConfig
11-
from ads.aqua.constants import EVALUATION_SERVICE_CONFIG
9+
from ads.aqua.common.entities import ContainerSpec
10+
from ads.aqua.config.config import evaluation_service_config
1211

1312

1413
class TestConfig:
1514
"""Unit tests for AQUA common configurations."""
1615

17-
@patch.object(EvaluationServiceConfig, "from_json")
18-
def test_evaluation_service_config(self, mock_from_json):
19-
"""Ensures that the common evaluation configuration can be successfully retrieved."""
20-
21-
expected_result = MagicMock()
22-
mock_from_json.return_value = expected_result
16+
def setup_class(cls):
17+
cls.curr_dir = os.path.dirname(os.path.abspath(__file__))
18+
cls.artifact_dir = os.path.join(cls.curr_dir, "test_data", "config")
2319

24-
test_result = evaluation_config()
20+
@patch("ads.aqua.config.config.get_container_config")
21+
def test_evaluation_service_config(self, mock_get_container_config):
22+
"""Ensures that the common evaluation configuration can be successfully retrieved."""
2523

26-
mock_from_json.assert_called_with(
27-
uri=f"{service_config_path()}/{EVALUATION_SERVICE_CONFIG}"
24+
with open(
25+
os.path.join(
26+
self.artifact_dir, "evaluation_config_with_default_params.json"
27+
)
28+
) as file:
29+
expected_result = {
30+
ContainerSpec.CONTAINER_SPEC: {"test_container": json.load(file)}
31+
}
32+
33+
mock_get_container_config.return_value = expected_result
34+
35+
test_result = evaluation_service_config(container="test_container")
36+
assert (
37+
test_result.to_dict()
38+
== expected_result[ContainerSpec.CONTAINER_SPEC]["test_container"]
2839
)
29-
assert test_result == expected_result

0 commit comments

Comments
 (0)