Skip to content

Commit 5b58330

Browse files
committed
Updating max threads inference
1 parent 158efdd commit 5b58330

File tree

5 files changed

+73
-14
lines changed

5 files changed

+73
-14
lines changed

ads/aqua/common/entities.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,3 +14,4 @@ class ContainerSpec:
1414
HEALTH_CHECK_PORT = "healthCheckPort"
1515
ENV_VARS = "envVars"
1616
RESTRICTED_PARAMS = "restrictedParams"
17+
EVALUATION_CONFIGURATION = "evaluationConfiguration"

ads/aqua/evaluation/entities.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
from dataclasses import dataclass, field
1313
from typing import List, Optional, Union
1414

15-
from ads.aqua.constants import EVALUATION_INFERENCE_DEFAULT_THREADS
1615
from ads.aqua.data import AquaResourceIdentifier
1716
from ads.common.serializer import DataClassSerializable
1817

@@ -88,7 +87,7 @@ class CreateAquaEvaluationDetails(DataClassSerializable):
8887
log_id: Optional[str] = None
8988
metrics: Optional[List] = None
9089
force_overwrite: Optional[bool] = False
91-
inference_max_threads: Optional[int] = EVALUATION_INFERENCE_DEFAULT_THREADS
90+
inference_max_threads: Optional[int] = None
9291

9392

9493
@dataclass(repr=False)

ads/aqua/evaluation/evaluation.py

Lines changed: 39 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,15 +13,9 @@
1313
from threading import Lock
1414
from typing import Any, Dict, List, Union
1515

16-
import oci
1716
from cachetools import TTLCache
18-
from oci.data_science.models import (
19-
JobRun,
20-
Metadata,
21-
UpdateModelDetails,
22-
UpdateModelProvenanceDetails,
23-
)
2417

18+
import oci
2519
from ads.aqua import logger
2620
from ads.aqua.app import AquaApp
2721
from ads.aqua.common import utils
@@ -47,6 +41,7 @@
4741
)
4842
from ads.aqua.constants import (
4943
CONSOLE_LINK_RESOURCE_TYPE_MAPPING,
44+
EVALUATION_INFERENCE_DEFAULT_THREADS,
5045
EVALUATION_REPORT,
5146
EVALUATION_REPORT_JSON,
5247
EVALUATION_REPORT_MD,
@@ -76,6 +71,7 @@
7671
ModelParams,
7772
)
7873
from ads.aqua.evaluation.errors import EVALUATION_JOB_EXIT_CODE_MESSAGE
74+
from ads.aqua.ui import AquaContainerConfig
7975
from ads.common.auth import default_signer
8076
from ads.common.object_storage_details import ObjectStorageDetails
8177
from ads.common.utils import get_console_link, get_files, get_log_links
@@ -90,7 +86,9 @@
9086
from ads.jobs.builders.runtimes.base import Runtime
9187
from ads.jobs.builders.runtimes.container_runtime import ContainerRuntime
9288
from ads.model.datascience_model import DataScienceModel
89+
from ads.model.deployment import ModelDeploymentContainerRuntime
9390
from ads.model.deployment.model_deployment import ModelDeployment
91+
from ads.model.generic_model import ModelDeploymentRuntimeType
9492
from ads.model.model_metadata import (
9593
MetadataTaxonomyKeys,
9694
ModelCustomMetadata,
@@ -99,6 +97,12 @@
9997
)
10098
from ads.model.model_version_set import ModelVersionSet
10199
from ads.telemetry import telemetry
100+
from oci.data_science.models import (
101+
JobRun,
102+
Metadata,
103+
UpdateModelDetails,
104+
UpdateModelProvenanceDetails,
105+
)
102106

103107

