Skip to content

Commit 3a0b38f

Browse files
change load_artifact flag to download_artifact
1 parent a33f348 commit 3a0b38f

File tree

2 files changed

+22
-22
lines changed

2 files changed

+22
-22
lines changed

ads/model/generic_model.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1521,7 +1521,7 @@ def from_model_catalog(
15211521
bucket_uri: Optional[str] = None,
15221522
remove_existing_artifact: Optional[bool] = True,
15231523
ignore_conda_error: Optional[bool] = False,
1524-
load_artifact: Optional[bool] = True,
1524+
download_artifact: Optional[bool] = True,
15251525
**kwargs,
15261526
) -> Self:
15271527
"""Loads model from model catalog.
@@ -1551,7 +1551,7 @@ def from_model_catalog(
15511551
Wether artifacts uploaded to object storage bucket need to be removed or not.
15521552
ignore_conda_error: (bool, optional). Defaults to False.
15531553
Parameter to ignore error when collecting conda information.
1554-
load_artifact: (bool, optional). Defaults to True.
1554+
download_artifact: (bool, optional). Defaults to True.
15551555
Whether to download the model pickle or checkpoints
15561556
kwargs:
15571557
compartment_id : (str, optional)
@@ -1594,7 +1594,7 @@ def from_model_catalog(
15941594
)
15951595
dsc_model = DataScienceModel.from_id(model_id)
15961596

1597-
if not load_artifact:
1597+
if not download_artifact:
15981598
result_model = cls(
15991599
artifact_dir=artifact_dir,
16001600
bucket_uri=bucket_uri,
@@ -1692,7 +1692,7 @@ def from_model_deployment(
16921692
bucket_uri: Optional[str] = None,
16931693
remove_existing_artifact: Optional[bool] = True,
16941694
ignore_conda_error: Optional[bool] = False,
1695-
load_artifact: Optional[bool] = True,
1695+
download_artifact: Optional[bool] = True,
16961696
**kwargs,
16971697
) -> Self:
16981698
"""Loads model from model deployment.
@@ -1722,7 +1722,7 @@ def from_model_deployment(
17221722
Wether artifacts uploaded to object storage bucket need to be removed or not.
17231723
ignore_conda_error: (bool, optional). Defaults to False.
17241724
Parameter to ignore error when collecting conda information.
1725-
load_artifact: (bool, optional). Defaults to True.
1725+
download_artifact: (bool, optional). Defaults to True.
17261726
Whether to download the model pickle or checkpoints
17271727
kwargs:
17281728
compartment_id : (str, optional)
@@ -1767,7 +1767,7 @@ def from_model_deployment(
17671767
bucket_uri=bucket_uri,
17681768
remove_existing_artifact=remove_existing_artifact,
17691769
ignore_conda_error=ignore_conda_error,
1770-
load_artifact=load_artifact,
1770+
download_artifact=download_artifact,
17711771
**kwargs,
17721772
)
17731773
model._summary_status.update_status(
@@ -1890,7 +1890,7 @@ def from_id(
18901890
bucket_uri: Optional[str] = None,
18911891
remove_existing_artifact: Optional[bool] = True,
18921892
ignore_conda_error: Optional[bool] = False,
1893-
load_artifact: Optional[bool] = True,
1893+
download_artifact: Optional[bool] = True,
18941894
**kwargs,
18951895
) -> Self:
18961896
"""Loads model from model OCID or model deployment OCID.
@@ -1920,7 +1920,7 @@ def from_id(
19201920
Wether artifacts uploaded to object storage bucket need to be removed or not.
19211921
ignore_conda_error: (bool, optional). Defaults to False.
19221922
Parameter to ignore error when collecting conda information.
1923-
load_artifact: (bool, optional). Defaults to True.
1923+
download_artifact: (bool, optional). Defaults to True.
19241924
Whether to download the model pickle or checkpoints
19251925
kwargs:
19261926
compartment_id : (str, optional)
@@ -1945,7 +1945,7 @@ def from_id(
19451945
bucket_uri=bucket_uri,
19461946
remove_existing_artifact=remove_existing_artifact,
19471947
ignore_conda_error=ignore_conda_error,
1948-
load_artifact=load_artifact,
1948+
download_artifact=download_artifact,
19491949
**kwargs,
19501950
)
19511951
elif DataScienceModelType.MODEL in ocid:
@@ -1959,7 +1959,7 @@ def from_id(
19591959
bucket_uri=bucket_uri,
19601960
remove_existing_artifact=remove_existing_artifact,
19611961
ignore_conda_error=ignore_conda_error,
1962-
load_artifact=load_artifact,
1962+
download_artifact=download_artifact,
19631963
**kwargs,
19641964
)
19651965
else:

tests/unitary/with_extras/model/test_generic_model.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1070,7 +1070,7 @@ def test_from_model_deployment(
10701070
bucket_uri="test_bucket_uri",
10711071
remove_existing_artifact=True,
10721072
compartment_id="test_compartment_id",
1073-
load_artifact=True,
1073+
download_artifact=True,
10741074
)
10751075

10761076
mock_from_id.assert_called_with(test_model_deployment_id)
@@ -1085,7 +1085,7 @@ def test_from_model_deployment(
10851085
remove_existing_artifact=True,
10861086
compartment_id="test_compartment_id",
10871087
ignore_conda_error=False,
1088-
load_artifact=True,
1088+
download_artifact=True,
10891089
)
10901090

10911091
assert test_result == test_model
@@ -1189,7 +1189,7 @@ def test_from_id_model_deployment(
11891189
remove_existing_artifact=True,
11901190
compartment_id="test_compartment_id",
11911191
ignore_conda_error=True,
1192-
load_artifact=True,
1192+
download_artifact=True,
11931193
)
11941194

11951195
mock_from_model_deployment.assert_called_with(
@@ -1203,7 +1203,7 @@ def test_from_id_model_deployment(
12031203
remove_existing_artifact=True,
12041204
compartment_id="test_compartment_id",
12051205
ignore_conda_error=True,
1206-
load_artifact=True,
1206+
download_artifact=True,
12071207
)
12081208

12091209
assert test_model_deployment_result == test_model
@@ -1239,7 +1239,7 @@ def test_from_id_model(self, mock_from_model_catalog):
12391239
remove_existing_artifact=True,
12401240
compartment_id="test_compartment_id",
12411241
ignore_conda_error=True,
1242-
load_artifact=True,
1242+
download_artifact=True,
12431243
)
12441244

12451245
mock_from_model_catalog.assert_called_with(
@@ -1253,7 +1253,7 @@ def test_from_id_model(self, mock_from_model_catalog):
12531253
remove_existing_artifact=True,
12541254
compartment_id="test_compartment_id",
12551255
ignore_conda_error=True,
1256-
load_artifact=True,
1256+
download_artifact=True,
12571257
)
12581258

12591259
assert test_model_result == mock_oci_model
@@ -1280,15 +1280,15 @@ def test_from_id_fail(self):
12801280

12811281
@patch.object(GenericModel, "from_model_catalog")
12821282
def test_from_id_model_without_artifact(self, mock_from_model_catalog):
1283-
"""Test to check model artifact is not loaded when load_artifact is set to False"""
1283+
"""Test to check model artifact is not loaded when download_artifact is set to False"""
12841284
test_model_id = "xxxx.datasciencemodel.xxxx"
12851285
mock_model = MagicMock(model_id=test_model_id, model_artifact=None)
12861286
mock_from_model_catalog.return_value = mock_model
12871287

12881288
test_model_result = GenericModel.from_id(
12891289
ocid=test_model_id,
12901290
compartment_id="test_compartment_id",
1291-
load_artifact=False,
1291+
download_artifact=False,
12921292
)
12931293
mock_from_model_catalog.assert_called_with(
12941294
test_model_id,
@@ -1301,14 +1301,14 @@ def test_from_id_model_without_artifact(self, mock_from_model_catalog):
13011301
remove_existing_artifact=True,
13021302
compartment_id="test_compartment_id",
13031303
ignore_conda_error=False,
1304-
load_artifact=False,
1304+
download_artifact=False,
13051305
)
13061306
assert test_model_result.model_artifact is None
13071307
assert test_model_result == mock_model
13081308

13091309
@patch.object(GenericModel, "from_model_catalog")
13101310
def test_from_id_with_artifact(self, mock_from_model_catalog):
1311-
"""Test to check model artifact is loaded when load_artifact is set to True"""
1311+
"""Test to check model artifact is loaded when download_artifact is set to True"""
13121312
test_model_id = "xxxx.datasciencemodel.xxxx"
13131313
artifact_dir = "test_dir"
13141314
model_artifact = MagicMock(artifact_dir=artifact_dir, reload=False)
@@ -1320,7 +1320,7 @@ def test_from_id_with_artifact(self, mock_from_model_catalog):
13201320
ocid=test_model_id,
13211321
compartment_id="test_compartment_id",
13221322
remove_existing_artifact=True,
1323-
load_artifact=True,
1323+
download_artifact=True,
13241324
)
13251325

13261326
mock_from_model_catalog.assert_called_with(
@@ -1334,7 +1334,7 @@ def test_from_id_with_artifact(self, mock_from_model_catalog):
13341334
remove_existing_artifact=True,
13351335
compartment_id="test_compartment_id",
13361336
ignore_conda_error=False,
1337-
load_artifact=True,
1337+
download_artifact=True,
13381338
)
13391339
assert test_model_result.model_artifact is not None
13401340
assert test_model_result == mock_model

0 commit comments

Comments
 (0)