Skip to content

Commit 4ccd311

Browse files
fix model tests
1 parent 5e30a01 commit 4ccd311

File tree

1 file changed

+130
-113
lines changed

1 file changed

+130
-113
lines changed

tests/unitary/with_extras/aqua/test_model.py

Lines changed: 130 additions & 113 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ def mock_get_container_config():
6565
yield mock_config
6666

6767

68-
@pytest.fixture(autouse=True, scope="class")
68+
@pytest.fixture(autouse=True, scope="function")
6969
def mock_get_hf_model_info():
7070
with patch.object(HfApi, "model_info") as mock_get_hf_model_info:
7171
test_hf_model_info = ModelInfo(
@@ -230,17 +230,17 @@ class TestDataset:
230230
class TestAquaModel:
231231
"""Contains unittests for AquaModelApp."""
232232

233-
@pytest.fixture(autouse=True, scope="class")
234-
def mock_auth(cls):
235-
with patch("ads.common.auth.default_signer") as mock_default_signer:
236-
yield mock_default_signer
237-
238-
@pytest.fixture(autouse=True, scope="class")
239-
def mock_init_client(cls):
240-
with patch(
241-
"ads.common.oci_datascience.OCIDataScienceMixin.init_client"
242-
) as mock_client:
243-
yield mock_client
233+
# @pytest.fixture(autouse=True, scope="class")
234+
# def mock_auth(cls):
235+
# with patch("ads.common.auth.default_signer") as mock_default_signer:
236+
# yield mock_default_signer
237+
#
238+
# @pytest.fixture(autouse=True, scope="class")
239+
# def mock_init_client(cls):
240+
# with patch(
241+
# "ads.common.oci_datascience.OCIDataScienceMixin.init_client"
242+
# ) as mock_client:
243+
# yield mock_client
244244

245245
def setup_method(self):
246246
self.default_signer_patch = patch(
@@ -658,6 +658,9 @@ def test_get_model_fine_tuned(
658658
(False, False),
659659
],
660660
)
661+
@patch("ads.model.service.oci_datascience_model.OCIDataScienceModel.create")
662+
@patch("ads.model.datascience_model.DataScienceModel.sync")
663+
@patch("ads.model.datascience_model.DataScienceModel.upload_artifact")
661664
@patch.object(AquaModelApp, "_find_matching_aqua_model")
662665
@patch("ads.aqua.common.utils.copy_file")
663666
@patch("ads.common.object_storage_details.ObjectStorageDetails.list_objects")
@@ -672,16 +675,15 @@ def test_import_verified_model(
672675
mock_list_objects,
673676
mock_copy_file,
674677
mock__find_matching_aqua_model,
678+
mock_upload_artifact,
679+
mock_sync,
680+
mock_ocidsc_create,
675681
artifact_location_set,
676682
download_from_hf,
677683
mock_get_hf_model_info,
684+
mock_init_client,
678685
):
679686
ObjectStorageDetails.is_bucket_versioned = MagicMock(return_value=True)
680-
ads.common.oci_datascience.OCIDataScienceMixin.init_client = MagicMock()
681-
DataScienceModel.upload_artifact = MagicMock()
682-
DataScienceModel.sync = MagicMock()
683-
OCIDataScienceModel.create = MagicMock()
684-
685687
# The name attribute cannot be mocked during creation of the mock object,
686688
# hence attach it separately to the mocked objects.
687689
artifact_path = "service_models/model-name/commit-id/artifact"
@@ -780,17 +782,21 @@ def test_import_verified_model(
780782
assert model.ready_to_deploy is True
781783
assert model.ready_to_finetune is False
782784

785+
@patch("ads.model.service.oci_datascience_model.OCIDataScienceModel.create")
786+
@patch("ads.model.datascience_model.DataScienceModel.sync")
787+
@patch("ads.model.datascience_model.DataScienceModel.upload_artifact")
783788
@patch.object(AquaModelApp, "_validate_model")
784789
@patch("ads.aqua.common.utils.load_config", return_value={})
785790
def test_import_any_model_no_containers_specified(
786-
self, mock_load_config, mock__validate_model, mock_get_hf_model_info
791+
self,
792+
mock_load_config,
793+
mock__validate_model,
794+
mock_upload_artifact,
795+
mock_sync,
796+
mock_ocidsc_create,
797+
mock_get_hf_model_info,
787798
):
788799
ObjectStorageDetails.is_bucket_versioned = MagicMock(return_value=True)
789-
ads.common.oci_datascience.OCIDataScienceMixin.init_client = MagicMock()
790-
DataScienceModel.upload_artifact = MagicMock()
791-
DataScienceModel.sync = MagicMock()
792-
OCIDataScienceModel.create = MagicMock()
793-
794800
os_path = "oci://aqua-bkt@aqua-ns/prefix/path"
795801
model_name = "oracle/aqua-1t-mega-model"
796802
ds_freeform_tags = {
@@ -827,6 +833,9 @@ def test_import_any_model_no_containers_specified(
827833
"download_from_hf",
828834
[True, False],
829835
)
836+
@patch("ads.model.service.oci_datascience_model.OCIDataScienceModel.create")
837+
@patch("ads.model.datascience_model.DataScienceModel.sync")
838+
@patch("ads.model.datascience_model.DataScienceModel.upload_artifact")
830839
@patch.object(AquaModelApp, "_find_matching_aqua_model")
831840
@patch("ads.common.object_storage_details.ObjectStorageDetails.list_objects")
832841
@patch("ads.aqua.common.utils.load_config", return_value={})
@@ -839,14 +848,13 @@ def test_import_model_with_project_compartment_override(
839848
mock_load_config,
840849
mock_list_objects,
841850
mock__find_matching_aqua_model,
851+
mock_upload_artifact,
852+
mock_sync,
853+
mock_ocidsc_create,
842854
download_from_hf,
843855
mock_get_hf_model_info,
844856
):
845857
ObjectStorageDetails.is_bucket_versioned = MagicMock(return_value=True)
846-
ads.common.oci_datascience.OCIDataScienceMixin.init_client = MagicMock()
847-
DataScienceModel.upload_artifact = MagicMock()
848-
DataScienceModel.sync = MagicMock()
849-
OCIDataScienceModel.create = MagicMock()
850858

851859
mock_list_objects.return_value = MagicMock(objects=[])
852860
ds_model = DataScienceModel()
@@ -908,6 +916,8 @@ def test_import_model_with_project_compartment_override(
908916
"download_from_hf",
909917
[True, False],
910918
)
919+
@patch("ads.model.service.oci_datascience_model.OCIDataScienceModel.create")
920+
@patch("ads.model.datascience_model.DataScienceModel.upload_artifact")
911921
@patch("ads.common.object_storage_details.ObjectStorageDetails.list_objects")
912922
@patch("ads.aqua.common.utils.load_config", side_effect=AquaFileNotFoundError)
913923
@patch("huggingface_hub.snapshot_download")
@@ -918,17 +928,18 @@ def test_import_model_with_missing_config(
918928
mock_snapshot_download,
919929
mock_load_config,
920930
mock_list_objects,
931+
mock_upload_artifact,
932+
mock_ocidsc_create,
921933
mock_get_container_config,
922934
download_from_hf,
923935
mock_get_hf_model_info,
936+
mock_init_client,
924937
):
925938
"""Test for validating if error is returned when model artifacts are incomplete or not available."""
926939

927940
os_path = "oci://aqua-bkt@aqua-ns/prefix/path"
928941
model_name = "oracle/aqua-1t-mega-model"
929942
ObjectStorageDetails.is_bucket_versioned = MagicMock(return_value=True)
930-
ads.common.oci_datascience.OCIDataScienceMixin.init_client = MagicMock()
931-
DataScienceModel.upload_artifact = MagicMock()
932943
mock_list_objects.return_value = MagicMock(objects=[])
933944
reload(ads.aqua.model.model)
934945
app = AquaModelApp()
@@ -952,21 +963,24 @@ def test_import_model_with_missing_config(
952963
download_from_hf=False,
953964
)
954965

966+
@patch("ads.model.service.oci_datascience_model.OCIDataScienceModel.create")
967+
@patch("ads.model.datascience_model.DataScienceModel.sync")
968+
@patch("ads.model.datascience_model.DataScienceModel.upload_artifact")
955969
@patch("ads.common.object_storage_details.ObjectStorageDetails.list_objects")
956970
@patch.object(HfApi, "model_info")
957971
@patch("ads.aqua.common.utils.load_config", return_value={})
958972
def test_import_any_model_smc_container(
959973
self,
960974
mock_load_config,
961975
mock_list_objects,
976+
mock_upload_artifact,
977+
mock_sync,
978+
mock_ocidsc_create,
962979
mock_get_hf_model_info,
980+
mock_init_client,
963981
):
964982
my_model = "oracle/aqua-1t-mega-model"
965983
ObjectStorageDetails.is_bucket_versioned = MagicMock(return_value=True)
966-
ads.common.oci_datascience.OCIDataScienceMixin.init_client = MagicMock()
967-
DataScienceModel.upload_artifact = MagicMock()
968-
DataScienceModel.sync = MagicMock()
969-
OCIDataScienceModel.create = MagicMock()
970984

971985
os_path = "oci://aqua-bkt@aqua-ns/prefix/path"
972986
ds_freeform_tags = {
@@ -1014,86 +1028,89 @@ def test_import_any_model_smc_container(
10141028
assert model.ready_to_deploy is True
10151029
assert model.ready_to_finetune is True
10161030

1017-
# @pytest.mark.parametrize(
1018-
# "download_from_hf",
1019-
# [True, False],
1020-
# )
1021-
# @patch.object(AquaModelApp, "_find_matching_aqua_model")
1022-
# @patch("ads.common.object_storage_details.ObjectStorageDetails.list_objects")
1023-
# @patch("ads.aqua.common.utils.load_config", return_value={})
1024-
# @patch("huggingface_hub.snapshot_download")
1025-
# @patch("subprocess.check_call")
1026-
# def test_import_tei_model_byoc(
1027-
# self,
1028-
# mock_subprocess,
1029-
# mock_snapshot_download,
1030-
# mock_load_config,
1031-
# mock_list_objects,
1032-
# mock__find_matching_aqua_model,
1033-
# download_from_hf,
1034-
# mock_get_hf_model_info,
1035-
# ):
1036-
# ObjectStorageDetails.is_bucket_versioned = MagicMock(return_value=True)
1037-
# ads.common.oci_datascience.OCIDataScienceMixin.init_client = MagicMock()
1038-
# DataScienceModel.upload_artifact = MagicMock()
1039-
# DataScienceModel.sync = MagicMock()
1040-
# OCIDataScienceModel.create = MagicMock()
1041-
#
1042-
# artifact_path = "service_models/model-name/commit-id/artifact"
1043-
# obj1 = MagicMock(etag="12345-1234-1234-1234-123456789", size=150)
1044-
# obj1.name = f"{artifact_path}/config.json"
1045-
# objects = [obj1]
1046-
# mock_list_objects.return_value = MagicMock(objects=objects)
1047-
# ds_model = DataScienceModel()
1048-
# os_path = "oci://aqua-bkt@aqua-ns/prefix/path"
1049-
# model_name = "oracle/aqua-1t-mega-model"
1050-
# ds_freeform_tags = {
1051-
# "OCI_AQUA": "ACTIVE",
1052-
# "license": "aqua-license",
1053-
# "organization": "oracle",
1054-
# "task": "text_embedding",
1055-
# }
1056-
# ds_model = (
1057-
# ds_model.with_compartment_id("test_model_compartment_id")
1058-
# .with_project_id("test_project_id")
1059-
# .with_display_name(model_name)
1060-
# .with_description("test_description")
1061-
# .with_model_version_set_id("test_model_version_set_id")
1062-
# .with_freeform_tags(**ds_freeform_tags)
1063-
# .with_version_id("ocid1.version.id")
1064-
# )
1065-
# custom_metadata_list = ModelCustomMetadata()
1066-
# custom_metadata_list.add(
1067-
# **{"key": "deployment-container", "value": "odsc-tei-serving"}
1068-
# )
1069-
# ds_model.with_custom_metadata_list(custom_metadata_list)
1070-
# ds_model.set_spec(ds_model.CONST_MODEL_FILE_DESCRIPTION, {})
1071-
# DataScienceModel.from_id = MagicMock(return_value=ds_model)
1072-
# mock__find_matching_aqua_model.return_value = None
1073-
# reload(ads.aqua.model.model)
1074-
# app = AquaModelApp()
1075-
#
1076-
# if download_from_hf:
1077-
# with tempfile.TemporaryDirectory() as tmpdir:
1078-
# model: AquaModel = app.register(
1079-
# model=model_name,
1080-
# os_path=os_path,
1081-
# local_dir=str(tmpdir),
1082-
# download_from_hf=True,
1083-
# inference_container="odsc-tei-serving",
1084-
# inference_container_uri="region.ocir.io/your_tenancy/your_image",
1085-
# )
1086-
# else:
1087-
# model: AquaModel = app.register(
1088-
# model="ocid1.datasciencemodel.xxx.xxxx.",
1089-
# os_path=os_path,
1090-
# download_from_hf=False,
1091-
# inference_container="odsc-tei-serving",
1092-
# inference_container_uri="region.ocir.io/your_tenancy/your_image",
1093-
# )
1094-
# assert model.inference_container == "odsc-tei-serving"
1095-
# assert model.ready_to_deploy is True
1096-
# assert model.ready_to_finetune is False
1031+
@pytest.mark.parametrize(
1032+
"download_from_hf",
1033+
[True, False],
1034+
)
1035+
@patch("ads.model.service.oci_datascience_model.OCIDataScienceModel.create")
1036+
@patch("ads.model.datascience_model.DataScienceModel.sync")
1037+
@patch("ads.model.datascience_model.DataScienceModel.upload_artifact")
1038+
@patch.object(AquaModelApp, "_find_matching_aqua_model")
1039+
@patch("ads.common.object_storage_details.ObjectStorageDetails.list_objects")
1040+
@patch("ads.aqua.common.utils.load_config", return_value={})
1041+
@patch("huggingface_hub.snapshot_download")
1042+
@patch("subprocess.check_call")
1043+
def test_import_tei_model_byoc(
1044+
self,
1045+
mock_subprocess,
1046+
mock_snapshot_download,
1047+
mock_load_config,
1048+
mock_list_objects,
1049+
mock__find_matching_aqua_model,
1050+
mock_upload_artifact,
1051+
mock_sync,
1052+
mock_ocidsc_create,
1053+
download_from_hf,
1054+
mock_get_hf_model_info,
1055+
mock_init_client,
1056+
):
1057+
ObjectStorageDetails.is_bucket_versioned = MagicMock(return_value=True)
1058+
1059+
artifact_path = "service_models/model-name/commit-id/artifact"
1060+
obj1 = MagicMock(etag="12345-1234-1234-1234-123456789", size=150)
1061+
obj1.name = f"{artifact_path}/config.json"
1062+
objects = [obj1]
1063+
mock_list_objects.return_value = MagicMock(objects=objects)
1064+
ds_model = DataScienceModel()
1065+
os_path = "oci://aqua-bkt@aqua-ns/prefix/path"
1066+
model_name = "oracle/aqua-1t-mega-model"
1067+
ds_freeform_tags = {
1068+
"OCI_AQUA": "ACTIVE",
1069+
"license": "aqua-license",
1070+
"organization": "oracle",
1071+
"task": "text_embedding",
1072+
}
1073+
ds_model = (
1074+
ds_model.with_compartment_id("test_model_compartment_id")
1075+
.with_project_id("test_project_id")
1076+
.with_display_name(model_name)
1077+
.with_description("test_description")
1078+
.with_model_version_set_id("test_model_version_set_id")
1079+
.with_freeform_tags(**ds_freeform_tags)
1080+
.with_version_id("ocid1.version.id")
1081+
)
1082+
custom_metadata_list = ModelCustomMetadata()
1083+
custom_metadata_list.add(
1084+
**{"key": "deployment-container", "value": "odsc-tei-serving"}
1085+
)
1086+
ds_model.with_custom_metadata_list(custom_metadata_list)
1087+
ds_model.set_spec(ds_model.CONST_MODEL_FILE_DESCRIPTION, {})
1088+
DataScienceModel.from_id = MagicMock(return_value=ds_model)
1089+
mock__find_matching_aqua_model.return_value = None
1090+
reload(ads.aqua.model.model)
1091+
app = AquaModelApp()
1092+
1093+
if download_from_hf:
1094+
with tempfile.TemporaryDirectory() as tmpdir:
1095+
model: AquaModel = app.register(
1096+
model=model_name,
1097+
os_path=os_path,
1098+
local_dir=str(tmpdir),
1099+
download_from_hf=True,
1100+
inference_container="odsc-tei-serving",
1101+
inference_container_uri="region.ocir.io/your_tenancy/your_image",
1102+
)
1103+
else:
1104+
model: AquaModel = app.register(
1105+
model="ocid1.datasciencemodel.xxx.xxxx.",
1106+
os_path=os_path,
1107+
download_from_hf=False,
1108+
inference_container="odsc-tei-serving",
1109+
inference_container_uri="region.ocir.io/your_tenancy/your_image",
1110+
)
1111+
assert model.inference_container == "odsc-tei-serving"
1112+
assert model.ready_to_deploy is True
1113+
assert model.ready_to_finetune is False
10971114

10981115
@pytest.mark.parametrize(
10991116
"data, expected_output",

0 commit comments

Comments
 (0)