Skip to content

Commit 8284014

Browse files
authored
Websocket refactoring: per-client threads (#77)
Each sync API client now creates a single background event loop thread and shares it across all of its websockets (instead of giving each websocket handler its own dedicated background thread).
1 parent b40b9cc commit 8284014

File tree

3 files changed

+168
-76
lines changed

3 files changed

+168
-76
lines changed

src/lmstudio/_ws_impl.py

Lines changed: 135 additions & 67 deletions
Original file line number | Diff line number | Diff line change
@@ -29,7 +29,7 @@
2929
)
3030

3131
# Synchronous API still uses an async websocket (just in a background thread)
32-
from anyio import create_task_group
32+
from anyio import create_task_group, get_cancelled_exc_class
3333
from exceptiongroup import suppress
3434
from httpx_ws import aconnect_ws, AsyncWebSocketSession, HTTPXWSException
3535

@@ -47,43 +47,11 @@
4747
T = TypeVar("T")
4848

4949

50-
class BackgroundThread(threading.Thread):
51-
"""Background async event loop thread."""
52-
53-
def __init__(
54-
self,
55-
task_target: Callable[[], Coroutine[Any, Any, Any]] | None = None,
56-
name: str | None = None,
57-
) -> None:
58-
# Accepts the same args as `threading.Thread`, *except*:
59-
# * a `task_target` coroutine replaces the `target` function
60-
# * No `daemon` option (always runs as a daemon)
61-
# Variant: accept `debug` and `loop_factory` options to forward to `asyncio.run`
62-
# Alternative: accept a `task_runner` callback, defaulting to `asyncio.run`
63-
self._task_target = task_target
64-
self._loop_started = threading.Event()
65-
self._terminate = asyncio.Event()
66-
self._event_loop: asyncio.AbstractEventLoop | None = None
67-
# Annoyingly, we have to mark the background thread as a daemon thread to
68-
# prevent hanging at shutdown. Even checking `sys.is_finalizing()` is inadequate
69-
# https://discuss.python.org/t/should-sys-is-finalizing-report-interpreter-finalization-instead-of-runtime-finalization/76695
70-
super().__init__(name=name, daemon=True)
71-
weakref.finalize(self, self.terminate)
72-
73-
def run(self) -> None:
74-
"""Run an async event loop in the background thread."""
75-
# Only public to override threading.Thread.run
76-
asyncio.run(self._run_until_terminated())
77-
78-
def wait_for_loop(self) -> asyncio.AbstractEventLoop | None:
79-
"""Wait for the event loop to start from a synchronous foreground thread."""
80-
if self._event_loop is None and not self._loop_started.is_set():
81-
self._loop_started.wait()
82-
return self._event_loop
83-
84-
async def wait_for_loop_async(self) -> asyncio.AbstractEventLoop | None:
85-
"""Wait for the event loop to start from an asynchronous foreground thread."""
86-
return await asyncio.to_thread(self.wait_for_loop)
50+
class _BackgroundTaskHandlerMixin:
51+
# Subclasses need to handle providing these instance attributes
52+
_event_loop: asyncio.AbstractEventLoop | None
53+
_task_target: Callable[[], Coroutine[Any, Any, Any]] | None
54+
_terminate: asyncio.Event
8755

8856
def called_in_background_loop(self) -> bool:
8957
"""Returns true if currently running in this thread's event loop, false otherwise."""
@@ -123,10 +91,12 @@ async def terminate_async(self) -> bool:
12391
"""Request termination of the event loop from an asynchronous foreground thread."""
12492
return await asyncio.to_thread(self.terminate)
12593

