1
1
"""Async I/O protocol implementation for the LM Studio remote access API."""
2
2
3
3
import asyncio
4
- import asyncio .queues
5
4
import warnings
6
5
7
6
from abc import abstractmethod
28
27
TypeIs ,
29
28
)
30
29
30
+ from anyio import create_task_group
31
+ from anyio .abc import TaskGroup
31
32
from httpx import RequestError , HTTPStatusError
32
33
from httpx_ws import aconnect_ws , AsyncWebSocketSession , HTTPXWSException
33
34
@@ -163,7 +164,10 @@ async def rx_stream(
163
164
# Avoid emitting tracebacks that delve into supporting libraries
164
165
# (we can't easily suppress the SDK's own frames for iterators)
165
166
message = await self ._rx_queue .get ()
166
- contents = self ._api_channel .handle_rx_message (message )
167
+ if message is None :
168
+ contents = None
169
+ else :
170
+ contents = self ._api_channel .handle_rx_message (message )
167
171
if contents is None :
168
172
self ._is_finished = True
169
173
break
@@ -204,6 +208,8 @@ def get_rpc_message(
204
208
async def receive_result (self ) -> Any :
205
209
"""Receive call response on the receive queue."""
206
210
message = await self ._rx_queue .get ()
211
+ if message is None :
212
+ return None
207
213
return self ._rpc .handle_rx_message (message )
208
214
209
215
@@ -220,8 +226,10 @@ def __init__(
220
226
) -> None :
221
227
"""Initialize asynchronous websocket client."""
222
228
super ().__init__ (ws_url , auth_details , log_context )
223
- self ._resource_manager = AsyncExitStack ()
229
+ self ._resource_manager = rm = AsyncExitStack ()
230
+ rm .push_async_callback (self ._notify_client_termination )
224
231
self ._rx_task : asyncio .Task [None ] | None = None
232
+ self ._terminate = asyncio .Event ()
225
233
226
234
@property
227
235
def _httpx_ws (self ) -> AsyncWebSocketSession | None :
@@ -241,7 +249,9 @@ async def __aexit__(self, *args: Any) -> None:
241
249
async def _send_json (self , message : DictObject ) -> None :
242
250
# Callers are expected to call `_ensure_connected` before this method
243
251
ws = self ._ws
244
- assert ws is not None
252
+ if ws is None :
253
+ # Assume app is shutting down and the owning task has already been cancelled
254
+ return
245
255
try :
246
256
await ws .send_json (message )
247
257
except Exception as exc :
@@ -253,7 +263,9 @@ async def _send_json(self, message: DictObject) -> None:
253
263
async def _receive_json (self ) -> Any :
254
264
# Callers are expected to call `_ensure_connected` before this method
255
265
ws = self ._ws
256
- assert ws is not None
266
+ if ws is None :
267
+ # Assume app is shutting down and the owning task has already been cancelled
268
+ return
257
269
try :
258
270
return await ws .receive_json ()
259
271
except Exception as exc :
@@ -291,7 +303,7 @@ async def connect(self) -> Self:
291
303
self ._rx_task = rx_task = asyncio .create_task (self ._receive_messages ())
292
304
293
305
async def _terminate_rx_task () -> None :
294
- rx_task . cancel ()
306
+ self . _terminate . set ()
295
307
try :
296
308
await rx_task
297
309
except asyncio .CancelledError :
@@ -305,19 +317,34 @@ async def disconnect(self) -> None:
305
317
"""Drop the LM Studio API connection."""
306
318
self ._ws = None
307
319
self ._rx_task = None
308
- await self ._notify_client_termination ()
320
+ self ._terminate . set ()
309
321
await self ._resource_manager .aclose ()
310
322
self ._logger .info (f"Websocket session disconnected ({ self ._ws_url } )" )
311
323
312
324
aclose = disconnect
313
325
326
+ async def _cancel_on_termination (self , tg : TaskGroup ) -> None :
327
+ await self ._terminate .wait ()
328
+ tg .cancel_scope .cancel ()
329
+
314
330
async def _process_next_message (self ) -> bool :
315
331
"""Process the next message received on the websocket.
316
332
317
333
Returns True if a message queue was updated.
318
334
"""
319
335
self ._ensure_connected ("receive messages" )
320
- message = await self ._receive_json ()
336
+ async with create_task_group () as tg :
337
+ tg .start_soon (self ._cancel_on_termination , tg )
338
+ try :
339
+ message = await self ._receive_json ()
340
+ except (LMStudioWebsocketError , HTTPXWSException ):
341
+ if self ._ws is not None and not self ._terminate .is_set ():
342
+ # Websocket failed unexpectedly (rather than due to client shutdown)
343
+ self ._logger .error ("Websocket failed, terminating session." )
344
+ self ._terminate .set ()
345
+ tg .cancel_scope .cancel ()
346
+ if self ._terminate .is_set ():
347
+ return (await self ._notify_client_termination ()) > 0
321
348
rx_queue = self ._mux .map_rx_message (message )
322
349
if rx_queue is None :
323
350
return False
@@ -326,18 +353,20 @@ async def _process_next_message(self) -> bool:
326
353
327
354
async def _receive_messages (self ) -> None :
328
355
"""Process received messages until connection is terminated."""
329
- while True :
330
- try :
331
- await self ._process_next_message ()
332
- except (LMStudioWebsocketError , HTTPXWSException ):
333
- self ._logger .exception ("Websocket failed, terminating session." )
334
- await self .disconnect ()
335
- break
356
+ while not self ._terminate .is_set ():
357
+ await self ._process_next_message ()
336
358
337
- async def _notify_client_termination (self ) -> None :
359
+ async def _notify_client_termination (self ) -> int :
338
360
"""Send None to all clients with open receive queues."""
361
+ num_clients = 0
339
362
for rx_queue in self ._mux .all_queues ():
340
363
await rx_queue .put (None )
364
+ num_clients += 1
365
+ self ._logger .info (
366
+ f"Notified { num_clients } clients of websocket termination" ,
367
+ num_clients = num_clients ,
368
+ )
369
+ return num_clients
341
370
342
371
async def _connect_to_endpoint (self , channel : AsyncChannel [Any ]) -> None :
343
372
"""Connect channel to specified endpoint."""
@@ -362,6 +391,9 @@ async def open_channel(
362
391
self ._logger .event_context ,
363
392
)
364
393
await self ._connect_to_endpoint (channel )
394
+ if self ._terminate .is_set ():
395
+ # Link has been terminated, ensure client gets a response
396
+ await rx_queue .put (None )
365
397
yield channel
366
398
367
399
async def _send_call (
@@ -396,6 +428,9 @@ async def remote_call(
396
428
call_id , rx_queue , self ._logger .event_context , notice_prefix
397
429
)
398
430
await self ._send_call (rpc , endpoint , params )
431
+ if self ._terminate .is_set ():
432
+ # Link has been terminated, ensure client gets a response
433
+ await rx_queue .put (None )
399
434
return await rpc .receive_result ()
400
435
401
436
0 commit comments