Skip to content

Commit 0f446a5

Browse files
authored
Improved register BYOM model (#1005)
2 parents 8d358c7 + 89d6052 commit 0f446a5

File tree

8 files changed

+53
-26
lines changed

8 files changed

+53
-26
lines changed

ads/aqua/common/utils.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -788,13 +788,14 @@ def get_ocid_substring(ocid: str, key_len: int) -> str:
788788
return ocid[-key_len:] if ocid and len(ocid) > key_len else ""
789789

790790

791-
def upload_folder(os_path: str, local_dir: str, model_name: str) -> str:
791+
def upload_folder(os_path: str, local_dir: str, model_name: str, exclude_pattern: str = None) -> str:
792792
"""Upload the local folder to the object storage
793793
794794
Args:
795795
os_path (str): object storage URI with prefix. This is the path to upload
796796
local_dir (str): Local directory where the object is downloaded
797797
model_name (str): Name of the huggingface model
798+
exclude_pattern (optional, str): The matching pattern of files to be excluded from uploading.
798799
Returns:
799800
str: Object name inside the bucket
800801
"""
@@ -804,6 +805,8 @@ def upload_folder(os_path: str, local_dir: str, model_name: str) -> str:
804805
auth_state = AuthState()
805806
object_path = os_details.filepath.rstrip("/") + "/" + model_name + "/"
806807
command = f"oci os object bulk-upload --src-dir {local_dir} --prefix {object_path} -bn {os_details.bucket} -ns {os_details.namespace} --auth {auth_state.oci_iam_type} --profile {auth_state.oci_key_profile} --no-overwrite"
808+
if exclude_pattern:
809+
command += f" --exclude {exclude_pattern}"
807810
try:
808811
logger.info(f"Running: {command}")
809812
subprocess.check_call(shlex.split(command))

ads/aqua/constants.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
AQUA_MODEL_ARTIFACT_CONFIG_MODEL_NAME = "_name_or_path"
3636
AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE = "model_type"
3737
AQUA_MODEL_ARTIFACT_FILE = "model_file"
38+
HF_METADATA_FOLDER = ".cache/"
3839
HF_LOGIN_DEFAULT_TIMEOUT = 2
3940

4041
TRAINING_METRICS_FINAL = "training_metrics_final"

ads/aqua/extension/model_handler.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,8 @@ def post(self, *args, **kwargs):
129129
str(input_data.get("download_from_hf", "false")).lower() == "true"
130130
)
131131
inference_container_uri = input_data.get("inference_container_uri")
132+
allow_patterns = input_data.get("allow_patterns")
133+
ignore_patterns = input_data.get("ignore_patterns")
132134

133135
return self.finish(
134136
AquaModelApp().register(
@@ -141,6 +143,8 @@ def post(self, *args, **kwargs):
141143
project_id=project_id,
142144
model_file=model_file,
143145
inference_container_uri=inference_container_uri,
146+
allow_patterns=allow_patterns,
147+
ignore_patterns=ignore_patterns,
144148
)
145149
)
146150

ads/aqua/model/entities.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -289,6 +289,8 @@ class ImportModelDetails(CLIBuilderMixin):
289289
project_id: Optional[str] = None
290290
model_file: Optional[str] = None
291291
inference_container_uri: Optional[str] = None
292+
allow_patterns: Optional[List[str]] = None
293+
ignore_patterns: Optional[List[str]] = None
292294

293295
def __post_init__(self):
294296
self._command = "model register"

ads/aqua/model/model.py

