
Commit 993f1db

Merge branch 'ODSC-61884/global_evaluation_config' of https://github.com/oracle/accelerated-data-science into ODSC-61986/evaluation_supported_shapes
2 parents: b868f2d + bbda290

5 files changed: +41 additions, -96 deletions


ads/aqua/config/evaluation/evaluation_service_config.py
Lines changed: 11 additions & 28 deletions

@@ -10,14 +10,6 @@
 
 from ads.aqua.config.utils.serializer import Serializable
 
-# Constants
-INFERENCE_RPS = 25  # Max RPS for inferencing deployed model.
-INFERENCE_TIMEOUT = 120
-INFERENCE_MAX_THREADS = 10  # Maximum parallel threads for model inference.
-INFERENCE_RETRIES = 3
-INFERENCE_BACKOFF_FACTOR = 3
-INFERENCE_DELAY = 0
-
 
 class ModelParamsOverrides(Serializable):
     """Defines overrides for model parameters, including exclusions and additional inclusions."""

@@ -54,13 +46,6 @@ class Config:
 class InferenceParams(Serializable):
     """Contains inference-related parameters with defaults."""
 
-    inference_rps: Optional[int] = INFERENCE_RPS
-    inference_timeout: Optional[int] = INFERENCE_TIMEOUT
-    inference_max_threads: Optional[int] = INFERENCE_MAX_THREADS
-    inference_retries: Optional[int] = INFERENCE_RETRIES
-    inference_backoff_factor: Optional[float] = INFERENCE_BACKOFF_FACTOR
-    inference_delay: Optional[float] = INFERENCE_DELAY
-
     class Config:
         extra = "allow"
 
@@ -224,20 +209,18 @@ def search_shapes(
         -------
            List[ShapeConfig]: A list of shapes that match the filters.
         """
-        results = []
-        for shape in self.shapes:
-            if (
-                evaluation_container
-                and evaluation_container not in shape.filter.evaluation_container
-            ):
-                continue
+        return [
+            shape
+            for shape in self.shapes
             if (
-                evaluation_target
-                and evaluation_target not in shape.filter.evaluation_target
-            ):
-                continue
-            results.append(shape)
-        return results
+                not evaluation_container
+                or evaluation_container in shape.filter.evaluation_container
+            )
+            and (
+                not evaluation_target
+                or evaluation_target in shape.filter.evaluation_target
+            )
+        ]
 
     class Config:
         extra = "ignore"

ads/aqua/evaluation/evaluation.py
Lines changed: 1 addition & 41 deletions

@@ -930,48 +930,8 @@ def get_status(self, eval_id: str) -> dict:
 
     def get_supported_metrics(self) -> dict:
         """Gets a list of supported metrics for evaluation."""
-        # TODO: implement it when starting to support more metrics.
         return [
-            {
-                "use_case": ["text_generation"],
-                "key": "bertscore",
-                "name": "bertscore",
-                "description": (
-                    "BERT Score is a metric for evaluating the quality of text "
-                    "generation models, such as machine translation or summarization. "
-                    "It utilizes pre-trained BERT contextual embeddings for both the "
-                    "generated and reference texts, and then calculates the cosine "
-                    "similarity between these embeddings."
-                ),
-                "args": {},
-            },
-            {
-                "use_case": ["text_generation"],
-                "key": "rouge",
-                "name": "rouge",
-                "description": (
-                    "ROUGE scores compare a candidate document to a collection of "
-                    "reference documents to evaluate the similarity between them. "
-                    "The metrics range from 0 to 1, with higher scores indicating "
-                    "greater similarity. ROUGE is more suitable for models that don't "
-                    "include paraphrasing and do not generate new text units that don't "
-                    "appear in the references."
-                ),
-                "args": {},
-            },
-            {
-                "use_case": ["text_generation"],
-                "key": "bleu",
-                "name": "bleu",
-                "description": (
-                    "BLEU (Bilingual Evaluation Understudy) is an algorithm for evaluating the "
-                    "quality of text which has been machine-translated from one natural language to another. "
-                    "Quality is considered to be the correspondence between a machine's output and that of a "
-                    "human: 'the closer a machine translation is to a professional human translation, "
-                    "the better it is'."
-                ),
-                "args": {},
-            },
+            item.to_dict() for item in evaluation_service_config().ui_config.metrics
        ]
 
     @telemetry(entry_point="plugin=evaluation&action=load_metrics", name="aqua")
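
With the hard-coded metric dictionaries removed, get_supported_metrics simply serializes whatever metrics the global service config advertises via evaluation_service_config(). A minimal sketch of the new data flow, built from the MetricConfig fields exercised in the updated test below; the example values are illustrative:

from ads.aqua.config.evaluation.evaluation_service_config import (
    EvaluationServiceConfig,
    MetricConfig,
    UIConfig,
)

# Construct a config the same way the unit test does.
config = EvaluationServiceConfig(
    ui_config=UIConfig(
        metrics=[
            MetricConfig(
                args={},
                description="BERT Score.",
                key="bertscore",
                name="BERT Score",
                tags=[],
                task=["text-generation"],
            )
        ]
    )
)

# get_supported_metrics() now reduces to this comprehension over the service config:
supported = [metric.to_dict() for metric in config.ui_config.metrics]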

tests/unitary/with_extras/aqua/test_data/config/evaluation_config_with_default_params.json
Lines changed: 1 addition & 8 deletions

@@ -5,14 +5,7 @@
   },
   "inference_params": {
     "containers": [],
-    "default": {
-      "inference_backoff_factor": 3,
-      "inference_delay": 0,
-      "inference_max_threads": 10,
-      "inference_retries": 3,
-      "inference_rps": 25,
-      "inference_timeout": 120
-    }
+    "default": {}
   },
   "kind": "evaluation_service_config",
   "report_params": {

tests/unitary/with_extras/aqua/test_evaluation.py
Lines changed: 28 additions & 8 deletions

@@ -24,6 +24,7 @@
 )
 from ads.aqua.config.evaluation.evaluation_service_config import (
     EvaluationServiceConfig,
+    MetricConfig,
     ModelParamsConfig,
     ShapeConfig,
     UIConfig,

@@ -302,7 +303,7 @@ class TestDataset:
             "category": null,
             "description": null,
             "key": "Hyperparameters",
-            "value": '{"model_params": {"max_tokens": 500, "top_p": 1, "top_k": 50, "temperature": 0.7, "presence_penalty": 0, "frequency_penalty": 0, "stop": [], "shape": "VM.Standard.E3.Flex", "dataset_path": "oci://mybucket@mytenancy/data.jsonl", "report_path": "oci://mybucket@mytenancy/report"}}',
+            "value": '{"model_params": {"model": "odsc-llm", "max_tokens": 500, "top_p": 1, "top_k": 50, "temperature": 0.7, "presence_penalty": 0, "frequency_penalty": 0, "stop": [], "shape": "VM.Standard.E3.Flex", "dataset_path": "oci://mybucket@mytenancy/data.jsonl", "report_path": "oci://mybucket@mytenancy/report"}}',
         },
         {
             "category": null,

@@ -506,6 +507,7 @@ def test_create_evaluation(
             "lifecycle_state": f"{evaluation_job_run.lifecycle_state}",
             "name": f"{evaluation_model.display_name}",
             "parameters": {
+                "model": "odsc-llm",
                 "dataset_path": "",
                 "frequency_penalty": 0.0,
                 "max_tokens": "",

@@ -881,17 +883,35 @@ def test_extract_job_lifecycle_details(self, input, expect_output):
         msg = self.app._extract_job_lifecycle_details(input)
         assert msg == expect_output, msg
 
-    def test_get_supported_metrics(self):
-        """Tests getting a list of supported metrics for evaluation.
-        This method currently hardcoded the return value.
+    @patch("ads.aqua.evaluation.evaluation.evaluation_service_config")
+    def test_get_supported_metrics(self, mock_evaluation_service_config):
+        """
+        Tests getting a list of supported metrics for evaluation.
         """
-        from .utils import SupportMetricsFormat as metric_schema
-        from .utils import check
 
+        test_evaluation_service_config = EvaluationServiceConfig(
+            ui_config=UIConfig(
+                metrics=[
+                    MetricConfig(
+                        **{
+                            "args": {},
+                            "description": "BERT Score.",
+                            "key": "bertscore",
+                            "name": "BERT Score",
+                            "tags": [],
+                            "task": ["text-generation"],
+                        },
+                    )
+                ]
+            )
+        )
+        mock_evaluation_service_config.return_value = test_evaluation_service_config
         response = self.app.get_supported_metrics()
         assert isinstance(response, list)
-        for metric in response:
-            assert check(metric_schema, metric)
+        assert len(response) == len(test_evaluation_service_config.ui_config.metrics)
+        assert response == [
+            item.to_dict() for item in test_evaluation_service_config.ui_config.metrics
+        ]
 
     @patch("ads.aqua.evaluation.evaluation.evaluation_service_config")
     def test_load_evaluation_config(self, mock_evaluation_service_config):

tests/unitary/with_extras/aqua/utils.py
Lines changed: 0 additions & 11 deletions

@@ -76,17 +76,6 @@ def __post_init__(self):
         )
 
 
-@dataclass
-class SupportMetricsFormat(BaseFormat):
-    """Format for supported evaluation metrics."""
-
-    use_case: list
-    key: str
-    name: str
-    description: str
-    args: dict
-
-
 def check(conf_schema, conf):
     """Check if the format of the output dictionary is correct."""
     try:
