Skip to content

Commit 212ec33

Browse files
committed
Improved register BYOM model.
1 parent 7cdaee1 commit 212ec33

File tree

8 files changed

+43
-26
lines changed

8 files changed

+43
-26
lines changed

ads/aqua/common/utils.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -788,13 +788,14 @@ def get_ocid_substring(ocid: str, key_len: int) -> str:
788788
return ocid[-key_len:] if ocid and len(ocid) > key_len else ""
789789

790790

791-
def upload_folder(os_path: str, local_dir: str, model_name: str) -> str:
791+
def upload_folder(os_path: str, local_dir: str, model_name: str, exclude_pattern: str = None) -> str:
792792
"""Upload the local folder to the object storage
793793
794794
Args:
795795
os_path (str): object storage URI with prefix. This is the path to upload
796796
local_dir (str): Local directory where the object is downloaded
797797
model_name (str): Name of the huggingface model
798+
exclude_pattern (optional, str): The matching pattern of files to be excluded from uploading.
798799
Retuns:
799800
str: Object name inside the bucket
800801
"""
@@ -804,6 +805,8 @@ def upload_folder(os_path: str, local_dir: str, model_name: str) -> str:
804805
auth_state = AuthState()
805806
object_path = os_details.filepath.rstrip("/") + "/" + model_name + "/"
806807
command = f"oci os object bulk-upload --src-dir {local_dir} --prefix {object_path} -bn {os_details.bucket} -ns {os_details.namespace} --auth {auth_state.oci_iam_type} --profile {auth_state.oci_key_profile} --no-overwrite"
808+
if exclude_pattern:
809+
command += f" --exclude {exclude_pattern}"
807810
try:
808811
logger.info(f"Running: {command}")
809812
subprocess.check_call(shlex.split(command))

ads/aqua/constants.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
AQUA_MODEL_ARTIFACT_CONFIG_MODEL_NAME = "_name_or_path"
3636
AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE = "model_type"
3737
AQUA_MODEL_ARTIFACT_FILE = "model_file"
38+
HF_METADATA_FOLDER = ".cache/huggingface/"
3839
HF_LOGIN_DEFAULT_TIMEOUT = 2
3940

4041
TRAINING_METRICS_FINAL = "training_metrics_final"

ads/aqua/extension/model_handler.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,8 @@ def post(self, *args, **kwargs):
129129
str(input_data.get("download_from_hf", "false")).lower() == "true"
130130
)
131131
inference_container_uri = input_data.get("inference_container_uri")
132+
allow_patterns = input_data.get("allow_patterns")
133+
ignore_patterns = input_data.get("ignore_patterns")
132134

133135
return self.finish(
134136
AquaModelApp().register(
@@ -141,6 +143,8 @@ def post(self, *args, **kwargs):
141143
project_id=project_id,
142144
model_file=model_file,
143145
inference_container_uri=inference_container_uri,
146+
allow_patterns=allow_patterns,
147+
ignore_patterns=ignore_patterns,
144148
)
145149
)
146150

ads/aqua/model/entities.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -289,6 +289,8 @@ class ImportModelDetails(CLIBuilderMixin):
289289
project_id: Optional[str] = None
290290
model_file: Optional[str] = None
291291
inference_container_uri: Optional[str] = None
292+
allow_patterns: Optional[List[str]] = None
293+
ignore_patterns: Optional[List[str]] = None
292294

293295
def __post_init__(self):
294296
self._command = "model register"

ads/aqua/model/model.py

