Skip to content

Commit 065eeb5

Browse files
use snapshot download return value as local dir
1 parent 3569ba7 commit 065eeb5

File tree

2 files changed

+4
-1
lines changed

2 files changed

+4
-1
lines changed

ads/aqua/model/model.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1328,7 +1328,9 @@ def _download_model_from_hf(
13281328
if local_dir:
13291329
local_dir = os.path.join(local_dir, model_name)
13301330
os.makedirs(local_dir, exist_ok=True)
1331-
snapshot_download(
1331+
1332+
# if local_dir is not set, the return value points to the cached data folder
1333+
local_dir = snapshot_download(
13321334
repo_id=model_name,
13331335
local_dir=local_dir,
13341336
allow_patterns=allow_patterns,

tests/unitary/with_extras/aqua/test_model.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -743,6 +743,7 @@ def test_import_verified_model(
743743
app = AquaModelApp()
744744
if download_from_hf:
745745
with tempfile.TemporaryDirectory() as tmpdir:
746+
mock_snapshot_download.return_value = f"{str(tmpdir)}/{model_name}"
746747
model: AquaModel = app.register(
747748
model=model_name,
748749
os_path=os_path,

0 commit comments

Comments
 (0)