|
1 | 1 | #!/usr/bin/env python
|
2 | 2 | # -*- coding: utf-8 -*--
|
3 |
| - |
| 3 | +import json |
4 | 4 | # Copyright (c) 2024 Oracle and/or its affiliates.
|
5 | 5 | # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
|
6 | 6 |
|
@@ -37,6 +37,21 @@ def mock_auth():
|
37 | 37 | yield mock_default_signer
|
38 | 38 |
|
39 | 39 |
|
| 40 | +@pytest.fixture(autouse=True, scope="class") |
| 41 | +def mock_get_container_config(): |
| 42 | + with patch("ads.aqua.ui.get_container_config") as mock_config: |
| 43 | + with open( |
| 44 | + os.path.join( |
| 45 | + os.path.dirname(os.path.abspath(__file__)), |
| 46 | + "test_data/ui/container_index.json", |
| 47 | + ), |
| 48 | + "r", |
| 49 | + ) as _file: |
| 50 | + container_index_json = json.load(_file) |
| 51 | + mock_config.return_value = container_index_json |
| 52 | + yield mock_config |
| 53 | + |
| 54 | + |
40 | 55 | @pytest.fixture(autouse=True, scope="class")
|
41 | 56 | def mock_init_client():
|
42 | 57 | with patch(
|
@@ -256,6 +271,7 @@ def test_get_foundation_models(
|
256 | 271 | mock_from_id,
|
257 | 272 | mock_read_file,
|
258 | 273 | foundation_model_type,
|
| 274 | + mock_get_container_config, |
259 | 275 | mock_auth,
|
260 | 276 | ):
|
261 | 277 | ds_model = MagicMock()
|
@@ -334,7 +350,7 @@ def test_get_foundation_models(
|
334 | 350 | "model_card": f"{mock_read_file.return_value}",
|
335 | 351 | "model_format": ModelFormat.SAFETENSORS,
|
336 | 352 | "name": f"{ds_model.display_name}",
|
337 |
| - "nvidia_gpu_supported": False, |
| 353 | + "nvidia_gpu_supported": True, |
338 | 354 | "organization": f'{ds_model.freeform_tags["organization"]}',
|
339 | 355 | "project_id": f"{ds_model.project_id}",
|
340 | 356 | "ready_to_deploy": False if foundation_model_type == "verified" else True,
|
@@ -366,6 +382,7 @@ def test_get_model_fine_tuned(
|
366 | 382 | mock_from_id,
|
367 | 383 | mock_read_file,
|
368 | 384 | mock_query_resource,
|
| 385 | + mock_get_container_config, |
369 | 386 | mock_auth,
|
370 | 387 | ):
|
371 | 388 | ds_model = MagicMock()
|
@@ -507,7 +524,7 @@ def test_get_model_fine_tuned(
|
507 | 524 | "model_card": f"{mock_read_file.return_value}",
|
508 | 525 | "model_format": ModelFormat.SAFETENSORS,
|
509 | 526 | "name": f"{ds_model.display_name}",
|
510 |
| - "nvidia_gpu_supported": False, |
| 527 | + "nvidia_gpu_supported": True, |
511 | 528 | "organization": "test_organization",
|
512 | 529 | "project_id": f"{ds_model.project_id}",
|
513 | 530 | "ready_to_deploy": True,
|
@@ -709,12 +726,16 @@ def test_import_model_with_project_compartment_override(self, mock_load_config):
|
709 | 726 | assert model.project_id == project_override
|
710 | 727 |
|
711 | 728 | @patch("ads.aqua.common.utils.load_config", side_effect=AquaFileNotFoundError)
|
712 |
| - def test_import_model_with_missing_config(self, mock_load_config): |
| 729 | + def test_import_model_with_missing_config( |
| 730 | + self, mock_load_config, mock_get_container_config |
| 731 | + ): |
713 | 732 | """Test for validating if error is returned when model artifacts are incomplete or not available."""
|
| 733 | + |
714 | 734 | os_path = "oci://aqua-bkt@aqua-ns/prefix/path"
|
715 | 735 | model_name = "oracle/aqua-1t-mega-model"
|
716 | 736 | reload(ads.aqua.model.model)
|
717 | 737 | app = AquaModelApp()
|
| 738 | + app.list_resource = MagicMock(return_value=[]) |
718 | 739 | with pytest.raises(AquaRuntimeError):
|
719 | 740 | model: AquaModel = app.register(
|
720 | 741 | model=model_name,
|
|
0 commit comments