Skip to content

Commit 6d9c1a1

Browse files
Moving config from ADS to UI
1 parent 49da2cd commit 6d9c1a1

File tree

4 files changed

+58
-60
lines changed

4 files changed

+58
-60
lines changed

ads/aqua/extension/ui_handler.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
#!/usr/bin/env python
2-
# -*- coding: utf-8 -*-
32
# Copyright (c) 2024 Oracle and/or its affiliates.
43
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
54

@@ -10,8 +9,8 @@
109

1110
from ads.aqua.common.decorator import handle_exceptions
1211
from ads.aqua.common.enums import Tags
13-
from ads.aqua.extension.errors import Errors
1412
from ads.aqua.extension.base_handler import AquaAPIhandler
13+
from ads.aqua.extension.errors import Errors
1514
from ads.aqua.extension.utils import validate_function_parameters
1615
from ads.aqua.model.entities import ImportModelDetails
1716
from ads.aqua.ui import AquaUIApp
@@ -184,7 +183,10 @@ def get_shape_availability(self, **kwargs):
184183

185184
return self.finish(
186185
AquaUIApp().get_shape_availability(
187-
compartment_id=compartment_id, instance_shape=instance_shape, limit_name=limit_name, **kwargs
186+
compartment_id=compartment_id,
187+
instance_shape=instance_shape,
188+
limit_name=limit_name,
189+
**kwargs,
188190
)
189191
)
190192

