Skip to content

Commit 75fb959

Browse files
authored
[AQUA][Evaluate] Added the EvaluationServiceConfig class to manage evaluation service configurations. (#940)
2 parents 81494c2 + b0ca718 commit 75fb959

22 files changed

+1331
-178
lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,3 +163,6 @@ logs/
163163

164164
# Python Wheel
165165
*.whl
166+
167+
# The demo folder
168+
.demo

ads/aqua/common/utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -536,14 +536,14 @@ def _build_job_identifier(
536536
return AquaResourceIdentifier()
537537

538538

539-
def container_config_path():
539+
def service_config_path():
540540
return f"oci://{AQUA_SERVICE_MODELS_BUCKET}@{CONDA_BUCKET_NS}/service_models/config"
541541

542542

543543
@cached(cache=TTLCache(maxsize=1, ttl=timedelta(hours=5), timer=datetime.now))
544544
def get_container_config():
545545
config = load_config(
546-
file_path=container_config_path(),
546+
file_path=service_config_path(),
547547
config_file_name=CONTAINER_INDEX,
548548
)
549549

@@ -568,7 +568,7 @@ def get_container_image(
568568
"""
569569

570570
config = config_file_name or get_container_config()
571-
config_file_name = container_config_path()
571+
config_file_name = service_config_path()
572572

573573
if container_type not in config:
574574
raise AquaValueError(

ads/aqua/config/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
#!/usr/bin/env python
2+
3+
# Copyright (c) 2024 Oracle and/or its affiliates.
4+
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/

ads/aqua/config/config.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,34 @@
33
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
44

55

6+
from typing import Optional
7+
8+
from ads.aqua.common.entities import ContainerSpec
9+
from ads.aqua.common.utils import get_container_config
10+
from ads.aqua.config.evaluation.evaluation_service_config import EvaluationServiceConfig
11+
12+
DEFAULT_EVALUATION_CONTAINER = "odsc-llm-evaluate"
13+
14+
15+
def evaluation_service_config(
16+
container: Optional[str] = DEFAULT_EVALUATION_CONTAINER,
17+
) -> EvaluationServiceConfig:
18+
"""
19+
Retrieves the common evaluation configuration.
20+
21+
Returns
22+
-------
23+
EvaluationServiceConfig: The evaluation common config.
24+
"""
25+
26+
container = container or DEFAULT_EVALUATION_CONTAINER
27+
return EvaluationServiceConfig(
28+
**get_container_config()
29+
.get(ContainerSpec.CONTAINER_SPEC, {})
30+
.get(container, {})
31+
)
32+
33+
634
# TODO: move this to global config.json in object storage
735
def get_finetuning_config_defaults():
836
"""Generate and return the fine-tuning default configuration dictionary."""
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
#!/usr/bin/env python
2+
3+
# Copyright (c) 2024 Oracle and/or its affiliates.
4+
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
Lines changed: 282 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,282 @@
1+
#!/usr/bin/env python
2+
3+
# Copyright (c) 2024 Oracle and/or its affiliates.
4+
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
5+
6+
from copy import deepcopy
7+
from typing import Any, Dict, List, Optional
8+
9+
from pydantic import Field
10+
11+
from ads.aqua.config.utils.serializer import Serializable
12+
13+
14+
class ModelParamsOverrides(Serializable):
15+
"""Defines overrides for model parameters, including exclusions and additional inclusions."""
16+
17+
exclude: Optional[List[str]] = Field(default_factory=list)
18+
include: Optional[Dict[str, Any]] = Field(default_factory=dict)
19+
20+
class Config:
21+
extra = "ignore"
22+
23+
24+
class ModelParamsVersion(Serializable):
25+
"""Handles version-specific model parameter overrides."""
26+
27+
overrides: Optional[ModelParamsOverrides] = Field(
28+
default_factory=ModelParamsOverrides
29+
)
30+
31+
class Config:
32+
extra = "ignore"
33+
34+
35+
class ModelParamsContainer(Serializable):
36+
"""Represents a container's model configuration, including tasks, defaults, and versions."""
37+
38+
name: Optional[str] = None
39+
default: Optional[Dict[str, Any]] = Field(default_factory=dict)
40+
versions: Optional[Dict[str, ModelParamsVersion]] = Field(default_factory=dict)
41+
42+
class Config:
43+
extra = "ignore"
44+
45+
46+
class InferenceParams(Serializable):
47+
"""Contains inference-related parameters with defaults."""
48+
49+
class Config:
50+
extra = "allow"
51+
52+
53+
class InferenceContainer(Serializable):
54+
"""Represents the inference parameters specific to a container."""
55+
56+
name: Optional[str] = None
57+
params: Optional[Dict[str, Any]] = Field(default_factory=dict)
58+
59+
class Config:
60+
extra = "ignore"
61+
62+
63+
class ReportParams(Serializable):
64+
"""Handles the report-related parameters."""
65+
66+
default: Optional[Dict[str, Any]] = Field(default_factory=dict)
67+
68+
class Config:
69+
extra = "ignore"
70+
71+
72+
class InferenceParamsConfig(Serializable):
73+
"""Combines default inference parameters with container-specific configurations."""
74+
75+
default: Optional[InferenceParams] = Field(default_factory=InferenceParams)
76+
containers: Optional[List[InferenceContainer]] = Field(default_factory=list)
77+
78+
def get_merged_params(self, container_name: str) -> InferenceParams:
79+
"""
80+
Merges default inference params with those specific to the given container.
81+
82+
Parameters
83+
----------
84+
container_name (str): The name of the container.
85+
86+
Returns
87+
-------
88+
InferenceParams: The merged inference parameters.
89+
"""
90+
merged_params = self.default.to_dict()
91+
for containers in self.containers:
92+
if containers.name.lower() == container_name.lower():
93+
merged_params.update(containers.params or {})
94+
break
95+
return InferenceParams(**merged_params)
96+
97+
class Config:
98+
extra = "ignore"
99+
100+
101+
class InferenceModelParamsConfig(Serializable):
102+
"""Encapsulates the model parameters for different containers."""
103+
104+
default: Optional[Dict[str, Any]] = Field(default_factory=dict)
105+
containers: Optional[List[ModelParamsContainer]] = Field(default_factory=list)
106+
107+
def get_merged_model_params(
108+
self,
109+
container_name: str,
110+
version: Optional[str] = None,
111+
) -> Dict[str, Any]:
112+
"""
113+
Gets the model parameters for a given container, version,
114+
merged with the defaults.
115+
116+
Parameters
117+
----------
118+
container_name (str): The name of the container.
119+
version (Optional[str]): The specific version of the container.
120+
121+
Returns
122+
-------
123+
Dict[str, Any]: The merged model parameters.
124+
"""
125+
params = deepcopy(self.default)
126+
127+
for container in self.containers:
128+
if container.name.lower() == container_name.lower():
129+
params.update(container.default)
130+
131+
if version and version in container.versions:
132+
version_overrides = container.versions[version].overrides
133+
if version_overrides:
134+
if version_overrides.include:
135+
params.update(version_overrides.include)
136+
if version_overrides.exclude:
137+
for key in version_overrides.exclude:
138+
params.pop(key, None)
139+
break
140+
141+
return params
142+
143+
class Config:
144+
extra = "ignore"
145+
146+
147+
class ShapeFilterConfig(Serializable):
148+
"""Represents the filtering options for a specific shape."""
149+
150+
evaluation_container: Optional[List[str]] = Field(default_factory=list)
151+
evaluation_target: Optional[List[str]] = Field(default_factory=list)
152+
153+
class Config:
154+
extra = "ignore"
155+
156+
157+
class ShapeConfig(Serializable):
158+
"""Defines the configuration for a specific shape."""
159+
160+
name: Optional[str] = None
161+
ocpu: Optional[int] = None
162+
memory_in_gbs: Optional[int] = None
163+
block_storage_size: Optional[int] = None
164+
filter: Optional[ShapeFilterConfig] = Field(default_factory=ShapeFilterConfig)
165+
166+
class Config:
167+
extra = "allow"
168+
169+
170+
class MetricConfig(Serializable):
171+
"""Handles metric configuration including task, key, and additional details."""
172+
173+
task: Optional[List[str]] = Field(default_factory=list)
174+
key: Optional[str] = None
175+
name: Optional[str] = None
176+
description: Optional[str] = None
177+
args: Optional[Dict[str, Any]] = Field(default_factory=dict)
178+
tags: Optional[List[str]] = Field(default_factory=list)
179+
180+
class Config:
181+
extra = "ignore"
182+
183+
184+
class ModelParamsConfig(Serializable):
185+
"""Encapsulates the default model parameters."""
186+
187+
default: Optional[Dict[str, Any]] = Field(default_factory=dict)
188+
189+
190+
class UIConfig(Serializable):
191+
model_params: Optional[ModelParamsConfig] = Field(default_factory=ModelParamsConfig)
192+
shapes: List[ShapeConfig] = Field(default_factory=list)
193+
metrics: List[MetricConfig] = Field(default_factory=list)
194+
195+
def search_shapes(
196+
self,
197+
evaluation_container: Optional[str] = None,
198+
evaluation_target: Optional[str] = None,
199+
) -> List[ShapeConfig]:
200+
"""
201+
Searches for shapes that match the given filters.
202+
203+
Parameters
204+
----------
205+
evaluation_container (Optional[str]): Filter for evaluation_container.
206+
evaluation_target (Optional[str]): Filter for evaluation_target.
207+
208+
Returns
209+
-------
210+
List[ShapeConfig]: A list of shapes that match the filters.
211+
"""
212+
return [
213+
shape
214+
for shape in self.shapes
215+
if (
216+
not evaluation_container
217+
or evaluation_container in shape.filter.evaluation_container
218+
)
219+
and (
220+
not evaluation_target
221+
or evaluation_target in shape.filter.evaluation_target
222+
)
223+
]
224+
225+
class Config:
226+
extra = "ignore"
227+
228+
229+
class EvaluationServiceConfig(Serializable):
230+
"""
231+
Root configuration class for evaluation setup including model,
232+
inference, and shape configurations.
233+
"""
234+
235+
version: Optional[str] = "1.0"
236+
kind: Optional[str] = "evaluation_service_config"
237+
report_params: Optional[ReportParams] = Field(default_factory=ReportParams)
238+
inference_params: Optional[InferenceParamsConfig] = Field(
239+
default_factory=InferenceParamsConfig
240+
)
241+
inference_model_params: Optional[InferenceModelParamsConfig] = Field(
242+
default_factory=InferenceModelParamsConfig
243+
)
244+
ui_config: Optional[UIConfig] = Field(default_factory=UIConfig)
245+
246+
def get_merged_inference_params(self, container_name: str) -> InferenceParams:
247+
"""
248+
Merges default inference params with those specific to the given container.
249+
250+
Params
251+
------
252+
container_name (str): The name of the container.
253+
254+
Returns
255+
-------
256+
InferenceParams: The merged inference parameters.
257+
"""
258+
return self.inference_params.get_merged_params(container_name=container_name)
259+
260+
def get_merged_inference_model_params(
261+
self,
262+
container_name: str,
263+
version: Optional[str] = None,
264+
) -> Dict[str, Any]:
265+
"""
266+
Gets the model parameters for a given container, version, and task, merged with the defaults.
267+
268+
Parameters
269+
----------
270+
container_name (str): The name of the container.
271+
version (Optional[str]): The specific version of the container.
272+
273+
Returns
274+
-------
275+
Dict[str, Any]: The merged model parameters.
276+
"""
277+
return self.inference_model_params.get_merged_model_params(
278+
container_name=container_name, version=version
279+
)
280+
281+
class Config:
282+
extra = "ignore"
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
#!/usr/bin/env python
2+
3+
# Copyright (c) 2024 Oracle and/or its affiliates.
4+
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
5+
6+
"""
7+
This serves as a future template for implementing model-specific evaluation configurations.
8+
"""

ads/aqua/config/utils/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
#!/usr/bin/env python
2+
3+
# Copyright (c) 2024 Oracle and/or its affiliates.
4+
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/

0 commit comments

Comments
 (0)