Skip to content

Commit 3b272eb

Browse files
authored
[DOWNLOAD] MLC_DOWNLOAD_POLICY and MLC_LLM_READONLY_WEIGHT_CACHES (#2421)
This PR introduces support for MLC_DOWNLOAD_POLICY and MLC_LLM_READONLY_WEIGHT_CACHES. It allows reading weights from read-only caches in addition to MLC_LLM_HOME, and it also introduces a domain subfolder in the cached weights.
1 parent 0b2cbb2 commit 3b272eb

File tree

4 files changed

+71
-14
lines changed

4 files changed

+71
-14
lines changed

python/mlc_llm/chat_module.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -356,11 +356,11 @@ def _get_model_path(model: str) -> Tuple[str, str]:
356356
"""
357357
if model.startswith("HF://"):
358358
from mlc_llm.support.download import ( # pylint: disable=import-outside-toplevel
359-
download_mlc_weights,
359+
download_and_cache_mlc_weights,
360360
)
361361

362362
logger.info("Downloading model from HuggingFace: %s", model)
363-
mlc_dir = download_mlc_weights(model)
363+
mlc_dir = download_and_cache_mlc_weights(model)
364364
cfg_dir = mlc_dir / "mlc-chat-config.json"
365365
return str(mlc_dir), str(cfg_dir)
366366

python/mlc_llm/support/auto_config.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,12 +35,12 @@ def detect_mlc_chat_config(mlc_chat_config: str) -> Path:
3535
# pylint: disable=import-outside-toplevel
3636
from mlc_llm.model import MODEL_PRESETS
3737

38-
from .download import download_mlc_weights
38+
from .download import download_and_cache_mlc_weights
3939

4040
# pylint: enable=import-outside-toplevel
4141

4242
if mlc_chat_config.startswith("HF://") or mlc_chat_config.startswith("http"):
43-
mlc_chat_config_path = Path(download_mlc_weights(model_url=mlc_chat_config))
43+
mlc_chat_config_path = Path(download_and_cache_mlc_weights(model_url=mlc_chat_config))
4444
elif isinstance(mlc_chat_config, str) and mlc_chat_config in MODEL_PRESETS:
4545
logger.info("%s mlc preset model: %s", FOUND, mlc_chat_config)
4646
content = MODEL_PRESETS[mlc_chat_config].copy()

python/mlc_llm/support/constants.py

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,13 @@ def _check():
1313
f"but got {MLC_JIT_POLICY}."
1414
)
1515

16+
if MLC_DOWNLOAD_POLICY not in ["ON", "OFF", "REDO", "READONLY"]:
17+
raise ValueError(
18+
"Invalid MLC_AUTO_DOWNLOAD_POLICY. "
19+
'It has to be one of "ON", "OFF", "REDO", "READONLY"'
20+
f"but got {MLC_DOWNLOAD_POLICY}."
21+
)
22+
1623

1724
def _get_cache_dir() -> Path:
1825
if "MLC_LLM_HOME" in os.environ:
@@ -48,23 +55,31 @@ def _get_dso_suffix() -> str:
4855

4956

5057
def _get_test_model_path() -> List[Path]:
51-
if "MLC_TEST_MODEL_PATH" in os.environ:
52-
return [Path(p) for p in os.environ["MLC_TEST_MODEL_PATH"].split(os.pathsep)]
58+
if "MLC_LLM_TEST_MODEL_PATH" in os.environ:
59+
return [Path(p) for p in os.environ["MLC_LLM_TEST_MODEL_PATH"].split(os.pathsep)]
5360
# by default, we reuse the cache dir via mlc_llm chat
5461
# note that we do not auto download for testcase
5562
# to avoid networking dependencies
56-
return [
57-
_get_cache_dir() / "model_weights" / "mlc-ai",
58-
Path(os.path.abspath(os.path.curdir)),
63+
base_list = ["hf"]
64+
return [_get_cache_dir() / "model_weights" / base / "mlc-ai" for base in base_list] + [
65+
Path(os.path.abspath(os.path.curdir))
5966
]
6067

6168

69+
def _get_read_only_weight_caches() -> List[Path]:
    """Return the read-only weight cache roots configured in the environment.

    The value of ``MLC_LLM_READONLY_WEIGHT_CACHES`` is split on ``os.pathsep``.
    Empty entries (an empty variable, a leading/trailing separator, or a
    doubled separator) are skipped: ``Path("")`` is the current directory, so
    keeping them would silently treat the working directory as a weight cache.

    Returns
    -------
    caches : List[Path]
        One path per non-empty entry, in the order given. Empty when the
        variable is unset or holds no entries.
    """
    raw_value = os.environ.get("MLC_LLM_READONLY_WEIGHT_CACHES", "")
    return [Path(entry) for entry in raw_value.split(os.pathsep) if entry]
73+
74+
6275
MLC_TEMP_DIR = os.getenv("MLC_TEMP_DIR", None)
6376
MLC_MULTI_ARCH = os.environ.get("MLC_MULTI_ARCH", None)
64-
MLC_LLM_HOME: Path = _get_cache_dir()
6577
MLC_JIT_POLICY = os.environ.get("MLC_JIT_POLICY", "ON")
6678
MLC_DSO_SUFFIX = _get_dso_suffix()
6779
MLC_TEST_MODEL_PATH: List[Path] = _get_test_model_path()
6880

81+
MLC_DOWNLOAD_POLICY = os.environ.get("MLC_DOWNLOAD_POLICY", "ON")
82+
MLC_LLM_HOME: Path = _get_cache_dir()
83+
MLC_LLM_READONLY_WEIGHT_CACHES = _get_read_only_weight_caches()
6984

7085
_check()

python/mlc_llm/support/download.py

Lines changed: 46 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,26 @@
1313
import requests # pylint: disable=import-error
1414

1515
from . import logging, tqdm
16-
from .constants import MLC_LLM_HOME, MLC_TEMP_DIR
16+
from .constants import (
17+
MLC_DOWNLOAD_POLICY,
18+
MLC_LLM_HOME,
19+
MLC_LLM_READONLY_WEIGHT_CACHES,
20+
MLC_TEMP_DIR,
21+
)
1722
from .style import bold
1823

1924
logger = logging.getLogger(__name__)
2025

2126

27+
def log_download_policy():
    """Log the active ``MLC_DOWNLOAD_POLICY`` value at info level.

    The message also lists the accepted policy values so that a reader of
    the log can see the alternatives without consulting the source.
    """
    policy_label = bold("MLC_DOWNLOAD_POLICY")
    message = "%s = %s. Can be one of: ON, OFF, REDO, READONLY"
    logger.info(message, policy_label, MLC_DOWNLOAD_POLICY)
34+
35+
2236
def _ensure_directory_not_exist(path: Path, force_redo: bool) -> None:
2337
if path.exists():
2438
if force_redo:
@@ -110,12 +124,16 @@ def download_file(
110124
return url, destination
111125

112126

113-
def download_mlc_weights( # pylint: disable=too-many-locals
127+
def download_and_cache_mlc_weights( # pylint: disable=too-many-locals
114128
model_url: str,
115129
num_processes: int = 4,
116-
force_redo: bool = False,
130+
force_redo: Optional[bool] = None,
117131
) -> Path:
118132
"""Download weights for a model from the HuggingFace Git LFS repo."""
133+
log_download_policy()
134+
if MLC_DOWNLOAD_POLICY == "OFF":
135+
raise RuntimeError(f"Cannot download {model_url} as MLC_DOWNLOAD_POLICY=OFF")
136+
119137
prefixes, mlc_prefix = ["HF://", "https://huggingface.co/"], ""
120138
mlc_prefix = next(p for p in prefixes if model_url.startswith(p))
121139
assert mlc_prefix
@@ -126,12 +144,36 @@ def download_mlc_weights( # pylint: disable=too-many-locals
126144
if model_url.count("/") != 1 + mlc_prefix.count("/") or not model_url.startswith(mlc_prefix):
127145
raise ValueError(f"Invalid model URL: {model_url}")
128146
user, repo = model_url[len(mlc_prefix) :].split("/")
129-
git_dir = MLC_LLM_HOME / "model_weights" / user / repo
147+
domain = "hf"
148+
149+
readonly_cache_dirs = []
150+
for base in MLC_LLM_READONLY_WEIGHT_CACHES:
151+
cache_dir = base / domain / user / repo
152+
readonly_cache_dirs.append(str(cache_dir))
153+
if (cache_dir / "mlc-chat-config.json").is_file():
154+
logger.info("Use cached weight: %s", bold(str(cache_dir)))
155+
return cache_dir
156+
157+
if force_redo is None:
158+
force_redo = MLC_DOWNLOAD_POLICY == "REDO"
159+
160+
git_dir = MLC_LLM_HOME / "model_weights" / domain / user / repo
161+
readonly_cache_dirs.append(str(git_dir))
162+
130163
try:
131164
_ensure_directory_not_exist(git_dir, force_redo=force_redo)
132165
except ValueError:
133166
logger.info("Weights already downloaded: %s", bold(str(git_dir)))
134167
return git_dir
168+
169+
if MLC_DOWNLOAD_POLICY == "READONLY":
170+
raise RuntimeError(
171+
f"Cannot find cache for {model_url}, "
172+
"cannot proceed to download as MLC_DOWNLOAD_POLICY=READONLY, "
173+
"please check settings MLC_LLM_READONLY_WEIGHT_CACHES, "
174+
f"local path candidates: {readonly_cache_dirs}"
175+
)
176+
135177
with tempfile.TemporaryDirectory(dir=MLC_TEMP_DIR) as tmp_dir_prefix:
136178
tmp_dir = Path(tmp_dir_prefix) / "tmp"
137179
git_url = git_url_template.format(user=user, repo=repo)

0 commit comments

Comments
 (0)