Lines changed: 15 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE,
4141
AQUA_MODEL_ARTIFACT_FILE,
4242
AQUA_MODEL_TYPE_CUSTOM,
43+
HF_METADATA_FOLDER,
4344
LICENSE_TXT,
4445
MODEL_BY_REFERENCE_OSS_PATH_KEY,
4546
README,
@@ -1274,6 +1275,8 @@ def _download_model_from_hf(
12741275
model_name: str,
12751276
os_path: str,
12761277
local_dir: str = None,
1278+
allow_patterns: List[str] = None,
1279+
ignore_patterns: List[str] = None,
12771280
) -> str:
12781281
"""This helper function downloads the model artifact from Hugging Face to a local folder, then uploads
12791282
to object storage location.
@@ -1283,6 +1286,8 @@ def _download_model_from_hf(
12831286
model_name (str): The huggingface model name.
12841287
os_path (str): The OS path where the model files are located.
12851288
local_dir (str): The local temp dir to store the huggingface model.
1289+
allow_patterns (list): Model files matching at least one pattern are downloaded.
1290+
ignore_patterns (list): Model files matching any of the patterns are not downloaded.
12861291
12871292
Returns
12881293
-------
@@ -1293,30 +1298,19 @@ def _download_model_from_hf(
12931298
if not local_dir:
12941299
local_dir = os.path.join(os.path.expanduser("~"), "cached-model")
12951300
local_dir = os.path.join(local_dir, model_name)
1296-
retry = 10
1297-
i = 0
1298-
huggingface_download_err_message = None
1299-
while i < retry:
1300-
try:
1301-
# Download to cache folder. The while loop retries when there is a network failure
1302-
snapshot_download(repo_id=model_name)
1303-
except Exception as e:
1304-
huggingface_download_err_message = str(e)
1305-
i += 1
1306-
else:
1307-
break
1308-
if i == retry:
1309-
raise Exception(
1310-
f"Could not download the model {model_name} from https://huggingface.co with message {huggingface_download_err_message}"
1311-
)
13121301
os.makedirs(local_dir, exist_ok=True)
1313-
# Copy the model from the cache to destination
1314-
snapshot_download(repo_id=model_name, local_dir=local_dir)
1315-
# Upload to object storage
1302+
snapshot_download(
1303+
repo_id=model_name,
1304+
local_dir=local_dir,
1305+
allow_patterns=allow_patterns,
1306+
ignore_patterns=ignore_patterns,
1307+
)
1308+
# Upload to object storage and skip .cache/huggingface/ folder
13161309
model_artifact_path = upload_folder(
13171310
os_path=os_path,
13181311
local_dir=local_dir,
13191312
model_name=model_name,
1313+
exclude_pattern=f"{HF_METADATA_FOLDER}*"
13201314
)
13211315

13221316
return model_artifact_path
@@ -1381,6 +1375,8 @@ def register(
13811375
model_name=model_name,
13821376
os_path=import_model_details.os_path,
13831377
local_dir=import_model_details.local_dir,
1378+
allow_patterns=import_model_details.allow_patterns,
1379+
ignore_patterns=import_model_details.ignore_patterns,
13841380
).rstrip("/")
13851381
else:
13861382
artifact_path = import_model_details.os_path.rstrip("/")

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ opctl = [
125125
"rich",
126126
"fire",
127127
"cachetools",
128-
"huggingface_hub==0.23.4"
128+
"huggingface_hub==0.26.2"
129129
]
130130
optuna = ["optuna==2.9.0", "oracle_ads[viz]"]
131131
spark = ["pyspark>=3.0.0"]

tests/unitary/with_extras/aqua/test_model.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from unittest.mock import MagicMock, patch
1313

1414
import oci
15+
from ads.aqua.constants import HF_METADATA_FOLDER
1516
import pytest
1617
from ads.aqua.ui import ModelFormat
1718
from parameterized import parameterized
@@ -746,14 +747,18 @@ def test_import_verified_model(
746747
os_path=os_path,
747748
local_dir=str(tmpdir),
748749
download_from_hf=True,
750+
allow_patterns=["*.json"],
751+
ignore_patterns=["test.json"]
749752
)
750753
mock_snapshot_download.assert_called_with(
751754
repo_id=model_name,
752755
local_dir=f"{str(tmpdir)}/{model_name}",
756+
allow_patterns=["*.json"],
757+
ignore_patterns=["test.json"]
753758
)
754759
mock_subprocess.assert_called_with(
755760
shlex.split(
756-
f"oci os object bulk-upload --src-dir {str(tmpdir)}/{model_name} --prefix prefix/path/{model_name}/ -bn aqua-bkt -ns aqua-ns --auth api_key --profile DEFAULT --no-overwrite"
761+
f"oci os object bulk-upload --src-dir {str(tmpdir)}/{model_name} --prefix prefix/path/{model_name}/ -bn aqua-bkt -ns aqua-ns --auth api_key --profile DEFAULT --no-overwrite --exclude {HF_METADATA_FOLDER}*"
757762
)
758763
)
759764
else:

tests/unitary/with_extras/aqua/test_model_handler.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -132,10 +132,10 @@ def test_list(self, mock_list):
132132

133133
@parameterized.expand(
134134
[
135-
(None, None, False, None),
136-
("odsc-llm-fine-tuning", None, False, None),
137-
(None, "test.gguf", True, None),
138-
(None, None, True, "iad.ocir.io/<namespace>/<image>:<tag>"),
135+
(None, None, False, None, None, None),
136+
("odsc-llm-fine-tuning", None, False, None, None, ["test.json"]),
137+
(None, "test.gguf", True, None, ["*.json"], None),
138+
(None, None, True, "iad.ocir.io/<namespace>/<image>:<tag>", ["*.json"], ["test.json"]),
139139
],
140140
)
141141
@patch("notebook.base.handlers.APIHandler.finish")
@@ -146,6 +146,8 @@ def test_register(
146146
model_file,
147147
download_from_hf,
148148
inference_container_uri,
149+
allow_patterns,
150+
ignore_patterns,
149151
mock_register,
150152
mock_finish,
151153
):
@@ -165,6 +167,8 @@ def test_register(
165167
model_file=model_file,
166168
download_from_hf=download_from_hf,
167169
inference_container_uri=inference_container_uri,
170+
allow_patterns=allow_patterns,
171+
ignore_patterns=ignore_patterns
168172
)
169173
)
170174
result = self.model_handler.post()
@@ -178,6 +182,8 @@ def test_register(
178182
model_file=model_file,
179183
download_from_hf=download_from_hf,
180184
inference_container_uri=inference_container_uri,
185+
allow_patterns=allow_patterns,
186+
ignore_patterns=ignore_patterns
181187
)
182188
assert result["id"] == "test_id"
183189
assert result["inference_container"] == "odsc-tgi-serving"

0 commit comments

Comments
 (0)