Skip to content

Commit 9d51d44

Browse files
mrDzurbelizjolu-ohai
authored
[Ready For Review][AQUA] Add Supporting Fine-Tuned Models in Multi-Model Deployment (#1186)
Co-authored-by: Liz Johnson <liz.j.johnson@oracle.com> Co-authored-by: Lu Peng <118394507+lu-ohai@users.noreply.github.com>
1 parent 1024cca commit 9d51d44

File tree

15 files changed

+1499
-917
lines changed

15 files changed

+1499
-917
lines changed

ads/aqua/app.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from ads.aqua import logger
2323
from ads.aqua.common.entities import ModelConfigResult
2424
from ads.aqua.common.enums import ConfigFolder, Tags
25-
from ads.aqua.common.errors import AquaRuntimeError, AquaValueError
25+
from ads.aqua.common.errors import AquaValueError
2626
from ads.aqua.common.utils import (
2727
_is_valid_mvs,
2828
get_artifact_path,
@@ -284,8 +284,11 @@ def if_artifact_exist(self, model_id: str, **kwargs) -> bool:
284284
logger.info(f"Artifact not found in model {model_id}.")
285285
return False
286286

287+
@cached(cache=TTLCache(maxsize=5, ttl=timedelta(minutes=1), timer=datetime.now))
287288
def get_config_from_metadata(
288-
self, model_id: str, metadata_key: str
289+
self,
290+
model_id: str,
291+
metadata_key: str,
289292
) -> ModelConfigResult:
290293
"""Gets the config for the given Aqua model from model catalog metadata content.
291294
@@ -300,8 +303,9 @@ def get_config_from_metadata(
300303
ModelConfigResult
301304
A Pydantic model containing the model_details (extracted from OCI) and the config dictionary.
302305
"""
303-
config = {}
306+
config: Dict[str, Any] = {}
304307
oci_model = self.ds_client.get_model(model_id).data
308+
305309
try:
306310
config = self.ds_client.get_model_defined_metadatum_artifact_content(
307311
model_id, metadata_key
@@ -321,7 +325,7 @@ def get_config_from_metadata(
321325
)
322326
return ModelConfigResult(config=config, model_details=oci_model)
323327

324-
@cached(cache=TTLCache(maxsize=1, ttl=timedelta(minutes=1), timer=datetime.now))
328+
@cached(cache=TTLCache(maxsize=1, ttl=timedelta(minutes=5), timer=datetime.now))
325329
def get_config(
326330
self,
327331
model_id: str,
@@ -346,8 +350,10 @@ def get_config(
346350
ModelConfigResult
347351
A Pydantic model containing the model_details (extracted from OCI) and the config dictionary.
348352
"""
349-
config_folder = config_folder or ConfigFolder.CONFIG
353+
config: Dict[str, Any] = {}
350354
oci_model = self.ds_client.get_model(model_id).data
355+
356+
config_folder = config_folder or ConfigFolder.CONFIG
351357
oci_aqua = (
352358
(
353359
Tags.AQUA_TAG in oci_model.freeform_tags
@@ -357,9 +363,9 @@ def get_config(
357363
else False
358364
)
359365
if not oci_aqua:
360-
raise AquaRuntimeError(f"Target model {oci_model.id} is not an Aqua model.")
366+
logger.debug(f"Target model {oci_model.id} is not an Aqua model.")
367+
return ModelConfigResult(config=config, model_details=oci_model)
361368

362-
config: Dict[str, Any] = {}
363369
artifact_path = get_artifact_path(oci_model.custom_metadata_list)
364370
if not artifact_path:
365371
logger.debug(

ads/aqua/common/entities.py

Lines changed: 31 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
44

55
import re
6-
from typing import Any, Dict, Optional
6+
from typing import Any, Dict, List, Optional
77

88
from oci.data_science.models import Model
99
from pydantic import BaseModel, Field, model_validator
@@ -136,6 +136,28 @@ def set_gpu_specs(cls, model: "ComputeShapeSummary") -> "ComputeShapeSummary":
136136
return model
137137

138138

139+
class LoraModuleSpec(Serializable):
140+
"""
141+
Lightweight descriptor for LoRA Modules used in fine-tuning models.
142+
143+
Attributes
144+
----------
145+
model_id : str
146+
The unique identifier of the fine tuned model.
147+
model_name : str
148+
The name of the fine-tuned model.
149+
model_path : str
150+
The model-by-reference path to the LoRA Module within the model artifact
151+
"""
152+
153+
model_id: Optional[str] = Field(None, description="The fine tuned model OCID to deploy.")
154+
model_name: Optional[str] = Field(None, description="The name of the fine-tuned model.")
155+
model_path: Optional[str] = Field(
156+
None,
157+
description="The model-by-reference path to the LoRA Module within the model artifact.",
158+
)
159+
160+
139161
class AquaMultiModelRef(Serializable):
140162
"""
141163
Lightweight model descriptor used for multi-model deployment.
@@ -157,7 +179,7 @@ class AquaMultiModelRef(Serializable):
157179
Optional environment variables to override during deployment.
158180
artifact_location : Optional[str]
159181
Artifact path of model in the multimodel group.
160-
fine_tune_weights_location : Optional[str]
182+
fine_tune_weights : Optional[List[LoraModuleSpec]]
161183
For fine tuned models, the artifact path of the modified model weights
162184
"""
163185

@@ -166,15 +188,19 @@ class AquaMultiModelRef(Serializable):
166188
gpu_count: Optional[int] = Field(
167189
None, description="The gpu count allocation for the model."
168190
)
169-
model_task: Optional[str] = Field(None, description="The task that model operates on. Supported tasks are in MultiModelSupportedTaskType")
191+
model_task: Optional[str] = Field(
192+
None,
193+
description="The task that model operates on. Supported tasks are in MultiModelSupportedTaskType",
194+
)
170195
env_var: Optional[dict] = Field(
171196
default_factory=dict, description="The environment variables of the model."
172197
)
173198
artifact_location: Optional[str] = Field(
174199
None, description="Artifact path of model in the multimodel group."
175200
)
176-
fine_tune_weights_location: Optional[str] = Field(
177-
None, description="For fine tuned models, the artifact path of the modified model weights"
201+
fine_tune_weights: Optional[List[LoraModuleSpec]] = Field(
202+
None,
203+
description="For fine tuned models, the artifact path of the modified model weights",
178204
)
179205

180206
class Config:

ads/aqua/common/utils.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -870,6 +870,41 @@ def get_combined_params(params1: str = None, params2: str = None) -> str:
870870
return " ".join(combined_params)
871871

872872

873+
def find_restricted_params(
874+
default_params: Union[str, List[str]],
875+
user_params: Union[str, List[str]],
876+
container_family: str,
877+
) -> List[str]:
878+
"""Returns a list of restricted params that user chooses to override when creating an Aqua deployment.
879+
The default parameters coming from the container index json file cannot be overridden.
880+
881+
Parameters
882+
----------
883+
default_params:
884+
Inference container parameter string with default values.
885+
user_params:
886+
Inference container parameter string with user provided values.
887+
container_family: str
888+
The image family of model deployment container runtime.
889+
890+
Returns
891+
-------
892+
A list with params keys common between params1 and params2.
893+
894+
"""
895+
restricted_params = []
896+
if default_params and user_params:
897+
default_params_dict = get_params_dict(default_params)
898+
user_params_dict = get_params_dict(user_params)
899+
900+
restricted_params_set = get_restricted_params_by_container(container_family)
901+
for key, _items in user_params_dict.items():
902+
if key in default_params_dict or key in restricted_params_set:
903+
restricted_params.append(key.lstrip("-"))
904+
905+
return restricted_params
906+
907+
873908
def build_params_string(params: dict) -> str:
874909
"""Builds params string from params dict
875910

ads/aqua/evaluation/evaluation.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -727,10 +727,11 @@ def validate_model_name(
727727
raise AquaRuntimeError(error_message) from ex
728728

729729
# Build the list of valid model names from custom metadata.
730-
model_names = [
731-
AquaMultiModelRef(**metadata).model_name
732-
for metadata in multi_model_metadata
733-
]
730+
model_names = []
731+
for metadata in multi_model_metadata:
732+
model = AquaMultiModelRef(**metadata)
733+
model_names.append(model.model_name)
734+
model_names.extend(ft.model_name for ft in (model.fine_tune_weights or []) if ft.model_name)
734735

735736
# Check if the provided model name is among the valid names.
736737
if user_model_name not in model_names:

ads/aqua/model/enums.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,24 @@ class FineTuningCustomMetadata(ExtendedEnum):
2727

2828
class MultiModelSupportedTaskType(ExtendedEnum):
2929
TEXT_GENERATION = "text_generation"
30+
TEXT_GENERATION_INFERENCE = "text_generation_inference"
31+
TEXT2TEXT_GENERATION = "text2text_generation"
32+
SUMMARIZATION = "summarization"
33+
TRANSLATION = "translation"
34+
CONVERSATIONAL = "conversational"
35+
FEATURE_EXTRACTION = "feature_extraction"
36+
SENTENCE_SIMILARITY = "sentence_similarity"
37+
AUTOMATIC_SPEECH_RECOGNITION = "automatic_speech_recognition"
38+
TEXT_TO_SPEECH = "text_to_speech"
39+
TEXT_TO_IMAGE = "text_to_image"
40+
TEXT_EMBEDDING = "text_embedding"
3041
IMAGE_TEXT_TO_TEXT = "image_text_to_text"
3142
CODE_SYNTHESIS = "code_synthesis"
32-
EMBEDDING = "text_embedding"
43+
QUESTION_ANSWERING = "question_answering"
44+
AUDIO_CLASSIFICATION = "audio_classification"
45+
AUDIO_TO_AUDIO = "audio_to_audio"
46+
IMAGE_CLASSIFICATION = "image_classification"
47+
IMAGE_TO_TEXT = "image_to_text"
48+
IMAGE_TO_IMAGE = "image_to_image"
49+
VIDEO_CLASSIFICATION = "video_classification"
50+
TIME_SERIES_FORECASTING = "time_series_forecasting"

ads/aqua/model/model.py

Lines changed: 45 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
from ads.aqua import logger
1818
from ads.aqua.app import AquaApp
19-
from ads.aqua.common.entities import AquaMultiModelRef
19+
from ads.aqua.common.entities import AquaMultiModelRef, LoraModuleSpec
2020
from ads.aqua.common.enums import (
2121
ConfigFolder,
2222
CustomInferenceContainerTypeFamily,
@@ -89,12 +89,7 @@
8989
)
9090
from ads.common.auth import default_signer
9191
from ads.common.oci_resource import SEARCH_TYPE, OCIResource
92-
from ads.common.utils import (
93-
UNKNOWN,
94-
get_console_link,
95-
is_path_exists,
96-
read_file,
97-
)
92+
from ads.common.utils import UNKNOWN, get_console_link, is_path_exists, read_file
9893
from ads.config import (
9994
AQUA_DEPLOYMENT_CONTAINER_CMD_VAR_METADATA_NAME,
10095
AQUA_DEPLOYMENT_CONTAINER_METADATA_NAME,
@@ -300,57 +295,73 @@ def create_multi(
300295

301296
selected_models_deployment_containers = set()
302297

303-
# Process each model
298+
# Process each model in the input list
304299
for model in models:
300+
# Retrieve model metadata from the Model Catalog using the model ID
305301
source_model = DataScienceModel.from_id(model.model_id)
306302
display_name = source_model.display_name
307303
model_file_description = source_model.model_file_description
308-
# Update model name in user's input model
304+
# If model_name is not explicitly provided, use the model's display name
309305
model.model_name = model.model_name or display_name
310306

311-
# TODO Uncomment the section below, if only service models should be allowed for multi-model deployment
312-
# if not source_model.freeform_tags.get(Tags.AQUA_SERVICE_MODEL_TAG, UNKNOWN):
313-
# raise AquaValueError(
314-
# f"Invalid selected model {display_name}. "
315-
# "Currently only service models are supported for multi model deployment."
316-
# )
307+
if not model_file_description:
308+
raise AquaValueError(
309+
f"Model '{source_model.display_name}' (ID: {model.model_id}) has no file description. "
310+
"Please register the model first."
311+
)
317312

318-
# check if model is a fine-tuned model and if so, add the fine tuned weights path to the fine_tune_weights_location pydantic field
313+
# Check if the model is a fine-tuned model based on its tags
319314
is_fine_tuned_model = (
320315
Tags.AQUA_FINE_TUNED_MODEL_TAG in source_model.freeform_tags
321316
)
322317

318+
base_model_artifact_path = ""
319+
fine_tune_path = ""
320+
323321
if is_fine_tuned_model:
324-
model.model_id, model.model_name = extract_base_model_from_ft(
325-
source_model
326-
)
327-
model_artifact_path, model.fine_tune_weights_location = (
322+
# Extract artifact paths for the base and fine-tuned model
323+
base_model_artifact_path, fine_tune_path = (
328324
extract_fine_tune_artifacts_path(source_model)
329325
)
330326

331-
else:
332-
# Retrieve model artifact for base models
333-
model_artifact_path = source_model.artifact
327+
# Create a single LoRA module specification for the fine-tuned model
328+
# TODO: Support multiple LoRA modules in the future
329+
model.fine_tune_weights = [
330+
LoraModuleSpec(
331+
model_id=model.model_id,
332+
model_name=model.model_name,
333+
model_path=fine_tune_path,
334+
)
335+
]
334336

335-
display_name_list.append(display_name)
337+
# Use the LoRA module name as the model's display name
338+
display_name = model.model_name
336339

337-
self._extract_model_task(model, source_model)
340+
# Temporarily override model ID and name with those of the base model
341+
# TODO: Revisit this logic once proper base/FT model handling is implemented
342+
model.model_id, model.model_name = extract_base_model_from_ft(
343+
source_model
344+
)
345+
else:
346+
# For base models, use the original artifact path
347+
base_model_artifact_path = source_model.artifact
348+
display_name = model.model_name
338349

339-
if not model_artifact_path:
350+
if not base_model_artifact_path:
351+
# Fail if no artifact is found for the base model model
340352
raise AquaValueError(
341-
f"Model '{display_name}' (ID: {model.model_id}) has no artifacts. "
353+
f"Model '{model.model_name}' (ID: {model.model_id}) has no artifacts. "
342354
"Please register the model first."
343355
)
344356

345-
# Update model artifact location in user's input model
346-
model.artifact_location = model_artifact_path
357+
# Update the artifact path in the model configuration
358+
model.artifact_location = base_model_artifact_path
359+
display_name_list.append(display_name)
347360

348-
if not model_file_description:
349-
raise AquaValueError(
350-
f"Model '{display_name}' (ID: {model.model_id}) has no file description. "
351-
"Please register the model first."
352-
)
361+
# Extract model task metadata from source model
362+
self._extract_model_task(model, source_model)
353363

364+
# Track model file description in a validated structure
354365
model_file_description_list.append(
355366
ModelFileDescription(**model_file_description)
356367
)

ads/aqua/model/utils.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,8 @@
33
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
44
"""AQUA model utils"""
55

6-
from typing import Dict, Optional, Tuple
6+
from typing import Tuple
77

8-
from ads.aqua.common.entities import AquaMultiModelRef
98
from ads.aqua.common.errors import AquaValueError
109
from ads.aqua.common.utils import get_model_by_reference_paths
1110
from ads.aqua.finetuning.constants import FineTuneCustomMetadata

0 commit comments

Comments
 (0)