Skip to content

Commit b40b9cc

Browse files
authored
Websocket refactoring: define BackgroundThread (#76)
* Move sync background thread to new websocket submodule * Split out the background thread functionality to a base class
1 parent bd04cdd commit b40b9cc

File tree

2 files changed

+358
-218
lines changed

2 files changed

+358
-218
lines changed

src/lmstudio/_ws_impl.py

Lines changed: 354 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,354 @@
1+
"""Shared core async websocket implementation for the LM Studio remote access API."""
2+
3+
# Sync API: runs in background thread with sync queues
4+
# Async convenience API: runs in background thread with async queues
5+
# Async structured API: runs in foreground event loop
6+
7+
# Callback handling rules:
8+
#
9+
# * All callbacks are synchronous (use external async queues if needed)
10+
# * All callbacks must be invoked from the *foreground* thread/event loop
11+
12+
import asyncio
13+
import threading
14+
import weakref
15+
16+
from concurrent.futures import Future as SyncFuture
17+
from contextlib import (
18+
asynccontextmanager,
19+
AsyncExitStack,
20+
)
21+
from typing import (
22+
Any,
23+
AsyncGenerator,
24+
Awaitable,
25+
Coroutine,
26+
Callable,
27+
NoReturn,
28+
TypeVar,
29+
)
30+
31+
# Synchronous API still uses an async websocket (just in a background thread)
32+
from anyio import create_task_group
33+
from exceptiongroup import suppress
34+
from httpx_ws import aconnect_ws, AsyncWebSocketSession, HTTPXWSException
35+
36+
from .schemas import DictObject
37+
from .json_api import LMStudioWebsocket, LMStudioWebsocketError
38+
39+
from ._logging import get_logger, LogEventContext
40+
41+
42+
# Allow the core client websocket management to be shared across all SDK interaction APIs
43+
# See https://discuss.python.org/t/daemon-threads-and-background-task-termination/77604
44+
# (Note: this implementation has the elements needed to run on *current* Python versions
45+
# and omits the generalised features that the SDK doesn't need)
46+
# Already used by the sync API, async client is still to be migrated
47+
T = TypeVar("T")
48+
49+
50+
class BackgroundThread(threading.Thread):
51+
"""Background async event loop thread."""
52+
53+
def __init__(
54+
self,
55+
task_target: Callable[[], Coroutine[Any, Any, Any]] | None = None,
56+
name: str | None = None,
57+
) -> None:
58+
# Accepts the same args as `threading.Thread`, *except*:
59+
# * a `task_target` coroutine replaces the `target` function
60+
# * No `daemon` option (always runs as a daemon)
61+
# Variant: accept `debug` and `loop_factory` options to forward to `asyncio.run`
62+
# Alternative: accept a `task_runner` callback, defaulting to `asyncio.run`
63+
self._task_target = task_target
64+
self._loop_started = threading.Event()
65+
self._terminate = asyncio.Event()
66+
self._event_loop: asyncio.AbstractEventLoop | None = None
67+
# Annoyingly, we have to mark the background thread as a daemon thread to
68+
# prevent hanging at shutdown. Even checking `sys.is_finalizing()` is inadequate
69+
# https://discuss.python.org/t/should-sys-is-finalizing-report-interpreter-finalization-instead-of-runtime-finalization/76695
70+
super().__init__(name=name, daemon=True)
71+
weakref.finalize(self, self.terminate)
72+
73+
def run(self) -> None:
74+
"""Run an async event loop in the background thread."""
75+
# Only public to override threading.Thread.run
76+
asyncio.run(self._run_until_terminated())
77+
78+
def wait_for_loop(self) -> asyncio.AbstractEventLoop | None:
79+
"""Wait for the event loop to start from a synchronous foreground thread."""
80+
if self._event_loop is None and not self._loop_started.is_set():
81+
self._loop_started.wait()
82+
return self._event_loop
83+
84+
async def wait_for_loop_async(self) -> asyncio.AbstractEventLoop | None:
85+
"""Wait for the event loop to start from an asynchronous foreground thread."""
86+
return await asyncio.to_thread(self.wait_for_loop)
87+
88+
def called_in_background_loop(self) -> bool:
89+
"""Returns true if currently running in this thread's event loop, false otherwise."""
90+
loop = self._event_loop
91+
if loop is None:
92+
# No loop active in background thread -> can't be running in it
93+
return False
94+
try:
95+
running_loop = asyncio.get_running_loop()
96+
except RuntimeError:
97+
# No loop in this thread -> can't be running in the background thread's loop
98+
return False
99+
# Otherwise, check if the running loop is the background thread's loop
100+
return running_loop is loop
101+
102+
async def _request_termination(self) -> bool:
103+
assert self.called_in_background_loop()
104+
if self._terminate.is_set():
105+
return False
106+
self._terminate.set()
107+
return True
108+
109+
def request_termination(self) -> SyncFuture[bool]:
110+
"""Request termination of the event loop (without blocking)."""
111+
loop = self._event_loop
112+
if loop is None or self._terminate.is_set():
113+
result: SyncFuture[bool] = SyncFuture()
114+
result.set_result(False)
115+
return result
116+
return self.schedule_background_task(self._request_termination())
117+
118+
def terminate(self) -> bool:
119+
"""Request termination of the event loop from a synchronous foreground thread."""
120+
return self.request_termination().result()
121+
122+
async def terminate_async(self) -> bool:
123+
"""Request termination of the event loop from an asynchronous foreground thread."""
124+
return await asyncio.to_thread(self.terminate)
125+
126+
async def _run_until_terminated(self) -> None:
127+
"""Run task in the background thread until termination is requested."""
128+
self._event_loop = asyncio.get_running_loop()
129+
self._loop_started.set()
130+
# Use anyio and exceptiongroup to handle the lack of native task
131+
# and exception groups prior to Python 3.11
132+
raise_on_termination, terminated_exc = self._raise_on_termination()
133+
with suppress(terminated_exc):
134+
try:
135+
async with create_task_group() as tg:
136+
tg.start_soon(self._run_task_target)
137+
tg.start_soon(raise_on_termination)
138+
finally:
139+
# Event loop is about to shut down
140+
self._event_loop = None
141+
142+
async def _run_task_target(self) -> None:
143+
if self._task_target is not None:
144+
main_task = self._task_target()
145+
self._task_target = None
146+
await main_task
147+
148+
def _raise_on_termination(
149+
self,
150+
) -> tuple[Callable[[], Awaitable[NoReturn]], type[Exception]]:
151+
class TerminateTask(Exception):
152+
pass
153+
154+
async def raise_on_termination() -> NoReturn:
155+
await self._terminate.wait()
156+
raise TerminateTask
157+
158+
return raise_on_termination, TerminateTask
159+
160+
def schedule_background_task(self, coro: Coroutine[Any, Any, T]) -> SyncFuture[T]:
161+
"""Schedule given coroutine in the background event loop."""
162+
loop = self._event_loop
163+
assert loop is not None
164+
return asyncio.run_coroutine_threadsafe(coro, loop)
165+
166+
def run_background_task(self, coro: Coroutine[Any, Any, T]) -> T:
167+
"""Run given coroutine in the background event loop and wait for the result."""
168+
return self.schedule_background_task(coro).result()
169+
170+
async def run_background_task_async(self, coro: Coroutine[Any, Any, T]) -> T:
171+
"""Run given coroutine in the background event loop and await the result."""
172+
return await asyncio.to_thread(self.run_background_task, coro)
173+
174+
def call_in_background(self, callback: Callable[[], Any]) -> None:
175+
"""Call given non-blocking function in the background event loop."""
176+
loop = self._event_loop
177+
assert loop is not None
178+
loop.call_soon_threadsafe(callback)
179+
180+
181+
# TODO: Allow multiple websockets to share a single event loop thread
182+
# (reduces thread usage in sync API, blocker for async API migration)
183+
class AsyncWebsocketThread(BackgroundThread):
184+
def __init__(
185+
self,
186+
ws_url: str,
187+
auth_details: DictObject,
188+
enqueue_message: Callable[[DictObject], bool],
189+
log_context: LogEventContext,
190+
) -> None:
191+
self._auth_details = auth_details
192+
self._connection_attempted = asyncio.Event()
193+
self._connection_failure: Exception | None = None
194+
self._auth_failure: Any | None = None
195+
self._terminate = asyncio.Event()
196+
self._ws_url = ws_url
197+
self._ws: AsyncWebSocketSession | None = None
198+
self._rx_task: asyncio.Task[None] | None = None
199+
self._queue_message = enqueue_message
200+
super().__init__(task_target=self._run_main_task)
201+
self._logger = logger = get_logger(type(self).__name__)
202+
logger.update_context(log_context, thread_id=self.name)
203+
204+
def connect(self) -> bool:
205+
if not self.is_alive():
206+
self.start()
207+
loop = self.wait_for_loop() # Block until connection has been attempted
208+
if loop is None:
209+
return False
210+
asyncio.run_coroutine_threadsafe(
211+
self._connection_attempted.wait(), loop
212+
).result()
213+
return self._ws is not None
214+
215+
def disconnect(self) -> None:
216+
if self._ws is not None:
217+
self.terminate()
218+
# Ensure thread has terminated
219+
self.join()
220+
221+
async def _run_main_task(self) -> None:
222+
self._logger.info("Websocket thread started")
223+
try:
224+
await self._main_task()
225+
except BaseException:
226+
err_msg = "Terminating websocket thread due to exception"
227+
self._logger.debug(err_msg, exc_info=True)
228+
finally:
229+
# Ensure the foreground thread is unblocked even if the
230+
# background async task errors out completely
231+
self._connection_attempted.set()
232+
self._logger.info("Websocket thread terminated")
233+
234+
# TODO: Improve code sharing between this background thread async websocket
235+
# and the async-native AsyncLMStudioWebsocket implementation
236+
async def _main_task(self) -> None:
237+
resources = AsyncExitStack()
238+
try:
239+
ws: AsyncWebSocketSession = await resources.enter_async_context(
240+
aconnect_ws(self._ws_url)
241+
)
242+
except Exception as exc:
243+
self._connection_failure = exc
244+
raise
245+
246+
def _clear_task_state() -> None:
247+
# Break the reference cycle with the foreground thread
248+
del self._queue_message
249+
# Websocket is about to be disconnected
250+
self._ws = None
251+
252+
resources.callback(_clear_task_state)
253+
async with resources:
254+
self._logger.debug("Websocket connected")
255+
self._ws = ws
256+
if not await self._authenticate():
257+
return
258+
async with self._manage_demultiplexing():
259+
self._connection_attempted.set()
260+
self._logger.info(f"Websocket session established ({self._ws_url})")
261+
# Keep the event loop alive until termination is requested
262+
await self._terminate.wait()
263+
264+
async def _send_json(self, message: DictObject) -> None:
265+
# This is only called if the websocket has been created
266+
assert self.called_in_background_loop()
267+
ws = self._ws
268+
assert ws is not None
269+
try:
270+
await ws.send_json(message)
271+
except Exception as exc:
272+
err = LMStudioWebsocket._get_tx_error(message, exc)
273+
# Log the underlying exception info, but simplify the raised traceback
274+
self._logger.debug(str(err), exc_info=True)
275+
raise err from None
276+
277+
async def _receive_json(self) -> Any:
278+
# This is only called if the websocket has been created
279+
assert self.called_in_background_loop()
280+
ws = self._ws
281+
assert ws is not None
282+
try:
283+
return await ws.receive_json()
284+
except Exception as exc:
285+
err = LMStudioWebsocket._get_rx_error(exc)
286+
# Log the underlying exception info, but simplify the raised traceback
287+
self._logger.debug(str(err), exc_info=True)
288+
raise err from None
289+
290+
async def _authenticate(self) -> bool:
291+
# This is only called if the websocket has been created
292+
assert self.called_in_background_loop()
293+
ws = self._ws
294+
assert ws is not None
295+
auth_message = self._auth_details
296+
await self._send_json(auth_message)
297+
auth_result = await self._receive_json()
298+
self._logger.debug("Websocket authenticated", json=auth_result)
299+
if not auth_result["success"]:
300+
self._auth_failure = auth_result["error"]
301+
return False
302+
return True
303+
304+
@asynccontextmanager
305+
async def _manage_demultiplexing(
306+
self,
307+
) -> AsyncGenerator[asyncio.Task[None], None]:
308+
assert self.called_in_background_loop()
309+
self._rx_task = rx_task = asyncio.create_task(self._demultiplexing_task())
310+
try:
311+
yield rx_task
312+
finally:
313+
if rx_task.cancel():
314+
try:
315+
await rx_task
316+
except asyncio.CancelledError:
317+
pass
318+
319+
async def _process_next_message(self) -> bool:
320+
"""Process the next message received on the websocket.
321+
322+
Returns True if a message queue was updated.
323+
"""
324+
# This is only called if the websocket has been created
325+
assert self.called_in_background_loop()
326+
ws = self._ws
327+
assert ws is not None
328+
message = await ws.receive_json()
329+
return await asyncio.to_thread(self._queue_message, message)
330+
331+
async def _demultiplexing_task(self) -> None:
332+
"""Process received messages until connection is terminated."""
333+
try:
334+
await self._receive_messages()
335+
finally:
336+
self._logger.info("Websocket closed, terminating demultiplexing task.")
337+
338+
raise_on_termination, terminated_exc = self._raise_on_termination()
339+
340+
async def _receive_messages(self) -> None:
341+
"""Process received messages until task is cancelled."""
342+
while True:
343+
try:
344+
await self._process_next_message()
345+
except (LMStudioWebsocketError, HTTPXWSException):
346+
if self._ws is not None:
347+
# Websocket failed unexpectedly (rather than due to client shutdown)
348+
self._logger.exception("Websocket failed, terminating session.")
349+
self._terminate.set()
350+
break
351+
352+
def send_json(self, message: DictObject) -> None:
353+
# Block until message has been sent
354+
self.run_background_task(self._send_json(message))

0 commit comments

Comments
 (0)