Skip to content

Commit f1afde5

Browse files
authored
Merge branch 'main' into feature/custom_chat_template
2 parents 1df6fff + 95c5a5d commit f1afde5

File tree

8 files changed

+458
-126
lines changed

8 files changed

+458
-126
lines changed

ads/aqua/app.py

Lines changed: 60 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import json
66
import os
77
import traceback
8+
from concurrent.futures import ThreadPoolExecutor
89
from dataclasses import fields
910
from datetime import datetime, timedelta
1011
from itertools import chain
@@ -58,6 +59,8 @@
5859
class AquaApp:
5960
"""Base Aqua App to contain common components."""
6061

62+
MAX_WORKERS = 10 # Number of workers for asynchronous resource loading
63+
6164
@telemetry(name="aqua")
6265
def __init__(self) -> None:
6366
if OCI_RESOURCE_PRINCIPAL_VERSION:
@@ -128,20 +131,69 @@ def update_model_provenance(
128131
update_model_provenance_details=update_model_provenance_details,
129132
)
130133

131-
# TODO: refactor model evaluation implementation to use it.
132134
@staticmethod
133135
def get_source(source_id: str) -> Union[ModelDeployment, DataScienceModel]:
134-
if is_valid_ocid(source_id):
135-
if "datasciencemodeldeployment" in source_id:
136-
return ModelDeployment.from_id(source_id)
137-
elif "datasciencemodel" in source_id:
138-
return DataScienceModel.from_id(source_id)
136+
"""
137+
Fetches a model or model deployment based on the provided OCID.
138+
139+
Parameters
140+
----------
141+
source_id : str
142+
OCID of the Data Science model or model deployment.
143+
144+
Returns
145+
-------
146+
Union[ModelDeployment, DataScienceModel]
147+
The corresponding resource object.
148+
149+
Raises
150+
------
151+
AquaValueError
152+
If the OCID is invalid or unsupported.
153+
"""
154+
logger.debug(f"Resolving source for ID: {source_id}")
155+
if not is_valid_ocid(source_id):
156+
logger.error(f"Invalid OCID format: {source_id}")
157+
raise AquaValueError(
158+
f"Invalid source ID: {source_id}. Please provide a valid model or model deployment OCID."
159+
)
160+
161+
if "datasciencemodeldeployment" in source_id:
162+
logger.debug(f"Identified as ModelDeployment OCID: {source_id}")
163+
return ModelDeployment.from_id(source_id)
139164

165+
if "datasciencemodel" in source_id:
166+
logger.debug(f"Identified as DataScienceModel OCID: {source_id}")
167+
return DataScienceModel.from_id(source_id)
168+
169+
logger.error(f"Unrecognized OCID type: {source_id}")
140170
raise AquaValueError(
141-
f"Invalid source {source_id}. "
142-
"Specify either a model or model deployment id."
171+
f"Unsupported source ID type: {source_id}. Must be a model or model deployment OCID."
143172
)
144173

