13
13
import requests # pylint: disable=import-error
14
14
15
15
from . import logging , tqdm
16
- from .constants import MLC_LLM_HOME , MLC_TEMP_DIR
16
+ from .constants import (
17
+ MLC_DOWNLOAD_POLICY ,
18
+ MLC_LLM_HOME ,
19
+ MLC_LLM_READONLY_WEIGHT_CACHES ,
20
+ MLC_TEMP_DIR ,
21
+ )
17
22
from .style import bold
18
23
19
24
logger = logging .getLogger (__name__ )
20
25
21
26
27
def log_download_policy():
    """Log the value of ``MLC_DOWNLOAD_POLICY`` currently in effect.

    Emits a single INFO record through the module-level ``logger``. The
    record shows the setting's name (rendered with ``bold``), its current
    value, and the list of accepted values: ON, OFF, REDO, READONLY.
    """
    # Build the highlighted setting name first, then hand both pieces to the
    # logger as lazy %-style arguments so formatting only happens if the
    # record is actually emitted.
    setting_name = bold("MLC_DOWNLOAD_POLICY")
    message_template = "%s = %s. Can be one of: ON, OFF, REDO, READONLY"
    logger.info(message_template, setting_name, MLC_DOWNLOAD_POLICY)
34
+
35
+
22
36
def _ensure_directory_not_exist (path : Path , force_redo : bool ) -> None :
23
37
if path .exists ():
24
38
if force_redo :
@@ -110,12 +124,16 @@ def download_file(
110
124
return url , destination
111
125
112
126
113
- def download_mlc_weights ( # pylint: disable=too-many-locals
127
+ def download_and_cache_mlc_weights ( # pylint: disable=too-many-locals
114
128
model_url : str ,
115
129
num_processes : int = 4 ,
116
- force_redo : bool = False ,
130
+ force_redo : Optional [ bool ] = None ,
117
131
) -> Path :
118
132
"""Download weights for a model from the HuggingFace Git LFS repo."""
133
+ log_download_policy ()
134
+ if MLC_DOWNLOAD_POLICY == "OFF" :
135
+ raise RuntimeError (f"Cannot download { model_url } as MLC_DOWNLOAD_POLICY=OFF" )
136
+
119
137
prefixes , mlc_prefix = ["HF://" , "https://huggingface.co/" ], ""
120
138
mlc_prefix = next (p for p in prefixes if model_url .startswith (p ))
121
139
assert mlc_prefix
@@ -126,12 +144,36 @@ def download_mlc_weights( # pylint: disable=too-many-locals
126
144
if model_url .count ("/" ) != 1 + mlc_prefix .count ("/" ) or not model_url .startswith (mlc_prefix ):
127
145
raise ValueError (f"Invalid model URL: { model_url } " )
128
146
user , repo = model_url [len (mlc_prefix ) :].split ("/" )
129
- git_dir = MLC_LLM_HOME / "model_weights" / user / repo
147
+ domain = "hf"
148
+
149
+ readonly_cache_dirs = []
150
+ for base in MLC_LLM_READONLY_WEIGHT_CACHES :
151
+ cache_dir = base / domain / user / repo
152
+ readonly_cache_dirs .append (str (cache_dir ))
153
+ if (cache_dir / "mlc-chat-config.json" ).is_file ():
154
+ logger .info ("Use cached weight: %s" , bold (str (cache_dir )))
155
+ return cache_dir
156
+
157
+ if force_redo is None :
158
+ force_redo = MLC_DOWNLOAD_POLICY == "REDO"
159
+
160
+ git_dir = MLC_LLM_HOME / "model_weights" / domain / user / repo
161
+ readonly_cache_dirs .append (str (git_dir ))
162
+
130
163
try :
131
164
_ensure_directory_not_exist (git_dir , force_redo = force_redo )
132
165
except ValueError :
133
166
logger .info ("Weights already downloaded: %s" , bold (str (git_dir )))
134
167
return git_dir
168
+
169
+ if MLC_DOWNLOAD_POLICY == "READONLY" :
170
+ raise RuntimeError (
171
+ f"Cannot find cache for { model_url } , "
172
+ "cannot proceed to download as MLC_DOWNLOAD_POLICY=READONLY, "
173
+ "please check settings MLC_LLM_READONLY_WEIGHT_CACHES, "
174
+ f"local path candidates: { readonly_cache_dirs } "
175
+ )
176
+
135
177
with tempfile .TemporaryDirectory (dir = MLC_TEMP_DIR ) as tmp_dir_prefix :
136
178
tmp_dir = Path (tmp_dir_prefix ) / "tmp"
137
179
git_url = git_url_template .format (user = user , repo = repo )
0 commit comments