Skip to content

Commit 5f3067c

Browse files
Addressing review comments
1 parent f470234 commit 5f3067c

File tree

4 files changed

+15
-14
lines changed

4 files changed

+15
-14
lines changed

ads/aqua/common/enums.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ class InferenceContainerType(str, metaclass=ExtendedEnumMeta):
4949
class InferenceContainerTypeFamily(str, metaclass=ExtendedEnumMeta):
5050
AQUA_VLLM_CONTAINER_FAMILY = "odsc-vllm-serving"
5151
AQUA_TGI_CONTAINER_FAMILY = "odsc-tgi-serving"
52+
AQUA_LLAMA_CPP_CONTAINER_FAMILY = "odsc-llama-cpp-serving"
5253

5354

5455
class InferenceContainerParamType(str, metaclass=ExtendedEnumMeta):

ads/aqua/constants.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,8 @@
2121
DEFAULT_FT_REPLICA = 1
2222
DEFAULT_FT_BATCH_SIZE = 1
2323
DEFAULT_FT_VALIDATION_SET_SIZE = 0.1
24-
24+
ARM_CPU="arm_cpu"
25+
NVIDIA_GPU="nvidia_gpu"
2526
MAXIMUM_ALLOWED_DATASET_IN_BYTE = 52428800 # 1024 x 1024 x 50 = 50MB
2627
JOB_INFRASTRUCTURE_TYPE_DEFAULT_NETWORKING = "ME_STANDALONE"
2728
NB_SESSION_IDENTIFIER = "NB_SESSION_OCID"

ads/aqua/model/entities.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ class AquaModelSummary(DataClassSerializable):
7676
ready_to_deploy: bool = True
7777
ready_to_finetune: bool = False
7878
ready_to_import: bool = False
79-
platform: List[str] = field(default_factory=lambda: ["gpu"])
79+
platform: List[str] = field(default_factory=lambda: ["nvidia_gpu"])
8080

8181

8282
@dataclass(repr=False)

ads/aqua/model/model.py

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -11,20 +11,16 @@
1111

1212
from ads.aqua import ODSC_MODEL_COMPARTMENT_OCID
1313
from ads.aqua.app import AquaApp
14-
from ads.aqua.common.enums import Tags
14+
from ads.aqua.common.enums import Tags, InferenceContainerTypeFamily
1515
from ads.aqua.common.errors import AquaRuntimeError
1616
from ads.aqua.common.utils import (
17-
copy_model_config,
1817
create_word_icon,
1918
get_artifact_path,
20-
load_config,
2119
read_file,
20+
copy_model_config,
21+
load_config,
2222
)
2323
from ads.aqua.constants import (
24-
AQUA_MODEL_ARTIFACT_CONFIG,
25-
AQUA_MODEL_ARTIFACT_CONFIG_MODEL_NAME,
26-
AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE,
27-
AQUA_MODEL_TYPE_CUSTOM,
2824
LICENSE_TXT,
2925
MODEL_BY_REFERENCE_OSS_PATH_KEY,
3026
README,
@@ -36,6 +32,10 @@
3632
UNKNOWN,
3733
VALIDATION_METRICS,
3834
VALIDATION_METRICS_FINAL,
35+
AQUA_MODEL_ARTIFACT_CONFIG,
36+
AQUA_MODEL_ARTIFACT_CONFIG_MODEL_NAME,
37+
AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE,
38+
AQUA_MODEL_TYPE_CUSTOM, ARM_CPU, NVIDIA_GPU,
3939
)
4040
from ads.aqua.model.constants import *
4141
from ads.aqua.model.entities import *
@@ -235,7 +235,7 @@ def get(self, model_id: str, load_model_card: Optional[bool] = True) -> "AquaMod
235235
try:
236236
jobrun_ocid = ds_model.provenance_metadata.training_id
237237
jobrun = self.ds_client.get_job_run(jobrun_ocid).data
238-
except Exception:
238+
except Exception as e:
239239
logger.debug(
240240
f"Missing jobrun information in the provenance metadata of the given model {model_id}."
241241
)
@@ -580,16 +580,15 @@ def _create_model_catalog_entry(
580580
{
581581
**verified_model.freeform_tags,
582582
Tags.AQUA_SERVICE_MODEL_TAG: verified_model.id,
583-
Tags.PLATFORM: "cpu" if is_gguf_model else "gpu",
584583
}
585584
if verified_model
586585
else {
587586
Tags.AQUA_TAG: "active",
588587
Tags.BASE_MODEL_CUSTOM: "true",
589-
Tags.PLATFORM: "cpu" if is_gguf_model else "gpu",
590588
}
591589
)
592590
tags.update({Tags.BASE_MODEL_CUSTOM: "true"})
591+
tags.update({Tags.PLATFORM: ARM_CPU if is_gguf_model else NVIDIA_GPU})
593592

594593
# Remove `ready_to_import` tag that might get copied from service model.
595594
tags.pop(Tags.READY_TO_IMPORT, None)
@@ -700,7 +699,8 @@ def register(
700699
model_config = None
701700
if not import_model_details:
702701
import_model_details = ImportModelDetails(**kwargs)
703-
is_gguf_model = import_model_details.inference_container == "odsc-llama-cpp"
702+
is_gguf_model = import_model_details.inference_container == InferenceContainerTypeFamily.AQUA_LLAMA_CPP_CONTAINER_FAMILY
703+
platform = ARM_CPU if is_gguf_model else NVIDIA_GPU
704704
if not is_gguf_model:
705705
try:
706706
model_config = load_config(
@@ -792,7 +792,6 @@ def register(
792792
except:
793793
finetuning_container = None
794794

795-
platform = "cpu" if is_gguf_model else "gpu"
796795
aqua_model_attributes = dict(
797796
**self._process_model(ds_model, self.region),
798797
project_id=ds_model.project_id,

0 commit comments

Comments
 (0)