Skip to content

Commit f470234

Browse files
register model changes for 1.0.3
1 parent 592c107 commit f470234

File tree

3 files changed

+52
-38
lines changed

3 files changed

+52
-38
lines changed

ads/aqua/common/enums.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
#!/usr/bin/env python
2-
# -*- coding: utf-8 -*-
32
# Copyright (c) 2024 Oracle and/or its affiliates.
43
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
54

@@ -8,6 +7,7 @@
87
~~~~~~~~~~~~~~
98
This module contains the set of enums used in AQUA.
109
"""
10+
1111
from ads.common.extended_enum import ExtendedEnumMeta
1212

1313

@@ -28,6 +28,7 @@ class Tags(str, metaclass=ExtendedEnumMeta):
2828
TASK = "task"
2929
LICENSE = "license"
3030
ORGANIZATION = "organization"
31+
PLATFORM = "platform"
3132
AQUA_TAG = "OCI_AQUA"
3233
AQUA_SERVICE_MODEL_TAG = "aqua_service_model"
3334
AQUA_FINE_TUNED_MODEL_TAG = "aqua_fine_tuned_model"

ads/aqua/model/entities.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ class AquaModelSummary(DataClassSerializable):
7676
ready_to_deploy: bool = True
7777
ready_to_finetune: bool = False
7878
ready_to_import: bool = False
79-
platform: List[str] = None
79+
platform: List[str] = field(default_factory=lambda: ["gpu"])
8080

8181

8282
@dataclass(repr=False)

ads/aqua/model/model.py

Lines changed: 49 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
#!/usr/bin/env python
2-
# -*- coding: utf-8 -*-
32
# Copyright (c) 2024 Oracle and/or its affiliates.
43
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
54
import os
@@ -15,13 +14,17 @@
1514
from ads.aqua.common.enums import Tags
1615
from ads.aqua.common.errors import AquaRuntimeError
1716
from ads.aqua.common.utils import (
17+
copy_model_config,
1818
create_word_icon,
1919
get_artifact_path,
20-
read_file,
21-
copy_model_config,
2220
load_config,
21+
read_file,
2322
)
2423
from ads.aqua.constants import (
24+
AQUA_MODEL_ARTIFACT_CONFIG,
25+
AQUA_MODEL_ARTIFACT_CONFIG_MODEL_NAME,
26+
AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE,
27+
AQUA_MODEL_TYPE_CUSTOM,
2528
LICENSE_TXT,
2629
MODEL_BY_REFERENCE_OSS_PATH_KEY,
2730
README,
@@ -33,10 +36,6 @@
3336
UNKNOWN,
3437
VALIDATION_METRICS,
3538
VALIDATION_METRICS_FINAL,
36-
AQUA_MODEL_ARTIFACT_CONFIG,
37-
AQUA_MODEL_ARTIFACT_CONFIG_MODEL_NAME,
38-
AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE,
39-
AQUA_MODEL_TYPE_CUSTOM,
4039
)
4140
from ads.aqua.model.constants import *
4241
from ads.aqua.model.entities import *
@@ -236,7 +235,7 @@ def get(self, model_id: str, load_model_card: Optional[bool] = True) -> "AquaMod
236235
try:
237236
jobrun_ocid = ds_model.provenance_metadata.training_id
238237
jobrun = self.ds_client.get_job_run(jobrun_ocid).data
239-
except Exception as e:
238+
except Exception:
240239
logger.debug(
241240
f"Missing jobrun information in the provenance metadata of the given model {model_id}."
242241
)
@@ -268,8 +267,7 @@ def get(self, model_id: str, load_model_card: Optional[bool] = True) -> "AquaMod
268267

269268
job_run_status = (
270269
jobrun.lifecycle_state
271-
if jobrun
272-
and not jobrun.lifecycle_state == JobRun.LIFECYCLE_STATE_DELETED
270+
if jobrun and jobrun.lifecycle_state != JobRun.LIFECYCLE_STATE_DELETED
273271
else (
274272
JobRun.LIFECYCLE_STATE_SUCCEEDED
275273
if self.if_artifact_exist(ds_model.id)
@@ -540,7 +538,7 @@ def clear_model_list_cache(
540538
dict with the key used, and True if cache has the key that needs to be deleted.
541539
"""
542540
res = {}
543-
logger.info(f"Clearing _service_models_cache")
541+
logger.info("Clearing _service_models_cache")
544542
with self._cache_lock:
545543
if ODSC_MODEL_COMPARTMENT_OCID in self._service_models_cache.keys():
546544
self._service_models_cache.pop(key=ODSC_MODEL_COMPARTMENT_OCID)
@@ -561,6 +559,7 @@ def _create_model_catalog_entry(
561559
verified_model: DataScienceModel,
562560
compartment_id: Optional[str],
563561
project_id: Optional[str],
562+
is_gguf_model: bool,
564563
) -> DataScienceModel:
565564
"""Create model by reference from the object storage path
566565
@@ -581,9 +580,14 @@ def _create_model_catalog_entry(
581580
{
582581
**verified_model.freeform_tags,
583582
Tags.AQUA_SERVICE_MODEL_TAG: verified_model.id,
583+
Tags.PLATFORM: "cpu" if is_gguf_model else "gpu",
584584
}
585585
if verified_model
586-
else {Tags.AQUA_TAG: "active", Tags.BASE_MODEL_CUSTOM: "true"}
586+
else {
587+
Tags.AQUA_TAG: "active",
588+
Tags.BASE_MODEL_CUSTOM: "true",
589+
Tags.PLATFORM: "cpu" if is_gguf_model else "gpu",
590+
}
587591
)
588592
tags.update({Tags.BASE_MODEL_CUSTOM: "true"})
589593

@@ -615,8 +619,8 @@ def _create_model_catalog_entry(
615619
)
616620
else:
617621
logger.warn(
618-
f"Proceeding with model registration without the fine-tuning container information. "
619-
f"This model will not be available for fine tuning."
622+
"Proceeding with model registration without the fine-tuning container information. "
623+
"This model will not be available for fine tuning."
620624
)
621625

622626
metadata.add(
@@ -693,24 +697,25 @@ def register(
693697
The registered model as a AquaModel object.
694698
"""
695699
verified_model_details: DataScienceModel = None
696-
700+
model_config = None
697701
if not import_model_details:
698702
import_model_details = ImportModelDetails(**kwargs)
699-
700-
try:
701-
model_config = load_config(
702-
file_path=import_model_details.os_path,
703-
config_file_name=AQUA_MODEL_ARTIFACT_CONFIG,
704-
)
705-
except Exception as ex:
706-
logger.error(
707-
f"Exception occurred while loading config file from {import_model_details.os_path}"
708-
f"Exception message: {ex}"
709-
)
710-
raise AquaRuntimeError(
711-
f"The model path {import_model_details.os_path} does not contain the file config.json. "
712-
f"Please check if the path is correct or the model artifacts are available at this location."
713-
)
703+
is_gguf_model = import_model_details.inference_container == "odsc-llama-cpp"
704+
if not is_gguf_model:
705+
try:
706+
model_config = load_config(
707+
file_path=import_model_details.os_path,
708+
config_file_name=AQUA_MODEL_ARTIFACT_CONFIG,
709+
)
710+
except Exception as ex:
711+
logger.error(
712+
f"Exception occurred while loading config file from {import_model_details.os_path}"
713+
f"Exception message: {ex}"
714+
)
715+
raise AquaRuntimeError(
716+
f"The model path {import_model_details.os_path} does not contain the file config.json. "
717+
f"Please check if the path is correct or the model artifacts are available at this location."
718+
)
714719