Lines changed: 25 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE,
4141
AQUA_MODEL_ARTIFACT_FILE,
4242
AQUA_MODEL_TYPE_CUSTOM,
43+
HF_METADATA_FOLDER,
4344
LICENSE_TXT,
4445
MODEL_BY_REFERENCE_OSS_PATH_KEY,
4546
README,
@@ -1274,6 +1275,8 @@ def _download_model_from_hf(
12741275
model_name: str,
12751276
os_path: str,
12761277
local_dir: str = None,
1278+
allow_patterns: List[str] = None,
1279+
ignore_patterns: List[str] = None,
12771280
) -> str:
12781281
"""This helper function downloads the model artifact from Hugging Face to a local folder, then uploads
12791282
to object storage location.
@@ -1283,6 +1286,12 @@ def _download_model_from_hf(
12831286
model_name (str): The huggingface model name.
12841287
os_path (str): The OS path where the model files are located.
12851288
local_dir (str): The local temp dir to store the huggingface model.
1289+
allow_patterns (list): Model files matching at least one pattern are downloaded.
1290+
Example: ["*.json"] will download all .json files. ["folder/*"] will download all files under `folder`.
1291+
Patterns are Standard Wildcards (globbing patterns) and rules can be found here: https://docs.python.org/3/library/fnmatch.html
1292+
ignore_patterns (list): Model files matching any of the patterns are not downloaded.
1293+
Example: ["*.json"] will ignore all .json files. ["folder/*"] will ignore all files under `folder`.
1294+
Patterns are Standard Wildcards (globbing patterns) and rules can be found here: https://docs.python.org/3/library/fnmatch.html
12861295
12871296
Returns
12881297
-------
@@ -1293,30 +1302,19 @@ def _download_model_from_hf(
12931302
if not local_dir:
12941303
local_dir = os.path.join(os.path.expanduser("~"), "cached-model")
12951304
local_dir = os.path.join(local_dir, model_name)
1296-
retry = 10
1297-
i = 0
1298-
huggingface_download_err_message = None
1299-
while i < retry:
1300-
try:
1301-
# Download to cache folder. The while loop retries when there is a network failure
1302-
snapshot_download(repo_id=model_name)
1303-
except Exception as e:
1304-
huggingface_download_err_message = str(e)
1305-
i += 1
1306-
else:
1307-
break
1308-
if i == retry:
1309-
raise Exception(
1310-
f"Could not download the model {model_name} from https://huggingface.co with message {huggingface_download_err_message}"
1311-
)
13121305
os.makedirs(local_dir, exist_ok=True)
1313-
# Copy the model from the cache to destination
1314-
snapshot_download(repo_id=model_name, local_dir=local_dir)
1315-
# Upload to object storage
1306+
snapshot_download(
1307+
repo_id=model_name,
1308+
local_dir=local_dir,
1309+
allow_patterns=allow_patterns,
1310+
ignore_patterns=ignore_patterns,
1311+
)
1312+
# Upload to object storage, skipping the Hugging Face metadata folder (.cache/)
13161313
model_artifact_path = upload_folder(
13171314
os_path=os_path,
13181315
local_dir=local_dir,
13191316
model_name=model_name,
1317+
exclude_pattern=f"{HF_METADATA_FOLDER}*"
13201318
)
13211319

13221320
return model_artifact_path
@@ -1335,6 +1333,12 @@ def register(
13351333
os_path (str): Object storage destination URI to store the downloaded model. Format: oci://bucket-name@namespace/prefix
13361334
inference_container (str): selects service defaults
13371335
finetuning_container (str): selects service defaults
1336+
allow_patterns (list): Model files matching at least one pattern are downloaded.
1337+
Example: ["*.json"] will download all .json files. ["folder/*"] will download all files under `folder`.
1338+
Patterns are Standard Wildcards (globbing patterns) and rules can be found here: https://docs.python.org/3/library/fnmatch.html
1339+
ignore_patterns (list): Model files matching any of the patterns are not downloaded.
1340+
Example: ["*.json"] will ignore all .json files. ["folder/*"] will ignore all files under `folder`.
1341+
Patterns are Standard Wildcards (globbing patterns) and rules can be found here: https://docs.python.org/3/library/fnmatch.html
13381342
13391343
Returns:
13401344
AquaModel:
@@ -1381,6 +1385,8 @@ def register(
13811385
model_name=model_name,
13821386
os_path=import_model_details.os_path,
13831387
local_dir=import_model_details.local_dir,
1388+
allow_patterns=import_model_details.allow_patterns,
1389+
ignore_patterns=import_model_details.ignore_patterns,
13841390
).rstrip("/")
13851391
else:
13861392
artifact_path = import_model_details.os_path.rstrip("/")

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ opctl = [
125125
"rich",
126126
"fire",
127127
"cachetools",
128-
"huggingface_hub==0.23.4"
128+
"huggingface_hub==0.26.2"
129129
]
130130
optuna = ["optuna==2.9.0", "oracle_ads[viz]"]
131131
spark = ["pyspark>=3.0.0"]

tests/unitary/with_extras/aqua/test_model.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from unittest.mock import MagicMock, patch
1313

1414
import oci
15+
from ads.aqua.constants import HF_METADATA_FOLDER
1516
import pytest
1617
from ads.aqua.ui import ModelFormat
1718
from parameterized import parameterized
@@ -746,14 +747,18 @@ def test_import_verified_model(
746747
os_path=os_path,
747748
local_dir=str(tmpdir),
748749
download_from_hf=True,
750+
allow_patterns=["*.json"],
751+
ignore_patterns=["test.json"]
749752
)
750753
mock_snapshot_download.assert_called_with(
751754
repo_id=model_name,
752755
local_dir=f"{str(tmpdir)}/{model_name}",
756+
allow_patterns=["*.json"],
757+
ignore_patterns=["test.json"]
753758
)
754759
mock_subprocess.assert_called_with(
755760
shlex.split(
756-
f"oci os object bulk-upload --src-dir {str(tmpdir)}/{model_name} --prefix prefix/path/{model_name}/ -bn aqua-bkt -ns aqua-ns --auth api_key --profile DEFAULT --no-overwrite"
761+
f"oci os object bulk-upload --src-dir {str(tmpdir)}/{model_name} --prefix prefix/path/{model_name}/ -bn aqua-bkt -ns aqua-ns --auth api_key --profile DEFAULT --no-overwrite --exclude {HF_METADATA_FOLDER}*"
757762
)
758763
)
759764
else:

tests/unitary/with_extras/aqua/test_model_handler.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -132,10 +132,10 @@ def test_list(self, mock_list):
132132

133133
@parameterized.expand(
134134
[
135-
(None, None, False, None),
136-
("odsc-llm-fine-tuning", None, False, None),
137-
(None, "test.gguf", True, None),
138-
(None, None, True, "iad.ocir.io/<namespace>/<image>:<tag>"),
135+
(None, None, False, None, None, None),
136+
("odsc-llm-fine-tuning", None, False, None, None, ["test.json"]),
137+
(None, "test.gguf", True, None, ["*.json"], None),
138+
(None, None, True, "iad.ocir.io/<namespace>/<image>:<tag>", ["*.json"], ["test.json"]),
139139
],
140140
)
141141
@patch("notebook.base.handlers.APIHandler.finish")
@@ -146,6 +146,8 @@ def test_register(
146146
model_file,
147147
download_from_hf,
148148
inference_container_uri,
149+
allow_patterns,
150+
ignore_patterns,
149151
mock_register,
150152
mock_finish,
151153
):
@@ -165,6 +167,8 @@ def test_register(
165167
model_file=model_file,
166168
download_from_hf=download_from_hf,
167169
inference_container_uri=inference_container_uri,
170+
allow_patterns=allow_patterns,
171+
ignore_patterns=ignore_patterns
168172
)
169173
)
170174
result = self.model_handler.post()
@@ -178,6 +182,8 @@ def test_register(
178182
model_file=model_file,
179183
download_from_hf=download_from_hf,
180184
inference_container_uri=inference_container_uri,
185+
allow_patterns=allow_patterns,
186+
ignore_patterns=ignore_patterns
181187
)
182188
assert result["id"] == "test_id"
183189
assert result["inference_container"] == "odsc-tgi-serving"

0 commit comments

Comments
 (0)