Skip to content

Commit 3cdcea2

Browse files
ODSC-34939: Implement optional loading of artifacts in the GenericModel (#488)
2 parents 94965b5 + 3a0b38f commit 3cdcea2

File tree

2 files changed

+325
-4
lines changed

2 files changed

+325
-4
lines changed

ads/model/generic_model.py

Lines changed: 189 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,13 @@ def __init__(self, state: str):
153153
super().__init__(msg)
154154

155155

156+
class ArtifactsNotAvailableError(Exception):
157+
def __init__(
158+
self, msg="Model artifacts are either not generated or not available locally."
159+
):
160+
super().__init__(msg)
161+
162+
156163
class SerializeModelNotImplementedError(NotImplementedError): # pragma: no cover
157164
pass
158165

@@ -281,6 +288,8 @@ class GenericModel(MetadataMixin, Introspectable, EvaluatorMixin):
281288
Tests if deployment works in local environment.
282289
upload_artifact(...)
283290
Uploads model artifacts to the provided `uri`.
291+
download_artifact(...)
292+
Downloads model artifacts from the model catalog.
284293
285294
286295
Examples
@@ -1249,6 +1258,9 @@ def verify(
12491258
Dict
12501259
A dictionary which contains prediction results.
12511260
"""
1261+
if self.model_artifact is None:
1262+
raise ArtifactsNotAvailableError
1263+
12521264
endpoint = f"http://127.0.0.1:8000/predict"
12531265
data = self._handle_input_data(data, auto_serialize_data, **kwargs)
12541266

@@ -1402,6 +1414,117 @@ def from_model_artifact(
14021414

14031415
return model
14041416

1417+
def download_artifact(
1418+
self,
1419+
artifact_dir: Optional[str] = None,
1420+
auth: Optional[Dict] = None,
1421+
force_overwrite: Optional[bool] = False,
1422+
bucket_uri: Optional[str] = None,
1423+
remove_existing_artifact: Optional[bool] = True,
1424+
**kwargs,
1425+
) -> "GenericModel":
1426+
"""Downloads model artifacts from the model catalog.
1427+
1428+
Parameters
1429+
----------
1430+
artifact_dir: (str, optional). Defaults to `None`.
1431+
The artifact directory to store the files needed for deployment.
1432+
Will be created if not exists.
1433+
auth: (Dict, optional). Defaults to None.
1434+
The default authentication is set using `ads.set_auth` API. If you need to override the
1435+
default, use the `ads.common.auth.api_keys` or `ads.common.auth.resource_principal` to create appropriate
1436+
authentication signer and kwargs required to instantiate IdentityClient object.
1437+
force_overwrite: (bool, optional). Defaults to False.
1438+
Whether to overwrite existing files or not.
1439+
bucket_uri: (str, optional). Defaults to None.
1440+
The OCI Object Storage URI where model artifacts will be copied to.
1441+
The `bucket_uri` is only necessary for downloading large artifacts with
1442+
size is greater than 2GB. Example: `oci://<bucket_name>@<namespace>/prefix/`.
1443+
remove_existing_artifact: (bool, optional). Defaults to `True`.
1444+
Whether artifacts uploaded to object storage bucket need to be removed or not.
1445+
1446+
Returns
1447+
-------
1448+
Self
1449+
An instance of `GenericModel` class.
1450+
1451+
Raises
1452+
------
1453+
ValueError
1454+
If `model_id` is not available in the GenericModel object.
1455+
"""
1456+
model_id = self.model_id
1457+
if not model_id:
1458+
raise ValueError(
1459+
"`model_id` is not available, load the GenericModel object first."
1460+
)
1461+
1462+
if not artifact_dir:
1463+
artifact_dir = self.artifact_dir
1464+
artifact_dir = _prepare_artifact_dir(artifact_dir)
1465+
1466+
target_dir = (
1467+
_prepare_artifact_dir()
1468+
if ObjectStorageDetails.is_oci_path(artifact_dir)
1469+
else artifact_dir
1470+
)
1471+
1472+
dsc_model = DataScienceModel.from_id(model_id)
1473+
dsc_model.download_artifact(
1474+
target_dir=target_dir,
1475+
force_overwrite=force_overwrite,
1476+
bucket_uri=bucket_uri,
1477+
remove_existing_artifact=remove_existing_artifact,
1478+
auth=auth,
1479+
region=kwargs.pop("region", None),
1480+
timeout=kwargs.pop("timeout", None),
1481+
)
1482+
model_artifact = ModelArtifact.from_uri(
1483+
uri=target_dir,
1484+
artifact_dir=artifact_dir,
1485+
model_file_name=self.model_file_name,
1486+
force_overwrite=force_overwrite,
1487+
auth=auth,
1488+
ignore_conda_error=self.ignore_conda_error,
1489+
)
1490+
self.dsc_model = dsc_model
1491+
self.local_copy_dir = model_artifact.local_copy_dir
1492+
self.model_artifact = model_artifact
1493+
self.reload_runtime_info()
1494+
1495+
self._summary_status.update_status(
1496+
detail="Generated score.py",
1497+
status=ModelState.DONE.value,
1498+
)
1499+
self._summary_status.update_status(
1500+
detail="Generated runtime.yaml",
1501+
status=ModelState.DONE.value,
1502+
)
1503+
self._summary_status.update_status(
1504+
detail="Serialized model", status=ModelState.DONE.value
1505+
)
1506+
self._summary_status.update_status(
1507+
detail="Populated metadata(Custom, Taxonomy and Provenance)",
1508+
status=ModelState.DONE.value,
1509+
)
1510+
self._summary_status.update_status(
1511+
detail="Local tested .predict from score.py",
1512+
status=ModelState.AVAILABLE.value,
1513+
)
1514+
self._summary_status.update_action(
1515+
detail="Local tested .predict from score.py",
1516+
action="",
1517+
)
1518+
self._summary_status.update_status(
1519+
detail="Conducted Introspect Test",
1520+
status=ModelState.AVAILABLE.value,
1521+
)
1522+
self._summary_status.update_status(
1523+
detail="Uploaded artifact to model catalog",
1524+
status=ModelState.AVAILABLE.value,
1525+
)
1526+
return self
1527+
14051528
@classmethod
14061529
def from_model_catalog(
14071530
cls: Type[Self],
@@ -1414,6 +1537,7 @@ def from_model_catalog(
14141537
bucket_uri: Optional[str] = None,
14151538
remove_existing_artifact: Optional[bool] = True,
14161539
ignore_conda_error: Optional[bool] = False,
1540+
download_artifact: Optional[bool] = True,
14171541
**kwargs,
14181542
) -> Self:
14191543
"""Loads model from model catalog.
@@ -1443,6 +1567,8 @@ def from_model_catalog(
14431567
Wether artifacts uploaded to object storage bucket need to be removed or not.
14441568
ignore_conda_error: (bool, optional). Defaults to False.
14451569
Parameter to ignore error when collecting conda information.
1570+
download_artifact: (bool, optional). Defaults to True.
1571+
Whether to download the model pickle or checkpoints
14461572
kwargs:
14471573
compartment_id : (str, optional)
14481574
Compartment OCID. If not specified, the value will be taken from the environment variables.
@@ -1475,14 +1601,60 @@ def from_model_catalog(
14751601
artifact_dir = _prepare_artifact_dir(artifact_dir)
14761602

14771603
target_dir = (
1478-
artifact_dir
1479-
if not ObjectStorageDetails.is_oci_path(artifact_dir)
1480-
else tempfile.mkdtemp()
1604+
_prepare_artifact_dir()
1605+
if ObjectStorageDetails.is_oci_path(artifact_dir)
1606+
else artifact_dir
14811607
)
14821608
bucket_uri = bucket_uri or (
14831609
artifact_dir if ObjectStorageDetails.is_oci_path(artifact_dir) else None
14841610
)
14851611
dsc_model = DataScienceModel.from_id(model_id)
1612+
1613+
if not download_artifact:
1614+
result_model = cls(
1615+
artifact_dir=artifact_dir,
1616+
bucket_uri=bucket_uri,
1617+
auth=auth,
1618+
properties=properties,
1619+
ignore_conda_error=ignore_conda_error,
1620+
**kwargs,
1621+
)
1622+
result_model._summary_status.update_status(
1623+
detail="Generated score.py",
1624+
status=ModelState.NOTAPPLICABLE.value,
1625+
)
1626+
result_model._summary_status.update_status(
1627+
detail="Generated runtime.yaml",
1628+
status=ModelState.NOTAPPLICABLE.value,
1629+
)
1630+
result_model._summary_status.update_status(
1631+
detail="Serialized model", status=ModelState.NOTAPPLICABLE.value
1632+
)
1633+
result_model._summary_status.update_status(
1634+
detail="Populated metadata(Custom, Taxonomy and Provenance)",
1635+
status=ModelState.NOTAPPLICABLE.value,
1636+
)
1637+
result_model._summary_status.update_status(
1638+
detail="Local tested .predict from score.py",
1639+
status=ModelState.NOTAPPLICABLE.value,
1640+
)
1641+
result_model._summary_status.update_action(
1642+
detail="Local tested .predict from score.py",
1643+
action="Local artifact is not available. "
1644+
"Set load_artifact flag to True while loading the model or "
1645+
"call .download_artifact().",
1646+
)
1647+
result_model._summary_status.update_status(
1648+
detail="Conducted Introspect Test",
1649+
status=ModelState.NOTAPPLICABLE.value,
1650+
)
1651+
result_model._summary_status.update_status(
1652+
detail="Uploaded artifact to model catalog",
1653+
status=ModelState.NOTAPPLICABLE.value,
1654+
)
1655+
result_model.dsc_model = dsc_model
1656+
return result_model
1657+
14861658
dsc_model.download_artifact(
14871659
target_dir=target_dir,
14881660
force_overwrite=force_overwrite,
@@ -1536,6 +1708,7 @@ def from_model_deployment(
15361708
bucket_uri: Optional[str] = None,
15371709
remove_existing_artifact: Optional[bool] = True,
15381710
ignore_conda_error: Optional[bool] = False,
1711+
download_artifact: Optional[bool] = True,
15391712
**kwargs,
15401713
) -> Self:
15411714
"""Loads model from model deployment.
@@ -1565,6 +1738,8 @@ def from_model_deployment(
15651738
Wether artifacts uploaded to object storage bucket need to be removed or not.
15661739
ignore_conda_error: (bool, optional). Defaults to False.
15671740
Parameter to ignore error when collecting conda information.
1741+
download_artifact: (bool, optional). Defaults to True.
1742+
Whether to download the model pickle or checkpoints
15681743
kwargs:
15691744
compartment_id : (str, optional)
15701745
Compartment OCID. If not specified, the value will be taken from the environment variables.
@@ -1608,6 +1783,7 @@ def from_model_deployment(
16081783
bucket_uri=bucket_uri,
16091784
remove_existing_artifact=remove_existing_artifact,
16101785
ignore_conda_error=ignore_conda_error,
1786+
download_artifact=download_artifact,
16111787
**kwargs,
16121788
)
16131789
model._summary_status.update_status(
@@ -1730,6 +1906,7 @@ def from_id(
17301906
bucket_uri: Optional[str] = None,
17311907
remove_existing_artifact: Optional[bool] = True,
17321908
ignore_conda_error: Optional[bool] = False,
1909+
download_artifact: Optional[bool] = True,
17331910
**kwargs,
17341911
) -> Self:
17351912
"""Loads model from model OCID or model deployment OCID.
@@ -1757,6 +1934,10 @@ def from_id(
17571934
size is greater than 2GB. Example: `oci://<bucket_name>@<namespace>/prefix/`.
17581935
remove_existing_artifact: (bool, optional). Defaults to `True`.
17591936
Wether artifacts uploaded to object storage bucket need to be removed or not.
1937+
ignore_conda_error: (bool, optional). Defaults to False.
1938+
Parameter to ignore error when collecting conda information.
1939+
download_artifact: (bool, optional). Defaults to True.
1940+
Whether to download the model pickle or checkpoints
17601941
kwargs:
17611942
compartment_id : (str, optional)
17621943
Compartment OCID. If not specified, the value will be taken from the environment variables.
@@ -1780,6 +1961,7 @@ def from_id(
17801961
bucket_uri=bucket_uri,
17811962
remove_existing_artifact=remove_existing_artifact,
17821963
ignore_conda_error=ignore_conda_error,
1964+
download_artifact=download_artifact,
17831965
**kwargs,
17841966
)
17851967
elif DataScienceModelType.MODEL in ocid:
@@ -1793,6 +1975,7 @@ def from_id(
17931975
bucket_uri=bucket_uri,
17941976
remove_existing_artifact=remove_existing_artifact,
17951977
ignore_conda_error=ignore_conda_error,
1978+
download_artifact=download_artifact,
17961979
**kwargs,
17971980
)
17981981
else:
@@ -1924,11 +2107,13 @@ def save(
19242107
... bucket_uri="oci://my-bucket@my-tenancy/",
19252108
... overwrite_existing_artifact=True,
19262109
... remove_existing_artifact=True,
1927-
... remove_existing_artifact=True,
19282110
... parallel_process_count=9,
19292111
... )
19302112
19312113
"""
2114+
if self.model_artifact is None:
2115+
raise ArtifactsNotAvailableError
2116+
19322117
# Set default display_name if not specified - randomly generated easy to remember name generated
19332118
if not display_name:
19342119
display_name = self._random_display_name()

0 commit comments

Comments
 (0)