Skip to content

Commit b983a7f

Browse files
committed
Improved evaluation model parameters.
1 parent 2799630 commit b983a7f

File tree

5 files changed

+72
-146
lines changed

5 files changed

+72
-146
lines changed

ads/aqua/evaluation/entities.py

Lines changed: 47 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -9,19 +9,18 @@
99
This module contains dataclasses for aqua evaluation.
1010
"""
1111

12-
from dataclasses import dataclass, field
13-
from typing import List, Optional, Union
12+
from pydantic import Field
13+
from typing import Any, Dict, List, Optional, Union
1414

1515
from ads.aqua.data import AquaResourceIdentifier
16-
from ads.common.serializer import DataClassSerializable
16+
from ads.aqua.config.utils.serializer import Serializable
1717

1818

19-
@dataclass(repr=False)
20-
class CreateAquaEvaluationDetails(DataClassSerializable):
21-
"""Dataclass to create aqua model evaluation.
19+
class CreateAquaEvaluationDetails(Serializable):
20+
"""Class for creating aqua model evaluation.
2221
23-
Fields
24-
------
22+
Properties
23+
----------
2524
evaluation_source_id: str
2625
The evaluation source id. Must be either model or model deployment ocid.
2726
evaluation_name: str
@@ -83,69 +82,74 @@ class CreateAquaEvaluationDetails(DataClassSerializable):
8382
ocpus: Optional[float] = None
8483
log_group_id: Optional[str] = None
8584
log_id: Optional[str] = None
86-
metrics: Optional[List] = None
85+
metrics: Optional[List[str]] = None
8786
force_overwrite: Optional[bool] = False
8887

88+
class Config:
89+
extra = "ignore"
8990

90-
@dataclass(repr=False)
91-
class AquaEvalReport(DataClassSerializable):
91+
class AquaEvalReport(Serializable):
9292
evaluation_id: str = ""
9393
content: str = ""
9494

95+
class Config:
96+
extra = "ignore"
9597

96-
@dataclass(repr=False)
97-
class ModelParams(DataClassSerializable):
98+
class ModelParams(Serializable):
9899
max_tokens: str = ""
99100
top_p: str = ""
100101
top_k: str = ""
101102
temperature: str = ""
102103
presence_penalty: Optional[float] = 0.0
103104
frequency_penalty: Optional[float] = 0.0
104-
stop: Optional[Union[str, List[str]]] = field(default_factory=list)
105+
stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
105106
model: Optional[str] = "odsc-llm"
106107

108+
class Config:
109+
extra = "allow"
107110

108-
@dataclass(repr=False)
109-
class AquaEvalParams(ModelParams, DataClassSerializable):
111+
class AquaEvalParams(ModelParams):
110112
shape: str = ""
111113
dataset_path: str = ""
112114
report_path: str = ""
113115

114-
115-
@dataclass(repr=False)
116-
class AquaEvalMetric(DataClassSerializable):
116+
class AquaEvalMetric(Serializable):
117117
key: str
118118
name: str
119119
description: str = ""
120120

121+
class Config:
122+
extra = "ignore"
121123

122-
@dataclass(repr=False)
123-
class AquaEvalMetricSummary(DataClassSerializable):
124+
class AquaEvalMetricSummary(Serializable):
124125
metric: str = ""
125126
score: str = ""
126127
grade: str = ""
127128

129+
class Config:
130+
extra = "ignore"
128131

129-
@dataclass(repr=False)
130-
class AquaEvalMetrics(DataClassSerializable):
132+
class AquaEvalMetrics(Serializable):
131133
id: str
132134
report: str
133-
metric_results: List[AquaEvalMetric] = field(default_factory=list)
134-
metric_summary_result: List[AquaEvalMetricSummary] = field(default_factory=list)
135+
metric_results: List[AquaEvalMetric] = Field(default_factory=list)
136+
metric_summary_result: List[AquaEvalMetricSummary] = Field(default_factory=list)
135137

138+
class Config:
139+
extra = "ignore"
136140

137-
@dataclass(repr=False)
138-
class AquaEvaluationCommands(DataClassSerializable):
141+
class AquaEvaluationCommands(Serializable):
139142
evaluation_id: str
140143
evaluation_target_id: str
141-
input_data: dict
142-
metrics: list
144+
input_data: Dict[str, Any]
145+
metrics: List[str]
143146
output_dir: str
144-
params: dict
147+
params: Dict[str, Any]
145148

149+
class Config:
150+
extra = "ignore"
146151

147-
@dataclass(repr=False)
148-
class AquaEvaluationSummary(DataClassSerializable):
152+
class AquaEvaluationSummary(Serializable):
149153
"""Represents a summary of Aqua evaluation."""
150154

151155
id: str
@@ -154,17 +158,18 @@ class AquaEvaluationSummary(DataClassSerializable):
154158
lifecycle_state: str
155159
lifecycle_details: str
156160
time_created: str
157-
tags: dict
158-
experiment: AquaResourceIdentifier = field(default_factory=AquaResourceIdentifier)
159-
source: AquaResourceIdentifier = field(default_factory=AquaResourceIdentifier)
160-
job: AquaResourceIdentifier = field(default_factory=AquaResourceIdentifier)
161-
parameters: AquaEvalParams = field(default_factory=AquaEvalParams)
161+
tags: Dict[str, Any]
162+
experiment: AquaResourceIdentifier = Field(default_factory=AquaResourceIdentifier)
163+
source: AquaResourceIdentifier = Field(default_factory=AquaResourceIdentifier)
164+
job: AquaResourceIdentifier = Field(default_factory=AquaResourceIdentifier)
165+
parameters: AquaEvalParams = Field(default_factory=AquaEvalParams)
162166

167+
class Config:
168+
extra = "ignore"
163169

164-
@dataclass(repr=False)
165-
class AquaEvaluationDetail(AquaEvaluationSummary, DataClassSerializable):
170+
class AquaEvaluationDetail(AquaEvaluationSummary):
166171
"""Represents the details of an Aqua evaluation."""
167172

168-
log_group: AquaResourceIdentifier = field(default_factory=AquaResourceIdentifier)
169-
log: AquaResourceIdentifier = field(default_factory=AquaResourceIdentifier)
170-
introspection: dict = field(default_factory=dict)
173+
log_group: AquaResourceIdentifier = Field(default_factory=AquaResourceIdentifier)
174+
log: AquaResourceIdentifier = Field(default_factory=AquaResourceIdentifier)
175+
introspection: dict = Field(default_factory=dict)

ads/aqua/evaluation/evaluation.py

Lines changed: 23 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
import re
88
import tempfile
99
from concurrent.futures import ThreadPoolExecutor, as_completed
10-
from dataclasses import asdict, fields
1110
from datetime import datetime, timedelta
1211
from pathlib import Path
1312
from threading import Lock
@@ -46,7 +45,6 @@
4645
upload_local_to_os,
4746
)
4847
from ads.aqua.config.config import get_evaluation_service_config
49-
from ads.aqua.config.evaluation.evaluation_service_config import EvaluationServiceConfig
5048
from ads.aqua.constants import (
5149
CONSOLE_LINK_RESOURCE_TYPE_MAPPING,
5250
EVALUATION_REPORT,
@@ -75,7 +73,6 @@
7573
AquaEvaluationSummary,
7674
AquaResourceIdentifier,
7775
CreateAquaEvaluationDetails,
78-
ModelParams,
7976
)
8077
from ads.aqua.evaluation.errors import EVALUATION_JOB_EXIT_CODE_MESSAGE
8178
from ads.aqua.ui import AquaContainerConfig
@@ -164,7 +161,7 @@ def create(
164161
raise AquaValueError(
165162
"Invalid create evaluation parameters. "
166163
"Allowable parameters are: "
167-
f"{', '.join([field.name for field in fields(CreateAquaEvaluationDetails)])}."
164+
f"{', '.join([field for field in CreateAquaEvaluationDetails.model_fields])}."
168165
) from ex
169166

170167
if not is_valid_ocid(create_aqua_evaluation_details.evaluation_source_id):
@@ -175,15 +172,7 @@ def create(
175172

176173
# The model to evaluate
177174
evaluation_source = None
178-
# The evaluation service config
179-
evaluation_config: EvaluationServiceConfig = get_evaluation_service_config()
180-
# The evaluation inference configuration. The inference configuration will be extracted
181-
# based on the inferencing container family.
182175
eval_inference_configuration: Dict = {}
183-
# The evaluation inference model sampling params. The system parameters that will not be
184-
# visible for user, but will be applied implicitly for evaluation. The service model params
185-
# will be extracted based on the container family and version.
186-
eval_inference_service_model_params: Dict = {}
187176

188177
if (
189178
DataScienceResource.MODEL_DEPLOYMENT
@@ -200,29 +189,14 @@ def create(
200189
runtime = ModelDeploymentContainerRuntime.from_dict(
201190
evaluation_source.runtime.to_dict()
202191
)
203-
container_config = AquaContainerConfig.from_container_index_json(
192+
inference_config = AquaContainerConfig.from_container_index_json(
204193
enable_spec=True
205-
)
206-
for (
207-
inference_container_family,
208-
inference_container_info,
209-
) in container_config.inference.items():
210-
if (
211-
inference_container_info.name
212-
== runtime.image[: runtime.image.rfind(":")]
213-
):
194+
).inference
195+
for container in inference_config.values():
196+
if container.name == runtime.image[: runtime.image.rfind(":")]:
214197
eval_inference_configuration = (
215-
evaluation_config.get_merged_inference_params(
216-
inference_container_family
217-
).to_dict()
218-
)
219-
eval_inference_service_model_params = (
220-
evaluation_config.get_merged_inference_model_params(
221-
inference_container_family,
222-
inference_container_info.version,
223-
)
198+
container.spec.evaluation_configuration
224199
)
225-
226200
except Exception:
227201
logger.debug(
228202
f"Could not load inference config details for the evaluation source id: "
@@ -277,19 +251,12 @@ def create(
277251
)
278252
evaluation_dataset_path = dst_uri
279253

280-
evaluation_model_parameters = None
281-
try:
282-
evaluation_model_parameters = AquaEvalParams(
283-
shape=create_aqua_evaluation_details.shape_name,
284-
dataset_path=evaluation_dataset_path,
285-
report_path=create_aqua_evaluation_details.report_path,
286-
**create_aqua_evaluation_details.model_parameters,
287-
)
288-
except Exception as ex:
289-
raise AquaValueError(
290-
"Invalid model parameters. Model parameters should "
291-
f"be a dictionary with keys: {', '.join(list(ModelParams.__annotations__.keys()))}."
292-
) from ex
254+
evaluation_model_parameters = AquaEvalParams(
255+
shape=create_aqua_evaluation_details.shape_name,
256+
dataset_path=evaluation_dataset_path,
257+
report_path=create_aqua_evaluation_details.report_path,
258+
**create_aqua_evaluation_details.model_parameters,
259+
)
293260

294261
target_compartment = (
295262
create_aqua_evaluation_details.compartment_id or COMPARTMENT_OCID
@@ -370,7 +337,7 @@ def create(
370337
evaluation_model_taxonomy_metadata = ModelTaxonomyMetadata()
371338
evaluation_model_taxonomy_metadata[
372339
MetadataTaxonomyKeys.HYPERPARAMETERS
373-
].value = {"model_params": dict(asdict(evaluation_model_parameters))}
340+
].value = {"model_params": evaluation_model_parameters.to_dict()}
374341

375342
evaluation_model = (
376343
DataScienceModel()
@@ -443,7 +410,6 @@ def create(
443410
dataset_path=evaluation_dataset_path,
444411
report_path=create_aqua_evaluation_details.report_path,
445412
model_parameters={
446-
**eval_inference_service_model_params,
447413
**create_aqua_evaluation_details.model_parameters,
448414
},
449415
metrics=create_aqua_evaluation_details.metrics,
@@ -580,16 +546,14 @@ def _build_evaluation_runtime(
580546
**{
581547
"AIP_SMC_EVALUATION_ARGUMENTS": json.dumps(
582548
{
583-
**asdict(
584-
self._build_launch_cmd(
585-
evaluation_id=evaluation_id,
586-
evaluation_source_id=evaluation_source_id,
587-
dataset_path=dataset_path,
588-
report_path=report_path,
589-
model_parameters=model_parameters,
590-
metrics=metrics,
591-
),
592-
),
549+
**self._build_launch_cmd(
550+
evaluation_id=evaluation_id,
551+
evaluation_source_id=evaluation_source_id,
552+
dataset_path=dataset_path,
553+
report_path=report_path,
554+
model_parameters=model_parameters,
555+
metrics=metrics,
556+
).to_dict(),
593557
**(inference_configuration or {}),
594558
},
595559
),
@@ -662,9 +626,9 @@ def _build_launch_cmd(
662626
"format": Path(dataset_path).suffix,
663627
"url": dataset_path,
664628
},
665-
metrics=metrics,
629+
metrics=metrics or [],
666630
output_dir=report_path,
667-
params=model_parameters,
631+
params=model_parameters or {},
668632
)
669633

670634
@telemetry(entry_point="plugin=evaluation&action=get", name="aqua")

ads/aqua/extension/evaluation_handler.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
from ads.aqua.evaluation.entities import CreateAquaEvaluationDetails
1313
from ads.aqua.extension.base_handler import AquaAPIhandler
1414
from ads.aqua.extension.errors import Errors
15-
from ads.aqua.extension.utils import validate_function_parameters
1615
from ads.config import COMPARTMENT_OCID
1716

1817

@@ -47,10 +46,6 @@ def post(self, *args, **kwargs): # noqa
4746
if not input_data:
4847
raise HTTPError(400, Errors.NO_INPUT_DATA)
4948

50-
validate_function_parameters(
51-
data_class=CreateAquaEvaluationDetails, input_data=input_data
52-
)
53-
5449
self.finish(
5550
# TODO: decide what other kwargs will be needed for create aqua evaluation.
5651
AquaEvaluationApp().create(

tests/unitary/with_extras/aqua/test_evaluation.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
import json
1010
import os
1111
import unittest
12-
from dataclasses import asdict
1312
from unittest.mock import MagicMock, PropertyMock, patch
1413

1514
import oci
@@ -419,14 +418,13 @@ def assert_payload(self, response, response_type):
419418
"""Checks each field is not empty."""
420419

421420
attributes = response_type.__annotations__.keys()
422-
rdict = asdict(response)
421+
rdict = response.to_dict()
423422

424423
for attr in attributes:
425424
if attr == "lifecycle_details": # can be empty when jobrun is succeed
426425
continue
427426
assert rdict.get(attr), f"{attr} is empty"
428427

429-
@patch("ads.aqua.evaluation.evaluation.get_evaluation_service_config")
430428
@patch.object(Job, "run")
431429
@patch("ads.jobs.ads_job.Job.name", new_callable=PropertyMock)
432430
@patch("ads.jobs.ads_job.Job.id", new_callable=PropertyMock)
@@ -445,7 +443,6 @@ def test_create_evaluation(
445443
mock_job_id,
446444
mock_job_name,
447445
mock_job_run,
448-
mock_get_evaluation_service_config,
449446
):
450447
foundation_model = MagicMock()
451448
foundation_model.display_name = "test_foundation_model"
@@ -475,8 +472,6 @@ def test_create_evaluation(
475472
evaluation_job_run.lifecycle_state = "IN_PROGRESS"
476473
mock_job_run.return_value = evaluation_job_run
477474

478-
mock_get_evaluation_service_config.return_value = EvaluationServiceConfig()
479-
480475
self.app.ds_client.update_model = MagicMock()
481476
self.app.ds_client.update_model_provenance = MagicMock()
482477

@@ -494,7 +489,7 @@ def test_create_evaluation(
494489
)
495490
aqua_evaluation_summary = self.app.create(**create_aqua_evaluation_details)
496491

497-
assert asdict(aqua_evaluation_summary) == {
492+
assert aqua_evaluation_summary.to_dict() == {
498493
"console_url": f"https://cloud.oracle.com/data-science/models/{evaluation_model.id}?region={self.app.region}",
499494
"experiment": {
500495
"id": f"{experiment.id}",

0 commit comments

Comments
 (0)