Skip to content

Commit 1857aff

Browse files
add support to manage concurrent eval requests
1 parent ca3c278 commit 1857aff

File tree

4 files changed

+12
-1
lines changed

4 files changed

+12
-1
lines changed

ads/aqua/constants.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
EVALUATION_REPORT_JSON = "report.json"
1616
EVALUATION_REPORT_MD = "report.md"
1717
EVALUATION_REPORT = "report.html"
18+
EVALUATION_INFERENCE_DEFAULT_THREADS = 10
1819
UNKNOWN_JSON_STR = "{}"
1920
FINE_TUNING_RUNTIME_CONTAINER = "iad.ocir.io/ociodscdev/aqua_ft_cuda121:0.3.17.20"
2021
DEFAULT_FT_BLOCK_STORAGE_SIZE = 750

ads/aqua/evaluation/entities.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
#!/usr/bin/env python
2-
# -*- coding: utf-8 -*-
32
# Copyright (c) 2024 Oracle and/or its affiliates.
43
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
54

@@ -13,6 +12,7 @@
1312
from dataclasses import dataclass, field
1413
from typing import List, Optional, Union
1514

15+
from ads.aqua.constants import EVALUATION_INFERENCE_DEFAULT_THREADS
1616
from ads.aqua.data import AquaResourceIdentifier
1717
from ads.common.serializer import DataClassSerializable
1818

@@ -65,6 +65,8 @@ class CreateAquaEvaluationDetails(DataClassSerializable):
6565
The metrics for the evaluation.
6666
force_overwrite: (bool, optional). Defaults to `False`.
6767
Whether to force overwrite the existing file in object storage.
68+
inference_max_threads: (int, optional). Defaults to `EVALUATION_INFERENCE_DEFAULT_THREADS` (10).
69+
The number of concurrent requests to make to the inference endpoint during evaluation.
6870
"""
6971

7072
evaluation_source_id: str
@@ -86,6 +88,7 @@ class CreateAquaEvaluationDetails(DataClassSerializable):
8688
log_id: Optional[str] = None
8789
metrics: Optional[List] = None
8890
force_overwrite: Optional[bool] = False
91+
inference_max_threads: Optional[int] = EVALUATION_INFERENCE_DEFAULT_THREADS
8992

9093

9194
@dataclass(repr=False)
@@ -142,6 +145,7 @@ class AquaEvaluationCommands(DataClassSerializable):
142145
metrics: list
143146
output_dir: str
144147
params: dict
148+
inference_max_threads: int
145149

146150

147151
@dataclass(repr=False)

ads/aqua/evaluation/evaluation.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -390,6 +390,7 @@ def create(
390390
report_path=create_aqua_evaluation_details.report_path,
391391
model_parameters=create_aqua_evaluation_details.model_parameters,
392392
metrics=create_aqua_evaluation_details.metrics,
393+
inference_max_threads=create_aqua_evaluation_details.inference_max_threads,
393394
)
394395
).create(**kwargs) ## TODO: decide what parameters will be needed
395396
logger.debug(
@@ -511,6 +512,7 @@ def _build_evaluation_runtime(
511512
report_path: str,
512513
model_parameters: dict,
513514
metrics: List = None,
515+
inference_max_threads: int = None,
514516
) -> Runtime:
515517
"""Builds evaluation runtime for Job."""
516518
# TODO the image name needs to be extracted from the mapping index.json file.
@@ -528,6 +530,7 @@ def _build_evaluation_runtime(
528530
report_path=report_path,
529531
model_parameters=model_parameters,
530532
metrics=metrics,
533+
inference_max_threads=inference_max_threads,
531534
)
532535
)
533536
),
@@ -587,6 +590,7 @@ def _build_launch_cmd(
587590
report_path: str,
588591
model_parameters: dict,
589592
metrics: List = None,
593+
inference_max_threads: int = None,
590594
):
591595
return AquaEvaluationCommands(
592596
evaluation_id=evaluation_id,
@@ -603,6 +607,7 @@ def _build_launch_cmd(
603607
metrics=metrics,
604608
output_dir=report_path,
605609
params=model_parameters,
610+
inference_max_threads=inference_max_threads,
606611
)
607612

608613
@telemetry(entry_point="plugin=evaluation&action=get", name="aqua")

ads/aqua/modeldeployment/deployment.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,7 @@ def create(
167167
Tags.AQUA_SERVICE_MODEL_TAG,
168168
Tags.AQUA_FINE_TUNED_MODEL_TAG,
169169
Tags.AQUA_TAG,
170+
Tags.MODEL_FORMAT,
170171
]:
171172
if tag in aqua_model.freeform_tags:
172173
tags[tag] = aqua_model.freeform_tags[tag]

0 commit comments

Comments
 (0)