Skip to content

Commit 76f9494

Browse files
committed
Update model tests
1 parent 93609fa commit 76f9494

File tree

2 files changed

+25
-6
lines changed

2 files changed

+25
-6
lines changed

ads/aqua/common/utils.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,6 @@
4646
)
4747
from ads.aqua.data import AquaResourceIdentifier
4848
from ads.common.auth import default_signer
49-
from ads.common.decorator.threaded import threaded
5049
from ads.common.extended_enum import ExtendedEnumMeta
5150
from ads.common.object_storage_details import ObjectStorageDetails
5251
from ads.common.oci_resource import SEARCH_TYPE, OCIResource
@@ -216,7 +215,6 @@ def read_file(file_path: str, **kwargs) -> str:
216215
return UNKNOWN
217216

218217

219-
@threaded()
220218
def load_config(file_path: str, config_file_name: str, **kwargs) -> dict:
221219
artifact_path = f"{file_path.rstrip('/')}/{config_file_name}"
222220
signer = default_signer() if artifact_path.startswith("oci://") else {}

tests/unitary/with_extras/aqua/test_model.py

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
#!/usr/bin/env python
22
# -*- coding: utf-8 -*--
3-
3+
import json
44
# Copyright (c) 2024 Oracle and/or its affiliates.
55
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
66

@@ -37,6 +37,21 @@ def mock_auth():
3737
yield mock_default_signer
3838

3939

40+
@pytest.fixture(autouse=True, scope="class")
41+
def mock_get_container_config():
42+
with patch("ads.aqua.ui.get_container_config") as mock_config:
43+
with open(
44+
os.path.join(
45+
os.path.dirname(os.path.abspath(__file__)),
46+
"test_data/ui/container_index.json",
47+
),
48+
"r",
49+
) as _file:
50+
container_index_json = json.load(_file)
51+
mock_config.return_value = container_index_json
52+
yield mock_config
53+
54+
4055
@pytest.fixture(autouse=True, scope="class")
4156
def mock_init_client():
4257
with patch(
@@ -256,6 +271,7 @@ def test_get_foundation_models(
256271
mock_from_id,
257272
mock_read_file,
258273
foundation_model_type,
274+
mock_get_container_config,
259275
mock_auth,
260276
):
261277
ds_model = MagicMock()
@@ -334,7 +350,7 @@ def test_get_foundation_models(
334350
"model_card": f"{mock_read_file.return_value}",
335351
"model_format": ModelFormat.SAFETENSORS,
336352
"name": f"{ds_model.display_name}",
337-
"nvidia_gpu_supported": False,
353+
"nvidia_gpu_supported": True,
338354
"organization": f'{ds_model.freeform_tags["organization"]}',
339355
"project_id": f"{ds_model.project_id}",
340356
"ready_to_deploy": False if foundation_model_type == "verified" else True,
@@ -366,6 +382,7 @@ def test_get_model_fine_tuned(
366382
mock_from_id,
367383
mock_read_file,
368384
mock_query_resource,
385+
mock_get_container_config,
369386
mock_auth,
370387
):
371388
ds_model = MagicMock()
@@ -507,7 +524,7 @@ def test_get_model_fine_tuned(
507524
"model_card": f"{mock_read_file.return_value}",
508525
"model_format": ModelFormat.SAFETENSORS,
509526
"name": f"{ds_model.display_name}",
510-
"nvidia_gpu_supported": False,
527+
"nvidia_gpu_supported": True,
511528
"organization": "test_organization",
512529
"project_id": f"{ds_model.project_id}",
513530
"ready_to_deploy": True,
@@ -709,12 +726,16 @@ def test_import_model_with_project_compartment_override(self, mock_load_config):
709726
assert model.project_id == project_override
710727

711728
@patch("ads.aqua.common.utils.load_config", side_effect=AquaFileNotFoundError)
712-
def test_import_model_with_missing_config(self, mock_load_config):
729+
def test_import_model_with_missing_config(
730+
self, mock_load_config, mock_get_container_config
731+
):
713732
"""Test for validating if error is returned when model artifacts are incomplete or not available."""
733+
714734
os_path = "oci://aqua-bkt@aqua-ns/prefix/path"
715735
model_name = "oracle/aqua-1t-mega-model"
716736
reload(ads.aqua.model.model)
717737
app = AquaModelApp()
738+
app.list_resource = MagicMock(return_value=[])
718739
with pytest.raises(AquaRuntimeError):
719740
model: AquaModel = app.register(
720741
model=model_name,

0 commit comments

Comments
 (0)