Skip to content

Commit eb31921

Browse files
[ODSC-60256] llama cpp evaluation support (#909)
2 parents ca3c278 + 11abb35 commit eb31921

File tree

7 files changed: +106 additions, -16 deletions.

ads/aqua/common/entities.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,3 +14,4 @@ class ContainerSpec:
1414
HEALTH_CHECK_PORT = "healthCheckPort"
1515
ENV_VARS = "envVars"
1616
RESTRICTED_PARAMS = "restrictedParams"
17+
EVALUATION_CONFIGURATION = "evaluationConfiguration"

ads/aqua/common/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -249,7 +249,7 @@ def list_os_files_with_extension(oss_path: str, extension: str) -> [str]:
249249
files: List[ObjectSummary] = oss_client.list_objects().objects
250250

251251
return [
252-
file.name[len(oss_client.filepath) :]
252+
file.name[len(oss_client.filepath) :].lstrip("/")
253253
for file in files
254254
if file.name.endswith(extension)
255255
]

ads/aqua/evaluation/entities.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
#!/usr/bin/env python
2-
# -*- coding: utf-8 -*-
32
# Copyright (c) 2024 Oracle and/or its affiliates.
43
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
54

ads/aqua/evaluation/evaluation.py

Lines changed: 33 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@
7676
ModelParams,
7777
)
7878
from ads.aqua.evaluation.errors import EVALUATION_JOB_EXIT_CODE_MESSAGE
79+
from ads.aqua.ui import AquaContainerConfig
7980
from ads.common.auth import default_signer
8081
from ads.common.object_storage_details import ObjectStorageDetails
8182
from ads.common.utils import get_console_link, get_files, get_log_links
@@ -90,7 +91,9 @@
9091
from ads.jobs.builders.runtimes.base import Runtime
9192
from ads.jobs.builders.runtimes.container_runtime import ContainerRuntime
9293
from ads.model.datascience_model import DataScienceModel
94+
from ads.model.deployment import ModelDeploymentContainerRuntime
9395
from ads.model.deployment.model_deployment import ModelDeployment
96+
from ads.model.generic_model import ModelDeploymentRuntimeType
9497
from ads.model.model_metadata import (
9598
MetadataTaxonomyKeys,
9699
ModelCustomMetadata,
@@ -166,15 +169,27 @@ def create(
166169
f"Invalid evaluation source {create_aqua_evaluation_details.evaluation_source_id}. "
167170
"Specify either a model or model deployment id."
168171
)
169-
170172
evaluation_source = None
173+
eval_inference_configuration = None
171174
if (
172175
DataScienceResource.MODEL_DEPLOYMENT
173176
in create_aqua_evaluation_details.evaluation_source_id
174177
):
175178
evaluation_source = ModelDeployment.from_id(
176179
create_aqua_evaluation_details.evaluation_source_id
177180
)
181+
if evaluation_source.runtime.type == ModelDeploymentRuntimeType.CONTAINER:
182+
runtime = ModelDeploymentContainerRuntime.from_dict(
183+
evaluation_source.runtime.to_dict()
184+
)
185+
inference_config = AquaContainerConfig.from_container_index_json(
186+
enable_spec=True
187+
).inference
188+
for container in inference_config.values():
189+
if container.name == runtime.image.split(":")[0]:
190+
eval_inference_configuration = (
191+
container.spec.evaluation_configuration
192+
)
178193
elif (
179194
DataScienceResource.MODEL
180195
in create_aqua_evaluation_details.evaluation_source_id
@@ -390,6 +405,9 @@ def create(
390405
report_path=create_aqua_evaluation_details.report_path,
391406
model_parameters=create_aqua_evaluation_details.model_parameters,
392407
metrics=create_aqua_evaluation_details.metrics,
408+
inference_configuration=eval_inference_configuration.to_filtered_dict()
409+
if eval_inference_configuration
410+
else {},
393411
)
394412
).create(**kwargs) ## TODO: decide what parameters will be needed
395413
logger.debug(
@@ -511,6 +529,7 @@ def _build_evaluation_runtime(
511529
report_path: str,
512530
model_parameters: dict,
513531
metrics: List = None,
532+
inference_configuration: dict = None,
514533
) -> Runtime:
515534
"""Builds evaluation runtime for Job."""
516535
# TODO the image name needs to be extracted from the mapping index.json file.
@@ -520,16 +539,19 @@ def _build_evaluation_runtime(
520539
.with_environment_variable(
521540
**{
522541
"AIP_SMC_EVALUATION_ARGUMENTS": json.dumps(
523-
asdict(
524-
self._build_launch_cmd(
525-
evaluation_id=evaluation_id,
526-
evaluation_source_id=evaluation_source_id,
527-
dataset_path=dataset_path,
528-
report_path=report_path,
529-
model_parameters=model_parameters,
530-
metrics=metrics,
531-
)
532-
)
542+
{
543+
**asdict(
544+
self._build_launch_cmd(
545+
evaluation_id=evaluation_id,
546+
evaluation_source_id=evaluation_source_id,
547+
dataset_path=dataset_path,
548+
report_path=report_path,
549+
model_parameters=model_parameters,
550+
metrics=metrics,
551+
),
552+
),
553+
**inference_configuration,
554+
},
533555
),
534556
"CONDA_BUCKET_NS": CONDA_BUCKET_NS,
535557
},

ads/aqua/ui.py

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
# Copyright (c) 2024 Oracle and/or its affiliates.
33
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
44
import concurrent.futures
5-
from dataclasses import dataclass, field
5+
from dataclasses import dataclass, field, fields
66
from datetime import datetime, timedelta
77
from enum import Enum
88
from threading import Lock
@@ -45,13 +45,48 @@ def to_dict(self):
4545
# within ads.aqua.common.entities. In that case, check for circular imports due to usage of get_container_config.
4646

4747

48+
@dataclass(repr=False)
class AquaContainerEvaluationConfig(DataClassSerializable):
    """Evaluation-time inference settings declared for a container.

    Every setting is optional; ``None`` means no value was supplied for it.
    """

    inference_max_threads: Optional[int] = None
    inference_rps: Optional[int] = None
    inference_timeout: Optional[int] = None
    inference_retries: Optional[int] = None
    # NOTE(review): a backoff factor is often fractional -- confirm whether
    # Optional[float] would be the more accurate annotation here.
    inference_backoff_factor: Optional[int] = None
    inference_delay: Optional[int] = None

    @classmethod
    def from_config(cls, config: dict) -> "AquaContainerEvaluationConfig":
        """Build an instance from a raw configuration mapping.

        Args:
            config: Mapping keyed by field name. Keys that do not match a
                declared field are ignored and missing keys become ``None``.
                A falsy value (for example ``None`` coming from an explicit
                JSON ``null``) is treated as an empty mapping.

        Returns:
            The populated configuration object.
        """
        config = config or {}
        # Drive the lookup from the declared fields so that a setting added
        # to the dataclass later cannot be forgotten here.
        return cls(**{spec.name: config.get(spec.name) for spec in fields(cls)})

    def to_filtered_dict(self) -> dict:
        """Return only the settings that carry a value (``None`` is dropped)."""
        # The loop variable is not called ``field`` on purpose: that name is
        # the imported ``dataclasses.field`` helper.
        settings = ((spec.name, getattr(self, spec.name)) for spec in fields(self))
        return {name: value for name, value in settings if value is not None}
78+
79+
4880
# Per-container specification record. All plain fields default to None.
# NOTE(review): the bare numbers between the code lines are diff-gutter line
# numbers left over from the page scrape, not part of the Python source.
@dataclass(repr=False)
4981
class AquaContainerConfigSpec(DataClassSerializable):
5082
cli_param: str = None
5183
server_port: str = None
5284
health_check_port: str = None
5385
env_vars: List[dict] = None
5486
restricted_params: List[str] = None
87+
# default_factory gives every spec its own AquaContainerEvaluationConfig
# instance (all settings None) when no evaluation configuration is passed.
evaluation_configuration: AquaContainerEvaluationConfig = field(
88+
default_factory=AquaContainerEvaluationConfig
89+
)
5590

5691

5792
@dataclass(repr=False)
@@ -161,6 +196,11 @@ def from_container_index_json(
161196
restricted_params=container_spec.get(
162197
ContainerSpec.RESTRICTED_PARAMS, []
163198
),
199+
evaluation_configuration=AquaContainerEvaluationConfig.from_config(
200+
container_spec.get(
201+
ContainerSpec.EVALUATION_CONFIGURATION, {}
202+
)
203+
),
164204
)
165205
if container_spec
166206
else None,

tests/unitary/with_extras/aqua/test_data/ui/container_index.json

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,10 @@
1919
"HEALTH_CHECK_PORT": "8080"
2020
}
2121
],
22+
"evaluationConfiguration": {
23+
"inference_delay": 1,
24+
"inference_max_threads": 1
25+
},
2226
"healthCheckPort": "8080",
2327
"restrictedParams": [],
2428
"serverPort": "8080"
@@ -81,7 +85,7 @@
8185
"modelFormats": [
8286
"GGUF"
8387
],
84-
"name": "iad.ocir.io/ociodscdev/odsc-llama-cpp-python-aio-linux_arm64_v8",
88+
"name": "dsmc://odsc-llama-cpp-python-aio-linux_arm64_v8",
8589
"platforms": [
8690
"ARM_CPU"
8791
],

tests/unitary/with_extras/aqua/test_ui.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -484,7 +484,7 @@ def test_list_containers(self, mock_get_container_config):
484484
],
485485
"inference": [
486486
{
487-
"name": "iad.ocir.io/ociodscdev/odsc-llama-cpp-python-aio-linux_arm64_v8",
487+
"name": "dsmc://odsc-llama-cpp-python-aio-linux_arm64_v8",
488488
"version": "0.2.75.5",
489489
"display_name": "LLAMA-CPP:0.2.75",
490490
"family": "odsc-llama-cpp-serving",
@@ -502,6 +502,14 @@ def test_list_containers(self, mock_get_container_config):
502502
"health_check_port": "8080",
503503
"restricted_params": [],
504504
"server_port": "8080",
505+
"evaluation_configuration": {
506+
"inference_max_threads": 1,
507+
"inference_rps": None,
508+
"inference_timeout": None,
509+
"inference_backoff_factor": None,
510+
"inference_delay": 1,
511+
"inference_retries": None,
512+
},
505513
},
506514
},
507515
{
@@ -528,6 +536,14 @@ def test_list_containers(self, mock_get_container_config):
528536
"--trust-remote-code",
529537
],
530538
"server_port": "8080",
539+
"evaluation_configuration": {
540+
"inference_max_threads": None,
541+
"inference_rps": None,
542+
"inference_timeout": None,
543+
"inference_backoff_factor": None,
544+
"inference_delay": None,
545+
"inference_retries": None,
546+
},
531547
},
532548
},
533549
{
@@ -553,6 +569,14 @@ def test_list_containers(self, mock_get_container_config):
553569
"--seed",
554570
],
555571
"server_port": "8080",
572+
"evaluation_configuration": {
573+
"inference_max_threads": None,
574+
"inference_rps": None,
575+
"inference_timeout": None,
576+
"inference_backoff_factor": None,
577+
"inference_delay": None,
578+
"inference_retries": None,
579+
},
556580
},
557581
},
558582
],

0 commit comments

Comments
 (0)