174+
def get_multi_source(
175+
self,
176+
ids: List[str],
177+
) -> Dict[str, Union[ModelDeployment, DataScienceModel]]:
178+
"""
179+
Retrieves multiple DataScience resources concurrently.
180+
181+
Parameters
182+
----------
183+
ids : List[str]
184+
A list of DataScience OCIDs.
185+
186+
Returns
187+
-------
188+
Dict[str, Union[ModelDeployment, DataScienceModel]]
189+
A mapping from OCID to the corresponding resolved resource object.
190+
"""
191+
logger.debug(f"Fetching {ids} sources in parallel.")
192+
with ThreadPoolExecutor(max_workers=self.MAX_WORKERS) as executor:
193+
results = list(executor.map(self.get_source, ids))
194+
195+
return dict(zip(ids, results))
196+
145197
# TODO: refactor model evaluation implementation to use it.
146198
@staticmethod
147199
def create_model_version_set(

ads/aqua/common/entities.py

Lines changed: 58 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from typing import Any, Dict, List, Optional
77

88
from oci.data_science.models import Model
9-
from pydantic import BaseModel, Field, model_validator
9+
from pydantic import BaseModel, ConfigDict, Field, model_validator
1010

1111
from ads.aqua import logger
1212
from ads.aqua.config.utils.serializer import Serializable
@@ -80,24 +80,29 @@ class GPUShapesIndex(Serializable):
8080

8181
class ComputeShapeSummary(Serializable):
8282
"""
83-
Represents the specifications of a compute instance's shape.
83+
Represents the specifications of a compute instance shape,
84+
including CPU, memory, and optional GPU characteristics.
8485
"""
8586

8687
core_count: Optional[int] = Field(
87-
default=None, description="The number of CPU cores available."
88+
default=None,
89+
description="Total number of CPU cores available for the compute shape.",
8890
)
8991
memory_in_gbs: Optional[int] = Field(
90-
default=None, description="The amount of memory (in GB) available."
92+
default=None,
93+
description="Amount of memory (in GB) available for the compute shape.",
9194
)
9295
name: Optional[str] = Field(
93-
default=None, description="The name identifier of the compute shape."
96+
default=None,
97+
description="Full name of the compute shape, e.g., 'VM.GPU.A10.2'.",
9498
)
9599
shape_series: Optional[str] = Field(
96-
default=None, description="The series or category of the compute shape."
100+
default=None,
101+
description="Shape family or series, e.g., 'GPU', 'Standard', etc.",
97102
)
98103
gpu_specs: Optional[GPUSpecs] = Field(
99104
default=None,
100-
description="The GPU specifications associated with the compute shape.",
105+
description="Optional GPU specifications associated with the shape.",
101106
)
102107

103108
@model_validator(mode="after")
@@ -136,27 +141,46 @@ def set_gpu_specs(cls, model: "ComputeShapeSummary") -> "ComputeShapeSummary":
136141
return model
137142

138143

139-
class LoraModuleSpec(Serializable):
144+
class LoraModuleSpec(BaseModel):
140145
"""
141-
Lightweight descriptor for LoRA Modules used in fine-tuning models.
146+
Descriptor for a LoRA (Low-Rank Adaptation) module used in fine-tuning base models.
147+
148+
This class is used to define a single fine-tuned module that can be loaded during
149+
multi-model deployment alongside a base model.
142150
143151
Attributes
144152
----------
145153
model_id : str
146-
The unique identifier of the fine tuned model.
147-
model_name : str
148-
The name of the fine-tuned model.
149-
model_path : str
150-
The model-by-reference path to the LoRA Module within the model artifact
154+
The OCID of the fine-tuned model registered in the OCI Model Catalog.
155+
model_name : Optional[str]
156+
The unique name used to route inference requests to this model variant.
157+
model_path : Optional[str]
158+
The relative path within the artifact pointing to the LoRA adapter weights.
151159
"""
152160

153-
model_id: Optional[str] = Field(None, description="The fine tuned model OCID to deploy.")
154-
model_name: Optional[str] = Field(None, description="The name of the fine-tuned model.")
161+
model_config = ConfigDict(protected_namespaces=(), extra="allow")
162+
163+
model_id: str = Field(
164+
...,
165+
description="OCID of the fine-tuned model (must be registered in the Model Catalog).",
166+
)
167+
model_name: Optional[str] = Field(
168+
default=None,
169+
description="Name assigned to the fine-tuned model for serving (used as inference route).",
170+
)
155171
model_path: Optional[str] = Field(
156-
None,
157-
description="The model-by-reference path to the LoRA Module within the model artifact.",
172+
default=None,
173+
description="Relative path to the LoRA weights inside the model artifact.",
158174
)
159175

176+
@model_validator(mode="before")
177+
@classmethod
178+
def validate_lora_module(cls, data: dict) -> dict:
179+
"""Validates that required structure exists for a LoRA module."""
180+
if "model_id" not in data or not data["model_id"]:
181+
raise ValueError("Missing required field: 'model_id' for fine-tuned model.")
182+
return data
183+
160184

161185
class AquaMultiModelRef(Serializable):
162186
"""
@@ -203,6 +227,22 @@ class AquaMultiModelRef(Serializable):
203227
description="For fine tuned models, the artifact path of the modified model weights",
204228
)
205229

230+
def all_model_ids(self) -> List[str]:
231+
"""
232+
Returns all associated model OCIDs, including the base model and any fine-tuned models.
233+
234+
Returns
235+
-------
236+
List[str]
237+
A list of all model OCIDs associated with this multi-model reference.
238+
"""
239+
ids = {self.model_id}
240+
if self.fine_tune_weights:
241+
ids.update(
242+
module.model_id for module in self.fine_tune_weights if module.model_id
243+
)
244+
return list(ids)
245+
206246
class Config:
207247
extra = "ignore"
208248
protected_namespaces = ()

0 commit comments

Comments
 (0)