Skip to content

Commit cf636ff

Browse files
authored
Updated download artifact path (#490)
2 parents f9163f7 + 7e4fd6f commit cf636ff

File tree

4 files changed

+30
-25
lines changed

4 files changed

+30
-25
lines changed

ads/config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
NB_SESSION_OCID or JOB_RUN_OCID or MD_OCID or PIPELINE_RUN_OCID or DATAFLOW_RUN_OCID
4646
)
4747
NO_CONTAINER = os.environ.get("NO_CONTAINER")
48+
TMPDIR = os.environ.get("TMPDIR")
4849

4950

5051
def export(

ads/model/artifact_downloader.py

Lines changed: 24 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ def download(self):
6363
"Set `force_overwrite` to `True` if you wish to overwrite."
6464
)
6565
shutil.rmtree(self.target_dir)
66+
os.mkdir(self.target_dir)
6667
with utils.get_progress_bar(
6768
ArtifactDownloader.PROGRESS_STEPS_COUNT + self.PROGRESS_STEPS_COUNT
6869
) as progress:
@@ -85,15 +86,16 @@ def _download(self):
8586
"""Downloads model artifacts."""
8687
self.progress.update("Importing model artifacts from catalog")
8788
zip_content = self.dsc_model.get_model_artifact_content()
88-
with tempfile.TemporaryDirectory() as temp_dir:
89-
self.progress.update("Copying model artifacts to the artifact directory")
90-
zip_file_path = os.path.join(temp_dir, f"{str(uuid.uuid4())}.zip")
91-
with open(zip_file_path, "wb") as zip_file:
92-
zip_file.write(zip_content)
93-
self.progress.update("Extracting model artifacts")
94-
with ZipFile(zip_file_path) as zip_file:
95-
zip_file.extractall(self.target_dir)
96-
89+
self.progress.update("Copying model artifacts to the artifact directory")
90+
91+
zip_file_path = os.path.join(self.target_dir, f"{str(uuid.uuid4())}.zip")
92+
with open(zip_file_path, "wb") as zip_file:
93+
zip_file.write(zip_content)
94+
self.progress.update("Extracting model artifacts")
95+
with ZipFile(zip_file_path) as zip_file:
96+
zip_file.extractall(self.target_dir)
97+
98+
utils.remove_file(zip_file_path)
9799

98100
class LargeArtifactDownloader(ArtifactDownloader):
99101
PROGRESS_STEPS_COUNT = 4
@@ -157,17 +159,19 @@ def _download(self):
157159

158160
self.dsc_model.import_model_artifact(bucket_uri=bucket_uri, region=self.region)
159161
self.progress.update("Copying model artifacts to the artifact directory")
160-
with tempfile.TemporaryDirectory() as temp_dir:
161-
zip_file_path = os.path.join(temp_dir, f"{str(uuid.uuid4())}.zip")
162-
zip_file_path = utils.copy_file(
163-
uri_src=bucket_uri,
164-
uri_dst=zip_file_path,
165-
auth=self.auth,
166-
progressbar_description="Copying model artifacts to the artifact directory",
167-
)
168-
self.progress.update("Extracting model artifacts")
169-
with ZipFile(zip_file_path) as zip_file:
170-
zip_file.extractall(self.target_dir)
162+
163+
zip_file_path = os.path.join(self.target_dir, f"{str(uuid.uuid4())}.zip")
164+
zip_file_path = utils.copy_file(
165+
uri_src=bucket_uri,
166+
uri_dst=zip_file_path,
167+
auth=self.auth,
168+
progressbar_description="Copying model artifacts to the artifact directory",
169+
)
170+
self.progress.update("Extracting model artifacts")
171+
with ZipFile(zip_file_path) as zip_file:
172+
zip_file.extractall(self.target_dir)
173+
174+
utils.remove_file(zip_file_path)
171175

172176
if self.remove_existing_artifact:
173177
self.progress.update(

ads/model/generic_model.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
NB_SESSION_OCID,
3232
PIPELINE_RUN_COMPARTMENT_OCID,
3333
PROJECT_OCID,
34+
TMPDIR,
3435
)
3536
from ads.evaluations import EvaluatorMixin
3637
from ads.feature_engineering import ADSImage
@@ -183,7 +184,7 @@ def _prepare_artifact_dir(artifact_dir: str = None) -> str:
183184
if artifact_dir and isinstance(artifact_dir, str):
184185
return os.path.abspath(os.path.expanduser(artifact_dir))
185186

186-
artifact_dir = tempfile.mkdtemp()
187+
artifact_dir = TMPDIR or tempfile.mkdtemp()
187188
logger.info(
188189
f"The `artifact_dir` was not provided and "
189190
f"automatically set to: {artifact_dir}"

tests/unitary/default_setup/model/test_artifact_downloader.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -100,20 +100,19 @@ def test_downaload_small_artifact(self):
100100
expected_artifact_bytes_content
101101
)
102102
with tempfile.TemporaryDirectory() as tmp_dir:
103-
target_dir = os.path.join(tmp_dir, "model_artifacts/")
104103
SmallArtifactDownloader(
105104
dsc_model=self.mock_dsc_model,
106-
target_dir=target_dir,
105+
target_dir=tmp_dir,
107106
force_overwrite=True,
108107
).download()
109108

110109
self.mock_dsc_model.get_model_artifact_content.assert_called()
111110

112111
test_files = list(
113-
glob.iglob(os.path.join(target_dir, "**"), recursive=True)
112+
glob.iglob(os.path.join(tmp_dir, "**"), recursive=True)
114113
)
115114
expected_files = [
116-
os.path.join(target_dir, file_name)
115+
os.path.join(tmp_dir, file_name)
117116
for file_name in ["", "runtime.yaml", "score.py"]
118117
]
119118
assert sorted(test_files) == sorted(expected_files)

0 commit comments

Comments
 (0)