Skip to content

Commit d9756a8

Browse files
authored
ODSC-46634/utilize oci UploadManager to upload model artifacts (#304)
1 parent ca5766d commit d9756a8

File tree

8 files changed

+375
-103
lines changed

8 files changed

+375
-103
lines changed

ads/common/utils.py

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,9 @@
5353
from ads.dataset.progress import DummyProgressBar, TqdmProgressBar
5454

5555
from . import auth as authutil
56+
from oci import object_storage
57+
from ads.common.oci_client import OCIClientFactory
58+
from ads.common.object_storage_details import ObjectStorageDetails
5659

5760
# For Model / Model Artifact libraries
5861
lib_translator = {"sklearn": "scikit-learn"}
@@ -100,6 +103,9 @@
100103

101104
# declare custom exception class
102105

106+
# The number of worker processes to use in parallel for uploading individual parts of a multipart upload.
107+
DEFAULT_PARALLEL_PROCESS_COUNT = 9
108+
103109

104110
class FileOverwriteError(Exception): # pragma: no cover
105111
pass
@@ -1599,3 +1605,103 @@ def is_path_exists(uri: str, auth: Optional[Dict] = None) -> bool:
15991605
if fsspec.filesystem(path_scheme, **storage_options).exists(uri):
16001606
return True
16011607
return False
1608+
1609+
1610+
def upload_to_os(
1611+
src_uri: str,
1612+
dst_uri: str,
1613+
auth: dict = None,
1614+
parallel_process_count: int = DEFAULT_PARALLEL_PROCESS_COUNT,
1615+
progressbar_description: str = "Uploading `{src_uri}` to `{dst_uri}`.",
1616+
force_overwrite: bool = False,
1617+
):
1618+
"""Utilizes `oci.object_storage.Uploadmanager` to upload file to Object Storage.
1619+
1620+
Parameters
1621+
----------
1622+
src_uri: str
1623+
The path to the file to upload. This should be local path.
1624+
dst_uri: str
1625+
Object Storage path, e.g. `oci://my-bucket@my-tenancy/prefix`.
1626+
auth: (Dict, optional) Defaults to None.
1627+
default_signer()
1628+
parallel_process_count: (int, optional) Defaults to 9.
1629+
The number of worker processes to use in parallel for uploading individual
1630+
parts of a multipart upload.
1631+
progressbar_description: (str, optional) Defaults to `"Uploading `{src_uri}` to `{dst_uri}`"`.
1632+
Prefix for the progressbar.
1633+
force_overwrite: (bool, optional). Defaults to False.
1634+
Whether to overwrite existing files or not.
1635+
1636+
Returns
1637+
-------
1638+
Response: oci.response.Response
1639+
The response from multipart commit operation or the put operation.
1640+
1641+
Raise
1642+
-----
1643+
ValueError
1644+
When the given `dst_uri` is not a valid Object Storage path.
1645+
FileNotFoundError
1646+
When the given `src_uri` does not exist.
1647+
RuntimeError
1648+
When upload operation fails.
1649+
"""
1650+
if not os.path.exists(src_uri):
1651+
raise FileNotFoundError(f"The give src_uri: {src_uri} does not exist.")
1652+
1653+
if not ObjectStorageDetails.is_oci_path(
1654+
dst_uri
1655+
) or not ObjectStorageDetails.is_valid_uri(dst_uri):
1656+
raise ValueError(
1657+
f"The given dst_uri:{dst_uri} is not a valid Object Storage path."
1658+
)
1659+
1660+
auth = auth or authutil.default_signer()
1661+
1662+
if not force_overwrite and is_path_exists(dst_uri):
1663+
raise FileExistsError(
1664+
f"The `{dst_uri}` exists. Please use a new file name or "
1665+
"set force_overwrite to True if you wish to overwrite."
1666+
)
1667+
1668+
upload_manager = object_storage.UploadManager(
1669+
object_storage_client=OCIClientFactory(**auth).object_storage,
1670+
parallel_process_count=parallel_process_count,
1671+
allow_multipart_uploads=True,
1672+
allow_parallel_uploads=True,
1673+
)
1674+
1675+
file_size = os.path.getsize(src_uri)
1676+
with open(src_uri, "rb") as fs:
1677+
with tqdm(
1678+
total=file_size,
1679+
unit="B",
1680+
unit_scale=True,
1681+
unit_divisor=1024,
1682+
position=0,
1683+
leave=False,
1684+
file=sys.stdout,
1685+
desc=progressbar_description,
1686+
) as pbar:
1687+
1688+
def progress_callback(progress):
1689+
pbar.update(progress)
1690+
1691+
bucket_details = ObjectStorageDetails.from_path(dst_uri)
1692+
response = upload_manager.upload_stream(
1693+
namespace_name=bucket_details.namespace,
1694+
bucket_name=bucket_details.bucket,
1695+
object_name=bucket_details.filepath,
1696+
stream_ref=fs,
1697+
progress_callback=progress_callback,
1698+
)
1699+
1700+
if response.status == 200:
1701+
print(f"{src_uri} has been successfully uploaded to {dst_uri}.")
1702+
else:
1703+
raise RuntimeError(
1704+
f"Failed to upload {src_uri}. Response code is {response.status}"
1705+
)
1706+
1707+
return response

ads/model/artifact_uploader.py

Lines changed: 67 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,8 @@ def _upload(self):
9494

9595

9696
class SmallArtifactUploader(ArtifactUploader):
97+
"""The class helper to upload small model artifacts."""
98+
9799
PROGRESS_STEPS_COUNT = 1
98100

99101
def _upload(self):
@@ -104,6 +106,39 @@ def _upload(self):
104106

105107

106108
class LargeArtifactUploader(ArtifactUploader):
109+
"""
110+
The class helper to upload large model artifacts.
111+
112+
Attributes
113+
----------
114+
artifact_path: str
115+
The model artifact location.
116+
artifact_zip_path: str
117+
The uri of the zip of model artifact.
118+
auth: dict
119+
The default authentication is set using `ads.set_auth` API.
120+
If you need to override the default, use the `ads.common.auth.api_keys` or
121+
`ads.common.auth.resource_principal` to create appropriate authentication signer
122+
and kwargs required to instantiate IdentityClient object.
123+
bucket_uri: str
124+
The OCI Object Storage URI where model artifacts will be copied to.
125+
The `bucket_uri` is only necessary for uploading large artifacts which
126+
size is greater than 2GB. Example: `oci://<bucket_name>@<namespace>/prefix/`.
127+
dsc_model: OCIDataScienceModel
128+
The data science model instance.
129+
overwrite_existing_artifact: bool
130+
Overwrite target bucket artifact if exists.
131+
progress: TqdmProgressBar
132+
An instance of the TqdmProgressBar.
133+
region: str
134+
The destination Object Storage bucket region.
135+
By default the value will be extracted from the `OCI_REGION_METADATA` environment variables.
136+
remove_existing_artifact: bool
137+
Whether artifacts uploaded to object storage bucket need to be removed or not.
138+
upload_manager: UploadManager
139+
The uploadManager simplifies interaction with the Object Storage service.
140+
"""
141+
107142
PROGRESS_STEPS_COUNT = 4
108143

109144
def __init__(
@@ -115,6 +150,7 @@ def __init__(
115150
region: Optional[str] = None,
116151
overwrite_existing_artifact: Optional[bool] = True,
117152
remove_existing_artifact: Optional[bool] = True,
153+
parallel_process_count: int = utils.DEFAULT_PARALLEL_PROCESS_COUNT,
118154
):
119155
"""Initializes `LargeArtifactUploader` instance.
120156
@@ -139,7 +175,9 @@ def __init__(
139175
overwrite_existing_artifact: (bool, optional). Defaults to `True`.
140176
Overwrite target bucket artifact if exists.
141177
remove_existing_artifact: (bool, optional). Defaults to `True`.
142-
Wether artifacts uploaded to object storage bucket need to be removed or not.
178+
Whether artifacts uploaded to object storage bucket need to be removed or not.
179+
parallel_process_count: (int, optional).
180+
The number of worker processes to use in parallel for uploading individual parts of a multipart upload.
143181
"""
144182
if not bucket_uri:
145183
raise ValueError("The `bucket_uri` must be provided.")
@@ -150,36 +188,45 @@ def __init__(
150188
self.bucket_uri = bucket_uri
151189
self.overwrite_existing_artifact = overwrite_existing_artifact
152190
self.remove_existing_artifact = remove_existing_artifact
191+
self._parallel_process_count = parallel_process_count
153192

154193
def _upload(self):
155194
"""Uploads model artifacts to the model catalog."""
156195
self.progress.update("Copying model artifact to the Object Storage bucket")
157196

158-
try:
159-
bucket_uri = self.bucket_uri
160-
bucket_uri_file_name = os.path.basename(bucket_uri)
197+
bucket_uri = self.bucket_uri
198+
bucket_uri_file_name = os.path.basename(bucket_uri)
161199

162-
if not bucket_uri_file_name:
163-
bucket_uri = os.path.join(bucket_uri, f"{self.dsc_model.id}.zip")
164-
elif not bucket_uri.lower().endswith(".zip"):
165-
bucket_uri = f"{bucket_uri}.zip"
200+
if not bucket_uri_file_name:
201+
bucket_uri = os.path.join(bucket_uri, f"{self.dsc_model.id}.zip")
202+
elif not bucket_uri.lower().endswith(".zip"):
203+
bucket_uri = f"{bucket_uri}.zip"
166204

167-
bucket_file_name = utils.copy_file(
168-
self.artifact_zip_path,
169-
bucket_uri,
170-
force_overwrite=self.overwrite_existing_artifact,
171-
auth=self.auth,
172-
progressbar_description="Copying model artifact to the Object Storage bucket",
173-
)
174-
except FileExistsError:
205+
if not self.overwrite_existing_artifact and utils.is_path_exists(
206+
uri=bucket_uri, auth=self.auth
207+
):
175208
raise FileExistsError(
176-
f"The `{self.bucket_uri}` exists. Please use a new file name or "
209+
f"The bucket_uri=`{self.bucket_uri}` exists. Please use a new file name or "
177210
"set `overwrite_existing_artifact` to `True` if you wish to overwrite."
178211
)
212+
213+
try:
214+
utils.upload_to_os(
215+
src_uri=self.artifact_zip_path,
216+
dst_uri=bucket_uri,
217+
auth=self.auth,
218+
parallel_process_count=self._parallel_process_count,
219+
force_overwrite=self.overwrite_existing_artifact,
220+
progressbar_description="Copying model artifact to the Object Storage bucket.",
221+
)
222+
except Exception as ex:
223+
raise RuntimeError(
224+
f"Failed to upload model artifact to the given Object Storage path `{self.bucket_uri}`."
225+
f"See Exception: {ex}"
226+
)
227+
179228
self.progress.update("Exporting model artifact to the model catalog")
180-
self.dsc_model.export_model_artifact(
181-
bucket_uri=bucket_file_name, region=self.region
182-
)
229+
self.dsc_model.export_model_artifact(bucket_uri=bucket_uri, region=self.region)
183230

184231
if self.remove_existing_artifact:
185232
self.progress.update(

ads/model/datascience_model.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535
_MAX_ARTIFACT_SIZE_IN_BYTES = 2147483648 # 2GB
3636

3737

38-
class ModelArtifactSizeError(Exception): # pragma: no cover
38+
class ModelArtifactSizeError(Exception): # pragma: no cover
3939
def __init__(self, max_artifact_size: str):
4040
super().__init__(
4141
f"The model artifacts size is greater than `{max_artifact_size}`. "
@@ -562,6 +562,8 @@ def create(self, **kwargs) -> "DataScienceModel":
562562
and kwargs required to instantiate IdentityClient object.
563563
timeout: (int, optional). Defaults to 10 seconds.
564564
The connection timeout in seconds for the client.
565+
parallel_process_count: (int, optional).
566+
The number of worker processes to use in parallel for uploading individual parts of a multipart upload.
565567
566568
Returns
567569
-------
@@ -607,6 +609,7 @@ def create(self, **kwargs) -> "DataScienceModel":
607609
region=kwargs.pop("region", None),
608610
auth=kwargs.pop("auth", None),
609611
timeout=kwargs.pop("timeout", None),
612+
parallel_process_count=kwargs.pop("parallel_process_count", None),
610613
)
611614

612615
# Sync up model
@@ -623,6 +626,7 @@ def upload_artifact(
623626
overwrite_existing_artifact: Optional[bool] = True,
624627
remove_existing_artifact: Optional[bool] = True,
625628
timeout: Optional[int] = None,
629+
parallel_process_count: int = utils.DEFAULT_PARALLEL_PROCESS_COUNT,
626630
) -> None:
627631
"""Uploads model artifacts to the model catalog.
628632
@@ -646,6 +650,8 @@ def upload_artifact(
646650
Whether artifacts uploaded to object storage bucket need to be removed or not.
647651
timeout: (int, optional). Defaults to 10 seconds.
648652
The connection timeout in seconds for the client.
653+
parallel_process_count: (int, optional)
654+
The number of worker processes to use in parallel for uploading individual parts of a multipart upload.
649655
"""
650656
# Upload artifact to the model catalog
651657
if not self.artifact:
@@ -676,6 +682,7 @@ def upload_artifact(
676682
bucket_uri=bucket_uri,
677683
overwrite_existing_artifact=overwrite_existing_artifact,
678684
remove_existing_artifact=remove_existing_artifact,
685+
parallel_process_count=parallel_process_count,
679686
)
680687
else:
681688
artifact_uploader = SmallArtifactUploader(

0 commit comments

Comments
 (0)