Skip to content

Commit 5451c2c

Browse files
[ODSC-6682] Delete HF cache by default while registering models (#1044)
2 parents 410dbe0 + a0645cb commit 5451c2c

File tree

6 files changed

+105
-23
lines changed

6 files changed

+105
-23
lines changed

ads/aqua/common/utils.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import random
1212
import re
1313
import shlex
14+
import shutil
1415
import subprocess
1516
from datetime import datetime, timedelta
1617
from functools import wraps
@@ -21,6 +22,8 @@
2122
import fsspec
2223
import oci
2324
from cachetools import TTLCache, cached
25+
from huggingface_hub.constants import HF_HUB_CACHE
26+
from huggingface_hub.file_download import repo_folder_name
2427
from huggingface_hub.hf_api import HfApi, ModelInfo
2528
from huggingface_hub.utils import (
2629
GatedRepoError,
@@ -821,6 +824,48 @@ def upload_folder(
821824
return f"oci://{os_details.bucket}@{os_details.namespace}" + "/" + object_path
822825

823826

827+
def cleanup_local_hf_model_artifact(
828+
model_name: str,
829+
local_dir: str = None,
830+
):
831+
"""
832+
Helper function that deletes local artifacts downloaded from Hugging Face to free up disk space.
833+
Parameters
834+
----------
835+
model_name (str): Name of the huggingface model
836+
local_dir (str): Local directory where the object is downloaded
837+
838+
"""
839+
if local_dir and os.path.exists(local_dir):
840+
model_dir = os.path.join(local_dir, model_name)
841+
model_dir = (
842+
os.path.dirname(model_dir)
843+
if "/" in model_name or os.sep in model_name
844+
else model_dir
845+
)
846+
shutil.rmtree(model_dir, ignore_errors=True)
847+
if os.path.exists(model_dir):
848+
logger.debug(
849+
f"Could not delete local model artifact directory: {model_dir}"
850+
)
851+
else:
852+
logger.debug(f"Deleted local model artifact directory: {model_dir}.")
853+
854+
hf_local_path = os.path.join(
855+
HF_HUB_CACHE, repo_folder_name(repo_id=model_name, repo_type="model")
856+
)
857+
shutil.rmtree(hf_local_path, ignore_errors=True)
858+
859+
if os.path.exists(hf_local_path):
860+
logger.debug(
861+
f"Could not clear the local Hugging Face cache directory {hf_local_path} for the model {model_name}."
862+
)
863+
else:
864+
logger.debug(
865+
f"Cleared contents of local Hugging Face cache directory {hf_local_path} for the model {model_name}."
866+
)
867+
868+
824869
def is_service_managed_container(container):
825870
return container and container.startswith(SERVICE_MANAGED_CONTAINER_URI_SCHEME)
826871

ads/aqua/extension/model_handler.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
#!/usr/bin/env python
2-
# Copyright (c) 2024 Oracle and/or its affiliates.
2+
# Copyright (c) 2024, 2025 Oracle and/or its affiliates.
33
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
44

55
from typing import Optional
@@ -131,6 +131,10 @@ def post(self, *args, **kwargs): # noqa: ARG002
131131
download_from_hf = (
132132
str(input_data.get("download_from_hf", "false")).lower() == "true"
133133
)
134+
local_dir = input_data.get("local_dir")
135+
cleanup_model_cache = (
136+
str(input_data.get("cleanup_model_cache", "true")).lower() == "true"
137+
)
134138
inference_container_uri = input_data.get("inference_container_uri")
135139
allow_patterns = input_data.get("allow_patterns")
136140
ignore_patterns = input_data.get("ignore_patterns")
@@ -142,6 +146,8 @@ def post(self, *args, **kwargs): # noqa: ARG002
142146
model=model,
143147
os_path=os_path,
144148
download_from_hf=download_from_hf,
149+
local_dir=local_dir,
150+
cleanup_model_cache=cleanup_model_cache,
145151
inference_container=inference_container,
146152
finetuning_container=finetuning_container,
147153
compartment_id=compartment_id,

ads/aqua/model/entities.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
#!/usr/bin/env python
2-
# Copyright (c) 2024 Oracle and/or its affiliates.
2+
# Copyright (c) 2024, 2025 Oracle and/or its affiliates.
33
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
44

55
"""
@@ -283,6 +283,7 @@ class ImportModelDetails(CLIBuilderMixin):
283283
os_path: str
284284
download_from_hf: Optional[bool] = True
285285
local_dir: Optional[str] = None
286+
cleanup_model_cache: Optional[bool] = True
286287
inference_container: Optional[str] = None
287288
finetuning_container: Optional[str] = None
288289
compartment_id: Optional[str] = None

ads/aqua/model/model.py

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
#!/usr/bin/env python
2-
# Copyright (c) 2024 Oracle and/or its affiliates.
2+
# Copyright (c) 2024, 2025 Oracle and/or its affiliates.
33
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
44
import os
55
import pathlib
@@ -24,6 +24,7 @@
2424
from ads.aqua.common.utils import (
2525
LifecycleStatus,
2626
_build_resource_identifier,
27+
cleanup_local_hf_model_artifact,
2728
copy_model_config,
2829
create_word_icon,
2930
generate_tei_cmd_var,
@@ -1349,20 +1350,20 @@ def _download_model_from_hf(
13491350
Returns
13501351
-------
13511352
model_artifact_path (str): Location where the model artifacts are downloaded.
1352-
13531353
"""
13541354
# Download the model from hub
1355-
if not local_dir:
1356-
local_dir = os.path.join(os.path.expanduser("~"), "cached-model")
1357-
local_dir = os.path.join(local_dir, model_name)
1358-
os.makedirs(local_dir, exist_ok=True)
1359-
snapshot_download(
1355+
if local_dir:
1356+
local_dir = os.path.join(local_dir, model_name)
1357+
os.makedirs(local_dir, exist_ok=True)
1358+
1359+
# if local_dir is not set, the return value points to the cached data folder
1360+
local_dir = snapshot_download(
13601361
repo_id=model_name,
13611362
local_dir=local_dir,
13621363
allow_patterns=allow_patterns,
13631364
ignore_patterns=ignore_patterns,
13641365
)
1365-
# Upload to object storage and skip .cache/huggingface/ folder
1366+
# Upload to object storage
13661367
model_artifact_path = upload_folder(
13671368
os_path=os_path,
13681369
local_dir=local_dir,
@@ -1392,6 +1393,8 @@ def register(
13921393
ignore_patterns (list): Model files matching any of the patterns are not downloaded.
13931394
Example: ["*.json"] will ignore all .json files. ["folder/*"] will ignore all files under `folder`.
13941395
Patterns are Standard Wildcards (globbing patterns) and rules can be found here: https://docs.python.org/3/library/fnmatch.html
1396+
cleanup_model_cache (bool): Deletes downloaded files from local machine after model is successfully
1397+
registered. Set to True by default.
13951398
13961399
Returns:
13971400
AquaModel:
@@ -1501,6 +1504,14 @@ def register(
15011504
detail=validation_result.telemetry_model_name,
15021505
)
15031506

1507+
if (
1508+
import_model_details.download_from_hf
1509+
and import_model_details.cleanup_model_cache
1510+
):
1511+
cleanup_local_hf_model_artifact(
1512+
model_name=model_name, local_dir=import_model_details.local_dir
1513+
)
1514+
15041515
return AquaModel(**aqua_model_attributes)
15051516

15061517
def _if_show(self, model: DataScienceModel) -> bool:

tests/unitary/with_extras/aqua/test_model.py

Lines changed: 29 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
#!/usr/bin/env python
22
# -*- coding: utf-8 -*--
3-
# Copyright (c) 2024 Oracle and/or its affiliates.
3+
# Copyright (c) 2024, 2025 Oracle and/or its affiliates.
44
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
55

66
import os
@@ -653,12 +653,12 @@ def test_get_model_fine_tuned(
653653
}
654654

655655
@pytest.mark.parametrize(
656-
("artifact_location_set", "download_from_hf"),
656+
("artifact_location_set", "download_from_hf", "cleanup_model_cache"),
657657
[
658-
(True, True),
659-
(True, False),
660-
(False, True),
661-
(False, False),
658+
(True, True, True),
659+
(True, False, True),
660+
(False, True, False),
661+
(False, False, True),
662662
],
663663
)
664664
@patch("ads.model.service.oci_datascience_model.OCIDataScienceModel.create")
@@ -683,6 +683,7 @@ def test_import_verified_model(
683683
mock_ocidsc_create,
684684
artifact_location_set,
685685
download_from_hf,
686+
cleanup_model_cache,
686687
mock_get_hf_model_info,
687688
mock_init_client,
688689
):
@@ -742,11 +743,13 @@ def test_import_verified_model(
742743
app = AquaModelApp()
743744
if download_from_hf:
744745
with tempfile.TemporaryDirectory() as tmpdir:
746+
mock_snapshot_download.return_value = f"{str(tmpdir)}/{model_name}"
745747
model: AquaModel = app.register(
746748
model=model_name,
747749
os_path=os_path,
748750
local_dir=str(tmpdir),
749751
download_from_hf=True,
752+
cleanup_model_cache=cleanup_model_cache,
750753
allow_patterns=["*.json"],
751754
ignore_patterns=["test.json"],
752755
)
@@ -761,6 +764,20 @@ def test_import_verified_model(
761764
f"oci os object bulk-upload --src-dir {str(tmpdir)}/{model_name} --prefix prefix/path/{model_name}/ -bn aqua-bkt -ns aqua-ns --auth api_key --profile DEFAULT --no-overwrite --exclude {HF_METADATA_FOLDER}*"
762765
)
763766
)
767+
if cleanup_model_cache:
768+
cache_dir = os.path.join(
769+
os.path.expanduser("~"),
770+
".cache",
771+
"huggingface",
772+
"hub",
773+
"models--oracle--aqua-1t-mega-model",
774+
)
775+
assert (
776+
os.path.exists(f"{str(tmpdir)}/{os.path.dirname(model_name)}")
777+
is False
778+
)
779+
assert os.path.exists(cache_dir) is False
780+
764781
else:
765782
model: AquaModel = app.register(
766783
model="ocid1.datasciencemodel.xxx.xxxx.",
@@ -1183,22 +1200,22 @@ def test_import_model_with_input_tags(
11831200
"model": "oracle/oracle-1it",
11841201
"inference_container": "odsc-vllm-serving",
11851202
},
1186-
"ads aqua model register --model oracle/oracle-1it --os_path oci://aqua-bkt@aqua-ns/path --download_from_hf True --inference_container odsc-vllm-serving",
1203+
"ads aqua model register --model oracle/oracle-1it --os_path oci://aqua-bkt@aqua-ns/path --download_from_hf True --cleanup_model_cache True --inference_container odsc-vllm-serving",
11871204
),
11881205
(
11891206
{
11901207
"os_path": "oci://aqua-bkt@aqua-ns/path",
11911208
"model": "ocid1.datasciencemodel.oc1.iad.<OCID>",
11921209
},
1193-
"ads aqua model register --model ocid1.datasciencemodel.oc1.iad.<OCID> --os_path oci://aqua-bkt@aqua-ns/path --download_from_hf True",
1210+
"ads aqua model register --model ocid1.datasciencemodel.oc1.iad.<OCID> --os_path oci://aqua-bkt@aqua-ns/path --download_from_hf True --cleanup_model_cache True",
11941211
),
11951212
(
11961213
{
11971214
"os_path": "oci://aqua-bkt@aqua-ns/path",
11981215
"model": "oracle/oracle-1it",
11991216
"download_from_hf": False,
12001217
},
1201-
"ads aqua model register --model oracle/oracle-1it --os_path oci://aqua-bkt@aqua-ns/path --download_from_hf False",
1218+
"ads aqua model register --model oracle/oracle-1it --os_path oci://aqua-bkt@aqua-ns/path --download_from_hf False --cleanup_model_cache True",
12021219
),
12031220
(
12041221
{
@@ -1207,7 +1224,7 @@ def test_import_model_with_input_tags(
12071224
"download_from_hf": True,
12081225
"model_file": "test_model_file",
12091226
},
1210-
"ads aqua model register --model oracle/oracle-1it --os_path oci://aqua-bkt@aqua-ns/path --download_from_hf True --model_file test_model_file",
1227+
"ads aqua model register --model oracle/oracle-1it --os_path oci://aqua-bkt@aqua-ns/path --download_from_hf True --cleanup_model_cache True --model_file test_model_file",
12111228
),
12121229
(
12131230
{
@@ -1216,7 +1233,7 @@ def test_import_model_with_input_tags(
12161233
"inference_container": "odsc-tei-serving",
12171234
"inference_container_uri": "<region>.ocir.io/<your_tenancy>/<your_image>",
12181235
},
1219-
"ads aqua model register --model oracle/oracle-1it --os_path oci://aqua-bkt@aqua-ns/path --download_from_hf True --inference_container odsc-tei-serving --inference_container_uri <region>.ocir.io/<your_tenancy>/<your_image>",
1236+
"ads aqua model register --model oracle/oracle-1it --os_path oci://aqua-bkt@aqua-ns/path --download_from_hf True --cleanup_model_cache True --inference_container odsc-tei-serving --inference_container_uri <region>.ocir.io/<your_tenancy>/<your_image>",
12201237
),
12211238
(
12221239
{
@@ -1227,7 +1244,7 @@ def test_import_model_with_input_tags(
12271244
"defined_tags": {"dtag1": "dvalue1", "dtag2": "dvalue2"},
12281245
},
12291246
"ads aqua model register --model oracle/oracle-1it --os_path oci://aqua-bkt@aqua-ns/path "
1230-
"--download_from_hf True --inference_container odsc-vllm-serving --freeform_tags "
1247+
"--download_from_hf True --cleanup_model_cache True --inference_container odsc-vllm-serving --freeform_tags "
12311248
'{"ftag1": "fvalue1", "ftag2": "fvalue2"} --defined_tags {"dtag1": "dvalue1", "dtag2": "dvalue2"}',
12321249
),
12331250
],

tests/unitary/with_extras/aqua/test_model_handler.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
#!/usr/bin/env python
22
# -*- coding: utf-8 -*--
3-
# Copyright (c) 2024 Oracle and/or its affiliates.
3+
# Copyright (c) 2024, 2025 Oracle and/or its affiliates.
44
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
55

66
from unittest import TestCase
@@ -213,6 +213,8 @@ def test_register(
213213
project_id=None,
214214
model_file=model_file,
215215
download_from_hf=download_from_hf,
216+
local_dir=None,
217+
cleanup_model_cache=True,
216218
inference_container_uri=inference_container_uri,
217219
allow_patterns=allow_patterns,
218220
ignore_patterns=ignore_patterns,

0 commit comments

Comments
 (0)