
Commit f3afcba

[ray.llm] Refactor model download utilities (#51604)

Signed-off-by: Cody Yu <hao.yu.cody@gmail.com>

1 parent ffe5a09 · commit f3afcba

File tree

10 files changed: +278 / -181 lines changed

python/ray/llm/_internal/batch/processor/vllm_engine_proc.py

Lines changed: 10 additions & 2 deletions
@@ -11,7 +11,6 @@
     ProcessorConfig,
     ProcessorBuilder,
 )
-from ray.llm._internal.batch.utils import download_hf_model
 from ray.llm._internal.batch.stages import (
     vLLMEngineStage,
     ChatTemplateStage,
@@ -25,6 +24,10 @@
     BatchModelTelemetry,
 )
 from ray.llm._internal.common.observability.telemetry_utils import DEFAULT_GPU_TYPE
+from ray.llm._internal.common.utils.download_utils import (
+    download_model_files,
+    NodeModelDownloadable,
+)
 from ray.llm._internal.batch.observability.usage_telemetry.usage import (
     get_or_create_telemetry_agent,
 )
@@ -215,7 +218,12 @@ def build_vllm_engine_processor(
         )
     )
 
-    model_path = download_hf_model(config.model_source, tokenizer_only=True)
+    model_path = download_model_files(
+        model_id=config.model_source,
+        mirror_config=None,
+        download_model=NodeModelDownloadable.TOKENIZER_ONLY,
+        download_extra_files=False,
+    )
     hf_config = transformers.AutoConfig.from_pretrained(model_path)
     architecture = getattr(hf_config, "architectures", [DEFAULT_MODEL_ARCHITECTURE])[0]
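Note: the same call-site rewrite repeats in chat_template_stage.py, tokenize_stage.py, and vllm_engine_stage.py below, so here is the pattern once, as a sketch. It assumes only what the hunks show: download_model_files takes model_id, mirror_config, download_model, and download_extra_files, and (like the removed download_hf_model) returns a local path. Passing mirror_config=None at every migrated call site presumably preserves the old behavior of inferring a cloud mirror from a remote model id.

    from ray.llm._internal.common.utils.download_utils import (
        download_model_files,
        NodeModelDownloadable,
    )

    # Before (removed in this commit):
    #     model_path = download_hf_model(model_source, tokenizer_only=True)
    # After: the boolean flag becomes an explicit enum value.
    model_path = download_model_files(
        model_id="facebook/opt-125m",  # hypothetical model id, for illustration only
        mirror_config=None,
        download_model=NodeModelDownloadable.TOKENIZER_ONLY,
        download_extra_files=False,
    )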
python/ray/llm/_internal/batch/stages/chat_template_stage.py

Lines changed: 10 additions & 2 deletions
@@ -6,7 +6,10 @@
     StatefulStage,
     StatefulStageUDF,
 )
-from ray.llm._internal.batch.utils import download_hf_model
+from ray.llm._internal.common.utils.download_utils import (
+    download_model_files,
+    NodeModelDownloadable,
+)
 
 
 class ChatTemplateUDF(StatefulStageUDF):
@@ -36,7 +39,12 @@ def __init__(
         # because tokenizers of VLM models may not have chat template attribute.
         # However, this may not be a reliable solution, because processors and
         # tokenizers are not standardized across different models.
-        model_path = download_hf_model(model, tokenizer_only=True)
+        model_path = download_model_files(
+            model_id=model,
+            mirror_config=None,
+            download_model=NodeModelDownloadable.TOKENIZER_ONLY,
+            download_extra_files=False,
+        )
         self.processor = AutoProcessor.from_pretrained(
             model_path, trust_remote_code=True
         )

python/ray/llm/_internal/batch/stages/tokenize_stage.py

Lines changed: 16 additions & 5 deletions
@@ -6,10 +6,11 @@
     StatefulStage,
     StatefulStageUDF,
 )
-from ray.llm._internal.batch.utils import (
-    get_cached_tokenizer,
-    download_hf_model,
+from ray.llm._internal.common.utils.download_utils import (
+    download_model_files,
+    NodeModelDownloadable,
 )
+from ray.llm._internal.batch.utils import get_cached_tokenizer
 
 
 class TokenizeUDF(StatefulStageUDF):
@@ -30,7 +31,12 @@ def __init__(
         from transformers import AutoTokenizer
 
         super().__init__(data_column, expected_input_keys)
-        model_path = download_hf_model(model, tokenizer_only=True)
+        model_path = download_model_files(
+            model_id=model,
+            mirror_config=None,
+            download_model=NodeModelDownloadable.TOKENIZER_ONLY,
+            download_extra_files=False,
+        )
         self.tokenizer = get_cached_tokenizer(
             AutoTokenizer.from_pretrained(
                 model_path,
@@ -88,7 +94,12 @@ def __init__(
         from transformers import AutoTokenizer
 
         super().__init__(data_column, expected_input_keys)
-        model_path = download_hf_model(model, tokenizer_only=True)
+        model_path = download_model_files(
+            model_id=model,
+            mirror_config=None,
+            download_model=NodeModelDownloadable.TOKENIZER_ONLY,
+            download_extra_files=False,
+        )
         self.tokenizer = get_cached_tokenizer(
             AutoTokenizer.from_pretrained(
                 model_path,

python/ray/llm/_internal/batch/stages/vllm_engine_stage.py

Lines changed: 10 additions & 4 deletions
@@ -17,11 +17,12 @@
     StatefulStage,
     StatefulStageUDF,
 )
-from ray.llm._internal.batch.utils import (
+from ray.llm._internal.common.utils.cloud_utils import is_remote_path
+from ray.llm._internal.common.utils.download_utils import (
     download_lora_adapter,
-    download_hf_model,
+    download_model_files,
+    NodeModelDownloadable,
 )
-from ray.llm._internal.common.utils.cloud_utils import is_remote_path
 from ray.llm._internal.utils import try_import
 from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
 
@@ -476,7 +477,12 @@ def __init__(
         logger.info("Max pending requests is set to %d", self.max_pending_requests)
 
         # Download the model if needed.
-        model_source = download_hf_model(self.model, tokenizer_only=False)
+        model_source = download_model_files(
+            model_id=self.model,
+            mirror_config=None,
+            download_model=NodeModelDownloadable.MODEL_AND_TOKENIZER,
+            download_extra_files=False,
+        )
 
         # Create an LLM engine.
         self.llm = vLLMEngineWrapper(
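This hunk is the one place the full model is fetched rather than just the tokenizer. Only two NodeModelDownloadable members appear anywhere in this commit; below is a minimal sketch of the shape the enum plausibly has. The real definition in download_utils may differ (extra members, different values), and the string values here are placeholders:

    from enum import Enum

    class NodeModelDownloadable(Enum):
        # Grounded in the diff: these two members are used at the call sites.
        MODEL_AND_TOKENIZER = "model_and_tokenizer"
        TOKENIZER_ONLY = "tokenizer_only"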
python/ray/llm/_internal/batch/utils.py

Lines changed: 1 addition & 56 deletions
@@ -1,13 +1,6 @@
 """Utility functions for batch processing."""
 import logging
-import os
-from typing import TYPE_CHECKING, Any, Optional, Union
-
-from ray.llm._internal.common.utils.cloud_utils import (
-    CloudMirrorConfig,
-    is_remote_path,
-)
-from ray.llm._internal.common.utils.download_utils import CloudModelDownloader
+from typing import TYPE_CHECKING, Any, Union
 
 if TYPE_CHECKING:
     from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
@@ -67,51 +60,3 @@ def __len__(self):
 
     tokenizer.__class__ = CachedTokenizer
     return tokenizer
-
-
-def download_hf_model(model_source: str, tokenizer_only: bool = True) -> str:
-    """Download the HF model from the model source.
-
-    Args:
-        model_source: The model source path.
-        tokenizer_only: Whether to download only the tokenizer.
-
-    Returns:
-        The local path to the downloaded model.
-    """
-
-    bucket_uri = None
-    if is_remote_path(model_source):
-        bucket_uri = model_source
-
-    mirror_config = CloudMirrorConfig(bucket_uri=bucket_uri)
-    downloader = CloudModelDownloader(model_source, mirror_config)
-    return downloader.get_model(tokenizer_only=tokenizer_only)
-
-
-def download_lora_adapter(
-    lora_name: str,
-    remote_path: Optional[str] = None,
-) -> str:
-    """If remote_path is specified, pull the lora to the local
-    directory and return the local path.
-
-    Args:
-        lora_name: The lora name.
-        remote_path: The remote path to the lora. If specified, the remote_path will be
-            used as the base path to load the lora.
-
-    Returns:
-        The local path to the lora if remote_path is specified, otherwise the lora name.
-    """
-    assert not is_remote_path(
-        lora_name
-    ), "lora_name cannot be a remote path (s3:// or gs://)"
-
-    if remote_path is None:
-        return lora_name
-
-    lora_path = os.path.join(remote_path, lora_name)
-    mirror_config = CloudMirrorConfig(bucket_uri=lora_path)
-    downloader = CloudModelDownloader(lora_name, mirror_config)
-    return downloader.get_model(tokenizer_only=False)
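The deleted bodies above, together with the call-site rewrites earlier in this commit, pin down how the old API maps onto the new one. A minimal compatibility shim, assuming download_model_files keeps the semantics those call sites rely on:

    from ray.llm._internal.common.utils.download_utils import (
        download_model_files,
        NodeModelDownloadable,
    )

    def download_hf_model(model_source: str, tokenizer_only: bool = True) -> str:
        """Drop-in stand-in for the deleted helper, built on the new API."""
        return download_model_files(
            model_id=model_source,
            # None at every migrated call site; presumably reproduces the old
            # infer-the-mirror-from-a-remote-model-id behavior.
            mirror_config=None,
            download_model=(
                NodeModelDownloadable.TOKENIZER_ONLY
                if tokenizer_only
                else NodeModelDownloadable.MODEL_AND_TOKENIZER
            ),
            download_extra_files=False,
        )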
Lines changed: 40 additions & 0 deletions
@@ -0,0 +1,40 @@
+import logging
+from typing import Optional
+
+from ray._private.ray_logging.filters import CoreContextFilter
+
+
+def _setup_logger(logger_name: str):
+    """Setup logger given the logger name.
+
+    This function is idempotent and won't set up the same logger multiple times. It will
+    also skip the setup if logger is already setup and has handlers.
+
+    Args:
+        logger_name: logger name used to get the logger.
+    """
+    logger = logging.getLogger(logger_name)
+    llm_logger = logging.getLogger("ray.llm")
+
+    # Skip setup if the logger already has handlers setup or if the parent (Data
+    # logger) has handlers.
+    if logger.handlers or llm_logger.handlers:
+        return
+
+    # Set up stream handler, which logs to console as plaintext.
+    stream_handler = logging.StreamHandler()
+    stream_handler.addFilter(CoreContextFilter())
+    logger.addHandler(stream_handler)
+    logger.setLevel(logging.INFO)
+    logger.propagate = False
+
+
+def get_logger(name: Optional[str] = None):
+    """Get a structured logger inherited from the Ray Data logger.
+
+    Loggers by default are logging to stdout, and are expected to be scraped by an
+    external process.
+    """
+    logger_name = f"ray.llm.{name}"
+    _setup_logger(logger_name)
+    return logging.getLogger(logger_name)
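This new module's behavior is fully visible above, and its import path is confirmed by the cloud_utils.py hunk at the end of this commit. Typical use:

    from ray.llm._internal.common.observability.logging import get_logger

    logger = get_logger(__name__)  # yields the "ray.llm.<module>" logger
    logger.info("model files downloaded")  # plaintext record with Core context attached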
Lines changed: 26 additions & 0 deletions
@@ -0,0 +1,26 @@
+import logging
+
+from ray._private.ray_logging.filters import CoreContextFilter
+from ray._private.ray_logging.formatters import JSONFormatter
+
+
+def _configure_stdlib_logging():
+    """Configures stdlib root logger to make sure stdlib loggers (created as
+    `logging.getLogger(...)`) are using Ray's `JSONFormatter` with Core and Serve
+    context filters.
+    """
+
+    handler = logging.StreamHandler()
+    handler.addFilter(CoreContextFilter())
+    handler.setFormatter(JSONFormatter())
+
+    root_logger = logging.getLogger()
+    # NOTE: It's crucial we reset all the handlers of the root logger,
+    # to make sure that logs aren't emitted twice
+    root_logger.handlers = []
+    root_logger.addHandler(handler)
+    root_logger.setLevel(logging.INFO)
+
+
+def setup_logging():
+    _configure_stdlib_logging()
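A sketch of intended use for this second new module: call setup_logging() once at process start so every stdlib logger emits Ray's JSON format exactly once (the handler reset guards against double emission). The file's path is not shown in this capture, so the import below is an assumption:

    # NOTE: import path assumed; this file's location is not visible in the diff header.
    from ray.llm._internal.common.observability.logging.setup import setup_logging
    import logging

    setup_logging()
    logging.getLogger("my.module").info("hello")  # one JSON record via the root handler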

python/ray/llm/_internal/common/utils/cloud_utils.py

Lines changed: 1 addition & 1 deletion
@@ -19,7 +19,7 @@
 # Use pyarrow for cloud storage access
 import pyarrow.fs as pa_fs
 
-from ray.llm._internal.serve.observability.logging import get_logger
+from ray.llm._internal.common.observability.logging import get_logger
 from ray.llm._internal.common.base_pydantic import BaseModelExtended
