Skip to content

Commit 56abb06

Browse files
committed
Added the EvaluationServiceConfig class along with supporting structures to manage evaluation service configurations.
1 parent 6b31d0f commit 56abb06

15 files changed

+1244
-4
lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,3 +163,6 @@ logs/
163163

164164
# Python Wheel
165165
*.whl
166+
167+
# The demo folder
168+
.demo

ads/aqua/common/utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -536,14 +536,14 @@ def _build_job_identifier(
536536
return AquaResourceIdentifier()
537537

538538

539-
def container_config_path():
539+
def service_config_path():
540540
return f"oci://{AQUA_SERVICE_MODELS_BUCKET}@{CONDA_BUCKET_NS}/service_models/config"
541541

542542

543543
@cached(cache=TTLCache(maxsize=1, ttl=timedelta(hours=5), timer=datetime.now))
544544
def get_container_config():
545545
config = load_config(
546-
file_path=container_config_path(),
546+
file_path=service_config_path(),
547547
config_file_name=CONTAINER_INDEX,
548548
)
549549

@@ -568,7 +568,7 @@ def get_container_image(
568568
"""
569569

570570
config = config_file_name or get_container_config()
571-
config_file_name = container_config_path()
571+
config_file_name = service_config_path()
572572

573573
if container_type not in config:
574574
raise AquaValueError(

ads/aqua/config/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
#!/usr/bin/env python
2+
3+
# Copyright (c) 2024 Oracle and/or its affiliates.
4+
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/

ads/aqua/config/config.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,30 @@
33
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
44

55

6+
from datetime import datetime, timedelta
7+
8+
from cachetools import TTLCache, cached
9+
10+
from ads.aqua.common.utils import service_config_path
11+
from ads.aqua.config.evaluation.evaluation_service_config import EvaluationServiceConfig
12+
from ads.aqua.constants import EVALUATION_SERVICE_CONFIG
13+
14+
15+
@cached(cache=TTLCache(maxsize=1, ttl=timedelta(hours=5), timer=datetime.now))
16+
def evaluation_service_config() -> EvaluationServiceConfig:
17+
"""
18+
Retrieves the common evaluation configuration.
19+
20+
Returns
21+
-------
22+
EvaluationServiceConfig: The evaluation common config.
23+
"""
24+
25+
return EvaluationServiceConfig.from_json(
26+
uri=f"{service_config_path()}/{EVALUATION_SERVICE_CONFIG}"
27+
)
28+
29+
630
# TODO: move this to global config.json in object storage
731
def get_finetuning_config_defaults():
832
"""Generate and return the fine-tuning default configuration dictionary."""
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
#!/usr/bin/env python
2+
3+
# Copyright (c) 2024 Oracle and/or its affiliates.
4+
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
Lines changed: 317 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,317 @@
1+
#!/usr/bin/env python
2+
3+
# Copyright (c) 2024 Oracle and/or its affiliates.
4+
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
5+
6+
from copy import deepcopy
7+
from typing import Any, Dict, List, Optional, Union
8+
9+
from pydantic import Field
10+
11+
from ads.aqua.config.utils.serializer import Serializable
12+
13+
# Constants
14+
INFERENCE_RPS = 25 # Max RPS for inferencing deployed model.
15+
INFERENCE_TIMEOUT = 120
16+
INFERENCE_MAX_THREADS = 10 # Maximum parallel threads for model inference.
17+
INFERENCE_RETRIES = 3
18+
INFERENCE_BACKOFF_FACTOR = 3
19+
INFERENCE_DELAY = 0
20+
21+
22+
class ModelParamItem(Serializable):
23+
"""Represents min, max, and default values for a model parameter."""
24+
25+
min: Optional[Union[int, float]] = None
26+
max: Optional[Union[int, float]] = None
27+
default: Optional[Union[int, float]] = None
28+
29+
class Config:
30+
extra = "ignore"
31+
32+
33+
class ModelParamsOverrides(Serializable):
34+
"""Defines overrides for model parameters, including exclusions and additional inclusions."""
35+
36+
exclude: Optional[List[str]] = Field(default_factory=list)
37+
include: Optional[Dict[str, Any]] = Field(default_factory=dict)
38+
39+
class Config:
40+
extra = "ignore"
41+
42+
43+
class ModelParamsVersion(Serializable):
44+
"""Handles version-specific model parameter overrides."""
45+
46+
overrides: Optional[ModelParamsOverrides] = Field(
47+
default_factory=ModelParamsOverrides
48+
)
49+
50+
class Config:
51+
extra = "ignore"
52+
53+
54+
class ModelDefaultParams(Serializable):
55+
"""Defines default parameters for a model within a specific framework."""
56+
57+
model: Optional[str] = None
58+
max_tokens: Optional[ModelParamItem] = Field(default_factory=ModelParamItem)
59+
temperature: Optional[ModelParamItem] = Field(default_factory=ModelParamItem)
60+
top_p: Optional[ModelParamItem] = Field(default_factory=ModelParamItem)
61+
top_k: Optional[ModelParamItem] = Field(default_factory=ModelParamItem)
62+
presence_penalty: Optional[ModelParamItem] = Field(default_factory=ModelParamItem)
63+
frequency_penalty: Optional[ModelParamItem] = Field(default_factory=ModelParamItem)
64+
stop: List[str] = Field(default_factory=list)
65+
66+
class Config:
67+
extra = "allow"
68+
69+
70+
class ModelFramework(Serializable):
71+
"""Represents a framework's model configuration, including tasks, defaults, and versions."""
72+
73+
framework: Optional[str] = None
74+
task: Optional[List[str]] = Field(default_factory=list)
75+
default: Optional[ModelDefaultParams] = Field(default_factory=ModelDefaultParams)
76+
versions: Optional[Dict[str, ModelParamsVersion]] = Field(default_factory=dict)
77+
78+
class Config:
79+
extra = "ignore"
80+
81+
82+
class InferenceParams(Serializable):
83+
"""Contains inference-related parameters with defaults."""
84+
85+
inference_rps: Optional[int] = INFERENCE_RPS
86+
inference_timeout: Optional[int] = INFERENCE_TIMEOUT
87+
inference_max_threads: Optional[int] = INFERENCE_MAX_THREADS
88+
inference_retries: Optional[int] = INFERENCE_RETRIES
89+
inference_backoff_factor: Optional[float] = INFERENCE_BACKOFF_FACTOR
90+
inference_delay: Optional[float] = INFERENCE_DELAY
91+
92+
class Config:
93+
extra = "allow"
94+
95+
96+
class InferenceFramework(Serializable):
97+
"""Represents the inference parameters specific to a framework."""
98+
99+
framework: Optional[str] = None
100+
params: Optional[Dict[str, Any]] = Field(default_factory=dict)
101+
102+
class Config:
103+
extra = "ignore"
104+
105+
106+
class ReportParams(Serializable):
107+
"""Handles the report-related parameters."""
108+
109+
default: Optional[Dict[str, Any]] = Field(default_factory=dict)
110+
111+
class Config:
112+
extra = "ignore"
113+
114+
115+
class InferenceParamsConfig(Serializable):
116+
"""Combines default inference parameters with framework-specific configurations."""
117+
118+
default: Optional[InferenceParams] = Field(default_factory=InferenceParams)
119+
frameworks: Optional[List[InferenceFramework]] = Field(default_factory=list)
120+
121+
def get_merged_params(self, framework_name: str) -> InferenceParams:
122+
"""
123+
Merges default inference params with those specific to the given framework.
124+
125+
Parameters
126+
----------
127+
framework_name (str): The name of the framework.
128+
129+
Returns
130+
-------
131+
InferenceParams: The merged inference parameters.
132+
"""
133+
merged_params = self.default.to_dict()
134+
for framework in self.frameworks:
135+
if framework.framework.lower() == framework_name.lower():
136+
merged_params.update(framework.params or {})
137+
break
138+
return InferenceParams(**merged_params)
139+
140+
class Config:
141+
extra = "ignore"
142+
143+
144+
class ModelParamsConfig(Serializable):
145+
"""Encapsulates the model parameters for different frameworks."""
146+
147+
default: Optional[Dict[str, Any]] = Field(default_factory=dict)
148+
frameworks: Optional[List[ModelFramework]] = Field(default_factory=list)
149+
150+
def get_model_params(
151+
self,
152+
framework_name: str,
153+
version: Optional[str] = None,
154+
task: Optional[str] = None,
155+
) -> Dict[str, Any]:
156+
"""
157+
Gets the model parameters for a given framework, version, and tasks,
158+
merged with the defaults.
159+
160+
Parameters
161+
----------
162+
framework_name (str): The name of the framework.
163+
version (Optional[str]): The specific version of the framework.
164+
task (Optional[str]): The specific task.
165+
166+
Returns
167+
-------
168+
Dict[str, Any]: The merged model parameters.
169+
"""
170+
params = deepcopy(self.default)
171+
172+
for framework in self.frameworks:
173+
if framework.framework.lower() == framework_name.lower() and (
174+
not task or task.lower() in framework.task
175+
):
176+
params.update(framework.default.to_dict())
177+
178+
if version and version in framework.versions:
179+
version_overrides = framework.versions[version].overrides
180+
if version_overrides:
181+
if version_overrides.include:
182+
params.update(version_overrides.include)
183+
if version_overrides.exclude:
184+
for key in version_overrides.exclude:
185+
params.pop(key, None)
186+
break
187+
188+
return params
189+
190+
class Config:
191+
extra = "ignore"
192+
193+
194+
class ShapeFilterConfig(Serializable):
195+
"""Represents the filtering options for a specific shape."""
196+
197+
evaluation_container: Optional[List[str]] = Field(default_factory=list)
198+
evaluation_target: Optional[List[str]] = Field(default_factory=list)
199+
200+
class Config:
201+
extra = "ignore"
202+
203+
204+
class ShapeConfig(Serializable):
205+
"""Defines the configuration for a specific shape."""
206+
207+
name: Optional[str] = None
208+
ocpu: Optional[int] = None
209+
memory_in_gbs: Optional[int] = None
210+
block_storage_size: Optional[int] = None
211+
filter: Optional[ShapeFilterConfig] = Field(default_factory=ShapeFilterConfig)
212+
213+
class Config:
214+
extra = "allow"
215+
216+
217+
class MetricConfig(Serializable):
218+
"""Handles metric configuration including task, key, and additional details."""
219+
220+
task: Optional[List[str]] = Field(default_factory=list)
221+
key: Optional[str] = None
222+
name: Optional[str] = None
223+
description: Optional[str] = None
224+
args: Optional[Dict[str, Any]] = Field(default_factory=dict)
225+
tags: Optional[List[str]] = Field(default_factory=list)
226+
227+
class Config:
228+
extra = "ignore"
229+
230+
231+
class EvaluationServiceConfig(Serializable):
232+
"""
233+
Root configuration class for evaluation setup including model,
234+
inference, and shape configurations.
235+
"""
236+
237+
version: Optional[str] = "1.0"
238+
kind: Optional[str] = "evaluation"
239+
report_params: Optional[ReportParams] = Field(default_factory=ReportParams)
240+
inference_params: Optional[InferenceParamsConfig] = Field(
241+
default_factory=InferenceParamsConfig
242+
)
243+
model_params: Optional[ModelParamsConfig] = Field(default_factory=ModelParamsConfig)
244+
shapes: List[ShapeConfig] = Field(default_factory=list)
245+
metrics: List[MetricConfig] = Field(default_factory=list)
246+
247+
def get_merged_inference_params(self, framework_name: str) -> InferenceParams:
248+
"""
249+
Merges default inference params with those specific to the given framework.
250+
251+
Params
252+
------
253+
framework_name (str): The name of the framework.
254+
255+
Returns
256+
-------
257+
InferenceParams: The merged inference parameters.
258+
"""
259+
return self.inference_params.get_merged_params(framework_name=framework_name)
260+
261+
def get_merged_model_params(
262+
self,
263+
framework_name: str,
264+
version: Optional[str] = None,
265+
task: Optional[str] = None,
266+
) -> Dict[str, Any]:
267+
"""
268+
Gets the model parameters for a given framework, version, and task, merged with the defaults.
269+
270+
Parameters
271+
----------
272+
framework_name (str): The name of the framework.
273+
version (Optional[str]): The specific version of the framework.
274+
task (Optional[str]): The task.
275+
276+
Returns
277+
-------
278+
Dict[str, Any]: The merged model parameters.
279+
"""
280+
return self.model_params.get_model_params(
281+
framework_name=framework_name, version=version, task=task
282+
)
283+
284+
def search_shapes(
285+
self,
286+
evaluation_container: Optional[str] = None,
287+
evaluation_target: Optional[str] = None,
288+
) -> List[ShapeConfig]:
289+
"""
290+
Searches for shapes that match the given filters.
291+
292+
Parameters
293+
----------
294+
evaluation_container (Optional[str]): Filter for evaluation_container.
295+
evaluation_target (Optional[str]): Filter for evaluation_target.
296+
297+
Returns
298+
-------
299+
List[ShapeConfig]: A list of shapes that match the filters.
300+
"""
301+
results = []
302+
for shape in self.shapes:
303+
if (
304+
evaluation_container
305+
and evaluation_container not in shape.filter.evaluation_container
306+
):
307+
continue
308+
if (
309+
evaluation_target
310+
and evaluation_target not in shape.filter.evaluation_target
311+
):
312+
continue
313+
results.append(shape)
314+
return results
315+
316+
class Config:
317+
extra = "ignore"

0 commit comments

Comments
 (0)