Skip to content

Commit 672bbac

Browse files
committed
Merge branch 'main' of https://github.com/oracle/accelerated-data-science into ODSC-50632/improve_progress_bar
2 parents ecc5347 + 2eabfb0 commit 672bbac

File tree

13 files changed

+138
-112
lines changed

13 files changed

+138
-112
lines changed

ads/catalog/model.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -82,11 +82,11 @@
8282
_WORK_REQUEST_INTERVAL_IN_SEC = 3
8383

8484

85-
class ModelWithActiveDeploymentError(Exception): # pragma: no cover
85+
class ModelWithActiveDeploymentError(Exception): # pragma: no cover
8686
pass
8787

8888

89-
class ModelArtifactSizeError(Exception): # pragma: no cover
89+
class ModelArtifactSizeError(Exception): # pragma: no cover
9090
def __init__(self, max_artifact_size: str):
9191
super().__init__(
9292
f"The model artifacts size is greater than `{max_artifact_size}`. "
@@ -391,7 +391,6 @@ def show_in_notebook(self, display_format: str = "dataframe") -> None:
391391
Nothing.
392392
"""
393393
if display_format == "dataframe":
394-
395394
from IPython.core.display import display
396395

397396
display(self.to_dataframe())
@@ -1560,9 +1559,9 @@ def _wait_for_work_request(
15601559
if work_request_logs:
15611560
new_work_request_logs = work_request_logs[i:]
15621561

1563-
for wr_item in new_work_request_logs:
1564-
progress.update(wr_item.message)
1565-
i += 1
1562+
for wr_item in new_work_request_logs:
1563+
progress.update(wr_item.message)
1564+
i += 1
15661565

15671566
if work_request.data.status in STOP_STATE:
15681567
if work_request.data.status != WorkRequest.STATUS_SUCCEEDED:

ads/common/dsc_file_system.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -265,8 +265,6 @@ def update_from_dsc_model(cls, dsc_model) -> dict:
265265

266266
class DSCFileSystemManager:
267267

268-
storage_mount_dest = set()
269-
270268
@classmethod
271269
def initialize(cls, arguments: dict) -> dict:
272270
"""Initialize and update arguments to dsc model.
@@ -286,12 +284,6 @@ def initialize(cls, arguments: dict) -> dict:
286284
"Parameter `dest` is required for mounting file storage system."
287285
)
288286

289-
if arguments["dest"] in cls.storage_mount_dest:
290-
raise ValueError(
291-
"Duplicate `dest` found. Please specify different `dest` for each file system to be mounted."
292-
)
293-
cls.storage_mount_dest.add(arguments["dest"])
294-
295287
# case oci://bucket@namespace/prefix
296288
if arguments["src"].startswith("oci://") and "@" in arguments["src"]:
297289
return OCIObjectStorage(**arguments).update_to_dsc_model()

ads/config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
NB_SESSION_OCID or JOB_RUN_OCID or MD_OCID or PIPELINE_RUN_OCID or DATAFLOW_RUN_OCID
4646
)
4747
NO_CONTAINER = os.environ.get("NO_CONTAINER")
48+
TMPDIR = os.environ.get("TMPDIR")
4849

4950

