
Commit 3d30db2

[ray.data.llm] Support S3 paths for model checkpoint and LoRA path (#51103)
## Why are these changes needed?

- Pre-download model checkpoints from a remote path (e.g. S3).
- Update the documentation on how to use the RunAI streamer in vLLM to load model checkpoints from a remote path.
- Support remote paths for LoRA adapters.

## Checks

- [x] I've signed off every commit (by using the -s flag, i.e., `git commit -s`) in this PR.
- [x] I've run `scripts/format.sh` to lint the changes in this PR.
- [x] I've included any doc changes needed for https://docs.ray.io/en/master/.
- [x] I've added any new APIs to the API Reference. For example, if I added a method in Tune, I've added it in `doc/source/tune/api/` under the corresponding `.rst` file.
- [x] I've made sure the tests are passing. Note that there might be a few flaky tests; see the recent failures at https://flakey-tests.ray.io/
- Testing Strategy
  - [ ] Unit tests
  - [ ] Release tests
  - [ ] This PR is not tested :(

---------

Signed-off-by: Cody Yu <hao.yu.cody@gmail.com>
1 parent 6268a2d commit 3d30db2

File tree

16 files changed: +304, -41 lines

.vale/styles/config/vocabularies/Data/accept.txt

Lines changed: 1 addition & 0 deletions
@@ -24,6 +24,7 @@ Predibase('s)?
 [Pp]reprocess
 [Pp]reprocessor(s)?
 [Pp]ushdown
+runai
 [Ss]calers
 Spotify('s)?
 TFRecord(s)?

doc/source/data/working-with-llms.rst

Lines changed: 19 additions & 0 deletions
@@ -88,6 +88,25 @@ Some models may require a Hugging Face token to be specified. You can specify th
         batch_size=64,
     )
 
+If your model is hosted on AWS S3, you can specify the S3 path in the `model_source` argument, and specify `load_format="runai_streamer"` in the `engine_kwargs` argument.
+
+.. note::
+    Install vLLM with runai dependencies: `pip install -U "vllm[runai]==0.7.2"`
+
+.. testcode::
+
+    config = vLLMEngineProcessorConfig(
+        model_source="s3://your-bucket/your-model/",  # Make sure to add the trailing slash!
+        engine_kwargs={"load_format": "runai_streamer"},
+        runtime_env={"env_vars": {
+            "AWS_ACCESS_KEY_ID": "your_access_key_id",
+            "AWS_SECRET_ACCESS_KEY": "your_secret_access_key",
+            "AWS_REGION": "your_region",
+        }},
+        concurrency=1,
+        batch_size=64,
+    )
+
 .. _vllm_llm:
 
 Configure vLLM for LLM inference
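
The doc snippet above only builds the config. A minimal end-to-end sketch of how such an S3-backed config is typically used with `ray.data.llm` follows; the bucket name, dataset column, and the exact preprocess/postprocess fields are illustrative assumptions taken from the working-with-llms guide, not part of this commit:

```python
import ray
from ray.data.llm import build_llm_processor, vLLMEngineProcessorConfig

# Hypothetical S3 bucket; requires AWS credentials in the runtime env and
# vLLM installed with the runai extra.
config = vLLMEngineProcessorConfig(
    model_source="s3://your-bucket/your-model/",  # trailing slash required
    engine_kwargs={"load_format": "runai_streamer"},
    concurrency=1,
    batch_size=64,
)

processor = build_llm_processor(
    config,
    preprocess=lambda row: dict(
        messages=[{"role": "user", "content": row["question"]}],
        sampling_params=dict(temperature=0.0, max_tokens=128),
    ),
    postprocess=lambda row: dict(answer=row["generated_text"]),
)

ds = ray.data.from_items([{"question": "What is Ray Data?"}])
print(processor(ds).take_all())
```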

doc/source/llm/examples/batch/vllm-with-lora.ipynb

Lines changed: 5 additions & 0 deletions
@@ -55,6 +55,8 @@
     "    batch_size=16,\n",
     "    # Use one GPU in this example.\n",
     "    concurrency=1,\n",
+    "    # If you save the LoRA adapter in S3, you can set the following path.\n",
+    "    # dynamic_lora_loading_path=\"s3://your-lora-bucket/\",\n",
     ")\n",
     "\n",
     "# 2. Construct a processor using the processor config.\n",
@@ -66,6 +68,9 @@
     "        # from the model you specify in the processor config, then this\n",
     "        # is the LoRA adapter. The \"model\" here can be a LoRA adapter\n",
     "        # available in the HuggingFace Hub or a local path.\n",
+    "        #\n",
+    "        # If you set dynamic_lora_loading_path, then only specify the LoRA\n",
+    "        # path under dynamic_lora_loading_path.\n",
     "        model=\"EdBergJr/Llama32_Baha_3\",\n",
     "        messages=[\n",
     "            {\"role\": \"system\",\n",

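A hedged sketch of how these two notebook changes fit together: once `dynamic_lora_loading_path` is set on the config, the per-row `model` value is just the adapter's subfolder name under that base path rather than a Hub ID or full URI. The bucket, adapter name, and the LoRA-related `engine_kwargs` below are placeholders, not values from this commit:

```python
from ray.data.llm import build_llm_processor, vLLMEngineProcessorConfig

# Assumes the adapter files live at s3://your-lora-bucket/my_adapter/.
config = vLLMEngineProcessorConfig(
    model_source="meta-llama/Llama-3.2-1B-Instruct",
    engine_kwargs={"enable_lora": True, "max_lora_rank": 32},
    dynamic_lora_loading_path="s3://your-lora-bucket/",
    batch_size=16,
    concurrency=1,
)

processor = build_llm_processor(
    config,
    preprocess=lambda row: dict(
        model="my_adapter",  # subfolder under dynamic_lora_loading_path, not "s3://..."
        messages=[{"role": "user", "content": row["prompt"]}],
        sampling_params=dict(max_tokens=64),
    ),
)
```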
python/ray/llm/_internal/batch/processor/vllm_engine_proc.py

Lines changed: 10 additions & 0 deletions
@@ -92,6 +92,15 @@ class vLLMEngineProcessorConfig(ProcessorConfig):
         description="Whether the input messages have images.",
     )
 
+    # LoRA configurations.
+    dynamic_lora_loading_path: Optional[str] = Field(
+        default=None,
+        description="The path to the dynamic LoRA adapter. It is expected "
+        "to hold subfolders each for a different lora checkpoint. If not "
+        "specified and LoRA is enabled, then the 'model' in LoRA "
+        "requests will be interpreted as model ID used by HF transformers.",
+    )
+
     @root_validator(pre=True)
     def validate_task_type(cls, values):
         task_type_str = values.get("task_type", "generate")
@@ -169,6 +178,7 @@ def build_vllm_engine_processor(
             engine_kwargs=config.engine_kwargs,
             task_type=config.task_type,
             max_pending_requests=config.max_pending_requests,
+            dynamic_lora_loading_path=config.dynamic_lora_loading_path,
         ),
         map_batches_kwargs=dict(
             zero_copy_batch=True,
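
The field description above implies a specific layout for the remote path: one subfolder per adapter, keyed by the name that later appears as the row-level `model`. A hypothetical example layout (bucket, adapter names, and adapter file names are made up for illustration):

```
s3://your-lora-bucket/              <- dynamic_lora_loading_path
├── adapter_a/                      <- referenced as model="adapter_a"
│   ├── adapter_config.json
│   └── adapter_model.safetensors
└── adapter_b/                      <- referenced as model="adapter_b"
    ├── adapter_config.json
    └── adapter_model.safetensors
```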

python/ray/llm/_internal/batch/stages/chat_template_stage.py

Lines changed: 5 additions & 1 deletion
@@ -6,6 +6,7 @@
     StatefulStage,
     StatefulStageUDF,
 )
+from ray.llm._internal.batch.utils import download_hf_model
 
 
 class ChatTemplateUDF(StatefulStageUDF):
@@ -33,7 +34,10 @@ def __init__(
         # because tokenizers of VLM models may not have chat template attribute.
         # However, this may not be a reliable solution, because processors and
         # tokenizers are not standardized across different models.
-        self.processor = AutoProcessor.from_pretrained(model, trust_remote_code=True)
+        model_path = download_hf_model(model, tokenizer_only=True)
+        self.processor = AutoProcessor.from_pretrained(
+            model_path, trust_remote_code=True
+        )
         self.chat_template = chat_template
 
     async def udf(self, batch: List[Dict[str, Any]]) -> AsyncIterator[Dict[str, Any]]:

python/ray/llm/_internal/batch/stages/tokenize_stage.py

Lines changed: 8 additions & 3 deletions
@@ -6,7 +6,10 @@
     StatefulStage,
     StatefulStageUDF,
 )
-from ray.llm._internal.batch.utils import get_cached_tokenizer
+from ray.llm._internal.batch.utils import (
+    get_cached_tokenizer,
+    download_hf_model,
+)
 
 
 class TokenizeUDF(StatefulStageUDF):
@@ -25,9 +28,10 @@ def __init__(
         from transformers import AutoTokenizer
 
         super().__init__(data_column)
+        model_path = download_hf_model(model, tokenizer_only=True)
         self.tokenizer = get_cached_tokenizer(
             AutoTokenizer.from_pretrained(
-                model,
+                model_path,
                 trust_remote_code=True,
             )
         )
@@ -81,9 +85,10 @@ def __init__(
         from transformers import AutoTokenizer
 
         super().__init__(data_column)
+        model_path = download_hf_model(model, tokenizer_only=True)
         self.tokenizer = get_cached_tokenizer(
             AutoTokenizer.from_pretrained(
-                model,
+                model_path,
                 trust_remote_code=True,
             )
         )

python/ray/llm/_internal/batch/stages/vllm_engine_stage.py

Lines changed: 72 additions & 20 deletions
@@ -17,6 +17,11 @@
     StatefulStage,
     StatefulStageUDF,
 )
+from ray.llm._internal.batch.utils import (
+    download_lora_adapter,
+    download_hf_model,
+)
+from ray.llm._internal.common.utils.cloud_utils import is_remote_path
 from ray.llm._internal.utils import try_import
 from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
 
@@ -119,22 +124,30 @@ class vLLMEngineWrapper:
     Args:
         *args: The positional arguments for the engine.
         max_pending_requests: The maximum number of pending requests in the queue.
+        dynamic_lora_loading_path: The S3 path to the dynamic LoRA adapter.
         **kwargs: The keyword arguments for the engine.
     """
 
     def __init__(
         self,
         idx_in_batch_column: str,
         max_pending_requests: int = -1,
+        dynamic_lora_loading_path: Optional[str] = None,
         **kwargs,
     ):
         self.request_id = 0
         self.idx_in_batch_column = idx_in_batch_column
         self.task_type = kwargs.get("task", vLLMTaskType.GENERATE)
-        self.model = kwargs.get("model", None)
-        assert self.model is not None
+
+        # Use model_source in kwargs["model"] because "model" is actually
+        # the model source in vLLM.
+        self.model = kwargs.pop("model", None)
+        self.model_source = kwargs.pop("model_source", None)
+        assert self.model is not None and self.model_source is not None
+        kwargs["model"] = self.model_source
 
         # LoRA related.
+        self.dynamic_lora_loading_path = dynamic_lora_loading_path
         self.lora_lock = asyncio.Lock()
         self.lora_name_to_request = {}
 
@@ -196,47 +209,77 @@ def _maybe_convert_ndarray_to_list(self, params: Dict[str, Any]) -> Dict[str, An
             return params.tolist()
         return params
 
-    async def _prepare_llm_request(self, row: Dict[str, Any]) -> vLLMEngineRequest:
-        """Prepare the inputs for LLM inference.
+    async def _maybe_get_lora_request(
+        self,
+        row: Dict[str, Any],
+    ) -> Optional[Any]:
+        """Get the LoRA request for the given row.
+        Specifically, if the model name is given and is different from the model
+        set in the config, then this request has LoRA.
 
         Args:
             row: The row.
 
         Returns:
-            A single vLLMEngineRequest.
+            The LoRA request (vllm.lora.request.LoRARequest),
+            or None if there is no LoRA. We use Any in type hint to
+            pass doc build in the environment without vLLM.
         """
-        prompt = row.pop("prompt")
-
-        if "tokenized_prompt" in row:
-            tokenized_prompt = row.pop("tokenized_prompt").tolist()
-        else:
-            tokenized_prompt = None
-
-        if "image" in row:
-            image = row.pop("image")
-        else:
-            image = []
-
-        # If the model name is given and is different from the model
-        # set in the config, then this is a LoRA.
         lora_request = None
         if "model" in row and row["model"] != self.model:
             if self.vllm_use_v1:
                 raise ValueError("LoRA is only supported with vLLM v0")
 
             lora_name = row["model"]
            if lora_name not in self.lora_name_to_request:
+                if is_remote_path(lora_name):
+                    raise ValueError(
+                        "LoRA name cannot be a remote path (s3:// or gs://). "
+                        "Please specify dynamic_lora_loading_path in the processor config."
+                    )
+
                 async with self.lora_lock:
                     if lora_name not in self.lora_name_to_request:
                         # Load a new LoRA adapter if it is not loaded yet.
+                        lora_path = download_lora_adapter(
+                            lora_name,
+                            remote_path=self.dynamic_lora_loading_path,
+                        )
+                        logger.info(
+                            "Downloaded LoRA adapter for %s to %s", lora_name, lora_path
+                        )
                         lora_request = vllm.lora.request.LoRARequest(
                             lora_name=lora_name,
                             # LoRA ID starts from 1.
                             lora_int_id=len(self.lora_name_to_request) + 1,
-                            lora_path=lora_name,
+                            lora_path=lora_path,
                         )
                         self.lora_name_to_request[lora_name] = lora_request
            lora_request = self.lora_name_to_request[lora_name]
+        return lora_request
+
+    async def _prepare_llm_request(self, row: Dict[str, Any]) -> vLLMEngineRequest:
+        """Prepare the inputs for LLM inference.
+
+        Args:
+            row: The row.
+
+        Returns:
+            A single vLLMEngineRequest.
+        """
+        prompt = row.pop("prompt")
+
+        if "tokenized_prompt" in row:
+            tokenized_prompt = row.pop("tokenized_prompt").tolist()
+        else:
+            tokenized_prompt = None
+
+        if "image" in row:
+            image = row.pop("image")
+        else:
+            image = []
+
+        lora_request = await self._maybe_get_lora_request(row)
 
         # Prepare sampling parameters.
         if self.task_type == vLLMTaskType.GENERATE:
@@ -396,6 +439,7 @@ def __init__(
         engine_kwargs: Dict[str, Any],
         task_type: vLLMTaskType = vLLMTaskType.GENERATE,
         max_pending_requests: Optional[int] = None,
+        dynamic_lora_loading_path: Optional[str] = None,
     ):
         """
         Initialize the vLLMEngineStageUDF.
@@ -407,6 +451,8 @@ def __init__(
             task_type: The task to use for the vLLM engine (e.g., "generate", "embed", etc).
             max_pending_requests: The maximum number of pending requests. If None,
                 it will be set to 1.1 * max_num_seqs * pipeline_parallel_size.
+            dynamic_lora_loading_path: The path to the dynamic LoRA adapter. It is expected
+                to hold subfolders each for a different lora checkpoint.
         """
         super().__init__(data_column)
         self.model = model
@@ -423,12 +469,17 @@ def __init__(
         if self.max_pending_requests > 0:
             logger.info("Max pending requests is set to %d", self.max_pending_requests)
 
+        # Download the model if needed.
+        model_source = download_hf_model(self.model, tokenizer_only=False)
+
         # Create an LLM engine.
         self.llm = vLLMEngineWrapper(
             model=self.model,
+            model_source=model_source,
             idx_in_batch_column=self.IDX_IN_BATCH_COLUMN,
             disable_log_stats=False,
             max_pending_requests=self.max_pending_requests,
+            dynamic_lora_loading_path=dynamic_lora_loading_path,
             **self.engine_kwargs,
         )
 
@@ -518,6 +569,7 @@ def expected_input_keys(self) -> List[str]:
 
     def __del__(self):
         if hasattr(self, "llm"):
+            # Kill the engine processes.
             self.llm.shutdown()
 
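
The LoRA registration above uses a check, lock, re-check pattern so that concurrent requests for the same adapter trigger only one download and one registration. A simplified, hypothetical sketch of that pattern (not the actual wrapper class):

```python
import asyncio


class LoraRegistry:
    """Minimal illustration of the check / lock / re-check caching pattern."""

    def __init__(self):
        self._lock = asyncio.Lock()
        self._cache = {}  # lora_name -> registered request

    async def get_or_register(self, lora_name, load_fn):
        if lora_name not in self._cache:              # fast path without the lock
            async with self._lock:
                if lora_name not in self._cache:      # re-check once the lock is held
                    # Only the first caller downloads and registers the adapter.
                    self._cache[lora_name] = load_fn(lora_name)
        return self._cache[lora_name]
```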

python/ray/llm/_internal/batch/utils.py

Lines changed: 56 additions & 1 deletion
@@ -1,6 +1,13 @@
 """Utility functions for batch processing."""
 import logging
-from typing import TYPE_CHECKING, Any, Union
+import os
+from typing import TYPE_CHECKING, Any, Optional, Union
+
+from ray.llm._internal.common.utils.cloud_utils import (
+    CloudMirrorConfig,
+    is_remote_path,
+)
+from ray.llm._internal.common.utils.download_utils import CloudModelDownloader
 
 if TYPE_CHECKING:
     from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
@@ -60,3 +67,51 @@ def __len__(self):
 
     tokenizer.__class__ = CachedTokenizer
     return tokenizer
+
+
+def download_hf_model(model_source: str, tokenizer_only: bool = True) -> str:
+    """Download the HF model from the model source.
+
+    Args:
+        model_source: The model source path.
+        tokenizer_only: Whether to download only the tokenizer.
+
+    Returns:
+        The local path to the downloaded model.
+    """
+
+    bucket_uri = None
+    if is_remote_path(model_source):
+        bucket_uri = model_source
+
+    mirror_config = CloudMirrorConfig(bucket_uri=bucket_uri)
+    downloader = CloudModelDownloader(model_source, mirror_config)
+    return downloader.get_model(tokenizer_only=tokenizer_only)
+
+
+def download_lora_adapter(
+    lora_name: str,
+    remote_path: Optional[str] = None,
+) -> str:
+    """If remote_path is specified, pull the lora to the local
+    directory and return the local path.
+
+    Args:
+        lora_name: The lora name.
+        remote_path: The remote path to the lora. If specified, the remote_path will be
+            used as the base path to load the lora.
+
+    Returns:
+        The local path to the lora if remote_path is specified, otherwise the lora name.
+    """
+    assert not is_remote_path(
+        lora_name
+    ), "lora_name cannot be a remote path (s3:// or gs://)"
+
+    if remote_path is None:
+        return lora_name
+
+    lora_path = os.path.join(remote_path, lora_name)
+    mirror_config = CloudMirrorConfig(bucket_uri=lora_path)
+    downloader = CloudModelDownloader(lora_name, mirror_config)
+    return downloader.get_model(tokenizer_only=False)
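
A hedged usage sketch of the two helpers above, based on their docstrings. It assumes cloud credentials are configured in the environment; the bucket and adapter names are placeholders, and the downloader returns a local directory containing the pulled files:

```python
from ray.llm._internal.batch.utils import download_hf_model, download_lora_adapter

# Remote checkpoint: only the tokenizer files are pulled here, and a local
# path usable by AutoTokenizer/AutoProcessor is returned.
tokenizer_path = download_hf_model("s3://your-bucket/your-model/", tokenizer_only=True)

# No remote base path: the LoRA name passes through unchanged and is later
# treated as a Hugging Face Hub ID or local directory.
assert download_lora_adapter("EdBergJr/Llama32_Baha_3") == "EdBergJr/Llama32_Baha_3"

# With a remote base path: "my_adapter" is joined onto the base path and the
# full adapter is downloaded to a local directory.
lora_path = download_lora_adapter("my_adapter", remote_path="s3://your-lora-bucket")
```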
