Skip to content

Commit 616f11a

Browse files
authored
BugFix/add reload option in GenericModel.save() (#462)
1 parent 0f497b0 commit 616f11a

File tree

2 files changed

+83
-34
lines changed

2 files changed

+83
-34
lines changed

ads/model/generic_model.py

Lines changed: 27 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,8 @@
6767
Framework,
6868
ModelCustomMetadata,
6969
ModelProvenanceMetadata,
70-
ModelTaxonomyMetadata, MetadataCustomCategory,
70+
ModelTaxonomyMetadata,
71+
MetadataCustomCategory,
7172
)
7273
from ads.model.model_metadata_mixin import MetadataMixin
7374
from ads.model.model_properties import ModelProperties
@@ -1824,18 +1825,19 @@ def _random_display_name(self):
18241825

18251826
def save(
18261827
self,
1827-
display_name: Optional[str] = None,
1828+
bucket_uri: Optional[str] = None,
1829+
defined_tags: Optional[dict] = None,
18281830
description: Optional[str] = None,
1831+
display_name: Optional[str] = None,
1832+
featurestore_dataset=None,
18291833
freeform_tags: Optional[dict] = None,
1830-
defined_tags: Optional[dict] = None,
18311834
ignore_introspection: Optional[bool] = False,
1832-
bucket_uri: Optional[str] = None,
1835+
model_version_set: Optional[Union[str, ModelVersionSet]] = None,
18331836
overwrite_existing_artifact: Optional[bool] = True,
1837+
parallel_process_count: int = utils.DEFAULT_PARALLEL_PROCESS_COUNT,
18341838
remove_existing_artifact: Optional[bool] = True,
1835-
model_version_set: Optional[Union[str, ModelVersionSet]] = None,
1839+
reload: Optional[bool] = True,
18361840
version_label: Optional[str] = None,
1837-
featurestore_dataset=None,
1838-
parallel_process_count: int = utils.DEFAULT_PARALLEL_PROCESS_COUNT,
18391841
**kwargs,
18401842
) -> str:
18411843
"""Saves model artifacts to the model catalog.
@@ -1862,7 +1864,7 @@ def save(
18621864
overwrite_existing_artifact: (bool, optional). Defaults to `True`.
18631865
Overwrite target bucket artifact if exists.
18641866
remove_existing_artifact: (bool, optional). Defaults to `True`.
1865-
Wether artifacts uploaded to object storage bucket need to be removed or not.
1867+
Whether artifacts uploaded to object storage bucket need to be removed or not.
18661868
model_version_set: (Union[str, ModelVersionSet], optional). Defaults to None.
18671869
The model version set OCID, or model version set name, or `ModelVersionSet` instance.
18681870
version_label: (str, optional). Defaults to None.
@@ -1871,6 +1873,8 @@ def save(
18711873
The feature store dataset
18721874
parallel_process_count: (int, optional)
18731875
The number of worker processes to use in parallel for uploading individual parts of a multipart upload.
1876+
reload: (bool, optional)
1877+
Whether to reload to check if `load_model()` works in `score.py`. Default to `True`.
18741878
kwargs:
18751879
project_id: (str, optional).
18761880
Project OCID. If not specified, the value will be taken either
@@ -1926,12 +1930,17 @@ def save(
19261930
raise RuntimeInfoInconsistencyError(
19271931
"`.runtime_info` does not sync with runtime.yaml file. Call "
19281932
"`.runtime_info.save()` if you updated `runtime_info`. "
1929-
"Call `.reload()` if you updated runtime.yaml file."
1933+
"Call `.reload_runtime_info()` if you updated runtime.yaml file."
19301934
)
19311935
# reload to check if load_model works in score.py, i.e.
19321936
# whether the model file has been serialized, and whether it can be loaded
19331937
# successfully.
1934-
self.reload()
1938+
if reload:
1939+
self.reload()
1940+
else:
1941+
logger.warning(
1942+
"The score.py file has not undergone testing, and this could result in deployment errors. To verify its functionality, please set `reload=True`."
1943+
)
19351944
except:
19361945
if not self.ignore_conda_error:
19371946
raise
@@ -1967,11 +1976,15 @@ def save(
19671976
if featurestore_dataset:
19681977
dataset_details = {
19691978
"dataset-id": featurestore_dataset.id,
1970-
"dataset-name": featurestore_dataset.name
1979+
"dataset-name": featurestore_dataset.name,
19711980
}
1972-
self.metadata_custom.add("featurestore.dataset", value=str(dataset_details),
1973-
category=MetadataCustomCategory.TRAINING_AND_VALIDATION_DATASETS,
1974-
description="feature store dataset", replace=True)
1981+
self.metadata_custom.add(
1982+
"featurestore.dataset",
1983+
value=str(dataset_details),
1984+
category=MetadataCustomCategory.TRAINING_AND_VALIDATION_DATASETS,
1985+
description="feature store dataset",
1986+
replace=True,
1987+
)
19751988

19761989
self.dsc_model = (
19771990
self.dsc_model.with_compartment_id(self.properties.compartment_id)

tests/unitary/with_extras/model/test_generic_model.py

Lines changed: 56 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -283,19 +283,19 @@ def test_prepare_fail(self, mock_handle_model_file_name):
283283
@patch("ads.common.auth.default_signer")
284284
def test_prepare_both_conda_env(self, mock_signer, mock_get_service_packs):
285285
"""prepare a model by only providing inference conda env."""
286-
inference_conda_env="oci://service-conda-packs@ociodscdev/service_pack/cpu/General_Machine_Learning_for_CPUs/1.0/mlcpuv1"
287-
inference_python_version="3.6"
288-
training_conda_env="oci://service-conda-packs@ociodscdev/service_pack/cpu/Oracle_Database_for_CPU_Python_3.7/1.0/database_p37_cpu_v1"
289-
training_python_version="3.7"
286+
inference_conda_env = "oci://service-conda-packs@ociodscdev/service_pack/cpu/General_Machine_Learning_for_CPUs/1.0/mlcpuv1"
287+
inference_python_version = "3.6"
288+
training_conda_env = "oci://service-conda-packs@ociodscdev/service_pack/cpu/Oracle_Database_for_CPU_Python_3.7/1.0/database_p37_cpu_v1"
289+
training_python_version = "3.7"
290290
mock_get_service_packs.return_value = (
291291
{
292-
inference_conda_env : ("mlcpuv1", inference_python_version),
293-
training_conda_env : ("database_p37_cpu_v1", training_python_version)
292+
inference_conda_env: ("mlcpuv1", inference_python_version),
293+
training_conda_env: ("database_p37_cpu_v1", training_python_version),
294294
},
295295
{
296-
"mlcpuv1" : (inference_conda_env, inference_python_version),
297-
"database_p37_cpu_v1" : (training_conda_env, training_python_version)
298-
}
296+
"mlcpuv1": (inference_conda_env, inference_python_version),
297+
"database_p37_cpu_v1": (training_conda_env, training_python_version),
298+
},
299299
)
300300
self.generic_model.prepare(
301301
inference_conda_env=inference_conda_env,
@@ -365,17 +365,19 @@ def test_reload(self):
365365
@patch.object(GenericModel, "_random_display_name", return_value="test_name")
366366
@patch.object(DataScienceModel, "create")
367367
@patch("ads.model.runtime.env_info.get_service_packs")
368-
def test_save(self, mock_get_service_packs, mock_dsc_model_create, mock__random_display_name):
368+
def test_save(
369+
self, mock_get_service_packs, mock_dsc_model_create, mock__random_display_name
370+
):
369371
"""test saving a model to artifact."""
370-
inference_conda_env="oci://service-conda-packs@ociodscdev/service_pack/cpu/Data_Exploration_and_Manipulation_for_CPU_Python_3.7/3.0/dataexpl_p37_cpu_v3"
371-
inference_python_version="3.7"
372+
inference_conda_env = "oci://service-conda-packs@ociodscdev/service_pack/cpu/Data_Exploration_and_Manipulation_for_CPU_Python_3.7/3.0/dataexpl_p37_cpu_v3"
373+
inference_python_version = "3.7"
372374
mock_get_service_packs.return_value = (
373375
{
374-
inference_conda_env : ("dataexpl_p37_cpu_v3", inference_python_version),
376+
inference_conda_env: ("dataexpl_p37_cpu_v3", inference_python_version),
375377
},
376378
{
377-
"dataexpl_p37_cpu_v3" : (inference_conda_env, inference_python_version),
378-
}
379+
"dataexpl_p37_cpu_v3": (inference_conda_env, inference_python_version),
380+
},
379381
)
380382
mock_dsc_model_create.return_value = MagicMock(id="fake_id")
381383
self.generic_model.prepare(
@@ -400,15 +402,15 @@ def test_save(self, mock_get_service_packs, mock_dsc_model_create, mock__random_
400402
@patch("ads.model.runtime.env_info.get_service_packs")
401403
def test_save_not_implemented_error(self, mock_get_service_packs):
402404
"""test saving a model to artifact."""
403-
inference_conda_env="oci://service-conda-packs@ociodscdev/service_pack/cpu/Data_Exploration_and_Manipulation_for_CPU_Python_3.7/3.0/dataexpl_p37_cpu_v3"
404-
inference_python_version="3.7"
405+
inference_conda_env = "oci://service-conda-packs@ociodscdev/service_pack/cpu/Data_Exploration_and_Manipulation_for_CPU_Python_3.7/3.0/dataexpl_p37_cpu_v3"
406+
inference_python_version = "3.7"
405407
mock_get_service_packs.return_value = (
406408
{
407-
inference_conda_env : ("dataexpl_p37_cpu_v3", inference_python_version),
409+
inference_conda_env: ("dataexpl_p37_cpu_v3", inference_python_version),
408410
},
409411
{
410-
"dataexpl_p37_cpu_v3" : (inference_conda_env, inference_python_version),
411-
}
412+
"dataexpl_p37_cpu_v3": (inference_conda_env, inference_python_version),
413+
},
412414
)
413415
self.generic_model._serialize = False
414416
self.generic_model.prepare(
@@ -429,6 +431,40 @@ def test_save_not_implemented_error(self, mock_get_service_packs):
429431
with pytest.raises(NotImplementedError):
430432
self.generic_model.save()
431433

434+
@patch.object(GenericModel, "_random_display_name", return_value="test_name")
435+
@patch.object(DataScienceModel, "create")
436+
@patch("ads.model.runtime.env_info.get_service_packs")
437+
@patch("ads.model.GenericModel.reload")
438+
def test_save_not_reload(
439+
self,
440+
mock_reload,
441+
mock_get_service_packs,
442+
mock_dsc_model_create,
443+
mock__random_display_name,
444+
):
445+
"""test saving a model to artifact without verify score.py."""
446+
inference_conda_env = "oci://bucket@tenancy/prefix/dataexpl_p37_cpu_v3"
447+
inference_python_version = "3.7"
448+
mock_get_service_packs.return_value = (
449+
{
450+
inference_conda_env: ("dataexpl_p37_cpu_v3", inference_python_version),
451+
},
452+
{
453+
"dataexpl_p37_cpu_v3": (inference_conda_env, inference_python_version),
454+
},
455+
)
456+
mock_dsc_model_create.return_value = MagicMock(id="fake_id")
457+
self.generic_model.prepare(
458+
inference_conda_env="dataexpl_p37_cpu_v3",
459+
namespace="ociodscdev",
460+
inference_python_version="3.7",
461+
model_file_name="model.joblib",
462+
force_overwrite=True,
463+
training_id=None,
464+
)
465+
self.generic_model.save(ignore_introspection=True, reload=False)
466+
mock_reload.assert_not_called()
467+
432468
def test_set_model_input_serializer(self):
433469
"""Tests set_model_input_serializer() with different input types."""
434470
from ads.model.serde.model_input import (

0 commit comments

Comments
 (0)