@@ -65,12 +65,13 @@ def __init__(
         super().__init__()
 
         self.base_model_paths = base_model_paths
+
         self.max_model_len = model_config.max_model_len
         self.engine_client = engine_client
         self.model_config = model_config
 
         self.static_lora_modules = lora_modules
-        self.lora_requests: list[LoRARequest] = []
+        self.lora_requests: dict[str, LoRARequest] = {}
         self.lora_id_counter = AtomicCounter(0)
 
         self.lora_resolvers: list[LoRAResolver] = []
@@ -138,7 +139,7 @@ async def show_available_models(self) -> ModelList:
                       parent=lora.base_model_name if lora.base_model_name else
                       self.base_model_paths[0].name,
                       permission=[ModelPermission()])
-            for lora in self.lora_requests
+            for lora in self.lora_requests.values()
         ]
         prompt_adapter_cards = [
             ModelCard(id=prompt_adapter.prompt_adapter_name,
@@ -155,53 +156,60 @@ async def load_lora_adapter(
             request: LoadLoRAAdapterRequest,
             base_model_name: Optional[str] = None
     ) -> Union[ErrorResponse, str]:
-        error_check_ret = await self._check_load_lora_adapter_request(request)
-        if error_check_ret is not None:
-            return error_check_ret
-
-        lora_name, lora_path = request.lora_name, request.lora_path
-        unique_id = self.lora_id_counter.inc(1)
-        lora_request = LoRARequest(lora_name=lora_name,
-                                   lora_int_id=unique_id,
-                                   lora_path=lora_path)
-        if base_model_name is not None and self.is_base_model(base_model_name):
-            lora_request.base_model_name = base_model_name
-
-        # Validate that the adapter can be loaded into the engine
-        # This will also pre-load it for incoming requests
-        try:
-            await self.engine_client.add_lora(lora_request)
-        except BaseException as e:
-            error_type = "BadRequestError"
-            status_code = HTTPStatus.BAD_REQUEST
-            if "No adapter found" in str(e):
-                error_type = "NotFoundError"
-                status_code = HTTPStatus.NOT_FOUND
-
-            return create_error_response(message=str(e),
-                                         err_type=error_type,
-                                         status_code=status_code)
-
-        self.lora_requests.append(lora_request)
-        logger.info("Loaded new LoRA adapter: name '%s', path '%s'", lora_name,
-                    lora_path)
-        return f"Success: LoRA adapter '{lora_name}' added successfully."
+        lora_name = request.lora_name
+
+        # Ensure atomicity based on the lora name
+        async with self.lora_resolver_lock[lora_name]:
+            error_check_ret = await self._check_load_lora_adapter_request(
+                request)
+            if error_check_ret is not None:
+                return error_check_ret
+
+            lora_path = request.lora_path
+            unique_id = self.lora_id_counter.inc(1)
+            lora_request = LoRARequest(lora_name=lora_name,
+                                       lora_int_id=unique_id,
+                                       lora_path=lora_path)
+            if base_model_name is not None and self.is_base_model(
+                    base_model_name):
+                lora_request.base_model_name = base_model_name
+
+            # Validate that the adapter can be loaded into the engine
+            # This will also pre-load it for incoming requests
+            try:
+                await self.engine_client.add_lora(lora_request)
+            except Exception as e:
+                error_type = "BadRequestError"
+                status_code = HTTPStatus.BAD_REQUEST
+                if "No adapter found" in str(e):
+                    error_type = "NotFoundError"
+                    status_code = HTTPStatus.NOT_FOUND
+
+                return create_error_response(message=str(e),
+                                             err_type=error_type,
+                                             status_code=status_code)
+
+            self.lora_requests[lora_name] = lora_request
+            logger.info("Loaded new LoRA adapter: name '%s', path '%s'",
+                        lora_name, lora_path)
+            return f"Success: LoRA adapter '{lora_name}' added successfully."
 
     async def unload_lora_adapter(
             self,
             request: UnloadLoRAAdapterRequest) -> Union[ErrorResponse, str]:
-        error_check_ret = await self._check_unload_lora_adapter_request(request
-                                                                        )
-        if error_check_ret is not None:
-            return error_check_ret
-
         lora_name = request.lora_name
-        self.lora_requests = [
-            lora_request for lora_request in self.lora_requests
-            if lora_request.lora_name != lora_name
-        ]
-        logger.info("Removed LoRA adapter: name '%s'", lora_name)
-        return f"Success: LoRA adapter '{lora_name}' removed successfully."
+
+        # Ensure atomicity based on the lora name
+        async with self.lora_resolver_lock[lora_name]:
+            error_check_ret = await self._check_unload_lora_adapter_request(
+                request)
+            if error_check_ret is not None:
+                return error_check_ret
+
+            # Safe to delete now since we hold the lock
+            del self.lora_requests[lora_name]
+            logger.info("Removed LoRA adapter: name '%s'", lora_name)
+            return f"Success: LoRA adapter '{lora_name}' removed successfully."
 
     async def _check_load_lora_adapter_request(
             self, request: LoadLoRAAdapterRequest) -> Optional[ErrorResponse]:
@@ -213,8 +221,7 @@ async def _check_load_lora_adapter_request(
                 status_code=HTTPStatus.BAD_REQUEST)
 
         # Check if the lora adapter with the given name already exists
-        if any(lora_request.lora_name == request.lora_name
-               for lora_request in self.lora_requests):
+        if request.lora_name in self.lora_requests:
             return create_error_response(
                 message=
                 f"The lora adapter '{request.lora_name}' has already been "
@@ -227,17 +234,16 @@ async def _check_load_lora_adapter_request(
     async def _check_unload_lora_adapter_request(
             self,
             request: UnloadLoRAAdapterRequest) -> Optional[ErrorResponse]:
-        # Check if either 'lora_name' or 'lora_int_id' is provided
-        if not request.lora_name and not request.lora_int_id:
+        # If 'lora_name' is not provided, return an error
+        if not request.lora_name:
             return create_error_response(
                 message=
-                "either 'lora_name' and 'lora_int_id' needs to be provided.",
+                "'lora_name' needs to be provided to unload a LoRA adapter.",
                 err_type="InvalidUserInput",
                 status_code=HTTPStatus.BAD_REQUEST)
 
         # Check if the lora adapter with the given name exists
-        if not any(lora_request.lora_name == request.lora_name
-                   for lora_request in self.lora_requests):
+        if request.lora_name not in self.lora_requests:
             return create_error_response(
                 message=
                 f"The lora adapter '{request.lora_name}' cannot be found.",
@@ -260,9 +266,8 @@ async def resolve_lora(
         """
         async with self.lora_resolver_lock[lora_name]:
             # First check if this LoRA is already loaded
-            for existing in self.lora_requests:
-                if existing.lora_name == lora_name:
-                    return existing
+            if lora_name in self.lora_requests:
+                return self.lora_requests[lora_name]
 
             base_model_name = self.model_config.model
             unique_id = self.lora_id_counter.inc(1)
@@ -279,7 +284,7 @@ async def resolve_lora(
                 try:
                     await self.engine_client.add_lora(lora_request)
-                    self.lora_requests.append(lora_request)
+                    self.lora_requests[lora_name] = lora_request
                     logger.info(
                         "Resolved and loaded LoRA adapter '%s' using %s",
                         lora_name, resolver.__class__.__name__)
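
Both the load and unload paths above serialize on `self.lora_resolver_lock[lora_name]`, a per-adapter-name async lock whose definition sits outside the hunks shown here. The sketch below is a minimal illustration, assuming that attribute behaves like a lazily populated `defaultdict` of `asyncio.Lock` objects; the standalone `load_lora_adapter`/`unload_lora_adapter` helpers are hypothetical stand-ins, not the module's API. It shows why the check-then-mutate sequence on the name-keyed `lora_requests` dict is atomic per adapter name while unrelated adapters never contend.

```python
import asyncio
from collections import defaultdict

# Assumption: one lazily created asyncio.Lock per adapter name, so
# `async with lora_resolver_lock[name]:` serializes same-name callers only.
lora_resolver_lock: defaultdict[str, asyncio.Lock] = defaultdict(asyncio.Lock)
lora_requests: dict[str, str] = {}  # stand-in for dict[str, LoRARequest]


async def load_lora_adapter(lora_name: str, lora_path: str) -> str:
    async with lora_resolver_lock[lora_name]:
        # Duplicate check and insert happen under the same lock, so there is
        # no window for a racing load of the same name.
        if lora_name in lora_requests:
            return f"'{lora_name}' has already been loaded."
        await asyncio.sleep(0)  # placeholder for engine_client.add_lora(...)
        lora_requests[lora_name] = lora_path
        return f"Success: LoRA adapter '{lora_name}' added successfully."


async def unload_lora_adapter(lora_name: str) -> str:
    async with lora_resolver_lock[lora_name]:
        if lora_name not in lora_requests:
            return f"'{lora_name}' cannot be found."
        # Safe to delete now since we hold the lock for this name.
        del lora_requests[lora_name]
        return f"Success: LoRA adapter '{lora_name}' removed successfully."


async def main() -> None:
    # Two concurrent loads of the same name: exactly one inserts, the other
    # observes the duplicate; a different adapter name would not contend.
    print(await asyncio.gather(load_lora_adapter("a1", "/tmp/a1"),
                               load_lora_adapter("a1", "/tmp/a1")))
    print(await unload_lora_adapter("a1"))


asyncio.run(main())
```

Under that assumption, switching `lora_requests` from a list to a dict keyed by name (as this diff does) also makes the duplicate check, the lookup in `resolve_lora`, and the unload O(1) instead of a scan over all loaded adapters.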