|
| 1 | +#!/usr/bin/env python |
| 2 | +# -*- coding: utf-8 -*- |
| 3 | +# Copyright (c) 2024 Oracle and/or its affiliates. |
| 4 | +# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ |
| 5 | + |
| 6 | +import os |
| 7 | +from typing import Dict, Union |
| 8 | + |
| 9 | +import oci |
| 10 | +from oci.data_science.models import UpdateModelDetails, UpdateModelProvenanceDetails |
| 11 | + |
| 12 | +from ads import set_auth |
| 13 | +from ads.aqua import logger |
| 14 | +from ads.aqua.data import Tags |
| 15 | +from ads.aqua.exception import AquaRuntimeError, AquaValueError |
| 16 | +from ads.aqua.utils import ( |
| 17 | + UNKNOWN, |
| 18 | + _is_valid_mvs, |
| 19 | + get_artifact_path, |
| 20 | + get_base_model_from_tags, |
| 21 | + is_valid_ocid, |
| 22 | + load_config, |
| 23 | + logger, |
| 24 | +) |
| 25 | +from ads.common import oci_client as oc |
| 26 | +from ads.common.auth import default_signer |
| 27 | +from ads.common.utils import extract_region |
| 28 | +from ads.config import ( |
| 29 | + AQUA_TELEMETRY_BUCKET, |
| 30 | + AQUA_TELEMETRY_BUCKET_NS, |
| 31 | + OCI_ODSC_SERVICE_ENDPOINT, |
| 32 | + OCI_RESOURCE_PRINCIPAL_VERSION, |
| 33 | +) |
| 34 | +from ads.model.datascience_model import DataScienceModel |
| 35 | +from ads.model.deployment.model_deployment import ModelDeployment |
| 36 | +from ads.model.model_metadata import ( |
| 37 | + ModelCustomMetadata, |
| 38 | + ModelProvenanceMetadata, |
| 39 | + ModelTaxonomyMetadata, |
| 40 | +) |
| 41 | +from ads.model.model_version_set import ModelVersionSet |
| 42 | +from ads.telemetry import telemetry |
| 43 | +from ads.telemetry.client import TelemetryClient |
| 44 | + |
| 45 | + |
| 46 | +class AquaApp: |
| 47 | + """Base Aqua App to contain common components.""" |
| 48 | + |
| 49 | + @telemetry(name="aqua") |
| 50 | + def __init__(self) -> None: |
| 51 | + if OCI_RESOURCE_PRINCIPAL_VERSION: |
| 52 | + set_auth("resource_principal") |
| 53 | + self._auth = default_signer({"service_endpoint": OCI_ODSC_SERVICE_ENDPOINT}) |
| 54 | + self.ds_client = oc.OCIClientFactory(**self._auth).data_science |
| 55 | + self.logging_client = oc.OCIClientFactory(**default_signer()).logging_management |
| 56 | + self.identity_client = oc.OCIClientFactory(**default_signer()).identity |
| 57 | + self.region = extract_region(self._auth) |
| 58 | + self._telemetry = None |
| 59 | + |
| 60 | + def list_resource( |
| 61 | + self, |
| 62 | + list_func_ref, |
| 63 | + **kwargs, |
| 64 | + ) -> list: |
| 65 | + """Generic method to list OCI Data Science resources. |
| 66 | +
|
| 67 | + Parameters |
| 68 | + ---------- |
| 69 | + list_func_ref : function |
| 70 | + A reference to the list operation which will be called. |
| 71 | + **kwargs : |
| 72 | + Additional keyword arguments to filter the resource. |
| 73 | + The kwargs are passed into OCI API. |
| 74 | +
|
| 75 | + Returns |
| 76 | + ------- |
| 77 | + list |
| 78 | + A list of OCI Data Science resources. |
| 79 | + """ |
| 80 | + return oci.pagination.list_call_get_all_results( |
| 81 | + list_func_ref, |
| 82 | + **kwargs, |
| 83 | + ).data |
| 84 | + |
| 85 | + def update_model(self, model_id: str, update_model_details: UpdateModelDetails): |
| 86 | + """Updates model details. |
| 87 | +
|
| 88 | + Parameters |
| 89 | + ---------- |
| 90 | + model_id : str |
| 91 | + The id of target model. |
| 92 | + update_model_details: UpdateModelDetails |
| 93 | + The model details to be updated. |
| 94 | + """ |
| 95 | + self.ds_client.update_model( |
| 96 | + model_id=model_id, update_model_details=update_model_details |
| 97 | + ) |
| 98 | + |
| 99 | + def update_model_provenance( |
| 100 | + self, |
| 101 | + model_id: str, |
| 102 | + update_model_provenance_details: UpdateModelProvenanceDetails, |
| 103 | + ): |
| 104 | + """Updates model provenance details. |
| 105 | +
|
| 106 | + Parameters |
| 107 | + ---------- |
| 108 | + model_id : str |
| 109 | + The id of target model. |
| 110 | + update_model_provenance_details: UpdateModelProvenanceDetails |
| 111 | + The model provenance details to be updated. |
| 112 | + """ |
| 113 | + self.ds_client.update_model_provenance( |
| 114 | + model_id=model_id, |
| 115 | + update_model_provenance_details=update_model_provenance_details, |
| 116 | + ) |
| 117 | + |
| 118 | + # TODO: refactor model evaluation implementation to use it. |
| 119 | + @staticmethod |
| 120 | + def get_source(source_id: str) -> Union[ModelDeployment, DataScienceModel]: |
| 121 | + if is_valid_ocid(source_id): |
| 122 | + if "datasciencemodeldeployment" in source_id: |
| 123 | + return ModelDeployment.from_id(source_id) |
| 124 | + elif "datasciencemodel" in source_id: |
| 125 | + return DataScienceModel.from_id(source_id) |
| 126 | + |
| 127 | + raise AquaValueError( |
| 128 | + f"Invalid source {source_id}. " |
| 129 | + "Specify either a model or model deployment id." |
| 130 | + ) |
| 131 | + |
| 132 | + # TODO: refactor model evaluation implementation to use it. |
| 133 | + @staticmethod |
| 134 | + def create_model_version_set( |
| 135 | + model_version_set_id: str = None, |
| 136 | + model_version_set_name: str = None, |
| 137 | + description: str = None, |
| 138 | + compartment_id: str = None, |
| 139 | + project_id: str = None, |
| 140 | + **kwargs, |
| 141 | + ) -> tuple: |
| 142 | + """Creates ModelVersionSet from given ID or Name. |
| 143 | +
|
| 144 | + Parameters |
| 145 | + ---------- |
| 146 | + model_version_set_id: (str, optional): |
| 147 | + ModelVersionSet OCID. |
| 148 | + model_version_set_name: (str, optional): |
| 149 | + ModelVersionSet Name. |
| 150 | + description: (str, optional): |
| 151 | + TBD |
| 152 | + compartment_id: (str, optional): |
| 153 | + Compartment OCID. |
| 154 | + project_id: (str, optional): |
| 155 | + Project OCID. |
| 156 | + tag: (str, optional) |
| 157 | + calling tag, can be Tags.AQUA_FINE_TUNING or Tags.AQUA_EVALUATION |
| 158 | +
|
| 159 | + Returns |
| 160 | + ------- |
| 161 | + tuple: (model_version_set_id, model_version_set_name) |
| 162 | + """ |
| 163 | + # TODO: tag should be selected based on which operation (eval/FT) invoke this method |
| 164 | + # currently only used by fine-tuning flow. |
| 165 | + tag = Tags.AQUA_FINE_TUNING.value |
| 166 | + |
| 167 | + if not model_version_set_id: |
| 168 | + tag = Tags.AQUA_FINE_TUNING.value # TODO: Fix this |
| 169 | + try: |
| 170 | + model_version_set = ModelVersionSet.from_name( |
| 171 | + name=model_version_set_name, |
| 172 | + compartment_id=compartment_id, |
| 173 | + ) |
| 174 | + |
| 175 | + if not _is_valid_mvs(model_version_set, tag): |
| 176 | + raise AquaValueError( |
| 177 | + f"Invalid model version set name. Please provide a model version set with `{tag}` in tags." |
| 178 | + ) |
| 179 | + |
| 180 | + except: |
| 181 | + logger.debug( |
| 182 | + f"Model version set {model_version_set_name} doesn't exist. " |
| 183 | + "Creating new model version set." |
| 184 | + ) |
| 185 | + mvs_freeform_tags = { |
| 186 | + tag: tag, |
| 187 | + } |
| 188 | + model_version_set = ( |
| 189 | + ModelVersionSet() |
| 190 | + .with_compartment_id(compartment_id) |
| 191 | + .with_project_id(project_id) |
| 192 | + .with_name(model_version_set_name) |
| 193 | + .with_description(description) |
| 194 | + .with_freeform_tags(**mvs_freeform_tags) |
| 195 | + # TODO: decide what parameters will be needed |
| 196 | + # when refactor eval to use this method, we need to pass tag here. |
| 197 | + .create(**kwargs) |
| 198 | + ) |
| 199 | + logger.debug( |
| 200 | + f"Successfully created model version set {model_version_set_name} with id {model_version_set.id}." |
| 201 | + ) |
| 202 | + return (model_version_set.id, model_version_set_name) |
| 203 | + else: |
| 204 | + model_version_set = ModelVersionSet.from_id(model_version_set_id) |
| 205 | + # TODO: tag should be selected based on which operation (eval/FT) invoke this method |
| 206 | + if not _is_valid_mvs(model_version_set, tag): |
| 207 | + raise AquaValueError( |
| 208 | + f"Invalid model version set id. Please provide a model version set with `{tag}` in tags." |
| 209 | + ) |
| 210 | + return (model_version_set_id, model_version_set.name) |
| 211 | + |
| 212 | + # TODO: refactor model evaluation implementation to use it. |
| 213 | + @staticmethod |
| 214 | + def create_model_catalog( |
| 215 | + display_name: str, |
| 216 | + description: str, |
| 217 | + model_version_set_id: str, |
| 218 | + model_custom_metadata: Union[ModelCustomMetadata, Dict], |
| 219 | + model_taxonomy_metadata: Union[ModelTaxonomyMetadata, Dict], |
| 220 | + compartment_id: str, |
| 221 | + project_id: str, |
| 222 | + **kwargs, |
| 223 | + ) -> DataScienceModel: |
| 224 | + model = ( |
| 225 | + DataScienceModel() |
| 226 | + .with_compartment_id(compartment_id) |
| 227 | + .with_project_id(project_id) |
| 228 | + .with_display_name(display_name) |
| 229 | + .with_description(description) |
| 230 | + .with_model_version_set_id(model_version_set_id) |
| 231 | + .with_custom_metadata_list(model_custom_metadata) |
| 232 | + .with_defined_metadata_list(model_taxonomy_metadata) |
| 233 | + .with_provenance_metadata(ModelProvenanceMetadata(training_id=UNKNOWN)) |
| 234 | + # TODO: decide what parameters will be needed |
| 235 | + .create( |
| 236 | + **kwargs, |
| 237 | + ) |
| 238 | + ) |
| 239 | + return model |
| 240 | + |
| 241 | + def if_artifact_exist(self, model_id: str, **kwargs) -> bool: |
| 242 | + """Checks if the artifact exists. |
| 243 | +
|
| 244 | + Parameters |
| 245 | + ---------- |
| 246 | + model_id : str |
| 247 | + The model OCID. |
| 248 | + **kwargs : |
| 249 | + Additional keyword arguments passed in head_model_artifact. |
| 250 | +
|
| 251 | + Returns |
| 252 | + ------- |
| 253 | + bool |
| 254 | + Whether the artifact exists. |
| 255 | + """ |
| 256 | + |
| 257 | + try: |
| 258 | + response = self.ds_client.head_model_artifact(model_id=model_id, **kwargs) |
| 259 | + return True if response.status == 200 else False |
| 260 | + except oci.exceptions.ServiceError as ex: |
| 261 | + if ex.status == 404: |
| 262 | + logger.info(f"Artifact not found in model {model_id}.") |
| 263 | + return False |
| 264 | + |
| 265 | + def get_config(self, model_id: str, config_file_name: str) -> Dict: |
| 266 | + """Gets the config for the given Aqua model. |
| 267 | +
|
| 268 | + Parameters |
| 269 | + ---------- |
| 270 | + model_id: str |
| 271 | + The OCID of the Aqua model. |
| 272 | + config_file_name: str |
| 273 | + name of the config file |
| 274 | +
|
| 275 | + Returns |
| 276 | + ------- |
| 277 | + Dict: |
| 278 | + A dict of allowed configs. |
| 279 | + """ |
| 280 | + oci_model = self.ds_client.get_model(model_id).data |
| 281 | + oci_aqua = ( |
| 282 | + ( |
| 283 | + Tags.AQUA_TAG.value in oci_model.freeform_tags |
| 284 | + or Tags.AQUA_TAG.value.lower() in oci_model.freeform_tags |
| 285 | + ) |
| 286 | + if oci_model.freeform_tags |
| 287 | + else False |
| 288 | + ) |
| 289 | + |
| 290 | + if not oci_aqua: |
| 291 | + raise AquaRuntimeError(f"Target model {oci_model.id} is not Aqua model.") |
| 292 | + |
| 293 | + config = {} |
| 294 | + artifact_path = get_artifact_path(oci_model.custom_metadata_list) |
| 295 | + if not artifact_path: |
| 296 | + logger.error( |
| 297 | + f"Failed to get artifact path from custom metadata for the model: {model_id}" |
| 298 | + ) |
| 299 | + return config |
| 300 | + |
| 301 | + try: |
| 302 | + config_path = f"{os.path.dirname(artifact_path)}/config/" |
| 303 | + config = load_config( |
| 304 | + config_path, |
| 305 | + config_file_name=config_file_name, |
| 306 | + ) |
| 307 | + except: |
| 308 | + pass |
| 309 | + |
| 310 | + if not config: |
| 311 | + logger.error( |
| 312 | + f"{config_file_name} is not available for the model: {model_id}. Check if the custom metadata has the artifact path set." |
| 313 | + ) |
| 314 | + return config |
| 315 | + |
| 316 | + return config |
| 317 | + |
| 318 | + @property |
| 319 | + def telemetry(self): |
| 320 | + if not self._telemetry: |
| 321 | + self._telemetry = TelemetryClient( |
| 322 | + bucket=AQUA_TELEMETRY_BUCKET, namespace=AQUA_TELEMETRY_BUCKET_NS |
| 323 | + ) |
| 324 | + return self._telemetry |
0 commit comments