Skip to content

Commit 88d6feb

Browse files
authored
[AQUA] Refactor evaluation service config to remove redundant information. (#1105)
2 parents 788920c + 6f607b2 commit 88d6feb

File tree

5 files changed

+6
-388
lines changed

5 files changed

+6
-388
lines changed
Lines changed: 5 additions & 181 deletions
Original file line numberDiff line numberDiff line change
@@ -1,157 +1,23 @@
11
#!/usr/bin/env python
22

3-
# Copyright (c) 2024 Oracle and/or its affiliates.
3+
# Copyright (c) 2024, 2025 Oracle and/or its affiliates.
44
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
55

6-
from copy import deepcopy
76
from typing import Any, Dict, List, Optional
87

98
from pydantic import Field
109

1110
from ads.aqua.config.utils.serializer import Serializable
1211

1312

14-
class ModelParamsOverrides(Serializable):
15-
"""Defines overrides for model parameters, including exclusions and additional inclusions."""
16-
17-
exclude: Optional[List[str]] = Field(default_factory=list)
18-
include: Optional[Dict[str, Any]] = Field(default_factory=dict)
19-
20-
class Config:
21-
extra = "ignore"
22-
23-
24-
class ModelParamsVersion(Serializable):
25-
"""Handles version-specific model parameter overrides."""
26-
27-
overrides: Optional[ModelParamsOverrides] = Field(
28-
default_factory=ModelParamsOverrides
29-
)
30-
31-
class Config:
32-
extra = "ignore"
33-
34-
35-
class ModelParamsContainer(Serializable):
36-
"""Represents a container's model configuration, including tasks, defaults, and versions."""
37-
38-
name: Optional[str] = None
39-
default: Optional[Dict[str, Any]] = Field(default_factory=dict)
40-
versions: Optional[Dict[str, ModelParamsVersion]] = Field(default_factory=dict)
41-
42-
class Config:
43-
extra = "ignore"
44-
45-
46-
class InferenceParams(Serializable):
47-
"""Contains inference-related parameters with defaults."""
48-
49-
class Config:
50-
extra = "allow"
51-
52-
53-
class InferenceContainer(Serializable):
54-
"""Represents the inference parameters specific to a container."""
55-
56-
name: Optional[str] = None
57-
params: Optional[Dict[str, Any]] = Field(default_factory=dict)
58-
59-
class Config:
60-
extra = "ignore"
61-
62-
63-
class ReportParams(Serializable):
64-
"""Handles the report-related parameters."""
65-
66-
default: Optional[Dict[str, Any]] = Field(default_factory=dict)
67-
68-
class Config:
69-
extra = "ignore"
70-
71-
72-
class InferenceParamsConfig(Serializable):
73-
"""Combines default inference parameters with container-specific configurations."""
74-
75-
default: Optional[InferenceParams] = Field(default_factory=InferenceParams)
76-
containers: Optional[List[InferenceContainer]] = Field(default_factory=list)
77-
78-
def get_merged_params(self, container_name: str) -> InferenceParams:
79-
"""
80-
Merges default inference params with those specific to the given container.
81-
82-
Parameters
83-
----------
84-
container_name (str): The name of the container.
85-
86-
Returns
87-
-------
88-
InferenceParams: The merged inference parameters.
89-
"""
90-
merged_params = self.default.to_dict()
91-
for containers in self.containers:
92-
if containers.name.lower() == container_name.lower():
93-
merged_params.update(containers.params or {})
94-
break
95-
return InferenceParams(**merged_params)
96-
97-
class Config:
98-
extra = "ignore"
99-
100-
101-
class InferenceModelParamsConfig(Serializable):
102-
"""Encapsulates the model parameters for different containers."""
103-
104-
default: Optional[Dict[str, Any]] = Field(default_factory=dict)
105-
containers: Optional[List[ModelParamsContainer]] = Field(default_factory=list)
106-
107-
def get_merged_model_params(
108-
self,
109-
container_name: str,
110-
version: Optional[str] = None,
111-
) -> Dict[str, Any]:
112-
"""
113-
Gets the model parameters for a given container, version,
114-
merged with the defaults.
115-
116-
Parameters
117-
----------
118-
container_name (str): The name of the container.
119-
version (Optional[str]): The specific version of the container.
120-
121-
Returns
122-
-------
123-
Dict[str, Any]: The merged model parameters.
124-
"""
125-
params = deepcopy(self.default)
126-
127-
for container in self.containers:
128-
if container.name.lower() == container_name.lower():
129-
params.update(container.default)
130-
131-
if version and version in container.versions:
132-
version_overrides = container.versions[version].overrides
133-
if version_overrides:
134-
if version_overrides.include:
135-
params.update(version_overrides.include)
136-
if version_overrides.exclude:
137-
for key in version_overrides.exclude:
138-
params.pop(key, None)
139-
break
140-
141-
return params
142-
143-
class Config:
144-
extra = "ignore"
145-
146-
14713
class ShapeFilterConfig(Serializable):
14814
"""Represents the filtering options for a specific shape."""
14915

15016
evaluation_container: Optional[List[str]] = Field(default_factory=list)
15117
evaluation_target: Optional[List[str]] = Field(default_factory=list)
15218

15319
class Config:
154-
extra = "ignore"
20+
extra = "allow"
15521

15622

15723
class ShapeConfig(Serializable):
@@ -178,7 +44,7 @@ class MetricConfig(Serializable):
17844
tags: Optional[List[str]] = Field(default_factory=list)
17945

18046
class Config:
181-
extra = "ignore"
47+
extra = "allow"
18248

18349

18450
class ModelParamsConfig(Serializable):
@@ -223,7 +89,7 @@ def search_shapes(
22389
]
22490

22591
class Config:
226-
extra = "ignore"
92+
extra = "allow"
22793
protected_namespaces = ()
22894

22995

@@ -235,49 +101,7 @@ class EvaluationServiceConfig(Serializable):
235101

236102
version: Optional[str] = "1.0"
237103
kind: Optional[str] = "evaluation_service_config"
238-
report_params: Optional[ReportParams] = Field(default_factory=ReportParams)
239-
inference_params: Optional[InferenceParamsConfig] = Field(
240-
default_factory=InferenceParamsConfig
241-
)
242-
inference_model_params: Optional[InferenceModelParamsConfig] = Field(
243-
default_factory=InferenceModelParamsConfig
244-
)
245104
ui_config: Optional[UIConfig] = Field(default_factory=UIConfig)
246105

247-
def get_merged_inference_params(self, container_name: str) -> InferenceParams:
248-
"""
249-
Merges default inference params with those specific to the given container.
250-
251-
Params
252-
------
253-
container_name (str): The name of the container.
254-
255-
Returns
256-
-------
257-
InferenceParams: The merged inference parameters.
258-
"""
259-
return self.inference_params.get_merged_params(container_name=container_name)
260-
261-
def get_merged_inference_model_params(
262-
self,
263-
container_name: str,
264-
version: Optional[str] = None,
265-
) -> Dict[str, Any]:
266-
"""
267-
Gets the model parameters for a given container, version, and task, merged with the defaults.
268-
269-
Parameters
270-
----------
271-
container_name (str): The name of the container.
272-
version (Optional[str]): The specific version of the container.
273-
274-
Returns
275-
-------
276-
Dict[str, Any]: The merged model parameters.
277-
"""
278-
return self.inference_model_params.get_merged_model_params(
279-
container_name=container_name, version=version
280-
)
281-
282106
class Config:
283-
extra = "ignore"
107+
extra = "allow"

ads/aqua/config/evaluation/evaluation_service_model_config.py

Lines changed: 0 additions & 8 deletions
This file was deleted.

tests/unitary/with_extras/aqua/test_data/config/evaluation_config.json

Lines changed: 0 additions & 106 deletions
Original file line numberDiff line numberDiff line change
@@ -1,111 +1,5 @@
11
{
2-
"inference_model_params": {
3-
"containers": [
4-
{
5-
"default": {
6-
"add_generation_prompt": false
7-
},
8-
"name": "odsc-vllm-serving",
9-
"versions": {
10-
"0.5.1": {
11-
"overrides": {
12-
"exclude": [
13-
"max_tokens",
14-
"frequency_penalty"
15-
],
16-
"include": {
17-
"some_other_param": "some_other_param_value"
18-
}
19-
}
20-
},
21-
"0.5.3.post1": {
22-
"overrides": {
23-
"exclude": [
24-
"add_generation_prompt"
25-
],
26-
"include": {}
27-
}
28-
}
29-
}
30-
},
31-
{
32-
"default": {
33-
"add_generation_prompt": false
34-
},
35-
"name": "odsc-tgi-serving",
36-
"versions": {
37-
"2.0.1.4": {
38-
"overrides": {
39-
"exclude": [
40-
"max_tokens",
41-
"frequency_penalty"
42-
],
43-
"include": {
44-
"some_other_param": "some_other_param_value"
45-
}
46-
}
47-
}
48-
}
49-
},
50-
{
51-
"default": {
52-
"add_generation_prompt": false
53-
},
54-
"name": "odsc-llama-cpp-serving",
55-
"versions": {
56-
"0.2.78.0": {
57-
"overrides": {
58-
"exclude": [],
59-
"include": {}
60-
}
61-
}
62-
}
63-
}
64-
],
65-
"default": {
66-
"add_generation_prompt": false,
67-
"frequency_penalty": 0.0,
68-
"max_tokens": 500,
69-
"model": "odsc-llm",
70-
"presence_penalty": 0.0,
71-
"some_default_param": "some_default_param",
72-
"stop": [],
73-
"temperature": 0.7,
74-
"top_k": 50,
75-
"top_p": 0.9
76-
}
77-
},
78-
"inference_params": {
79-
"containers": [
80-
{
81-
"name": "odsc-vllm-serving",
82-
"params": {}
83-
},
84-
{
85-
"name": "odsc-tgi-serving",
86-
"params": {}
87-
},
88-
{
89-
"name": "odsc-llama-cpp-serving",
90-
"params": {
91-
"inference_delay": 1,
92-
"inference_max_threads": 1
93-
}
94-
}
95-
],
96-
"default": {
97-
"inference_backoff_factor": 3,
98-
"inference_delay": 0,
99-
"inference_max_threads": 10,
100-
"inference_retries": 3,
101-
"inference_rps": 25,
102-
"inference_timeout": 120
103-
}
104-
},
1052
"kind": "evaluation_service_config",
106-
"report_params": {
107-
"default": {}
108-
},
1093
"ui_config": {
1104
"metrics": [
1115
{

tests/unitary/with_extras/aqua/test_data/config/evaluation_config_with_default_params.json

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,5 @@
11
{
2-
"inference_model_params": {
3-
"containers": [],
4-
"default": {}
5-
},
6-
"inference_params": {
7-
"containers": [],
8-
"default": {}
9-
},
102
"kind": "evaluation_service_config",
11-
"report_params": {
12-
"default": {}
13-
},
143
"ui_config": {
154
"metrics": [],
165
"model_params": {

0 commit comments

Comments
 (0)