5051
def export(

ads/model/artifact.py

Lines changed: 20 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -30,21 +30,21 @@
3030
ADS_VERSION = __version__
3131

3232

33-
class ArtifactNestedFolderError(Exception): # pragma: no cover
33+
class ArtifactNestedFolderError(Exception): # pragma: no cover
3434
def __init__(self, folder: str):
3535
self.folder = folder
3636
super().__init__("The required artifact files placed in a nested folder.")
3737

3838

39-
class ArtifactRequiredFilesError(Exception): # pragma: no cover
39+
class ArtifactRequiredFilesError(Exception): # pragma: no cover
4040
def __init__(self, required_files: Tuple[str]):
4141
super().__init__(
4242
"Not all required files presented in artifact folder. "
4343
f"Required files for conda runtime: {required_files}. If you are using container runtime, set `ignore_conda_error=True`."
4444
)
4545

4646

47-
class AritfactFolderStructureError(Exception): # pragma: no cover
47+
class AritfactFolderStructureError(Exception): # pragma: no cover
4848
def __init__(self, required_files: Tuple[str]):
4949
super().__init__(
5050
"The artifact folder has a wrong structure. "
@@ -171,6 +171,7 @@ def __init__(
171171
self.ignore_conda_error = ignore_conda_error
172172
self.model = None
173173
self.auth = auth or authutil.default_signer()
174+
174175
if reload and not ignore_conda_error:
175176
self.reload()
176177
# Extracts the model_file_name from the score.py.
@@ -272,8 +273,9 @@ def prepare_runtime_yaml(
272273
or runtime_info.model_deployment.inference_conda_env.inference_python_version.strip()
273274
== ""
274275
):
275-
warnings.warn(
276-
"Cannot automatically detect the inference python version. `inference_python_version` must be provided."
276+
raise ValueError(
277+
"Cannot automatically detect the inference python version. "
278+
"`inference_python_version` must be provided."
277279
)
278280
runtime_file_path = os.path.join(self.artifact_dir, "runtime.yaml")
279281
if os.path.exists(runtime_file_path) and not force_overwrite:
@@ -416,6 +418,7 @@ def from_uri(
416418
force_overwrite: Optional[bool] = False,
417419
auth: Optional[Dict] = None,
418420
ignore_conda_error: Optional[bool] = False,
421+
reload: Optional[bool] = False,
419422
):
420423
"""Constructs a ModelArtifact object from the existing model artifacts.
421424
@@ -426,16 +429,20 @@ def from_uri(
426429
OCI object storage URI.
427430
artifact_dir: str
428431
The local artifact folder to store the files needed for deployment.
429-
model_file_name: (str, optional). Defaults to `None`
430-
The file name of the serialized model.
431-
force_overwrite: (bool, optional). Defaults to False.
432-
Whether to overwrite existing files or not.
433432
auth: (Dict, optional). Defaults to None.
434433
The default authetication is set using `ads.set_auth` API.
435434
If you need to override the default, use the `ads.common.auth.api_keys`
436435
or `ads.common.auth.resource_principal` to create appropriate
437436
authentication signer and kwargs required to instantiate
438437
IdentityClient object.
438+
force_overwrite: (bool, optional). Defaults to False.
439+
Whether to overwrite existing files or not.
440+
ignore_conda_error: (bool, optional). Defaults to False.
441+
Parameter to ignore error when collecting conda information.
442+
model_file_name: (str, optional). Defaults to `None`
443+
The file name of the serialized model.
444+
reload: (bool, optional). Defaults to False.
445+
Whether to reload the Model into the environment.
439446
440447
Returns
441448
-------
@@ -492,6 +499,8 @@ def from_uri(
492499
utils.copy_from_uri(
493500
uri=temp_dir, to_path=to_path, force_overwrite=True
494501
)
502+
except ArtifactRequiredFilesError as ex:
503+
logger.warning(ex)
495504

496505
if ObjectStorageDetails.is_oci_path(artifact_dir):
497506
for root, dirs, files in os.walk(to_path):
@@ -507,10 +516,10 @@ def from_uri(
507516

508517
return cls(
509518
artifact_dir=artifact_dir,
510-
model_file_name=model_file_name,
511-
reload=True,
512519
ignore_conda_error=ignore_conda_error,
513520
local_copy_dir=to_path,
521+
model_file_name=model_file_name,
522+
reload=reload,
514523
)
515524

516525
def __getattr__(self, item):

ads/model/artifact_downloader.py

Lines changed: 24 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ def download(self):
6363
"Set `force_overwrite` to `True` if you wish to overwrite."
6464
)
6565
shutil.rmtree(self.target_dir)
66+
os.mkdir(self.target_dir)
6667
with utils.get_progress_bar(
6768
ArtifactDownloader.PROGRESS_STEPS_COUNT + self.PROGRESS_STEPS_COUNT
6869
) as progress:
@@ -85,15 +86,16 @@ def _download(self):
8586
"""Downloads model artifacts."""
8687
self.progress.update("Importing model artifacts from catalog")
8788
zip_content = self.dsc_model.get_model_artifact_content()
88-
with tempfile.TemporaryDirectory() as temp_dir:
89-
self.progress.update("Copying model artifacts to the artifact directory")
90-
zip_file_path = os.path.join(temp_dir, f"{str(uuid.uuid4())}.zip")
91-
with open(zip_file_path, "wb") as zip_file:
92-
zip_file.write(zip_content)
93-
self.progress.update("Extracting model artifacts")
94-
with ZipFile(zip_file_path) as zip_file:
95-
zip_file.extractall(self.target_dir)
96-
89+
self.progress.update("Copying model artifacts to the artifact directory")
90+
91+
zip_file_path = os.path.join(self.target_dir, f"{str(uuid.uuid4())}.zip")
92+
with open(zip_file_path, "wb") as zip_file:
93+
zip_file.write(zip_content)
94+
self.progress.update("Extracting model artifacts")
95+
with ZipFile(zip_file_path) as zip_file:
96+
zip_file.extractall(self.target_dir)
97+
98+
utils.remove_file(zip_file_path)
9799

98100
class LargeArtifactDownloader(ArtifactDownloader):
99101
PROGRESS_STEPS_COUNT = 4
@@ -157,17 +159,19 @@ def _download(self):
157159

158160
self.dsc_model.import_model_artifact(bucket_uri=bucket_uri, region=self.region)
159161
self.progress.update("Copying model artifacts to the artifact directory")
160-
with tempfile.TemporaryDirectory() as temp_dir:
161-
zip_file_path = os.path.join(temp_dir, f"{str(uuid.uuid4())}.zip")
162-
zip_file_path = utils.copy_file(
163-
uri_src=bucket_uri,
164-
uri_dst=zip_file_path,
165-
auth=self.auth,
166-
progressbar_description="Copying model artifacts to the artifact directory",
167-
)
168-
self.progress.update("Extracting model artifacts")
169-
with ZipFile(zip_file_path) as zip_file:
170-
zip_file.extractall(self.target_dir)
162+
163+
zip_file_path = os.path.join(self.target_dir, f"{str(uuid.uuid4())}.zip")
164+
zip_file_path = utils.copy_file(
165+
uri_src=bucket_uri,
166+
uri_dst=zip_file_path,
167+
auth=self.auth,
168+
progressbar_description="Copying model artifacts to the artifact directory",
169+
)
170+
self.progress.update("Extracting model artifacts")
171+
with ZipFile(zip_file_path) as zip_file:
172+
zip_file.extractall(self.target_dir)
173+
174+
utils.remove_file(zip_file_path)
171175

172176
if self.remove_existing_artifact:
173177
self.progress.update(

ads/model/deployment/model_deployment_infrastructure.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
MODEL_DEPLOYMENT_INFRASTRUCTURE_KIND = "infrastructure"
2121

2222
DEFAULT_BANDWIDTH_MBPS = 10
23-
DEFAULT_WEB_CONCURRENCY = 10
2423
DEFAULT_REPLICA = 1
2524
DEFAULT_SHAPE_NAME = "VM.Standard.E4.Flex"
2625
DEFAULT_OCPUS = 1
@@ -219,7 +218,6 @@ def _load_default_properties(self) -> Dict:
219218
defaults[self.CONST_PROJECT_ID] = PROJECT_OCID
220219

221220
defaults[self.CONST_BANDWIDTH_MBPS] = DEFAULT_BANDWIDTH_MBPS
222-
defaults[self.CONST_WEB_CONCURRENCY] = DEFAULT_WEB_CONCURRENCY
223221
defaults[self.CONST_REPLICA] = DEFAULT_REPLICA
224222

225223
if NB_SESSION_OCID:
@@ -628,7 +626,6 @@ def init(self, **kwargs) -> "ModelDeploymentInfrastructure":
628626
.with_compartment_id(self.compartment_id or "{Provide a compartment OCID}")
629627
.with_project_id(self.project_id or "{Provide a project OCID}")
630628
.with_bandwidth_mbps(self.bandwidth_mbps or DEFAULT_BANDWIDTH_MBPS)
631-
.with_web_concurrency(self.web_concurrency or DEFAULT_WEB_CONCURRENCY)
632629
.with_replica(self.replica or DEFAULT_REPLICA)
633630
.with_shape_name(self.shape_name or DEFAULT_SHAPE_NAME)
634631
.with_shape_config_details(

ads/model/generic_model.py

Lines changed: 38 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
NB_SESSION_OCID,
3232
PIPELINE_RUN_COMPARTMENT_OCID,
3333
PROJECT_OCID,
34+
TMPDIR,
3435
)
3536
from ads.evaluations import EvaluatorMixin
3637
from ads.feature_engineering import ADSImage
@@ -183,7 +184,7 @@ def _prepare_artifact_dir(artifact_dir: str = None) -> str:
183184
if artifact_dir and isinstance(artifact_dir, str):
184185
return os.path.abspath(os.path.expanduser(artifact_dir))
185186

186-
artifact_dir = tempfile.mkdtemp()
187+
artifact_dir = TMPDIR or tempfile.mkdtemp()
187188
logger.info(
188189
f"The `artifact_dir` was not provided and "
189190
f"automatically set to: {artifact_dir}"
@@ -965,17 +966,20 @@ def prepare(
965966
auth=self.auth,
966967
local_copy_dir=self.local_copy_dir,
967968
)
968-
self.runtime_info = self.model_artifact.prepare_runtime_yaml(
969-
inference_conda_env=self.properties.inference_conda_env,
970-
inference_python_version=self.properties.inference_python_version,
971-
training_conda_env=self.properties.training_conda_env,
972-
training_python_version=self.properties.training_python_version,
973-
force_overwrite=force_overwrite,
974-
namespace=namespace,
975-
bucketname=DEFAULT_CONDA_BUCKET_NAME,
976-
auth=self.auth,
977-
ignore_conda_error=self.ignore_conda_error,
978-
)
969+
try:
970+
self.runtime_info = self.model_artifact.prepare_runtime_yaml(
971+
inference_conda_env=self.properties.inference_conda_env,
972+
inference_python_version=self.properties.inference_python_version,
973+
training_conda_env=self.properties.training_conda_env,
974+
training_python_version=self.properties.training_python_version,
975+
force_overwrite=force_overwrite,
976+
namespace=namespace,
977+
bucketname=DEFAULT_CONDA_BUCKET_NAME,
978+
auth=self.auth,
979+
ignore_conda_error=self.ignore_conda_error,
980+
)
981+
except ValueError as e:
982+
raise e
979983

980984
self._summary_status.update_status(
981985
detail="Generated runtime.yaml", status=ModelState.DONE.value
@@ -1348,13 +1352,15 @@ def from_model_artifact(
13481352
properties.with_dict(local_vars)
13491353
auth = auth or authutil.default_signer()
13501354
artifact_dir = _prepare_artifact_dir(artifact_dir)
1355+
reload = kwargs.pop("reload", False)
13511356
model_artifact = ModelArtifact.from_uri(
13521357
uri=uri,
13531358
artifact_dir=artifact_dir,
1354-
model_file_name=model_file_name,
1355-
force_overwrite=force_overwrite,
13561359
auth=auth,
1360+
force_overwrite=force_overwrite,
13571361
ignore_conda_error=ignore_conda_error,
1362+
model_file_name=model_file_name,
1363+
reload=reload,
13581364
)
13591365
model = cls(
13601366
estimator=model_artifact.model,
@@ -1367,22 +1373,33 @@ def from_model_artifact(
13671373
model.local_copy_dir = model_artifact.local_copy_dir
13681374
model.model_artifact = model_artifact
13691375
model.ignore_conda_error = ignore_conda_error
1370-
model.reload_runtime_info()
1376+
1377+
if reload:
1378+
model.reload_runtime_info()
1379+
model._summary_status.update_action(
1380+
detail="Populated metadata(Custom, Taxonomy and Provenance)",
1381+
action="Call .populate_metadata() to populate metadata.",
1382+
)
1383+
13711384
model._summary_status.update_status(
13721385
detail="Generated score.py",
1373-
status=ModelState.DONE.value,
1386+
status=ModelState.NOTAPPLICABLE.value,
13741387
)
13751388
model._summary_status.update_status(
13761389
detail="Generated runtime.yaml",
1377-
status=ModelState.DONE.value,
1390+
status=ModelState.NOTAPPLICABLE.value,
13781391
)
13791392
model._summary_status.update_status(
1380-
detail="Serialized model", status=ModelState.DONE.value
1393+
detail="Serialized model",
1394+
status=ModelState.NOTAPPLICABLE.value,
13811395
)
1382-
model._summary_status.update_action(
1396+
model._summary_status.update_status(
13831397
detail="Populated metadata(Custom, Taxonomy and Provenance)",
1384-
action=f"Call .populate_metadata() to populate metadata.",
1398+
status=ModelState.AVAILABLE.value
1399+
if reload
1400+
else ModelState.NOTAPPLICABLE.value,
13851401
)
1402+
13861403
return model
13871404

13881405
@classmethod
@@ -3033,6 +3050,7 @@ class ModelState(Enum):
30333050
AVAILABLE = "Available"
30343051
NOTAVAILABLE = "Not Available"
30353052
NEEDSACTION = "Needs Action"
3053+
NOTAPPLICABLE = "Not Applicable"
30363054

30373055

30383056
class SummaryStatus:

0 commit comments

Comments
 (0)