94+
def _init_event_loop(self) -> None:
95+
self._event_loop = asyncio.get_running_loop()
96+
12697
async def _run_until_terminated(self) -> None:
12798
"""Run task in the background thread until termination is requested."""
128-
self._event_loop = asyncio.get_running_loop()
129-
self._loop_started.set()
99+
self._init_event_loop()
130100
# Use anyio and exceptiongroup to handle the lack of native task
131101
# and exception groups prior to Python 3.11
132102
raise_on_termination, terminated_exc = self._raise_on_termination()
@@ -163,6 +133,49 @@ def schedule_background_task(self, coro: Coroutine[Any, Any, T]) -> SyncFuture[T
163133
assert loop is not None
164134
return asyncio.run_coroutine_threadsafe(coro, loop)
165135

136+
137+
class BackgroundThread(_BackgroundTaskHandlerMixin, threading.Thread):
138+
"""Background async event loop thread."""
139+
140+
def __init__(
141+
self,
142+
task_target: Callable[[], Coroutine[Any, Any, Any]] | None = None,
143+
name: str | None = None,
144+
) -> None:
145+
# Accepts the same args as `threading.Thread`, *except*:
146+
# * a `task_target` coroutine replaces the `target` function
147+
# * No `daemon` option (always runs as a daemon)
148+
# Variant: accept `debug` and `loop_factory` options to forward to `asyncio.run`
149+
# Alternative: accept a `task_runner` callback, defaulting to `asyncio.run`
150+
self._task_target = task_target
151+
self._loop_started = threading.Event()
152+
self._terminate = asyncio.Event()
153+
self._event_loop: asyncio.AbstractEventLoop | None = None
154+
# Annoyingly, we have to mark the background thread as a daemon thread to
155+
# prevent hanging at shutdown. Even checking `sys.is_finalizing()` is inadequate
156+
# https://discuss.python.org/t/should-sys-is-finalizing-report-interpreter-finalization-instead-of-runtime-finalization/76695
157+
super().__init__(name=name, daemon=True)
158+
weakref.finalize(self, self.terminate)
159+
160+
def run(self) -> None:
161+
"""Run an async event loop in the background thread."""
162+
# Only public to override threading.Thread.run
163+
asyncio.run(self._run_until_terminated())
164+
165+
def _init_event_loop(self) -> None:
166+
super()._init_event_loop()
167+
self._loop_started.set()
168+
169+
def wait_for_loop(self) -> asyncio.AbstractEventLoop | None:
170+
"""Wait for the event loop to start from a synchronous foreground thread."""
171+
if self._event_loop is None and not self._loop_started.is_set():
172+
self._loop_started.wait()
173+
return self._event_loop
174+
175+
async def wait_for_loop_async(self) -> asyncio.AbstractEventLoop | None:
176+
"""Wait for the event loop to start from an asynchronous foreground thread."""
177+
return await asyncio.to_thread(self.wait_for_loop)
178+
166179
def run_background_task(self, coro: Coroutine[Any, Any, T]) -> T:
167180
"""Run given coroutine in the background event loop and wait for the result."""
168181
return self.schedule_background_task(coro).result()
@@ -178,62 +191,83 @@ def call_in_background(self, callback: Callable[[], Any]) -> None:
178191
loop.call_soon_threadsafe(callback)
179192

180193

181-
# TODO: Allow multiple websockets to share a single event loop thread
182-
# (reduces thread usage in sync API, blocker for async API migration)
183194
class AsyncWebsocketThread(BackgroundThread):
195+
def __init__(self, log_context: LogEventContext | None = None) -> None:
196+
super().__init__(task_target=self._run_main_task)
197+
self._logger = logger = get_logger(type(self).__name__)
198+
logger.update_context(log_context, thread_id=self.name)
199+
200+
async def _run_main_task(self) -> None:
201+
self._logger.info("Websocket handling thread started")
202+
never_set = asyncio.Event()
203+
try:
204+
# Run the event loop until termination is requested
205+
await never_set.wait()
206+
except get_cancelled_exc_class():
207+
pass
208+
except BaseException:
209+
err_msg = "Terminating websocket thread due to exception"
210+
self._logger.debug(err_msg, exc_info=True)
211+
self._logger.info("Websocket thread terminated")
212+
213+
214+
# TODO: Improve code sharing between AsyncWebsocketHandler and
215+
# the async-native AsyncLMStudioWebsocket implementation
216+
class AsyncWebsocketHandler(_BackgroundTaskHandlerMixin):
217+
"""Async task handler for a single websocket connection."""
218+
184219
def __init__(
185220
self,
221+
ws_thread: AsyncWebsocketThread,
186222
ws_url: str,
187223
auth_details: DictObject,
188224
enqueue_message: Callable[[DictObject], bool],
189-
log_context: LogEventContext,
225+
log_context: LogEventContext | None = None,
190226
) -> None:
191227
self._auth_details = auth_details
192228
self._connection_attempted = asyncio.Event()
193229
self._connection_failure: Exception | None = None
194230
self._auth_failure: Any | None = None
195231
self._terminate = asyncio.Event()
232+
self._ws_thread = ws_thread
196233
self._ws_url = ws_url
197234
self._ws: AsyncWebSocketSession | None = None
198235
self._rx_task: asyncio.Task[None] | None = None
199236
self._queue_message = enqueue_message
200-
super().__init__(task_target=self._run_main_task)
237+
self._logger = get_logger(type(self).__name__)
201238
self._logger = logger = get_logger(type(self).__name__)
202-
logger.update_context(log_context, thread_id=self.name)
239+
logger.update_context(log_context)
203240

204241
def connect(self) -> bool:
205-
if not self.is_alive():
206-
self.start()
207-
loop = self.wait_for_loop() # Block until connection has been attempted
242+
ws_thread = self._ws_thread
243+
if not ws_thread.is_alive():
244+
raise RuntimeError("Websocket handling thread has failed unexpectedly")
245+
loop = ws_thread.wait_for_loop() # Block until loop is available
208246
if loop is None:
209-
return False
247+
raise RuntimeError("Websocket handling thread has no event loop")
248+
ws_thread.schedule_background_task(self._run_until_terminated())
210249
asyncio.run_coroutine_threadsafe(
211250
self._connection_attempted.wait(), loop
212251
).result()
213252
return self._ws is not None
214253

215-
def disconnect(self) -> None:
216-
if self._ws is not None:
217-
self.terminate()
218-
# Ensure thread has terminated
219-
self.join()
220-
221-
async def _run_main_task(self) -> None:
222-
self._logger.info("Websocket thread started")
254+
async def _task_target(self) -> None:
255+
self._logger.info("Websocket handling task started")
256+
self._init_event_loop()
223257
try:
224-
await self._main_task()
258+
await self._handle_ws()
259+
except get_cancelled_exc_class():
260+
pass
225261
except BaseException:
226-
err_msg = "Terminating websocket thread due to exception"
262+
err_msg = "Terminating websocket task due to exception"
227263
self._logger.debug(err_msg, exc_info=True)
228264
finally:
229265
# Ensure the foreground thread is unblocked even if the
230266
# background async task errors out completely
231267
self._connection_attempted.set()
232-
self._logger.info("Websocket thread terminated")
268+
self._logger.info("Websocket task terminated")
233269

234-
# TODO: Improve code sharing between this background thread async websocket
235-
# and the async-native AsyncLMStudioWebsocket implementation
236-
async def _main_task(self) -> None:
270+
async def _handle_ws(self) -> None:
237271
resources = AsyncExitStack()
238272
try:
239273
ws: AsyncWebSocketSession = await resources.enter_async_context(
@@ -274,6 +308,10 @@ async def _send_json(self, message: DictObject) -> None:
274308
self._logger.debug(str(err), exc_info=True)
275309
raise err from None
276310

311+
def send_json(self, message: DictObject) -> None:
312+
future = self.schedule_background_task(self._send_json(message))
313+
future.result() # Block until the message is sent
314+
277315
async def _receive_json(self) -> Any:
278316
# This is only called if the websocket has been created
279317
assert self.called_in_background_loop()
@@ -335,8 +373,6 @@ async def _demultiplexing_task(self) -> None:
335373
finally:
336374
self._logger.info("Websocket closed, terminating demultiplexing task.")
337375

338-
raise_on_termination, terminated_exc = self._raise_on_termination()
339-
340376
async def _receive_messages(self) -> None:
341377
"""Process received messages until task is cancelled."""
342378
while True:
@@ -349,6 +385,38 @@ async def _receive_messages(self) -> None:
349385
self._terminate.set()
350386
break
351387

388+
389+
class SyncToAsyncWebsocketBridge:
390+
def __init__(
391+
self,
392+
ws_thread: AsyncWebsocketThread,
393+
ws_url: str,
394+
auth_details: DictObject,
395+
enqueue_message: Callable[[DictObject], bool],
396+
log_context: LogEventContext,
397+
) -> None:
398+
self._ws_handler = AsyncWebsocketHandler(
399+
ws_thread, ws_url, auth_details, enqueue_message, log_context
400+
)
401+
402+
def connect(self) -> bool:
403+
return self._ws_handler.connect()
404+
405+
def disconnect(self) -> None:
406+
self._ws_handler.terminate()
407+
352408
def send_json(self, message: DictObject) -> None:
353-
# Block until message has been sent
354-
self.run_background_task(self._send_json(message))
409+
self._ws_handler.send_json(message)
410+
411+
# These attributes are currently accessed directly...
412+
@property
413+
def _ws(self) -> AsyncWebSocketSession | None:
414+
return self._ws_handler._ws
415+
416+
@property
417+
def _connection_failure(self) -> Exception | None:
418+
return self._ws_handler._connection_failure
419+
420+
@property
421+
def _auth_failure(self) -> Any | None:
422+
return self._ws_handler._auth_failure

src/lmstudio/sync_api.py

Lines changed: 15 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -109,7 +109,7 @@
109109
_model_spec_to_api_dict,
110110
_redact_json,
111111
)
112-
from ._ws_impl import AsyncWebsocketThread
112+
from ._ws_impl import AsyncWebsocketThread, SyncToAsyncWebsocketBridge
113113
from ._kv_config import TLoadConfig, TLoadConfigDict, parse_server_config
114114
from ._sdk_models import (
115115
EmbeddingRpcCountTokensParameter,
@@ -233,17 +233,21 @@ def receive_result(self) -> Any:
233233
return self._rpc.handle_rx_message(message)
234234

235235

236-
class SyncLMStudioWebsocket(LMStudioWebsocket[AsyncWebsocketThread, queue.Queue[Any]]):
236+
class SyncLMStudioWebsocket(
237+
LMStudioWebsocket[SyncToAsyncWebsocketBridge, queue.Queue[Any]]
238+
):
237239
"""Synchronous websocket client that handles demultiplexing of reply messages."""
238240

239241
def __init__(
240242
self,
243+
ws_thread: AsyncWebsocketThread,
241244
ws_url: str,
242245
auth_details: DictObject,
243246
log_context: LogEventContext | None = None,
244247
) -> None:
245248
"""Initialize synchronous websocket client."""
246249
super().__init__(ws_url, auth_details, log_context)
250+
self._ws_thread = ws_thread
247251

248252
@property
249253
def _httpx_ws(self) -> AsyncWebSocketSession | None:
@@ -266,7 +270,8 @@ def __exit__(self, *args: Any) -> None:
266270
def connect(self) -> Self:
267271
"""Connect to and authenticate with the LM Studio API."""
268272
self._fail_if_connected("Attempted to connect already connected websocket")
269-
ws = AsyncWebsocketThread(
273+
ws = SyncToAsyncWebsocketBridge(
274+
self._ws_thread,
270275
self._ws_url,
271276
self._auth_details,
272277
self._enqueue_message,
@@ -409,7 +414,9 @@ def connect(self) -> SyncLMStudioWebsocket:
409414
session_url = f"ws://{api_host}/{namespace}"
410415
resources = self._resources
411416
self._lmsws = lmsws = resources.enter_context(
412-
SyncLMStudioWebsocket(session_url, self._client._auth_details)
417+
SyncLMStudioWebsocket(
418+
self._client._ws_thread, session_url, self._client._auth_details
419+
)
413420
)
414421
return lmsws
415422

@@ -1482,8 +1489,11 @@ def __init__(self, api_host: str | None = None) -> None:
14821489
"""Initialize API client."""
14831490
super().__init__(api_host)
14841491
self._resources = rm = ExitStack()
1492+
self._ws_thread = ws_thread = AsyncWebsocketThread(dict(client=repr(self)))
1493+
ws_thread.start()
1494+
rm.callback(ws_thread.terminate)
14851495
self._sessions: dict[str, SyncSession] = {}
1486-
# Suport GC-based resource management in the sync API by
1496+
# Support GC-based resource management in the sync API by
14871497
# finalizing at the client layer, and letting its resource
14881498
# manager handle clearing up everything else
14891499
rm.callback(self._sessions.clear)

0 commit comments

Comments
 (0)