Skip to content

Commit f0fd18b

Browse files
authored
Improve websocket failure handling (#121)
* sync client calls are interrupted and subsequently refused
* async clients avoid emitting structured concurrency errors
1 parent b3f7c66 commit f0fd18b

File tree

5 files changed

+143
-33
lines changed

5 files changed

+143
-33
lines changed

misc/open_client.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
#!/usr/bin/env python
2+
"""Open a client instance for link failure testing."""
3+
import asyncio
4+
import logging
5+
import sys
6+
import time
7+
8+
from lmstudio import AsyncClient, Client
9+
10+
LINK_POLLING_INTERVAL = 1
11+
12+
async def open_client_async():
    """Connect an async client and keep polling it until a call fails.

    Opens an ``AsyncClient`` context, issues one model-list request to
    confirm the link is usable, then repeats that request every
    ``LINK_POLLING_INTERVAL`` seconds. The loop has no exit condition of
    its own, so this coroutine only ends when a request (or the sleep)
    raises, e.g. on cancellation.
    """
    print("Connecting async client...")
    async with AsyncClient() as session:
        probe = session.list_downloaded_models
        # First request doubles as the "connection works" check
        await probe()
        print("Async client connected. Close LM Studio to terminate.")
        # Poll forever; termination is expected to come from a failed request
        # (presumably once LM Studio is closed -- see the message above)
        while True:
            await asyncio.sleep(LINK_POLLING_INTERVAL)
            await probe()
21+
22+
def open_client_sync():
    """Connect a sync client and keep polling it until a call fails.

    Opens a ``Client`` context, issues one model-list request to confirm
    the link is usable, then repeats that request every
    ``LINK_POLLING_INTERVAL`` seconds. The loop has no exit condition of
    its own, so this function only returns control when a request raises
    (or the process is interrupted).
    """
    print("Connecting sync client...")
    with Client() as session:
        probe = session.list_downloaded_models
        # First request doubles as the "connection works" check
        probe()
        print("Sync client connected. Close LM Studio to terminate.")
        # Poll forever; termination is expected to come from a failed request
        # (presumably once LM Studio is closed -- see the message above)
        while True:
            time.sleep(LINK_POLLING_INTERVAL)
            probe()
31+
32+
if __name__ == "__main__":
    # Command line flags (checked by simple membership in sys.argv):
    #   --debug  enable DEBUG level logging (default is INFO)
    #   --async  exercise the async client (default is the sync client)
    #
    # Debug logging is opt-in because the link polling loop makes it
    # excessively spammy
    if "--debug" in sys.argv:
        log_level = logging.DEBUG
    else:
        log_level = logging.INFO
    logging.basicConfig(level=log_level)
    if "--async" not in sys.argv:
        open_client_sync()
    else:
        asyncio.run(open_client_async())

src/lmstudio/_ws_impl.py

Lines changed: 33 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -290,6 +290,8 @@ async def _log_thread_execution(self) -> None:
290290
try:
291291
# Run the event loop until termination is requested
292292
await never_set.wait()
293+
except asyncio.CancelledError:
294+
raise
293295
except BaseException:
294296
err_msg = "Terminating websocket thread due to exception"
295297
self._logger.debug(err_msg, exc_info=True)
@@ -309,7 +311,7 @@ def __init__(
309311
task_manager: AsyncTaskManager,
310312
ws_url: str,
311313
auth_details: DictObject,
312-
enqueue_message: Callable[[DictObject], bool],
314+
enqueue_message: Callable[[DictObject | None], Awaitable[bool]],
313315
log_context: LogEventContext | None = None,
314316
) -> None:
315317
self._auth_details = auth_details
@@ -357,14 +359,16 @@ async def _logged_ws_handler(self) -> None:
357359
self._logger.info("Websocket handling task started")
358360
try:
359361
await self._handle_ws()
362+
except asyncio.CancelledError:
363+
raise
360364
except BaseException:
361365
err_msg = "Terminating websocket task due to exception"
362366
self._logger.debug(err_msg, exc_info=True)
363367
finally:
364368
# Ensure the foreground thread is unblocked even if the
365369
# background async task errors out completely
366370
self._connection_attempted.set()
367-
self._logger.info("Websocket task terminated")
371+
self._logger.info("Websocket task terminated")
368372

369373
async def _handle_ws(self) -> None:
370374
assert self._task_manager.check_running_in_task_loop()
@@ -396,12 +400,19 @@ def _clear_task_state() -> None:
396400
await self._receive_messages()
397401
finally:
398402
self._logger.info("Websocket demultiplexing task terminated.")
403+
# Notify foreground thread of background thread termination
404+
# (this covers termination due to link failure)
405+
await self._enqueue_message(None)
399406
dc_timeout = self.WS_DISCONNECT_TIMEOUT
400407
with move_on_after(dc_timeout, shield=True) as cancel_scope:
401408
# Workaround an anyio/httpx-ws issue with task cancellation:
402409
# https://github.com/frankie567/httpx-ws/issues/107
403410
self._ws = None
404-
await ws.close()
411+
try:
412+
await ws.close()
413+
except Exception:
414+
# Closing may fail if the link is already down
415+
pass
405416
if cancel_scope.cancelled_caught:
406417
self._logger.warn(
407418
f"Failed to close websocket in {dc_timeout} seconds."
@@ -413,7 +424,9 @@ async def send_json(self, message: DictObject) -> None:
413424
# This is only called if the websocket has been created
414425
assert self._task_manager.check_running_in_task_loop()
415426
ws = self._ws
416-
assert ws is not None
427+
if ws is None:
428+
# Assume app is shutting down and the owning task has already been cancelled
429+
return
417430
try:
418431
await ws.send_json(message)
419432
except Exception as exc:
@@ -430,7 +443,9 @@ async def _receive_json(self) -> Any:
430443
# This is only called if the websocket has been created
431444
assert self._task_manager.check_running_in_task_loop()
432445
ws = self._ws
433-
assert ws is not None
446+
if ws is None:
447+
# Assume app is shutting down and the owning task has already been cancelled
448+
return
434449
try:
435450
return await ws.receive_json()
436451
except Exception as exc:
@@ -443,7 +458,9 @@ async def _authenticate(self) -> bool:
443458
# This is only called if the websocket has been created
444459
assert self._task_manager.check_running_in_task_loop()
445460
ws = self._ws
446-
assert ws is not None
461+
if ws is None:
462+
# Assume app is shutting down and the owning task has already been cancelled
463+
return False
447464
auth_message = self._auth_details
448465
await self.send_json(auth_message)
449466
auth_result = await self._receive_json()
@@ -461,11 +478,11 @@ async def _process_next_message(self) -> bool:
461478
# This is only called if the websocket has been created
462479
assert self._task_manager.check_running_in_task_loop()
463480
ws = self._ws
464-
assert ws is not None
481+
if ws is None:
482+
# Assume app is shutting down and the owning task has already been cancelled
483+
return False
465484
message = await ws.receive_json()
466-
# Enqueueing messages may be a blocking call
467-
# TODO: Require it to return an Awaitable, move to_thread call to the sync bridge
468-
return await asyncio.to_thread(self._enqueue_message, message)
485+
return await self._enqueue_message(message)
469486

470487
async def _receive_messages(self) -> None:
471488
"""Process received messages until task is cancelled."""
@@ -475,7 +492,7 @@ async def _receive_messages(self) -> None:
475492
except (LMStudioWebsocketError, HTTPXWSException):
476493
if self._ws is not None and not self._ws_disconnected.is_set():
477494
# Websocket failed unexpectedly (rather than due to client shutdown)
478-
self._logger.exception("Websocket failed, terminating session.")
495+
self._logger.error("Websocket failed, terminating session.")
479496
break
480497

481498

@@ -485,11 +502,14 @@ def __init__(
485502
ws_thread: AsyncWebsocketThread,
486503
ws_url: str,
487504
auth_details: DictObject,
488-
enqueue_message: Callable[[DictObject], bool],
505+
enqueue_message: Callable[[DictObject | None], bool],
489506
log_context: LogEventContext,
490507
) -> None:
508+
async def enqueue_async(message: DictObject | None) -> bool:
509+
return await asyncio.to_thread(enqueue_message, message)
510+
491511
self._ws_handler = AsyncWebsocketHandler(
492-
ws_thread.task_manager, ws_url, auth_details, enqueue_message, log_context
512+
ws_thread.task_manager, ws_url, auth_details, enqueue_async, log_context
493513
)
494514

495515
def connect(self) -> bool:

src/lmstudio/async_api.py

Lines changed: 51 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
"""Async I/O protocol implementation for the LM Studio remote access API."""
22

33
import asyncio
4-
import asyncio.queues
54
import warnings
65

76
from abc import abstractmethod
@@ -28,6 +27,8 @@
2827
TypeIs,
2928
)
3029

30+
from anyio import create_task_group
31+
from anyio.abc import TaskGroup
3132
from httpx import RequestError, HTTPStatusError
3233
from httpx_ws import aconnect_ws, AsyncWebSocketSession, HTTPXWSException
3334

@@ -163,7 +164,10 @@ async def rx_stream(
163164
# Avoid emitting tracebacks that delve into supporting libraries
164165
# (we can't easily suppress the SDK's own frames for iterators)
165166
message = await self._rx_queue.get()
166-
contents = self._api_channel.handle_rx_message(message)
167+
if message is None:
168+
contents = None
169+
else:
170+
contents = self._api_channel.handle_rx_message(message)
167171
if contents is None:
168172
self._is_finished = True
169173
break
@@ -204,6 +208,8 @@ def get_rpc_message(
204208
async def receive_result(self) -> Any:
205209
"""Receive call response on the receive queue."""
206210
message = await self._rx_queue.get()
211+
if message is None:
212+
return None
207213
return self._rpc.handle_rx_message(message)
208214

209215

@@ -220,8 +226,10 @@ def __init__(
220226
) -> None:
221227
"""Initialize asynchronous websocket client."""
222228
super().__init__(ws_url, auth_details, log_context)
223-
self._resource_manager = AsyncExitStack()
229+
self._resource_manager = rm = AsyncExitStack()
230+
rm.push_async_callback(self._notify_client_termination)
224231
self._rx_task: asyncio.Task[None] | None = None
232+
self._terminate = asyncio.Event()
225233

226234
@property
227235
def _httpx_ws(self) -> AsyncWebSocketSession | None:
@@ -241,7 +249,9 @@ async def __aexit__(self, *args: Any) -> None:
241249
async def _send_json(self, message: DictObject) -> None:
242250
# Callers are expected to call `_ensure_connected` before this method
243251
ws = self._ws
244-
assert ws is not None
252+
if ws is None:
253+
# Assume app is shutting down and the owning task has already been cancelled
254+
return
245255
try:
246256
await ws.send_json(message)
247257
except Exception as exc:
@@ -253,7 +263,9 @@ async def _send_json(self, message: DictObject) -> None:
253263
async def _receive_json(self) -> Any:
254264
# Callers are expected to call `_ensure_connected` before this method
255265
ws = self._ws
256-
assert ws is not None
266+
if ws is None:
267+
# Assume app is shutting down and the owning task has already been cancelled
268+
return
257269
try:
258270
return await ws.receive_json()
259271
except Exception as exc:
@@ -291,7 +303,7 @@ async def connect(self) -> Self:
291303
self._rx_task = rx_task = asyncio.create_task(self._receive_messages())
292304

293305
async def _terminate_rx_task() -> None:
294-
rx_task.cancel()
306+
self._terminate.set()
295307
try:
296308
await rx_task
297309
except asyncio.CancelledError:
@@ -305,19 +317,34 @@ async def disconnect(self) -> None:
305317
"""Drop the LM Studio API connection."""
306318
self._ws = None
307319
self._rx_task = None
308-
await self._notify_client_termination()
320+
self._terminate.set()
309321
await self._resource_manager.aclose()
310322
self._logger.info(f"Websocket session disconnected ({self._ws_url})")
311323

312324
aclose = disconnect
313325

326+
async def _cancel_on_termination(self, tg: TaskGroup) -> None:
327+
await self._terminate.wait()
328+
tg.cancel_scope.cancel()
329+
314330
async def _process_next_message(self) -> bool:
315331
"""Process the next message received on the websocket.
316332
317333
Returns True if a message queue was updated.
318334
"""
319335
self._ensure_connected("receive messages")
320-
message = await self._receive_json()
336+
async with create_task_group() as tg:
337+
tg.start_soon(self._cancel_on_termination, tg)
338+
try:
339+
message = await self._receive_json()
340+
except (LMStudioWebsocketError, HTTPXWSException):
341+
if self._ws is not None and not self._terminate.is_set():
342+
# Websocket failed unexpectedly (rather than due to client shutdown)
343+
self._logger.error("Websocket failed, terminating session.")
344+
self._terminate.set()
345+
tg.cancel_scope.cancel()
346+
if self._terminate.is_set():
347+
return (await self._notify_client_termination()) > 0
321348
rx_queue = self._mux.map_rx_message(message)
322349
if rx_queue is None:
323350
return False
@@ -326,18 +353,20 @@ async def _process_next_message(self) -> bool:
326353

327354
async def _receive_messages(self) -> None:
328355
"""Process received messages until connection is terminated."""
329-
while True:
330-
try:
331-
await self._process_next_message()
332-
except (LMStudioWebsocketError, HTTPXWSException):
333-
self._logger.exception("Websocket failed, terminating session.")
334-
await self.disconnect()
335-
break
356+
while not self._terminate.is_set():
357+
await self._process_next_message()
336358

337-
async def _notify_client_termination(self) -> None:
359+
async def _notify_client_termination(self) -> int:
338360
"""Send None to all clients with open receive queues."""
361+
num_clients = 0
339362
for rx_queue in self._mux.all_queues():
340363
await rx_queue.put(None)
364+
num_clients += 1
365+
self._logger.info(
366+
f"Notified {num_clients} clients of websocket termination",
367+
num_clients=num_clients,
368+
)
369+
return num_clients
341370

342371
async def _connect_to_endpoint(self, channel: AsyncChannel[Any]) -> None:
343372
"""Connect channel to specified endpoint."""
@@ -362,6 +391,9 @@ async def open_channel(
362391
self._logger.event_context,
363392
)
364393
await self._connect_to_endpoint(channel)
394+
if self._terminate.is_set():
395+
# Link has been terminated, ensure client gets a response
396+
await rx_queue.put(None)
365397
yield channel
366398

367399
async def _send_call(
@@ -396,6 +428,9 @@ async def remote_call(
396428
call_id, rx_queue, self._logger.event_context, notice_prefix
397429
)
398430
await self._send_call(rpc, endpoint, params)
431+
if self._terminate.is_set():
432+
# Link has been terminated, ensure client gets a response
433+
await rx_queue.put(None)
399434
return await rpc.receive_result()
400435

401436

src/lmstudio/json_api.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -352,10 +352,10 @@ def _format_server_error(details: SerializedLMSExtendedError) -> str:
352352
lines.extend(_get_data_lines(details.error_data, " "))
353353
if details.cause is not None:
354354
lines.extend(("", " Reported cause:"))
355-
lines.extend(f" {details.cause}")
355+
lines.append(f" {details.cause}")
356356
if details.suggestion is not None:
357357
lines.extend(("", " Suggested potential remedy:"))
358-
lines.extend(f" {details.suggestion}")
358+
lines.append(f" {details.suggestion}")
359359
# Only use the multi-line format if at least one
360360
# of the extended error fields is populated
361361
if lines:

0 commit comments

Comments
 (0)