Skip to content

Commit baed180

Browse files
authored
[tech debt] Revisit lora request model checker (#20636)
Signed-off-by: Kourosh Hakhamaneshi <kourosh@anyscale.com>
1 parent 0b40747 commit baed180

File tree

3 files changed: +65 additions, -62 deletions

tests/entrypoints/openai/test_serving_models.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -57,7 +57,8 @@ async def test_load_lora_adapter_success():
5757
response = await serving_models.load_lora_adapter(request)
5858
assert response == LORA_LOADING_SUCCESS_MESSAGE.format(lora_name='adapter')
5959
assert len(serving_models.lora_requests) == 1
60-
assert serving_models.lora_requests[0].lora_name == "adapter"
60+
assert "adapter" in serving_models.lora_requests
61+
assert serving_models.lora_requests["adapter"].lora_name == "adapter"
6162

6263

6364
@pytest.mark.asyncio

vllm/entrypoints/openai/serving_engine.py

Lines changed: 3 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -438,9 +438,7 @@ async def _check_model(
438438

439439
if self._is_model_supported(request.model):
440440
return None
441-
if request.model in [
442-
lora.lora_name for lora in self.models.lora_requests
443-
]:
441+
if request.model in self.models.lora_requests:
444442
return None
445443
if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING and request.model and (
446444
load_result := await self.models.resolve_lora(request.model)):
@@ -466,9 +464,8 @@ def _maybe_get_adapters(
466464
None, PromptAdapterRequest]]:
467465
if self._is_model_supported(request.model):
468466
return None, None
469-
for lora in self.models.lora_requests:
470-
if request.model == lora.lora_name:
471-
return lora, None
467+
if request.model in self.models.lora_requests:
468+
return self.models.lora_requests[request.model], None
472469
for prompt_adapter in self.models.prompt_adapter_requests:
473470
if request.model == prompt_adapter.prompt_adapter_name:
474471
return None, prompt_adapter

vllm/entrypoints/openai/serving_models.py

Lines changed: 60 additions & 55 deletions
Original file line number | Diff line number | Diff line change
@@ -65,12 +65,13 @@ def __init__(
6565
super().__init__()
6666

6767
self.base_model_paths = base_model_paths
68+
6869
self.max_model_len = model_config.max_model_len
6970
self.engine_client = engine_client
7071
self.model_config = model_config
7172

7273
self.static_lora_modules = lora_modules
73-
self.lora_requests: list[LoRARequest] = []
74+
self.lora_requests: dict[str, LoRARequest] = {}
7475
self.lora_id_counter = AtomicCounter(0)
7576

7677
self.lora_resolvers: list[LoRAResolver] = []
@@ -138,7 +139,7 @@ async def show_available_models(self) -> ModelList:
138139
parent=lora.base_model_name if lora.base_model_name else
139140
self.base_model_paths[0].name,
140141
permission=[ModelPermission()])
141-
for lora in self.lora_requests
142+
for lora in self.lora_requests.values()
142143
]
143144
prompt_adapter_cards = [
144145
ModelCard(id=prompt_adapter.prompt_adapter_name,
@@ -155,53 +156,60 @@ async def load_lora_adapter(
155156
request: LoadLoRAAdapterRequest,
156157
base_model_name: Optional[str] = None
157158
) -> Union[ErrorResponse, str]:
158-
error_check_ret = await self._check_load_lora_adapter_request(request)
159-
if error_check_ret is not None:
160-
return error_check_ret
161-
162-
lora_name, lora_path = request.lora_name, request.lora_path
163-
unique_id = self.lora_id_counter.inc(1)
164-
lora_request = LoRARequest(lora_name=lora_name,
165-
lora_int_id=unique_id,
166-
lora_path=lora_path)
167-
if base_model_name is not None and self.is_base_model(base_model_name):
168-
lora_request.base_model_name = base_model_name
169-
170-
# Validate that the adapter can be loaded into the engine
171-
# This will also pre-load it for incoming requests
172-
try:
173-
await self.engine_client.add_lora(lora_request)
174-
except BaseException as e:
175-
error_type = "BadRequestError"
176-
status_code = HTTPStatus.BAD_REQUEST
177-
if "No adapter found" in str(e):
178-
error_type = "NotFoundError"
179-
status_code = HTTPStatus.NOT_FOUND
180-
181-
return create_error_response(message=str(e),
182-
err_type=error_type,
183-
status_code=status_code)
184-
185-
self.lora_requests.append(lora_request)
186-
logger.info("Loaded new LoRA adapter: name '%s', path '%s'", lora_name,
187-
lora_path)
188-
return f"Success: LoRA adapter '{lora_name}' added successfully."
159+
lora_name = request.lora_name
160+
161+
# Ensure atomicity based on the lora name
162+
async with self.lora_resolver_lock[lora_name]:
163+
error_check_ret = await self._check_load_lora_adapter_request(
164+
request)
165+
if error_check_ret is not None:
166+
return error_check_ret
167+
168+
lora_path = request.lora_path
169+
unique_id = self.lora_id_counter.inc(1)
170+
lora_request = LoRARequest(lora_name=lora_name,
171+
lora_int_id=unique_id,
172+
lora_path=lora_path)
173+
if base_model_name is not None and self.is_base_model(
174+
base_model_name):
175+
lora_request.base_model_name = base_model_name
176+
177+
# Validate that the adapter can be loaded into the engine
178+
# This will also pre-load it for incoming requests
179+
try:
180+
await self.engine_client.add_lora(lora_request)
181+
except Exception as e:
182+
error_type = "BadRequestError"
183+
status_code = HTTPStatus.BAD_REQUEST
184+
if "No adapter found" in str(e):
185+
error_type = "NotFoundError"
186+
status_code = HTTPStatus.NOT_FOUND
187+
188+
return create_error_response(message=str(e),
189+
err_type=error_type,
190+
status_code=status_code)
191+
192+
self.lora_requests[lora_name] = lora_request
193+
logger.info("Loaded new LoRA adapter: name '%s', path '%s'",
194+
lora_name, lora_path)
195+
return f"Success: LoRA adapter '{lora_name}' added successfully."
189196

190197
async def unload_lora_adapter(
191198
self,
192199
request: UnloadLoRAAdapterRequest) -> Union[ErrorResponse, str]:
193-
error_check_ret = await self._check_unload_lora_adapter_request(request
194-
)
195-
if error_check_ret is not None:
196-
return error_check_ret
197-
198200
lora_name = request.lora_name
199-
self.lora_requests = [
200-
lora_request for lora_request in self.lora_requests
201-
if lora_request.lora_name != lora_name
202-
]
203-
logger.info("Removed LoRA adapter: name '%s'", lora_name)
204-
return f"Success: LoRA adapter '{lora_name}' removed successfully."
201+
202+
# Ensure atomicity based on the lora name
203+
async with self.lora_resolver_lock[lora_name]:
204+
error_check_ret = await self._check_unload_lora_adapter_request(
205+
request)
206+
if error_check_ret is not None:
207+
return error_check_ret
208+
209+
# Safe to delete now since we hold the lock
210+
del self.lora_requests[lora_name]
211+
logger.info("Removed LoRA adapter: name '%s'", lora_name)
212+
return f"Success: LoRA adapter '{lora_name}' removed successfully."
205213

206214
async def _check_load_lora_adapter_request(
207215
self, request: LoadLoRAAdapterRequest) -> Optional[ErrorResponse]:
@@ -213,8 +221,7 @@ async def _check_load_lora_adapter_request(
213221
status_code=HTTPStatus.BAD_REQUEST)
214222

215223
# Check if the lora adapter with the given name already exists
216-
if any(lora_request.lora_name == request.lora_name
217-
for lora_request in self.lora_requests):
224+
if request.lora_name in self.lora_requests:
218225
return create_error_response(
219226
message=
220227
f"The lora adapter '{request.lora_name}' has already been "
@@ -227,17 +234,16 @@ async def _check_load_lora_adapter_request(
227234
async def _check_unload_lora_adapter_request(
228235
self,
229236
request: UnloadLoRAAdapterRequest) -> Optional[ErrorResponse]:
230-
# Check if either 'lora_name' or 'lora_int_id' is provided
231-
if not request.lora_name and not request.lora_int_id:
237+
# Check if 'lora_name' is not provided return an error
238+
if not request.lora_name:
232239
return create_error_response(
233240
message=
234-
"either 'lora_name' and 'lora_int_id' needs to be provided.",
241+
"'lora_name' needs to be provided to unload a LoRA adapter.",
235242
err_type="InvalidUserInput",
236243
status_code=HTTPStatus.BAD_REQUEST)
237244

238245
# Check if the lora adapter with the given name exists
239-
if not any(lora_request.lora_name == request.lora_name
240-
for lora_request in self.lora_requests):
246+
if request.lora_name not in self.lora_requests:
241247
return create_error_response(
242248
message=
243249
f"The lora adapter '{request.lora_name}' cannot be found.",
@@ -260,9 +266,8 @@ async def resolve_lora(
260266
"""
261267
async with self.lora_resolver_lock[lora_name]:
262268
# First check if this LoRA is already loaded
263-
for existing in self.lora_requests:
264-
if existing.lora_name == lora_name:
265-
return existing
269+
if lora_name in self.lora_requests:
270+
return self.lora_requests[lora_name]
266271

267272
base_model_name = self.model_config.model
268273
unique_id = self.lora_id_counter.inc(1)
@@ -279,7 +284,7 @@ async def resolve_lora(
279284

280285
try:
281286
await self.engine_client.add_lora(lora_request)
282-
self.lora_requests.append(lora_request)
287+
self.lora_requests[lora_name] = lora_request
283288
logger.info(
284289
"Resolved and loaded LoRA adapter '%s' using %s",
285290
lora_name, resolver.__class__.__name__)

0 commit comments

Comments
 (0)