Skip to content

Commit 63da499

Browse files
authored
ODSC-50786/GenericModel: add reload options to ModelArtifact.from_uri (#499)
1 parent cf636ff commit 63da499

File tree

3 files changed

+45
-20
lines changed

3 files changed

+45
-20
lines changed

ads/model/artifact.py

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -30,21 +30,21 @@
3030
ADS_VERSION = __version__
3131

3232

33-
class ArtifactNestedFolderError(Exception): # pragma: no cover
33+
class ArtifactNestedFolderError(Exception): # pragma: no cover
3434
def __init__(self, folder: str):
3535
self.folder = folder
3636
super().__init__("The required artifact files placed in a nested folder.")
3737

3838

39-
class ArtifactRequiredFilesError(Exception): # pragma: no cover
39+
class ArtifactRequiredFilesError(Exception): # pragma: no cover
4040
def __init__(self, required_files: Tuple[str]):
4141
super().__init__(
4242
"Not all required files presented in artifact folder. "
4343
f"Required files for conda runtime: {required_files}. If you are using container runtime, set `ignore_conda_error=True`."
4444
)
4545

4646

47-
class AritfactFolderStructureError(Exception): # pragma: no cover
47+
class AritfactFolderStructureError(Exception): # pragma: no cover
4848
def __init__(self, required_files: Tuple[str]):
4949
super().__init__(
5050
"The artifact folder has a wrong structure. "
@@ -171,6 +171,7 @@ def __init__(
171171
self.ignore_conda_error = ignore_conda_error
172172
self.model = None
173173
self.auth = auth or authutil.default_signer()
174+
174175
if reload and not ignore_conda_error:
175176
self.reload()
176177
# Extracts the model_file_name from the score.py.
@@ -416,6 +417,7 @@ def from_uri(
416417
force_overwrite: Optional[bool] = False,
417418
auth: Optional[Dict] = None,
418419
ignore_conda_error: Optional[bool] = False,
420+
reload: Optional[bool] = False,
419421
):
420422
"""Constructs a ModelArtifact object from the existing model artifacts.
421423
@@ -426,16 +428,20 @@ def from_uri(
426428
OCI object storage URI.
427429
artifact_dir: str
428430
The local artifact folder to store the files needed for deployment.
429-
model_file_name: (str, optional). Defaults to `None`
430-
The file name of the serialized model.
431-
force_overwrite: (bool, optional). Defaults to False.
432-
Whether to overwrite existing files or not.
433431
auth: (Dict, optional). Defaults to None.
434432
The default authetication is set using `ads.set_auth` API.
435433
If you need to override the default, use the `ads.common.auth.api_keys`
436434
or `ads.common.auth.resource_principal` to create appropriate
437435
authentication signer and kwargs required to instantiate
438436
IdentityClient object.
437+
force_overwrite: (bool, optional). Defaults to False.
438+
Whether to overwrite existing files or not.
439+
ignore_conda_error: (bool, optional). Defaults to False.
440+
Parameter to ignore error when collecting conda information.
441+
model_file_name: (str, optional). Defaults to `None`
442+
The file name of the serialized model.
443+
reload: (bool, optional). Defaults to False.
444+
Whether to reload the Model into the environment.
439445
440446
Returns
441447
-------
@@ -492,6 +498,8 @@ def from_uri(
492498
utils.copy_from_uri(
493499
uri=temp_dir, to_path=to_path, force_overwrite=True
494500
)
501+
except ArtifactRequiredFilesError as ex:
502+
logger.warning(ex)
495503

496504
if ObjectStorageDetails.is_oci_path(artifact_dir):
497505
for root, dirs, files in os.walk(to_path):
@@ -507,10 +515,10 @@ def from_uri(
507515

508516
return cls(
509517
artifact_dir=artifact_dir,
510-
model_file_name=model_file_name,
511-
reload=True,
512518
ignore_conda_error=ignore_conda_error,
513519
local_copy_dir=to_path,
520+
model_file_name=model_file_name,
521+
reload=reload,
514522
)
515523

516524
def __getattr__(self, item):

ads/model/generic_model.py

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1349,13 +1349,15 @@ def from_model_artifact(
13491349
properties.with_dict(local_vars)
13501350
auth = auth or authutil.default_signer()
13511351
artifact_dir = _prepare_artifact_dir(artifact_dir)
1352+
reload = kwargs.pop("reload", False)
13521353
model_artifact = ModelArtifact.from_uri(
13531354
uri=uri,
13541355
artifact_dir=artifact_dir,
1355-
model_file_name=model_file_name,
1356-
force_overwrite=force_overwrite,
13571356
auth=auth,
1357+
force_overwrite=force_overwrite,
13581358
ignore_conda_error=ignore_conda_error,
1359+
model_file_name=model_file_name,
1360+
reload=reload,
13591361
)
13601362
model = cls(
13611363
estimator=model_artifact.model,
@@ -1368,22 +1370,33 @@ def from_model_artifact(
13681370
model.local_copy_dir = model_artifact.local_copy_dir
13691371
model.model_artifact = model_artifact
13701372
model.ignore_conda_error = ignore_conda_error
1371-
model.reload_runtime_info()
1373+
1374+
if reload:
1375+
model.reload_runtime_info()
1376+
model._summary_status.update_action(
1377+
detail="Populated metadata(Custom, Taxonomy and Provenance)",
1378+
action="Call .populate_metadata() to populate metadata.",
1379+
)
1380+
13721381
model._summary_status.update_status(
13731382
detail="Generated score.py",
1374-
status=ModelState.DONE.value,
1383+
status=ModelState.NOTAPPLICABLE.value,
13751384
)
13761385
model._summary_status.update_status(
13771386
detail="Generated runtime.yaml",
1378-
status=ModelState.DONE.value,
1387+
status=ModelState.NOTAPPLICABLE.value,
13791388
)
13801389
model._summary_status.update_status(
1381-
detail="Serialized model", status=ModelState.DONE.value
1390+
detail="Serialized model",
1391+
status=ModelState.NOTAPPLICABLE.value,
13821392
)
1383-
model._summary_status.update_action(
1393+
model._summary_status.update_status(
13841394
detail="Populated metadata(Custom, Taxonomy and Provenance)",
1385-
action=f"Call .populate_metadata() to populate metadata.",
1395+
status=ModelState.AVAILABLE.value
1396+
if reload
1397+
else ModelState.NOTAPPLICABLE.value,
13861398
)
1399+
13871400
return model
13881401

13891402
@classmethod
@@ -3034,6 +3047,7 @@ class ModelState(Enum):
30343047
AVAILABLE = "Available"
30353048
NOTAVAILABLE = "Not Available"
30363049
NEEDSACTION = "Needs Action"
3050+
NOTAPPLICABLE = "Not Applicable"
30373051

30383052

30393053
class SummaryStatus:

tests/unitary/with_extras/model/test_generic_model.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -807,7 +807,6 @@ def test_predict_success(self, mock_url, mock_client, mock_signer):
807807
# )
808808
assert test_result == {"result": "result"}
809809

810-
# @pytest.mark.skip(reason="need to fix later.")
811810
@pytest.mark.parametrize(
812811
"test_args",
813812
[
@@ -849,6 +848,7 @@ def test_predict_success(self, mock_url, mock_client, mock_signer):
849848
"uri": "src/folder",
850849
"force_overwrite": True,
851850
"auth": {"config": "value"},
851+
"reload": True,
852852
},
853853
],
854854
)
@@ -882,6 +882,7 @@ def test_from_model_artifact(
882882
auth=test_args.get("auth", {"config": "value"}),
883883
model_file_name=test_args.get("model_file_name"),
884884
ignore_conda_error=False,
885+
reload=test_args.get("reload", False),
885886
)
886887

887888
if not test_args.get("as_onnx"):
@@ -899,8 +900,10 @@ def test_from_model_artifact(
899900
properties=ModelProperties(),
900901
as_onnx=True,
901902
)
902-
903-
mock_reload_runtime_info.assert_called()
903+
if test_args.get("reload", False):
904+
mock_reload_runtime_info.assert_called()
905+
else:
906+
mock_reload_runtime_info.assert_not_called()
904907
assert test_result == None
905908

906909
@patch("ads.common.auth.default_signer")

0 commit comments

Comments
 (0)