Skip to content

Commit a76b698

Browse files
authored
Fixes evaluation unit tests. (#944)
2 parents 9a9569b + e3de1d0 commit a76b698

File tree

4 files changed

+17
-13
lines changed

4 files changed

+17
-13
lines changed

ads/aqua/config/config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
DEFAULT_EVALUATION_CONTAINER = "odsc-llm-evaluate"
1313

1414

15-
def evaluation_service_config(
15+
def get_evaluation_service_config(
1616
container: Optional[str] = DEFAULT_EVALUATION_CONTAINER,
1717
) -> EvaluationServiceConfig:
1818
"""

ads/aqua/evaluation/evaluation.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@
4545
is_valid_ocid,
4646
upload_local_to_os,
4747
)
48-
from ads.aqua.config.config import evaluation_service_config
48+
from ads.aqua.config.config import get_evaluation_service_config
4949
from ads.aqua.config.evaluation.evaluation_service_config import EvaluationServiceConfig
5050
from ads.aqua.constants import (
5151
CONSOLE_LINK_RESOURCE_TYPE_MAPPING,
@@ -176,7 +176,7 @@ def create(
176176
# The model to evaluate
177177
evaluation_source = None
178178
# The evaluation service config
179-
evaluation_config: EvaluationServiceConfig = evaluation_service_config()
179+
evaluation_config: EvaluationServiceConfig = get_evaluation_service_config()
180180
# The evaluation inference configuration. The inference configuration will be extracted
181181
# based on the inferencing container family.
182182
eval_inference_configuration: Dict = {}
@@ -931,7 +931,7 @@ def get_status(self, eval_id: str) -> dict:
931931
def get_supported_metrics(self) -> dict:
932932
"""Gets a list of supported metrics for evaluation."""
933933
return [
934-
item.to_dict() for item in evaluation_service_config().ui_config.metrics
934+
item.to_dict() for item in get_evaluation_service_config().ui_config.metrics
935935
]
936936

937937
@telemetry(entry_point="plugin=evaluation&action=load_metrics", name="aqua")
@@ -1218,7 +1218,7 @@ def load_evaluation_config(self, container: Optional[str] = None) -> Dict:
12181218
"""Loads evaluation config."""
12191219

12201220
# retrieve the evaluation config by container family name
1221-
evaluation_config = evaluation_service_config(container)
1221+
evaluation_config = get_evaluation_service_config(container)
12221222

12231223
# convert the new config representation to the old one
12241224
return {

tests/unitary/with_extras/aqua/test_config.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from unittest.mock import patch
88

99
from ads.aqua.common.entities import ContainerSpec
10-
from ads.aqua.config.config import evaluation_service_config
10+
from ads.aqua.config.config import get_evaluation_service_config
1111

1212

1313
class TestConfig:
@@ -32,7 +32,7 @@ def test_evaluation_service_config(self, mock_get_container_config):
3232

3333
mock_get_container_config.return_value = expected_result
3434

35-
test_result = evaluation_service_config(container="test_container")
35+
test_result = get_evaluation_service_config(container="test_container")
3636
assert (
3737
test_result.to_dict()
3838
== expected_result[ContainerSpec.CONTAINER_SPEC]["test_container"]

tests/unitary/with_extras/aqua/test_evaluation.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -426,6 +426,7 @@ def assert_payload(self, response, response_type):
426426
continue
427427
assert rdict.get(attr), f"{attr} is empty"
428428

429+
@patch("ads.aqua.evaluation.evaluation.get_evaluation_service_config")
429430
@patch.object(Job, "run")
430431
@patch("ads.jobs.ads_job.Job.name", new_callable=PropertyMock)
431432
@patch("ads.jobs.ads_job.Job.id", new_callable=PropertyMock)
@@ -444,6 +445,7 @@ def test_create_evaluation(
444445
mock_job_id,
445446
mock_job_name,
446447
mock_job_run,
448+
mock_get_evaluation_service_config,
447449
):
448450
foundation_model = MagicMock()
449451
foundation_model.display_name = "test_foundation_model"
@@ -473,6 +475,8 @@ def test_create_evaluation(
473475
evaluation_job_run.lifecycle_state = "IN_PROGRESS"
474476
mock_job_run.return_value = evaluation_job_run
475477

478+
mock_get_evaluation_service_config.return_value = EvaluationServiceConfig()
479+
476480
self.app.ds_client.update_model = MagicMock()
477481
self.app.ds_client.update_model_provenance = MagicMock()
478482

@@ -883,8 +887,8 @@ def test_extract_job_lifecycle_details(self, input, expect_output):
883887
msg = self.app._extract_job_lifecycle_details(input)
884888
assert msg == expect_output, msg
885889

886-
@patch("ads.aqua.evaluation.evaluation.evaluation_service_config")
887-
def test_get_supported_metrics(self, mock_evaluation_service_config):
890+
@patch("ads.aqua.evaluation.evaluation.get_evaluation_service_config")
891+
def test_get_supported_metrics(self, mock_get_evaluation_service_config):
888892
"""
889893
Tests getting a list of supported metrics for evaluation.
890894
"""
@@ -905,16 +909,16 @@ def test_get_supported_metrics(self, mock_evaluation_service_config):
905909
]
906910
)
907911
)
908-
mock_evaluation_service_config.return_value = test_evaluation_service_config
912+
mock_get_evaluation_service_config.return_value = test_evaluation_service_config
909913
response = self.app.get_supported_metrics()
910914
assert isinstance(response, list)
911915
assert len(response) == len(test_evaluation_service_config.ui_config.metrics)
912916
assert response == [
913917
item.to_dict() for item in test_evaluation_service_config.ui_config.metrics
914918
]
915919

916-
@patch("ads.aqua.evaluation.evaluation.evaluation_service_config")
917-
def test_load_evaluation_config(self, mock_evaluation_service_config):
920+
@patch("ads.aqua.evaluation.evaluation.get_evaluation_service_config")
921+
def test_load_evaluation_config(self, mock_get_evaluation_service_config):
918922
"""
919923
Tests loading default config for evaluation.
920924
This method currently hardcoded the return value.
@@ -952,7 +956,7 @@ def test_load_evaluation_config(self, mock_evaluation_service_config):
952956
],
953957
)
954958
)
955-
mock_evaluation_service_config.return_value = test_evaluation_service_config
959+
mock_get_evaluation_service_config.return_value = test_evaluation_service_config
956960

957961
expected_result = {
958962
"model_params": {

0 commit comments

Comments
 (0)