40
40
AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE ,
41
41
AQUA_MODEL_ARTIFACT_FILE ,
42
42
AQUA_MODEL_TYPE_CUSTOM ,
43
+ HF_METADATA_FOLDER ,
43
44
LICENSE_TXT ,
44
45
MODEL_BY_REFERENCE_OSS_PATH_KEY ,
45
46
README ,
@@ -1274,6 +1275,8 @@ def _download_model_from_hf(
1274
1275
model_name : str ,
1275
1276
os_path : str ,
1276
1277
local_dir : str = None ,
1278
+ allow_patterns : List [str ] = None ,
1279
+ ignore_patterns : List [str ] = None ,
1277
1280
) -> str :
1278
1281
"""This helper function downloads the model artifact from Hugging Face to a local folder, then uploads
1279
1282
to object storage location.
@@ -1283,6 +1286,12 @@ def _download_model_from_hf(
1283
1286
model_name (str): The huggingface model name.
1284
1287
os_path (str): The OS path where the model files are located.
1285
1288
local_dir (str): The local temp dir to store the huggingface model.
1289
+ allow_patterns (list): Model files matching at least one pattern are downloaded.
1290
+ Example: ["*.json"] will download all .json files. ["folder/*"] will download all files under `folder`.
1291
+ Patterns are Standard Wildcards (globbing patterns) and rules can be found here: https://docs.python.org/3/library/fnmatch.html
1292
+ ignore_patterns (list): Model files matching any of the patterns are not downloaded.
1293
+ Example: ["*.json"] will ignore all .json files. ["folder/*"] will ignore all files under `folder`.
1294
+ Patterns are Standard Wildcards (globbing patterns) and rules can be found here: https://docs.python.org/3/library/fnmatch.html
1286
1295
1287
1296
Returns
1288
1297
-------
@@ -1293,30 +1302,19 @@ def _download_model_from_hf(
1293
1302
if not local_dir :
1294
1303
local_dir = os .path .join (os .path .expanduser ("~" ), "cached-model" )
1295
1304
local_dir = os .path .join (local_dir , model_name )
1296
- retry = 10
1297
- i = 0
1298
- huggingface_download_err_message = None
1299
- while i < retry :
1300
- try :
1301
- # Download to cache folder. The while loop retries when there is a network failure
1302
- snapshot_download (repo_id = model_name )
1303
- except Exception as e :
1304
- huggingface_download_err_message = str (e )
1305
- i += 1
1306
- else :
1307
- break
1308
- if i == retry :
1309
- raise Exception (
1310
- f"Could not download the model { model_name } from https://huggingface.co with message { huggingface_download_err_message } "
1311
- )
1312
1305
os .makedirs (local_dir , exist_ok = True )
1313
- # Copy the model from the cache to destination
1314
- snapshot_download (repo_id = model_name , local_dir = local_dir )
1315
- # Upload to object storage
1306
+ snapshot_download (
1307
+ repo_id = model_name ,
1308
+ local_dir = local_dir ,
1309
+ allow_patterns = allow_patterns ,
1310
+ ignore_patterns = ignore_patterns ,
1311
+ )
1312
+ # Upload to object storage and skip .cache/huggingface/ folder
1316
1313
model_artifact_path = upload_folder (
1317
1314
os_path = os_path ,
1318
1315
local_dir = local_dir ,
1319
1316
model_name = model_name ,
1317
+ exclude_pattern = f"{ HF_METADATA_FOLDER } *"
1320
1318
)
1321
1319
1322
1320
return model_artifact_path
@@ -1335,6 +1333,12 @@ def register(
1335
1333
os_path (str): Object storage destination URI to store the downloaded model. Format: oci://bucket-name@namespace/prefix
1336
1334
inference_container (str): selects service defaults
1337
1335
finetuning_container (str): selects service defaults
1336
+ allow_patterns (list): Model files matching at least one pattern are downloaded.
1337
+ Example: ["*.json"] will download all .json files. ["folder/*"] will download all files under `folder`.
1338
+ Patterns are Standard Wildcards (globbing patterns) and rules can be found here: https://docs.python.org/3/library/fnmatch.html
1339
+ ignore_patterns (list): Model files matching any of the patterns are not downloaded.
1340
+ Example: ["*.json"] will ignore all .json files. ["folder/*"] will ignore all files under `folder`.
1341
+ Patterns are Standard Wildcards (globbing patterns) and rules can be found here: https://docs.python.org/3/library/fnmatch.html
1338
1342
1339
1343
Returns:
1340
1344
AquaModel:
@@ -1381,6 +1385,8 @@ def register(
1381
1385
model_name = model_name ,
1382
1386
os_path = import_model_details .os_path ,
1383
1387
local_dir = import_model_details .local_dir ,
1388
+ allow_patterns = import_model_details .allow_patterns ,
1389
+ ignore_patterns = import_model_details .ignore_patterns ,
1384
1390
).rstrip ("/" )
1385
1391
else :
1386
1392
artifact_path = import_model_details .os_path .rstrip ("/" )
0 commit comments