Skip to content

Commit 2a24441

Browse files
authored
Aqua refactor: improvement on evaluation.py (#855)
1 parent 1e88353 commit 2a24441

File tree

3 files changed

+46
-48
lines changed

3 files changed

+46
-48
lines changed

ads/aqua/common/enums.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ class Tags(str, metaclass=ExtendedEnumMeta):
4242
READY_TO_FINE_TUNE = "ready_to_fine_tune"
4343
READY_TO_IMPORT = "ready_to_import"
4444
BASE_MODEL_CUSTOM = "aqua_custom_base_model"
45+
AQUA_EVALUATION_MODEL_ID = "evaluation_model_id"
4546

4647

4748
class RqsAdditionalDetails(str, metaclass=ExtendedEnumMeta):

ads/aqua/evaluation/constants.py

Lines changed: 16 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -28,29 +28,26 @@ class EvaluationCustomMetadata(str, metaclass=ExtendedEnumMeta):
2828
EVALUATION_ERROR = "aqua_evaluate_error"
2929

3030

31-
class EvaluationModelTags(str, metaclass=ExtendedEnumMeta):
32-
AQUA_EVALUATION = "aqua_evaluation"
31+
class EvaluationConfig(str, metaclass=ExtendedEnumMeta):
32+
PARAMS = "model_params"
3333

3434

35-
class EvaluationJobTags(str, metaclass=ExtendedEnumMeta):
36-
AQUA_EVALUATION = "aqua_evaluation"
37-
EVALUATION_MODEL_ID = "evaluation_model_id"
35+
class EvaluationReportJson(str, metaclass=ExtendedEnumMeta):
36+
"""Contains evaluation report.json fields name."""
3837

38+
METRIC_SUMMARY_RESULT = "metric_summary_result"
39+
METRIC_RESULT = "metric_results"
40+
MODEL_PARAMS = "model_params"
41+
MODEL_DETAILS = "model_details"
42+
DATA = "data"
43+
DATASET = "dataset"
3944

40-
class EvaluationUploadStatus(str, metaclass=ExtendedEnumMeta):
41-
IN_PROGRESS = "IN_PROGRESS"
42-
COMPLETED = "COMPLETED"
4345

46+
class EvaluationMetricResult(str, metaclass=ExtendedEnumMeta):
47+
"""Contains metric result's fields name in report.json."""
4448

45-
class RqsAdditionalDetails(str, metaclass=ExtendedEnumMeta):
46-
METADATA = "metadata"
47-
CREATED_BY = "createdBy"
49+
SHORT_NAME = "key"
50+
NAME = "name"
4851
DESCRIPTION = "description"
49-
MODEL_VERSION_SET_ID = "modelVersionSetId"
50-
MODEL_VERSION_SET_NAME = "modelVersionSetName"
51-
PROJECT_ID = "projectId"
52-
VERSION_LABEL = "versionLabel"
53-
54-
55-
class EvaluationConfig(str, metaclass=ExtendedEnumMeta):
56-
PARAMS = "model_params"
52+
SUMMARY_DATA = "summary_data"
53+
DATA = "data"

ads/aqua/evaluation/evaluation.py

Lines changed: 29 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -325,8 +325,8 @@ def create(
325325
# TODO: validate metrics if it's provided
326326

327327
evaluation_job_freeform_tags = {
328-
EvaluationJobTags.AQUA_EVALUATION: EvaluationJobTags.AQUA_EVALUATION,
329-
EvaluationJobTags.EVALUATION_MODEL_ID: evaluation_model.id,
328+
Tags.AQUA_EVALUATION: Tags.AQUA_EVALUATION,
329+
Tags.AQUA_EVALUATION_MODEL_ID: evaluation_model.id,
330330
}
331331

332332
evaluation_job = Job(name=evaluation_model.display_name).with_infrastructure(
@@ -408,7 +408,7 @@ def create(
408408
update_model_details=UpdateModelDetails(
409409
custom_metadata_list=updated_custom_metadata_list,
410410
freeform_tags={
411-
EvaluationModelTags.AQUA_EVALUATION: EvaluationModelTags.AQUA_EVALUATION,
411+
Tags.AQUA_EVALUATION: Tags.AQUA_EVALUATION,
412412
},
413413
),
414414
)
@@ -479,7 +479,7 @@ def create(
479479
),
480480
),
481481
tags=dict(
482-
aqua_evaluation=EvaluationModelTags.AQUA_EVALUATION,
482+
aqua_evaluation=Tags.AQUA_EVALUATION,
483483
evaluation_job_id=evaluation_job.id,
484484
evaluation_source=create_aqua_evaluation_details.evaluation_source_id,
485485
evaluation_experiment_id=experiment_model_version_set_id,
@@ -708,12 +708,10 @@ def list(
708708
models = utils.query_resources(
709709
compartment_id=compartment_id,
710710
resource_type="datasciencemodel",
711-
tag_list=[EvaluationModelTags.AQUA_EVALUATION],
711+
tag_list=[Tags.AQUA_EVALUATION],
712712
)
713713
logger.info(f"Fetched {len(models)} evaluations.")
714714

715-
# TODO: add filter based on project_id if needed.
716-
717715
mapping = self._prefetch_resources(compartment_id)
718716

719717
evaluations = []
@@ -934,11 +932,13 @@ def load_metrics(self, eval_id: str) -> AquaEvalMetrics:
934932
)
935933

936934
files_in_artifact = get_files(temp_dir)
937-
report_content = self._read_from_artifact(
935+
md_report_content = self._read_from_artifact(
938936
temp_dir, files_in_artifact, utils.EVALUATION_REPORT_MD
939937
)
938+
939+
# json report not available for failed evaluation
940940
try:
941-
report = json.loads(
941+
json_report = json.loads(
942942
self._read_from_artifact(
943943
temp_dir, files_in_artifact, utils.EVALUATION_REPORT_JSON
944944
)
@@ -947,27 +947,32 @@ def load_metrics(self, eval_id: str) -> AquaEvalMetrics:
947947
logger.debug(
948948
"Failed to load `report.json` from evaluation artifact" f"{str(e)}"
949949
)
950-
report = {}
950+
json_report = {}
951951

952-
# TODO: after finalizing the format of report.json, move the constant to class
953952
eval_metrics = AquaEvalMetrics(
954953
id=eval_id,
955-
report=base64.b64encode(report_content).decode(),
954+
report=base64.b64encode(md_report_content).decode(),
956955
metric_results=[
957956
AquaEvalMetric(
958-
key=metric_key,
959-
name=metadata.get("name", utils.UNKNOWN),
960-
description=metadata.get("description", utils.UNKNOWN),
957+
key=metadata.get(EvaluationMetricResult.SHORT_NAME, utils.UNKNOWN),
958+
name=metadata.get(EvaluationMetricResult.NAME, utils.UNKNOWN),
959+
description=metadata.get(
960+
EvaluationMetricResult.DESCRIPTION, utils.UNKNOWN
961+
),
961962
)
962-
for metric_key, metadata in report.get("metric_results", {}).items()
963+
for _, metadata in json_report.get(
964+
EvaluationReportJson.METRIC_RESULT, {}
965+
).items()
963966
],
964967
metric_summary_result=[
965968
AquaEvalMetricSummary(**m)
966-
for m in report.get("metric_summary_result", [{}])
969+
for m in json_report.get(
970+
EvaluationReportJson.METRIC_SUMMARY_RESULT, [{}]
971+
)
967972
],
968973
)
969974

970-
if report_content:
975+
if md_report_content:
971976
self._metrics_cache.__setitem__(key=eval_id, value=eval_metrics)
972977

973978
return eval_metrics
@@ -1266,16 +1271,16 @@ def _get_source(
12661271
)
12671272
)
12681273

1269-
if not source_name:
1274+
# try to resolve source_name from source id
1275+
if source_id and not source_name:
12701276
resource_type = utils.get_resource_type(source_id)
12711277

1272-
# TODO: adjust resource principal mapping
1273-
if resource_type == "datasciencemodel":
1274-
source_name = self.ds_client.get_model(source_id).data.display_name
1275-
elif resource_type == "datasciencemodeldeployment":
1278+
if resource_type.startswith("datasciencemodeldeployment"):
12761279
source_name = self.ds_client.get_model_deployment(
12771280
source_id
12781281
).data.display_name
1282+
elif resource_type.startswith("datasciencemodel"):
1283+
source_name = self.ds_client.get_model(source_id).data.display_name
12791284
else:
12801285
raise AquaRuntimeError(
12811286
f"Not supported source type: {resource_type}"
@@ -1404,8 +1409,6 @@ def _fetch_runtime_params(
14041409
"model parameters have not been saved in correct format in model taxonomy. ",
14051410
service_payload={"params": params},
14061411
)
1407-
# TODO: validate the format of parameters.
1408-
# self._validate_params(params)
14091412

14101413
return AquaEvalParams(**params[EvaluationConfig.PARAMS])
14111414
except Exception as e:
@@ -1438,7 +1441,6 @@ def _build_job_identifier(
14381441
)
14391442
return AquaResourceIdentifier()
14401443

1441-
# TODO: fix the logic for determine termination state
14421444
def _get_status(
14431445
self,
14441446
model: oci.resource_search.models.ResourceSummary,
@@ -1498,12 +1500,10 @@ def _get_status(
14981500

14991501
def _prefetch_resources(self, compartment_id) -> dict:
15001502
"""Fetches all AQUA resources."""
1501-
# TODO: handle cross compartment/tenency resources
1502-
# TODO: add cache
15031503
resources = utils.query_resources(
15041504
compartment_id=compartment_id,
15051505
resource_type="all",
1506-
tag_list=[EvaluationModelTags.AQUA_EVALUATION, "OCI_AQUA"],
1506+
tag_list=[Tags.AQUA_EVALUATION, "OCI_AQUA"],
15071507
connect_by_ampersands=False,
15081508
return_all=False,
15091509
)

0 commit comments

Comments (0)