Skip to content

Commit b0ca718

Browse files
authored
[AQUA][Evaluate] Externalize Supported Shapes List to Global Config. (#942)
2 parents bbda290 + 9000691 commit b0ca718

File tree

11 files changed

+169
-124
lines changed

11 files changed

+169
-124
lines changed

ads/aqua/config/config.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,19 +3,15 @@
33
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
44

55

6-
from datetime import datetime, timedelta
76
from typing import Optional
87

9-
from cachetools import TTLCache, cached
10-
118
from ads.aqua.common.entities import ContainerSpec
129
from ads.aqua.common.utils import get_container_config
1310
from ads.aqua.config.evaluation.evaluation_service_config import EvaluationServiceConfig
1411

1512
DEFAULT_EVALUATION_CONTAINER = "odsc-llm-evaluate"
1613

1714

18-
@cached(cache=TTLCache(maxsize=1, ttl=timedelta(hours=5), timer=datetime.now))
1915
def evaluation_service_config(
2016
container: Optional[str] = DEFAULT_EVALUATION_CONTAINER,
2117
) -> EvaluationServiceConfig:
@@ -27,6 +23,7 @@ def evaluation_service_config(
2723
EvaluationServiceConfig: The evaluation common config.
2824
"""
2925

26+
container = container or DEFAULT_EVALUATION_CONTAINER
3027
return EvaluationServiceConfig(
3128
**get_container_config()
3229
.get(ContainerSpec.CONTAINER_SPEC, {})

ads/aqua/config/evaluation/evaluation_service_config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -233,7 +233,7 @@ class EvaluationServiceConfig(Serializable):
233233
"""
234234

235235
version: Optional[str] = "1.0"
236-
kind: Optional[str] = "evaluation"
236+
kind: Optional[str] = "evaluation_service_config"
237237
report_params: Optional[ReportParams] = Field(default_factory=ReportParams)
238238
inference_params: Optional[InferenceParamsConfig] = Field(
239239
default_factory=InferenceParamsConfig

ads/aqua/evaluation/entities.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,7 @@ class ModelParams(DataClassSerializable):
102102
presence_penalty: Optional[float] = 0.0
103103
frequency_penalty: Optional[float] = 0.0
104104
stop: Optional[Union[str, List[str]]] = field(default_factory=list)
105+
model: Optional[str] = "odsc-llm"
105106

106107

107108
@dataclass(repr=False)

ads/aqua/evaluation/evaluation.py

Lines changed: 54 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from datetime import datetime, timedelta
1212
from pathlib import Path
1313
from threading import Lock
14-
from typing import Any, Dict, List, Union
14+
from typing import Any, Dict, List, Optional, Union
1515

1616
import oci
1717
from cachetools import TTLCache
@@ -46,6 +46,7 @@
4646
upload_local_to_os,
4747
)
4848
from ads.aqua.config.config import evaluation_service_config
49+
from ads.aqua.config.evaluation.evaluation_service_config import EvaluationServiceConfig
4950
from ads.aqua.constants import (
5051
CONSOLE_LINK_RESOURCE_TYPE_MAPPING,
5152
EVALUATION_REPORT,
@@ -171,8 +172,19 @@ def create(
171172
f"Invalid evaluation source {create_aqua_evaluation_details.evaluation_source_id}. "
172173
"Specify either a model or model deployment id."
173174
)
175+
176+
# The model to evaluate
174177
evaluation_source = None
175-
eval_inference_configuration = None
178+
# The evaluation service config
179+
evaluation_config: EvaluationServiceConfig = evaluation_service_config()
180+
# The evaluation inference configuration. The inference configuration will be extracted
181+
# based on the inferencing container family.
182+
eval_inference_configuration: Dict = {}
183+
# The evaluation inference model sampling params. These are system parameters that are not
184+
# visible to the user, but are applied implicitly for evaluation. The service model params
185+
# will be extracted based on the container family and version.
186+
eval_inference_service_model_params: Dict = {}
187+
176188
if (
177189
DataScienceResource.MODEL_DEPLOYMENT
178190
in create_aqua_evaluation_details.evaluation_source_id
@@ -188,17 +200,32 @@ def create(
188200
runtime = ModelDeploymentContainerRuntime.from_dict(
189201
evaluation_source.runtime.to_dict()
190202
)
191-
inference_config = AquaContainerConfig.from_container_index_json(
203+
container_config = AquaContainerConfig.from_container_index_json(
192204
enable_spec=True
193-
).inference
194-
for container in inference_config.values():
195-
if container.name == runtime.image[: runtime.image.rfind(":")]:
205+
)
206+
for (
207+
inference_container_family,
208+
inference_container_info,
209+
) in container_config.inference.items():
210+
if (
211+
inference_container_info.name
212+
== runtime.image[: runtime.image.rfind(":")]
213+
):
196214
eval_inference_configuration = (
197-
container.spec.evaluation_configuration
215+
evaluation_config.get_merged_inference_params(
216+
inference_container_family
217+
).to_dict()
218+
)
219+
eval_inference_service_model_params = (
220+
evaluation_config.get_merged_inference_model_params(
221+
inference_container_family,
222+
inference_container_info.version,
223+
)
198224
)
225+
199226
except Exception:
200227
logger.debug(
201-
f"Could not load inference config details for the evaluation id: "
228+
f"Could not load inference config details for the evaluation source id: "
202229
f"{create_aqua_evaluation_details.evaluation_source_id}. Please check if the container"
203230
f" runtime has the correct SMC image information."
204231
)
@@ -415,13 +442,12 @@ def create(
415442
container_image=container_image,
416443
dataset_path=evaluation_dataset_path,
417444
report_path=create_aqua_evaluation_details.report_path,
418-
model_parameters=create_aqua_evaluation_details.model_parameters,
445+
model_parameters={
446+
**eval_inference_service_model_params,
447+
**create_aqua_evaluation_details.model_parameters,
448+
},
419449
metrics=create_aqua_evaluation_details.metrics,
420-
inference_configuration=(
421-
eval_inference_configuration.to_filtered_dict()
422-
if eval_inference_configuration
423-
else {}
424-
),
450+
inference_configuration=eval_inference_configuration or {},
425451
)
426452
).create(**kwargs) ## TODO: decide what parameters will be needed
427453
logger.debug(
@@ -1188,45 +1214,24 @@ def _delete_job_and_model(job, model):
11881214
f"Exception message: {ex}"
11891215
)
11901216

1191-
def load_evaluation_config(self):
1217+
def load_evaluation_config(self, container: Optional[str] = None) -> Dict:
11921218
"""Loads evaluation config."""
1219+
1220+
# retrieve the evaluation config by container family name
1221+
evaluation_config = evaluation_service_config(container)
1222+
1223+
# convert the new config representation to the old one
11931224
return {
1194-
"model_params": {
1195-
"max_tokens": 500,
1196-
"temperature": 0.7,
1197-
"top_p": 1.0,
1198-
"top_k": 50,
1199-
"presence_penalty": 0.0,
1200-
"frequency_penalty": 0.0,
1201-
"stop": [],
1202-
},
1225+
"model_params": evaluation_config.ui_config.model_params.default,
12031226
"shape": {
1204-
"VM.Standard.E3.Flex": {
1205-
"ocpu": 8,
1206-
"memory_in_gbs": 128,
1207-
"block_storage_size": 200,
1208-
},
1209-
"VM.Standard.E4.Flex": {
1210-
"ocpu": 8,
1211-
"memory_in_gbs": 128,
1212-
"block_storage_size": 200,
1213-
},
1214-
"VM.Standard3.Flex": {
1215-
"ocpu": 8,
1216-
"memory_in_gbs": 128,
1217-
"block_storage_size": 200,
1218-
},
1219-
"VM.Optimized3.Flex": {
1220-
"ocpu": 8,
1221-
"memory_in_gbs": 128,
1222-
"block_storage_size": 200,
1223-
},
1224-
},
1225-
"default": {
1226-
"ocpu": 8,
1227-
"memory_in_gbs": 128,
1228-
"block_storage_size": 200,
1227+
shape.name: shape.to_dict()
1228+
for shape in evaluation_config.ui_config.shapes
12291229
},
1230+
"default": (
1231+
evaluation_config.ui_config.shapes[0].to_dict()
1232+
if len(evaluation_config.ui_config.shapes) > 0
1233+
else {}
1234+
),
12301235
}
12311236

12321237
def _get_attribute_from_model_metadata(

ads/aqua/extension/evaluation_handler.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
# Copyright (c) 2024 Oracle and/or its affiliates.
33
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
44

5+
from typing import Optional
56
from urllib.parse import urlparse
67

78
from tornado.web import HTTPError
@@ -30,7 +31,7 @@ def get(self, eval_id=""):
3031
return self.read(eval_id)
3132

3233
@handle_exceptions
33-
def post(self, *args, **kwargs):
34+
def post(self, *args, **kwargs): # noqa
3435
"""Handles post request for the evaluation APIs
3536
3637
Raises
@@ -117,10 +118,10 @@ class AquaEvaluationConfigHandler(AquaAPIhandler):
117118
"""Handler for Aqua Evaluation Config REST APIs."""
118119

119120
@handle_exceptions
120-
def get(self, model_id):
121+
def get(self, container: Optional[str] = None, **kwargs): # noqa
121122
"""Handle GET request."""
122123

123-
return self.finish(AquaEvaluationApp().load_evaluation_config(model_id))
124+
return self.finish(AquaEvaluationApp().load_evaluation_config(container))
124125

125126

126127
__handlers__ = [

ads/aqua/ui.py

Lines changed: 27 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -84,9 +84,6 @@ class AquaContainerConfigSpec(DataClassSerializable):
8484
health_check_port: str = None
8585
env_vars: List[dict] = None
8686
restricted_params: List[str] = None
87-
evaluation_configuration: AquaContainerEvaluationConfig = field(
88-
default_factory=AquaContainerEvaluationConfig
89-
)
9087

9188

9289
@dataclass(repr=False)
@@ -184,32 +181,37 @@ def from_container_index_json(
184181
family=container_type,
185182
platforms=platforms,
186183
model_formats=model_formats,
187-
spec=AquaContainerConfigSpec(
188-
cli_param=container_spec.get(ContainerSpec.CLI_PARM, ""),
189-
server_port=container_spec.get(
190-
ContainerSpec.SERVER_PORT, ""
191-
),
192-
health_check_port=container_spec.get(
193-
ContainerSpec.HEALTH_CHECK_PORT, ""
194-
),
195-
env_vars=container_spec.get(ContainerSpec.ENV_VARS, []),
196-
restricted_params=container_spec.get(
197-
ContainerSpec.RESTRICTED_PARAMS, []
198-
),
199-
evaluation_configuration=AquaContainerEvaluationConfig.from_config(
200-
container_spec.get(
201-
ContainerSpec.EVALUATION_CONFIGURATION, {}
202-
)
203-
),
204-
)
205-
if container_spec
206-
else None,
184+
spec=(
185+
AquaContainerConfigSpec(
186+
cli_param=container_spec.get(
187+
ContainerSpec.CLI_PARM, ""
188+
),
189+
server_port=container_spec.get(
190+
ContainerSpec.SERVER_PORT, ""
191+
),
192+
health_check_port=container_spec.get(
193+
ContainerSpec.HEALTH_CHECK_PORT, ""
194+
),
195+
env_vars=container_spec.get(ContainerSpec.ENV_VARS, []),
196+
restricted_params=container_spec.get(
197+
ContainerSpec.RESTRICTED_PARAMS, []
198+
),
199+
)
200+
if container_spec
201+
else None
202+
),
207203
)
208204
if container.get("type") == "inference":
209205
inference_items[container_type] = container_item
210-
elif container_type == "odsc-llm-fine-tuning":
206+
elif (
207+
container.get("type") == "fine-tune"
208+
or container_type == "odsc-llm-fine-tuning"
209+
):
211210
finetune_items[container_type] = container_item
212-
elif container_type == "odsc-llm-evaluate":
211+
elif (
212+
container.get("type") == "evaluate"
213+
or container_type == "odsc-llm-evaluate"
214+
):
213215
evaluate_items[container_type] = container_item
214216

215217
return AquaContainerConfig(

tests/unitary/with_extras/aqua/test_data/config/evaluation_config.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@
102102
"inference_timeout": 120
103103
}
104104
},
105-
"kind": "evaluation",
105+
"kind": "evaluation_service_config",
106106
"report_params": {
107107
"default": {}
108108
},

tests/unitary/with_extras/aqua/test_data/config/evaluation_config_with_default_params.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
"containers": [],
88
"default": {}
99
},
10-
"kind": "evaluation",
10+
"kind": "evaluation_service_config",
1111
"report_params": {
1212
"default": {}
1313
},

0 commit comments

Comments
 (0)