ads/aqua/finetuning/finetuning.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
get_container_image,
2222
upload_local_to_os,
2323
)
24-
from ads.aqua.config.config import get_finetuning_config_defaults
2524
from ads.aqua.constants import (
2625
DEFAULT_FT_BATCH_SIZE,
2726
DEFAULT_FT_BLOCK_STORAGE_SIZE,
@@ -563,7 +562,9 @@ def get_finetuning_config(self, model_id: str) -> Dict:
563562

564563
config = self.get_config(model_id, AQUA_MODEL_FINETUNING_CONFIG)
565564
if not config:
566-
logger.info(f"default fine-tuning config will be used for model: {model_id}")
565+
logger.info(
566+
f"default fine-tuning config will be used for model: {model_id}"
567+
)
567568
return config
568569

569570
@telemetry(

ads/aqua/modeldeployment/deployment.py

Lines changed: 48 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
get_params_list,
2424
get_resource_name,
2525
get_restricted_params_by_container,
26-
load_config,
2726
)
2827
from ads.aqua.constants import (
2928
AQUA_MODEL_ARTIFACT_FILE,
@@ -44,11 +43,8 @@
4443
from ads.common.object_storage_details import ObjectStorageDetails
4544
from ads.common.utils import get_log_links
4645
from ads.config import (
47-
AQUA_CONFIG_FOLDER,
4846
AQUA_DEPLOYMENT_CONTAINER_METADATA_NAME,
49-
AQUA_DEPLOYMENT_CONTAINER_OVERRIDE_FLAG_METADATA_NAME,
5047
AQUA_MODEL_DEPLOYMENT_CONFIG,
51-
AQUA_MODEL_DEPLOYMENT_CONFIG_DEFAULTS,
5248
COMPARTMENT_OCID,
5349
)
5450
from ads.model.datascience_model import DataScienceModel
@@ -87,26 +83,26 @@ class AquaDeploymentApp(AquaApp):
8783

8884
@telemetry(entry_point="plugin=deployment&action=create", name="aqua")
8985
def create(
90-
self,
91-
model_id: str,
92-
instance_shape: str,
93-
display_name: str,
94-
instance_count: int = None,
95-
log_group_id: str = None,
96-
access_log_id: str = None,
97-
predict_log_id: str = None,
98-
compartment_id: str = None,
99-
project_id: str = None,
100-
description: str = None,
101-
bandwidth_mbps: int = None,
102-
web_concurrency: int = None,
103-
server_port: int = None,
104-
health_check_port: int = None,
105-
env_var: Dict = None,
106-
container_family: str = None,
107-
memory_in_gbs: Optional[float] = None,
108-
ocpus: Optional[float] = None,
109-
model_file: Optional[str] = None,
86+
self,
87+
model_id: str,
88+
instance_shape: str,
89+
display_name: str,
90+
instance_count: int = None,
91+
log_group_id: str = None,
92+
access_log_id: str = None,
93+
predict_log_id: str = None,
94+
compartment_id: str = None,
95+
project_id: str = None,
96+
description: str = None,
97+
bandwidth_mbps: int = None,
98+
web_concurrency: int = None,
99+
server_port: int = None,
100+
health_check_port: int = None,
101+
env_var: Dict = None,
102+
container_family: str = None,
103+
memory_in_gbs: Optional[float] = None,
104+
ocpus: Optional[float] = None,
105+
model_file: Optional[str] = None,
110106
) -> "AquaDeployment":
111107
"""
112108
Creates a new Aqua deployment
@@ -231,8 +227,7 @@ def create(
231227
env_var.update({"FT_MODEL": f"{fine_tune_output_path}"})
232228

233229
container_type_key = self._get_container_type_key(
234-
model=aqua_model,
235-
container_family=container_family
230+
model=aqua_model, container_family=container_family
236231
)
237232

238233
# fetch image name from config
@@ -248,7 +243,11 @@ def create(
248243
model_format = model_formats_str.split(",")
249244

250245
# Figure out a better way to handle this in future release
251-
if ModelFormat.GGUF.value in model_format and container_type_key.lower() == InferenceContainerTypeFamily.AQUA_LLAMA_CPP_CONTAINER_FAMILY:
246+
if (
247+
ModelFormat.GGUF.value in model_format
248+
and container_type_key.lower()
249+
== InferenceContainerTypeFamily.AQUA_LLAMA_CPP_CONTAINER_FAMILY
250+
):
252251
if model_file is not None:
253252
logger.info(
254253
f"Overriding {model_file} as model_file for model {aqua_model.id}."
@@ -299,8 +298,8 @@ def create(
299298
if user_params:
300299
# todo: remove this check in the future version, logic to be moved to container_index
301300
if (
302-
container_type_key.lower()
303-
== InferenceContainerTypeFamily.AQUA_LLAMA_CPP_CONTAINER_FAMILY
301+
container_type_key.lower()
302+
== InferenceContainerTypeFamily.AQUA_LLAMA_CPP_CONTAINER_FAMILY
304303
):
305304
# AQUA_LLAMA_CPP_CONTAINER_FAMILY container uses uvicorn that required model/server params
306305
# to be set as env vars
@@ -422,9 +421,8 @@ def _get_container_type_key(model: DataScienceModel, container_family: str) -> s
422421
f"for model {model.id}. For unverified Aqua models, {AQUA_DEPLOYMENT_CONTAINER_METADATA_NAME} should be"
423422
f"set and value can be one of {', '.join(InferenceContainerTypeFamily.values())}."
424423
) from err
425-
424+
426425
return container_type_key
427-
428426

429427
@telemetry(entry_point="plugin=deployment&action=list", name="aqua")
430428
def list(self, **kwargs) -> List["AquaDeployment"]:
@@ -453,8 +451,8 @@ def list(self, **kwargs) -> List["AquaDeployment"]:
453451
for model_deployment in model_deployments:
454452
oci_aqua = (
455453
(
456-
Tags.AQUA_TAG in model_deployment.freeform_tags
457-
or Tags.AQUA_TAG.lower() in model_deployment.freeform_tags
454+
Tags.AQUA_TAG in model_deployment.freeform_tags
455+
or Tags.AQUA_TAG.lower() in model_deployment.freeform_tags
458456
)
459457
if model_deployment.freeform_tags
460458
else False
@@ -508,8 +506,8 @@ def get(self, model_deployment_id: str, **kwargs) -> "AquaDeploymentDetail":
508506

509507
oci_aqua = (
510508
(
511-
Tags.AQUA_TAG in model_deployment.freeform_tags
512-
or Tags.AQUA_TAG.lower() in model_deployment.freeform_tags
509+
Tags.AQUA_TAG in model_deployment.freeform_tags
510+
or Tags.AQUA_TAG.lower() in model_deployment.freeform_tags
513511
)
514512
if model_deployment.freeform_tags
515513
else False
@@ -526,8 +524,8 @@ def get(self, model_deployment_id: str, **kwargs) -> "AquaDeploymentDetail":
526524
log_group_name = ""
527525

528526
logs = (
529-
model_deployment.category_log_details.access
530-
or model_deployment.category_log_details.predict
527+
model_deployment.category_log_details.access
528+
or model_deployment.category_log_details.predict
531529
)
532530
if logs:
533531
log_id = logs.log_id
@@ -578,9 +576,9 @@ def get_deployment_config(self, model_id: str) -> Dict:
578576
return config
579577

580578
def get_deployment_default_params(
581-
self,
582-
model_id: str,
583-
instance_shape: str,
579+
self,
580+
model_id: str,
581+
instance_shape: str,
584582
) -> List[str]:
585583
"""Gets the default params set in the deployment configs for the given model and instance shape.
586584
@@ -612,8 +610,8 @@ def get_deployment_default_params(
612610
)
613611

614612
if (
615-
container_type_key
616-
and container_type_key in InferenceContainerTypeFamily.values()
613+
container_type_key
614+
and container_type_key in InferenceContainerTypeFamily.values()
617615
):
618616
deployment_config = self.get_deployment_config(model_id)
619617
config_params = (
@@ -636,10 +634,10 @@ def get_deployment_default_params(
636634
return default_params
637635

638636
def validate_deployment_params(
639-
self,
640-
model_id: str,
641-
params: List[str] = None,
642-
container_family: str = None,
637+
self,
638+
model_id: str,
639+
params: List[str] = None,
640+
container_family: str = None,
643641
) -> Dict:
644642
"""Validate if the deployment parameters passed by the user can be overridden. Parameter values are not
645643
validated, only param keys are validated.
@@ -662,8 +660,7 @@ def validate_deployment_params(
662660
if params:
663661
model = DataScienceModel.from_id(model_id)
664662
container_type_key = self._get_container_type_key(
665-
model=model,
666-
container_family=container_family
663+
model=model, container_family=container_family
667664
)
668665

669666
container_config = get_container_config()
@@ -685,9 +682,9 @@ def validate_deployment_params(
685682

686683
@staticmethod
687684
def _find_restricted_params(
688-
default_params: Union[str, List[str]],
689-
user_params: Union[str, List[str]],
690-
container_family: str,
685+
default_params: Union[str, List[str]],
686+
user_params: Union[str, List[str]],
687+
container_family: str,
691688
) -> List[str]:
692689
"""Returns a list of restricted params that user chooses to override when creating an Aqua deployment.
693690
The default parameters coming from the container index json file cannot be overridden.

ads/aqua/ui.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,12 @@
1717
from ads.aqua.common.entities import ContainerSpec
1818
from ads.aqua.common.enums import Tags
1919
from ads.aqua.common.errors import AquaResourceAccessError, AquaValueError
20-
from ads.aqua.common.utils import get_container_config, load_config, sanitize_response
20+
from ads.aqua.common.utils import get_container_config, sanitize_response
2121
from ads.common import oci_client as oc
2222
from ads.common.auth import default_signer
2323
from ads.common.object_storage_details import ObjectStorageDetails
2424
from ads.common.serializer import DataClassSerializable
2525
from ads.config import (
26-
AQUA_CONFIG_FOLDER,
27-
AQUA_RESOURCE_LIMIT_NAMES_CONFIG,
2826
COMPARTMENT_OCID,
2927
DATA_SCIENCE_SERVICE_NAME,
3028
TENANCY_OCID,
@@ -572,7 +570,7 @@ def get_shape_availability(self, **kwargs):
572570
"""
573571
compartment_id = kwargs.pop("compartment_id", COMPARTMENT_OCID)
574572
instance_shape = kwargs.pop("instance_shape", None)
575-
limit_name = kwargs.pop("limit_name",None)
573+
limit_name = kwargs.pop("limit_name", None)
576574

577575
if not instance_shape:
578576
raise AquaValueError("instance_shape argument is required.")

0 commit comments

Comments
 (0)