715720
model_service_id = None
716721
# If OCID of a model is passed, we need to copy the defaults for Tags and metadata from the service model.
@@ -770,6 +775,7 @@ def register(
770775
verified_model=verified_model_details,
771776
compartment_id=import_model_details.compartment_id,
772777
project_id=import_model_details.project_id,
778+
is_gguf_model=is_gguf_model,
773779
)
774780
# registered model will always have inference and evaluation container, but
775781
# fine-tuning container may be not set
@@ -786,6 +792,7 @@ def register(
786792
except:
787793
finetuning_container = None
788794

795+
platform = "cpu" if is_gguf_model else "gpu"
789796
aqua_model_attributes = dict(
790797
**self._process_model(ds_model, self.region),
791798
project_id=ds_model.project_id,
@@ -798,17 +805,23 @@ def register(
798805
inference_container=inference_container,
799806
finetuning_container=finetuning_container,
800807
evaluation_container=evaluation_container,
808+
platform=platform,
801809
)
802810

803811
if verified_model_details:
804812
telemetry_model_name = model_name
813+
elif (
814+
model_config is not None
815+
and AQUA_MODEL_ARTIFACT_CONFIG_MODEL_NAME in model_config
816+
):
817+
telemetry_model_name = f"{AQUA_MODEL_TYPE_CUSTOM}_{model_config[AQUA_MODEL_ARTIFACT_CONFIG_MODEL_NAME]}"
818+
elif (
819+
model_config is not None
820+
and AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE in model_config
821+
):
822+
telemetry_model_name = f"{AQUA_MODEL_TYPE_CUSTOM}_{model_config[AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE]}"
805823
else:
806-
if AQUA_MODEL_ARTIFACT_CONFIG_MODEL_NAME in model_config:
807-
telemetry_model_name = f"{AQUA_MODEL_TYPE_CUSTOM}_{model_config[AQUA_MODEL_ARTIFACT_CONFIG_MODEL_NAME]}"
808-
elif AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE in model_config:
809-
telemetry_model_name = f"{AQUA_MODEL_TYPE_CUSTOM}_{model_config[AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE]}"
810-
else:
811-
telemetry_model_name = AQUA_MODEL_TYPE_CUSTOM
824+
telemetry_model_name = AQUA_MODEL_TYPE_CUSTOM
812825

813826
self.telemetry.record_event_async(
814827
category="aqua/model",

0 commit comments

Comments
 (0)