Skip to content

Commit 14c4856

Browse files
Merge branch 'main' into ODSC-34939/explicit_download_artifact
2 parents c8cb9e8 + cf636ff commit 14c4856

32 files changed

+1112
-184
lines changed

ads/catalog/model.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -82,11 +82,11 @@
8282
_WORK_REQUEST_INTERVAL_IN_SEC = 3
8383

8484

85-
class ModelWithActiveDeploymentError(Exception): # pragma: no cover
85+
class ModelWithActiveDeploymentError(Exception): # pragma: no cover
8686
pass
8787

8888

89-
class ModelArtifactSizeError(Exception): # pragma: no cover
89+
class ModelArtifactSizeError(Exception): # pragma: no cover
9090
def __init__(self, max_artifact_size: str):
9191
super().__init__(
9292
f"The model artifacts size is greater than `{max_artifact_size}`. "
@@ -391,7 +391,6 @@ def show_in_notebook(self, display_format: str = "dataframe") -> None:
391391
Nothing.
392392
"""
393393
if display_format == "dataframe":
394-
395394
from IPython.core.display import display
396395

397396
display(self.to_dataframe())
@@ -1560,9 +1559,9 @@ def _wait_for_work_request(
15601559
if work_request_logs:
15611560
new_work_request_logs = work_request_logs[i:]
15621561

1563-
for wr_item in new_work_request_logs:
1564-
progress.update(wr_item.message)
1565-
i += 1
1562+
for wr_item in new_work_request_logs:
1563+
progress.update(wr_item.message)
1564+
i += 1
15661565

15671566
if work_request.data.status in STOP_STATE:
15681567
if work_request.data.status != WorkRequest.STATUS_SUCCEEDED:

ads/config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
NB_SESSION_OCID or JOB_RUN_OCID or MD_OCID or PIPELINE_RUN_OCID or DATAFLOW_RUN_OCID
4646
)
4747
NO_CONTAINER = os.environ.get("NO_CONTAINER")
48+
TMPDIR = os.environ.get("TMPDIR")
4849

4950

5051
def export(

ads/jobs/templates/driver_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -413,7 +413,7 @@ def run_command(
413413
logger.log(level=level, msg=msg)
414414
# Add a small delay so that
415415
# outputs from the subsequent code will have different timestamp for oci logging
416-
time.sleep(0.05)
416+
time.sleep(0.02)
417417
if check and process.returncode != 0:
418418
# If there is an error, exit the main process with the same return code.
419419
sys.exit(process.returncode)

ads/model/artifact_downloader.py

Lines changed: 24 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ def download(self):
6363
"Set `force_overwrite` to `True` if you wish to overwrite."
6464
)
6565
shutil.rmtree(self.target_dir)
66+
os.mkdir(self.target_dir)
6667
with utils.get_progress_bar(
6768
ArtifactDownloader.PROGRESS_STEPS_COUNT + self.PROGRESS_STEPS_COUNT
6869
) as progress:
@@ -85,15 +86,16 @@ def _download(self):
8586
"""Downloads model artifacts."""
8687
self.progress.update("Importing model artifacts from catalog")
8788
zip_content = self.dsc_model.get_model_artifact_content()
88-
with tempfile.TemporaryDirectory() as temp_dir:
89-
self.progress.update("Copying model artifacts to the artifact directory")
90-
zip_file_path = os.path.join(temp_dir, f"{str(uuid.uuid4())}.zip")
91-
with open(zip_file_path, "wb") as zip_file:
92-
zip_file.write(zip_content)
93-
self.progress.update("Extracting model artifacts")
94-
with ZipFile(zip_file_path) as zip_file:
95-
zip_file.extractall(self.target_dir)
96-
89+
self.progress.update("Copying model artifacts to the artifact directory")
90+
91+
zip_file_path = os.path.join(self.target_dir, f"{str(uuid.uuid4())}.zip")
92+
with open(zip_file_path, "wb") as zip_file:
93+
zip_file.write(zip_content)
94+
self.progress.update("Extracting model artifacts")
95+
with ZipFile(zip_file_path) as zip_file:
96+
zip_file.extractall(self.target_dir)
97+
98+
utils.remove_file(zip_file_path)
9799

98100
class LargeArtifactDownloader(ArtifactDownloader):
99101
PROGRESS_STEPS_COUNT = 4
@@ -157,17 +159,19 @@ def _download(self):
157159

158160
self.dsc_model.import_model_artifact(bucket_uri=bucket_uri, region=self.region)
159161
self.progress.update("Copying model artifacts to the artifact directory")
160-
with tempfile.TemporaryDirectory() as temp_dir:
161-
zip_file_path = os.path.join(temp_dir, f"{str(uuid.uuid4())}.zip")
162-
zip_file_path = utils.copy_file(
163-
uri_src=bucket_uri,
164-
uri_dst=zip_file_path,
165-
auth=self.auth,
166-
progressbar_description="Copying model artifacts to the artifact directory",
167-
)
168-
self.progress.update("Extracting model artifacts")
169-
with ZipFile(zip_file_path) as zip_file:
170-
zip_file.extractall(self.target_dir)
162+
163+
zip_file_path = os.path.join(self.target_dir, f"{str(uuid.uuid4())}.zip")
164+
zip_file_path = utils.copy_file(
165+
uri_src=bucket_uri,
166+
uri_dst=zip_file_path,
167+
auth=self.auth,
168+
progressbar_description="Copying model artifacts to the artifact directory",
169+
)
170+
self.progress.update("Extracting model artifacts")
171+
with ZipFile(zip_file_path) as zip_file:
172+
zip_file.extractall(self.target_dir)
173+
174+
utils.remove_file(zip_file_path)
171175

172176
if self.remove_existing_artifact:
173177
self.progress.update(

ads/model/deployment/model_deployment.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,6 @@
6464
DELETE_WORKFLOW_STEPS = 2
6565
DEACTIVATE_WORKFLOW_STEPS = 2
6666
DEFAULT_RETRYING_REQUEST_ATTEMPTS = 3
67-
TERMINAL_STATES = [State.ACTIVE, State.FAILED, State.DELETED, State.INACTIVE]
6867

6968
MODEL_DEPLOYMENT_KIND = "deployment"
7069
MODEL_DEPLOYMENT_TYPE = "modelDeployment"
@@ -720,7 +719,7 @@ def update(
720719

721720
def watch(
722721
self,
723-
log_type: str = ModelDeploymentLogType.ACCESS,
722+
log_type: str = None,
724723
time_start: datetime = None,
725724
interval: int = LOG_INTERVAL,
726725
log_filter: str = None,
@@ -731,7 +730,7 @@ def watch(
731730
----------
732731
log_type: str, optional
733732
The log type. Can be `access`, `predict` or None.
734-
Defaults to access.
733+
Defaults to None.
735734
time_start : datetime.datetime, optional
736735
Starting time for the log query.
737736
Defaults to None.
@@ -757,7 +756,7 @@ def watch(
757756
count = self.logs(log_type).stream(
758757
source=self.model_deployment_id,
759758
interval=interval,
760-
stop_condition=self._stop_condition,
759+
stop_condition=self._stream_stop_condition,
761760
time_start=time_start,
762761
log_filter=log_filter,
763762
)
@@ -773,7 +772,11 @@ def watch(
773772

774773
def _stop_condition(self):
775774
"""Stops the sync once the model deployment is in a terminal state."""
776-
return self.state in TERMINAL_STATES
775+
return self.state in [State.ACTIVE, State.FAILED, State.DELETED, State.INACTIVE]
776+
777+
def _stream_stop_condition(self):
778+
"""Stops the stream sync once the model deployment is in a terminal state."""
779+
return self.state in [State.FAILED, State.DELETED, State.INACTIVE]
777780

778781
def _check_and_print_status(self, prev_status) -> str:
779782
"""Check and print the next status.

ads/model/deployment/model_deployment_infrastructure.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
MODEL_DEPLOYMENT_INFRASTRUCTURE_KIND = "infrastructure"
2121

2222
DEFAULT_BANDWIDTH_MBPS = 10
23-
DEFAULT_WEB_CONCURRENCY = 10
2423
DEFAULT_REPLICA = 1
2524
DEFAULT_SHAPE_NAME = "VM.Standard.E4.Flex"
2625
DEFAULT_OCPUS = 1
@@ -219,7 +218,6 @@ def _load_default_properties(self) -> Dict:
219218
defaults[self.CONST_PROJECT_ID] = PROJECT_OCID
220219

221220
defaults[self.CONST_BANDWIDTH_MBPS] = DEFAULT_BANDWIDTH_MBPS
222-
defaults[self.CONST_WEB_CONCURRENCY] = DEFAULT_WEB_CONCURRENCY
223221
defaults[self.CONST_REPLICA] = DEFAULT_REPLICA
224222

225223
if NB_SESSION_OCID:
@@ -628,7 +626,6 @@ def init(self, **kwargs) -> "ModelDeploymentInfrastructure":
628626
.with_compartment_id(self.compartment_id or "{Provide a compartment OCID}")
629627
.with_project_id(self.project_id or "{Provide a project OCID}")
630628
.with_bandwidth_mbps(self.bandwidth_mbps or DEFAULT_BANDWIDTH_MBPS)
631-
.with_web_concurrency(self.web_concurrency or DEFAULT_WEB_CONCURRENCY)
632629
.with_replica(self.replica or DEFAULT_REPLICA)
633630
.with_shape_name(self.shape_name or DEFAULT_SHAPE_NAME)
634631
.with_shape_config_details(

ads/model/generic_model.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
NB_SESSION_OCID,
3232
PIPELINE_RUN_COMPARTMENT_OCID,
3333
PROJECT_OCID,
34+
TMPDIR,
3435
)
3536
from ads.evaluations import EvaluatorMixin
3637
from ads.feature_engineering import ADSImage
@@ -190,7 +191,7 @@ def _prepare_artifact_dir(artifact_dir: str = None) -> str:
190191
if artifact_dir and isinstance(artifact_dir, str):
191192
return os.path.abspath(os.path.expanduser(artifact_dir))
192193

193-
artifact_dir = tempfile.mkdtemp()
194+
artifact_dir = TMPDIR or tempfile.mkdtemp()
194195
logger.info(
195196
f"The `artifact_dir` was not provided and "
196197
f"automatically set to: {artifact_dir}"

ads/model/service/oci_datascience_model_deployment.py

Lines changed: 15 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -67,11 +67,11 @@ def wrapper(self, *args, **kwargs):
6767
return decorator
6868

6969

70-
class MissingModelDeploymentIdError(Exception): # pragma: no cover
70+
class MissingModelDeploymentIdError(Exception): # pragma: no cover
7171
pass
7272

7373

74-
class MissingModelDeploymentWorkflowIdError(Exception): # pragma: no cover
74+
class MissingModelDeploymentWorkflowIdError(Exception): # pragma: no cover
7575
pass
7676

7777

@@ -175,32 +175,23 @@ def activate(
175175
The `OCIDataScienceModelDeployment` instance (self).
176176
"""
177177
dsc_model_deployment = OCIDataScienceModelDeployment.from_id(self.id)
178-
if (
179-
dsc_model_deployment.lifecycle_state
180-
== self.LIFECYCLE_STATE_ACTIVE
181-
):
178+
if dsc_model_deployment.lifecycle_state == self.LIFECYCLE_STATE_ACTIVE:
182179
raise Exception(
183180
f"Model deployment {dsc_model_deployment.id} is already in active state."
184181
)
185182

186-
if (
187-
dsc_model_deployment.lifecycle_state
188-
== self.LIFECYCLE_STATE_INACTIVE
189-
):
183+
if dsc_model_deployment.lifecycle_state == self.LIFECYCLE_STATE_INACTIVE:
190184
logger.info(f"Activating model deployment `{self.id}`.")
191185
response = self.client.activate_model_deployment(
192186
self.id,
193187
)
194188

195189
if wait_for_completion:
196-
197190
self.workflow_req_id = response.headers.get("opc-work-request-id", None)
198191

199192
try:
200193
self.wait_for_progress(
201-
self.workflow_req_id,
202-
max_wait_time,
203-
poll_interval
194+
self.workflow_req_id, max_wait_time, poll_interval
204195
)
205196
except Exception as e:
206197
logger.error(
@@ -241,21 +232,17 @@ def create(
241232
response = self.client.create_model_deployment(create_model_deployment_details)
242233
self.update_from_oci_model(response.data)
243234
logger.info(f"Creating model deployment `{self.id}`.")
235+
print(f"Model Deployment OCID: {self.id}")
244236

245237
if wait_for_completion:
246-
247238
self.workflow_req_id = response.headers.get("opc-work-request-id", None)
248239

249240
try:
250241
self.wait_for_progress(
251-
self.workflow_req_id,
252-
max_wait_time,
253-
poll_interval
242+
self.workflow_req_id, max_wait_time, poll_interval
254243
)
255244
except Exception as e:
256-
logger.error(
257-
"Error while trying to create model deployment: " + str(e)
258-
)
245+
logger.error("Error while trying to create model deployment: " + str(e))
259246

260247
return self.sync()
261248

@@ -287,32 +274,23 @@ def deactivate(
287274
The `OCIDataScienceModelDeployment` instance (self).
288275
"""
289276
dsc_model_deployment = OCIDataScienceModelDeployment.from_id(self.id)
290-
if (
291-
dsc_model_deployment.lifecycle_state
292-
== self.LIFECYCLE_STATE_INACTIVE
293-
):
277+
if dsc_model_deployment.lifecycle_state == self.LIFECYCLE_STATE_INACTIVE:
294278
raise Exception(
295279
f"Model deployment {dsc_model_deployment.id} is already in inactive state."
296280
)
297281

298-
if (
299-
dsc_model_deployment.lifecycle_state
300-
== self.LIFECYCLE_STATE_ACTIVE
301-
):
282+
if dsc_model_deployment.lifecycle_state == self.LIFECYCLE_STATE_ACTIVE:
302283
logger.info(f"Deactivating model deployment `{self.id}`.")
303284
response = self.client.deactivate_model_deployment(
304285
self.id,
305286
)
306287

307288
if wait_for_completion:
308-
309289
self.workflow_req_id = response.headers.get("opc-work-request-id", None)
310290

311291
try:
312292
self.wait_for_progress(
313-
self.workflow_req_id,
314-
max_wait_time,
315-
poll_interval
293+
self.workflow_req_id, max_wait_time, poll_interval
316294
)
317295
except Exception as e:
318296
logger.error(
@@ -355,7 +333,7 @@ def delete(
355333
dsc_model_deployment = OCIDataScienceModelDeployment.from_id(self.id)
356334
if dsc_model_deployment.lifecycle_state in [
357335
self.LIFECYCLE_STATE_DELETED,
358-
self.LIFECYCLE_STATE_DELETING
336+
self.LIFECYCLE_STATE_DELETING,
359337
]:
360338
raise Exception(
361339
f"Model deployment {dsc_model_deployment.id} is either deleted or being deleted."
@@ -374,19 +352,14 @@ def delete(
374352
)
375353

376354
if wait_for_completion:
377-
378355
self.workflow_req_id = response.headers.get("opc-work-request-id", None)
379356

380357
try:
381358
self.wait_for_progress(
382-
self.workflow_req_id,
383-
max_wait_time,
384-
poll_interval
359+
self.workflow_req_id, max_wait_time, poll_interval
385360
)
386361
except Exception as e:
387-
logger.error(
388-
"Error while trying to delete model deployment: " + str(e)
389-
)
362+
logger.error("Error while trying to delete model deployment: " + str(e))
390363

391364
return self.sync()
392365

@@ -440,9 +413,7 @@ def update(
440413
)
441414
self.workflow_req_id = response.headers.get("opc-work-request-id", None)
442415
except Exception as e:
443-
logger.error(
444-
"Error while trying to update model deployment: " + str(e)
445-
)
416+
logger.error("Error while trying to update model deployment: " + str(e))
446417

447418
return self.sync()
448419

ads/opctl/operator/common/operator_loader.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,9 @@ class OperatorInfo(DataClassSerializable):
136136
dataflow_default_params (DataFlowDefaultParams)
137137
The default params for the DataFlow service.
138138
Will be used when operator run on the DataFlow service.
139+
logo: str
140+
The logo of the operator.
141+
Needs to be attached in the "svg+xml;base64" format.
139142
140143
Properties
141144
----------
@@ -157,6 +160,7 @@ class OperatorInfo(DataClassSerializable):
157160
dataflow_default_params: DataFlowDefaultParams = field(
158161
default_factory=DataFlowDefaultParams
159162
)
163+
logo: str = ""
160164

161165
@property
162166
def conda_prefix(self) -> str:

0 commit comments

Comments
 (0)