Skip to content

Commit 11abb35

Browse files
add evaluation config params
1 parent 5b58330 commit 11abb35

File tree

6 files changed

+89
-69
lines changed

6 files changed

+89
-69
lines changed

ads/aqua/constants.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
EVALUATION_REPORT_JSON = "report.json"
1616
EVALUATION_REPORT_MD = "report.md"
1717
EVALUATION_REPORT = "report.html"
18-
EVALUATION_INFERENCE_DEFAULT_THREADS = 10
1918
UNKNOWN_JSON_STR = "{}"
2019
FINE_TUNING_RUNTIME_CONTAINER = "iad.ocir.io/ociodscdev/aqua_ft_cuda121:0.3.17.20"
2120
DEFAULT_FT_BLOCK_STORAGE_SIZE = 750

ads/aqua/evaluation/entities.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -64,8 +64,6 @@ class CreateAquaEvaluationDetails(DataClassSerializable):
6464
The metrics for the evaluation.
6565
force_overwrite: (bool, optional). Defaults to `False`.
6666
Whether to force overwrite the existing file in object storage.
67-
inference_max_threads: (int, optional). Defaults to None
68-
Set the value of concurrent requests to be made to the inference endpoint during evaluation.
6967
"""
7068

7169
evaluation_source_id: str
@@ -87,7 +85,6 @@ class CreateAquaEvaluationDetails(DataClassSerializable):
8785
log_id: Optional[str] = None
8886
metrics: Optional[List] = None
8987
force_overwrite: Optional[bool] = False
90-
inference_max_threads: Optional[int] = None
9188

9289

9390
@dataclass(repr=False)
@@ -144,7 +141,6 @@ class AquaEvaluationCommands(DataClassSerializable):
144141
metrics: list
145142
output_dir: str
146143
params: dict
147-
inference_max_threads: int
148144

149145

150146
@dataclass(repr=False)

ads/aqua/evaluation/evaluation.py

Lines changed: 32 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,15 @@
1313
from threading import Lock
1414
from typing import Any, Dict, List, Union
1515

16+
import oci
1617
from cachetools import TTLCache
18+
from oci.data_science.models import (
19+
JobRun,
20+
Metadata,
21+
UpdateModelDetails,
22+
UpdateModelProvenanceDetails,
23+
)
1724

18-
import oci
1925
from ads.aqua import logger
2026
from ads.aqua.app import AquaApp
2127
from ads.aqua.common import utils
@@ -41,7 +47,6 @@
4147
)
4248
from ads.aqua.constants import (
4349
CONSOLE_LINK_RESOURCE_TYPE_MAPPING,
44-
EVALUATION_INFERENCE_DEFAULT_THREADS,
4550
EVALUATION_REPORT,
4651
EVALUATION_REPORT_JSON,
4752
EVALUATION_REPORT_MD,
@@ -97,12 +102,6 @@
97102
)
98103
from ads.model.model_version_set import ModelVersionSet
99104
from ads.telemetry import telemetry
100-
from oci.data_science.models import (
101-
JobRun,
102-
Metadata,
103-
UpdateModelDetails,
104-
UpdateModelProvenanceDetails,
105-
)
106105

107106

108107
class AquaEvaluationApp(AquaApp):
@@ -171,6 +170,7 @@ def create(
171170
"Specify either a model or model deployment id."
172171
)
173172
evaluation_source = None
173+
eval_inference_configuration = None
174174
if (
175175
DataScienceResource.MODEL_DEPLOYMENT
176176
in create_aqua_evaluation_details.evaluation_source_id
@@ -182,29 +182,14 @@ def create(
182182
runtime = ModelDeploymentContainerRuntime.from_dict(
183183
evaluation_source.runtime.to_dict()
184184
)
185-
container_config = AquaContainerConfig.from_container_index_json(
185+
inference_config = AquaContainerConfig.from_container_index_json(
186186
enable_spec=True
187-
)
188-
for container in container_config.inference.values():
187+
).inference
188+
for container in inference_config.values():
189189
if container.name == runtime.image.split(":")[0]:
190-
max_threads = container.spec.evaluation_configuration.evaluation_max_threads
191-
if (
192-
max_threads
193-
and create_aqua_evaluation_details.inference_max_threads
194-
and max_threads
195-
< create_aqua_evaluation_details.inference_max_threads
196-
):
197-
raise AquaValueError(
198-
f"Invalid inference max threads. The maximum allowed value for {runtime.image} is {max_threads}."
199-
)
200-
if not create_aqua_evaluation_details.inference_max_threads:
201-
create_aqua_evaluation_details.inference_max_threads = container.spec.evaluation_configuration.evaluation_default_threads
202-
break
203-
if not create_aqua_evaluation_details.inference_max_threads:
204-
create_aqua_evaluation_details.inference_max_threads = (
205-
EVALUATION_INFERENCE_DEFAULT_THREADS
206-
)
207-
190+
eval_inference_configuration = (
191+
container.spec.evaluation_configuration
192+
)
208193
elif (
209194
DataScienceResource.MODEL
210195
in create_aqua_evaluation_details.evaluation_source_id
@@ -420,7 +405,9 @@ def create(
420405
report_path=create_aqua_evaluation_details.report_path,
421406
model_parameters=create_aqua_evaluation_details.model_parameters,
422407
metrics=create_aqua_evaluation_details.metrics,
423-
inference_max_threads=create_aqua_evaluation_details.inference_max_threads,
408+
inference_configuration=eval_inference_configuration.to_filtered_dict()
409+
if eval_inference_configuration
410+
else {},
424411
)
425412
).create(**kwargs) ## TODO: decide what parameters will be needed
426413
logger.debug(
@@ -542,7 +529,7 @@ def _build_evaluation_runtime(
542529
report_path: str,
543530
model_parameters: dict,
544531
metrics: List = None,
545-
inference_max_threads: int = None,
532+
inference_configuration: dict = None,
546533
) -> Runtime:
547534
"""Builds evaluation runtime for Job."""
548535
# TODO the image name needs to be extracted from the mapping index.json file.
@@ -552,17 +539,19 @@ def _build_evaluation_runtime(
552539
.with_environment_variable(
553540
**{
554541
"AIP_SMC_EVALUATION_ARGUMENTS": json.dumps(
555-
asdict(
556-
self._build_launch_cmd(
557-
evaluation_id=evaluation_id,
558-
evaluation_source_id=evaluation_source_id,
559-
dataset_path=dataset_path,
560-
report_path=report_path,
561-
model_parameters=model_parameters,
562-
metrics=metrics,
563-
inference_max_threads=inference_max_threads,
564-
)
565-
)
542+
{
543+
**asdict(
544+
self._build_launch_cmd(
545+
evaluation_id=evaluation_id,
546+
evaluation_source_id=evaluation_source_id,
547+
dataset_path=dataset_path,
548+
report_path=report_path,
549+
model_parameters=model_parameters,
550+
metrics=metrics,
551+
),
552+
),
553+
**inference_configuration,
554+
},
566555
),
567556
"CONDA_BUCKET_NS": CONDA_BUCKET_NS,
568557
},
@@ -620,7 +609,6 @@ def _build_launch_cmd(
620609
report_path: str,
621610
model_parameters: dict,
622611
metrics: List = None,
623-
inference_max_threads: int = None,
624612
):
625613
return AquaEvaluationCommands(
626614
evaluation_id=evaluation_id,
@@ -637,7 +625,6 @@ def _build_launch_cmd(
637625
metrics=metrics,
638626
output_dir=report_path,
639627
params=model_parameters,
640-
inference_max_threads=inference_max_threads,
641628
)
642629

643630
@telemetry(entry_point="plugin=evaluation&action=get", name="aqua")
@@ -1227,7 +1214,7 @@ def _delete_job_and_model(job, model):
12271214
f"Exception message: {ex}"
12281215
)
12291216

1230-
def load_evaluation_config(self, _):
1217+
def load_evaluation_config(self, eval_id):
12311218
"""Loads evaluation config."""
12321219
return {
12331220
"model_params": {

ads/aqua/ui.py

Lines changed: 27 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2,21 +2,22 @@
22
# Copyright (c) 2024 Oracle and/or its affiliates.
33
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
44
import concurrent.futures
5-
from dataclasses import dataclass, field
5+
from dataclasses import dataclass, field, fields
66
from datetime import datetime, timedelta
77
from enum import Enum
88
from threading import Lock
99
from typing import Dict, List, Optional
1010

1111
from cachetools import TTLCache
12+
from oci.exceptions import ServiceError
13+
from oci.identity.models import Compartment
1214

1315
from ads.aqua import logger
1416
from ads.aqua.app import AquaApp
1517
from ads.aqua.common.entities import ContainerSpec
1618
from ads.aqua.common.enums import Tags
1719
from ads.aqua.common.errors import AquaResourceAccessError, AquaValueError
1820
from ads.aqua.common.utils import get_container_config, load_config, sanitize_response
19-
from ads.aqua.constants import EVALUATION_INFERENCE_DEFAULT_THREADS
2021
from ads.common import oci_client as oc
2122
from ads.common.auth import default_signer
2223
from ads.common.object_storage_details import ObjectStorageDetails
@@ -29,8 +30,6 @@
2930
TENANCY_OCID,
3031
)
3132
from ads.telemetry import telemetry
32-
from oci.exceptions import ServiceError
33-
from oci.identity.models import Compartment
3433

3534

3635
class ModelFormat(Enum):
@@ -47,25 +46,36 @@ def to_dict(self):
4746

4847

4948
@dataclass(repr=False)
50-
class AquaContainerEvaluationConfiguration(DataClassSerializable):
49+
class AquaContainerEvaluationConfig(DataClassSerializable):
5150
"""
5251
Represents the evaluation configuration for the container.
5352
"""
5453

55-
evaluation_max_threads: Optional[int] = None
56-
evaluation_default_threads: int = field(
57-
default=EVALUATION_INFERENCE_DEFAULT_THREADS
58-
)
54+
inference_max_threads: Optional[int] = None
55+
inference_rps: Optional[int] = None
56+
inference_timeout: Optional[int] = None
57+
inference_retries: Optional[int] = None
58+
inference_backoff_factor: Optional[int] = None
59+
inference_delay: Optional[int] = None
5960

6061
@classmethod
61-
def from_config(cls, config: dict) -> "AquaContainerEvaluationConfiguration":
62+
def from_config(cls, config: dict) -> "AquaContainerEvaluationConfig":
6263
return cls(
63-
evaluation_max_threads=config.get("MAX_THREADS"),
64-
evaluation_default_threads=config.get(
65-
"DEFAULT_THREADS", EVALUATION_INFERENCE_DEFAULT_THREADS
66-
),
64+
inference_max_threads=config.get("inference_max_threads"),
65+
inference_rps=config.get("inference_rps"),
66+
inference_timeout=config.get("inference_timeout"),
67+
inference_retries=config.get("inference_retries"),
68+
inference_backoff_factor=config.get("inference_backoff_factor"),
69+
inference_delay=config.get("inference_delay"),
6770
)
6871

72+
def to_filtered_dict(self):
73+
return {
74+
field.name: getattr(self, field.name)
75+
for field in fields(self)
76+
if getattr(self, field.name) is not None
77+
}
78+
6979

7080
@dataclass(repr=False)
7181
class AquaContainerConfigSpec(DataClassSerializable):
@@ -74,8 +84,8 @@ class AquaContainerConfigSpec(DataClassSerializable):
7484
health_check_port: str = None
7585
env_vars: List[dict] = None
7686
restricted_params: List[str] = None
77-
evaluation_configuration: AquaContainerEvaluationConfiguration = field(
78-
default_factory=AquaContainerEvaluationConfiguration
87+
evaluation_configuration: AquaContainerEvaluationConfig = field(
88+
default_factory=AquaContainerEvaluationConfig
7989
)
8090

8191

@@ -186,7 +196,7 @@ def from_container_index_json(
186196
restricted_params=container_spec.get(
187197
ContainerSpec.RESTRICTED_PARAMS, []
188198
),
189-
evaluation_configuration=AquaContainerEvaluationConfiguration.from_config(
199+
evaluation_configuration=AquaContainerEvaluationConfig.from_config(
190200
container_spec.get(
191201
ContainerSpec.EVALUATION_CONFIGURATION, {}
192202
)

tests/unitary/with_extras/aqua/test_data/ui/container_index.json

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,10 @@
1919
"HEALTH_CHECK_PORT": "8080"
2020
}
2121
],
22+
"evaluationConfiguration": {
23+
"inference_delay": 1,
24+
"inference_max_threads": 1
25+
},
2226
"healthCheckPort": "8080",
2327
"restrictedParams": [],
2428
"serverPort": "8080"
@@ -81,7 +85,7 @@
8185
"modelFormats": [
8286
"GGUF"
8387
],
84-
"name": "iad.ocir.io/ociodscdev/odsc-llama-cpp-python-aio-linux_arm64_v8",
88+
"name": "dsmc://odsc-llama-cpp-python-aio-linux_arm64_v8",
8589
"platforms": [
8690
"ARM_CPU"
8791
],

tests/unitary/with_extras/aqua/test_ui.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -484,7 +484,7 @@ def test_list_containers(self, mock_get_container_config):
484484
],
485485
"inference": [
486486
{
487-
"name": "iad.ocir.io/ociodscdev/odsc-llama-cpp-python-aio-linux_arm64_v8",
487+
"name": "dsmc://odsc-llama-cpp-python-aio-linux_arm64_v8",
488488
"version": "0.2.75.5",
489489
"display_name": "LLAMA-CPP:0.2.75",
490490
"family": "odsc-llama-cpp-serving",
@@ -502,6 +502,14 @@ def test_list_containers(self, mock_get_container_config):
502502
"health_check_port": "8080",
503503
"restricted_params": [],
504504
"server_port": "8080",
505+
"evaluation_configuration": {
506+
"inference_max_threads": 1,
507+
"inference_rps": None,
508+
"inference_timeout": None,
509+
"inference_backoff_factor": None,
510+
"inference_delay": 1,
511+
"inference_retries": None,
512+
},
505513
},
506514
},
507515
{
@@ -528,6 +536,14 @@ def test_list_containers(self, mock_get_container_config):
528536
"--trust-remote-code",
529537
],
530538
"server_port": "8080",
539+
"evaluation_configuration": {
540+
"inference_max_threads": None,
541+
"inference_rps": None,
542+
"inference_timeout": None,
543+
"inference_backoff_factor": None,
544+
"inference_delay": None,
545+
"inference_retries": None,
546+
},
531547
},
532548
},
533549
{
@@ -553,6 +569,14 @@ def test_list_containers(self, mock_get_container_config):
553569
"--seed",
554570
],
555571
"server_port": "8080",
572+
"evaluation_configuration": {
573+
"inference_max_threads": None,
574+
"inference_rps": None,
575+
"inference_timeout": None,
576+
"inference_backoff_factor": None,
577+
"inference_delay": None,
578+
"inference_retries": None,
579+
},
556580
},
557581
},
558582
],

0 commit comments

Comments (0)