104108
class AquaEvaluationApp(AquaApp):
@@ -166,7 +170,6 @@ def create(
166170
f"Invalid evaluation source {create_aqua_evaluation_details.evaluation_source_id}. "
167171
"Specify either a model or model deployment id."
168172
)
169-
170173
evaluation_source = None
171174
if (
172175
DataScienceResource.MODEL_DEPLOYMENT
@@ -175,6 +178,33 @@ def create(
175178
evaluation_source = ModelDeployment.from_id(
176179
create_aqua_evaluation_details.evaluation_source_id
177180
)
181+
if evaluation_source.runtime.type == ModelDeploymentRuntimeType.CONTAINER:
182+
runtime = ModelDeploymentContainerRuntime.from_dict(
183+
evaluation_source.runtime.to_dict()
184+
)
185+
container_config = AquaContainerConfig.from_container_index_json(
186+
enable_spec=True
187+
)
188+
for container in container_config.inference.values():
189+
if container.name == runtime.image.split(":")[0]:
190+
max_threads = container.spec.evaluation_configuration.evaluation_max_threads
191+
if (
192+
max_threads
193+
and create_aqua_evaluation_details.inference_max_threads
194+
and max_threads
195+
< create_aqua_evaluation_details.inference_max_threads
196+
):
197+
raise AquaValueError(
198+
f"Invalid inference max threads. The maximum allowed value for {runtime.image} is {max_threads}."
199+
)
200+
if not create_aqua_evaluation_details.inference_max_threads:
201+
create_aqua_evaluation_details.inference_max_threads = container.spec.evaluation_configuration.evaluation_default_threads
202+
break
203+
if not create_aqua_evaluation_details.inference_max_threads:
204+
create_aqua_evaluation_details.inference_max_threads = (
205+
EVALUATION_INFERENCE_DEFAULT_THREADS
206+
)
207+
178208
elif (
179209
DataScienceResource.MODEL
180210
in create_aqua_evaluation_details.evaluation_source_id
@@ -1197,7 +1227,7 @@ def _delete_job_and_model(job, model):
11971227
f"Exception message: {ex}"
11981228
)
11991229

1200-
def load_evaluation_config(self, eval_id):
1230+
def load_evaluation_config(self, _):
12011231
"""Loads evaluation config."""
12021232
return {
12031233
"model_params": {

ads/aqua/modeldeployment/deployment.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,6 @@ def create(
167167
Tags.AQUA_SERVICE_MODEL_TAG,
168168
Tags.AQUA_FINE_TUNED_MODEL_TAG,
169169
Tags.AQUA_TAG,
170-
Tags.MODEL_FORMAT,
171170
]:
172171
if tag in aqua_model.freeform_tags:
173172
tags[tag] = aqua_model.freeform_tags[tag]

ads/aqua/ui.py

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,15 +9,14 @@
99
from typing import Dict, List, Optional
1010

1111
from cachetools import TTLCache
12-
from oci.exceptions import ServiceError
13-
from oci.identity.models import Compartment
1412

1513
from ads.aqua import logger
1614
from ads.aqua.app import AquaApp
1715
from ads.aqua.common.entities import ContainerSpec
1816
from ads.aqua.common.enums import Tags
1917
from ads.aqua.common.errors import AquaResourceAccessError, AquaValueError
2018
from ads.aqua.common.utils import get_container_config, load_config, sanitize_response
19+
from ads.aqua.constants import EVALUATION_INFERENCE_DEFAULT_THREADS
2120
from ads.common import oci_client as oc
2221
from ads.common.auth import default_signer
2322
from ads.common.object_storage_details import ObjectStorageDetails
@@ -30,6 +29,8 @@
3029
TENANCY_OCID,
3130
)
3231
from ads.telemetry import telemetry
32+
from oci.exceptions import ServiceError
33+
from oci.identity.models import Compartment
3334

3435

3536
class ModelFormat(Enum):
@@ -45,13 +46,37 @@ def to_dict(self):
4546
# within ads.aqua.common.entities. In that case, check for circular imports due to usage of get_container_config.
4647

4748

49+
@dataclass(repr=False)
50+
class AquaContainerEvaluationConfiguration(DataClassSerializable):
51+
"""
52+
Represents the evaluation configuration for the container.
53+
"""
54+
55+
evaluation_max_threads: Optional[int] = None
56+
evaluation_default_threads: int = field(
57+
default=EVALUATION_INFERENCE_DEFAULT_THREADS
58+
)
59+
60+
@classmethod
61+
def from_config(cls, config: dict) -> "AquaContainerEvaluationConfiguration":
62+
return cls(
63+
evaluation_max_threads=config.get("MAX_THREADS"),
64+
evaluation_default_threads=config.get(
65+
"DEFAULT_THREADS", EVALUATION_INFERENCE_DEFAULT_THREADS
66+
),
67+
)
68+
69+
4870
@dataclass(repr=False)
4971
class AquaContainerConfigSpec(DataClassSerializable):
5072
cli_param: str = None
5173
server_port: str = None
5274
health_check_port: str = None
5375
env_vars: List[dict] = None
5476
restricted_params: List[str] = None
77+
evaluation_configuration: AquaContainerEvaluationConfiguration = field(
78+
default_factory=AquaContainerEvaluationConfiguration
79+
)
5580

5681

5782
@dataclass(repr=False)
@@ -161,6 +186,11 @@ def from_container_index_json(
161186
restricted_params=container_spec.get(
162187
ContainerSpec.RESTRICTED_PARAMS, []
163188
),
189+
evaluation_configuration=AquaContainerEvaluationConfiguration.from_config(
190+
container_spec.get(
191+
ContainerSpec.EVALUATION_CONFIGURATION, {}
192+
)
193+
),
164194
)
165195
if container_spec
166196
else None,

0 commit comments

Comments
 (0)