29
29
)
30
30
31
31
# Synchronous API still uses an async websocket (just in a background thread)
32
- from anyio import create_task_group
32
+ from anyio import create_task_group , get_cancelled_exc_class
33
33
from exceptiongroup import suppress
34
34
from httpx_ws import aconnect_ws , AsyncWebSocketSession , HTTPXWSException
35
35
47
47
T = TypeVar ("T" )
48
48
49
49
50
- class BackgroundThread (threading .Thread ):
51
- """Background async event loop thread."""
52
-
53
- def __init__ (
54
- self ,
55
- task_target : Callable [[], Coroutine [Any , Any , Any ]] | None = None ,
56
- name : str | None = None ,
57
- ) -> None :
58
- # Accepts the same args as `threading.Thread`, *except*:
59
- # * a `task_target` coroutine replaces the `target` function
60
- # * No `daemon` option (always runs as a daemon)
61
- # Variant: accept `debug` and `loop_factory` options to forward to `asyncio.run`
62
- # Alternative: accept a `task_runner` callback, defaulting to `asyncio.run`
63
- self ._task_target = task_target
64
- self ._loop_started = threading .Event ()
65
- self ._terminate = asyncio .Event ()
66
- self ._event_loop : asyncio .AbstractEventLoop | None = None
67
- # Annoyingly, we have to mark the background thread as a daemon thread to
68
- # prevent hanging at shutdown. Even checking `sys.is_finalizing()` is inadequate
69
- # https://discuss.python.org/t/should-sys-is-finalizing-report-interpreter-finalization-instead-of-runtime-finalization/76695
70
- super ().__init__ (name = name , daemon = True )
71
- weakref .finalize (self , self .terminate )
72
-
73
- def run (self ) -> None :
74
- """Run an async event loop in the background thread."""
75
- # Only public to override threading.Thread.run
76
- asyncio .run (self ._run_until_terminated ())
77
-
78
- def wait_for_loop (self ) -> asyncio .AbstractEventLoop | None :
79
- """Wait for the event loop to start from a synchronous foreground thread."""
80
- if self ._event_loop is None and not self ._loop_started .is_set ():
81
- self ._loop_started .wait ()
82
- return self ._event_loop
83
-
84
- async def wait_for_loop_async (self ) -> asyncio .AbstractEventLoop | None :
85
- """Wait for the event loop to start from an asynchronous foreground thread."""
86
- return await asyncio .to_thread (self .wait_for_loop )
50
+ class _BackgroundTaskHandlerMixin :
51
+ # Subclasses need to handle providing these instance attributes
52
+ _event_loop : asyncio .AbstractEventLoop | None
53
+ _task_target : Callable [[], Coroutine [Any , Any , Any ]] | None
54
+ _terminate : asyncio .Event
87
55
88
56
def called_in_background_loop (self ) -> bool :
89
57
"""Returns true if currently running in this thread's event loop, false otherwise."""
@@ -123,10 +91,12 @@ async def terminate_async(self) -> bool:
123
91
"""Request termination of the event loop from an asynchronous foreground thread."""
124
92
return await asyncio .to_thread (self .terminate )
125
93
94
+ def _init_event_loop (self ) -> None :
95
+ self ._event_loop = asyncio .get_running_loop ()
96
+
126
97
async def _run_until_terminated (self ) -> None :
127
98
"""Run task in the background thread until termination is requested."""
128
- self ._event_loop = asyncio .get_running_loop ()
129
- self ._loop_started .set ()
99
+ self ._init_event_loop ()
130
100
# Use anyio and exceptiongroup to handle the lack of native task
131
101
# and exception groups prior to Python 3.11
132
102
raise_on_termination , terminated_exc = self ._raise_on_termination ()
@@ -163,6 +133,49 @@ def schedule_background_task(self, coro: Coroutine[Any, Any, T]) -> SyncFuture[T
163
133
assert loop is not None
164
134
return asyncio .run_coroutine_threadsafe (coro , loop )
165
135
136
+
137
+ class BackgroundThread (_BackgroundTaskHandlerMixin , threading .Thread ):
138
+ """Background async event loop thread."""
139
+
140
+ def __init__ (
141
+ self ,
142
+ task_target : Callable [[], Coroutine [Any , Any , Any ]] | None = None ,
143
+ name : str | None = None ,
144
+ ) -> None :
145
+ # Accepts the same args as `threading.Thread`, *except*:
146
+ # * a `task_target` coroutine replaces the `target` function
147
+ # * No `daemon` option (always runs as a daemon)
148
+ # Variant: accept `debug` and `loop_factory` options to forward to `asyncio.run`
149
+ # Alternative: accept a `task_runner` callback, defaulting to `asyncio.run`
150
+ self ._task_target = task_target
151
+ self ._loop_started = threading .Event ()
152
+ self ._terminate = asyncio .Event ()
153
+ self ._event_loop : asyncio .AbstractEventLoop | None = None
154
+ # Annoyingly, we have to mark the background thread as a daemon thread to
155
+ # prevent hanging at shutdown. Even checking `sys.is_finalizing()` is inadequate
156
+ # https://discuss.python.org/t/should-sys-is-finalizing-report-interpreter-finalization-instead-of-runtime-finalization/76695
157
+ super ().__init__ (name = name , daemon = True )
158
+ weakref .finalize (self , self .terminate )
159
+
160
+ def run (self ) -> None :
161
+ """Run an async event loop in the background thread."""
162
+ # Only public to override threading.Thread.run
163
+ asyncio .run (self ._run_until_terminated ())
164
+
165
+ def _init_event_loop (self ) -> None :
166
+ super ()._init_event_loop ()
167
+ self ._loop_started .set ()
168
+
169
+ def wait_for_loop (self ) -> asyncio .AbstractEventLoop | None :
170
+ """Wait for the event loop to start from a synchronous foreground thread."""
171
+ if self ._event_loop is None and not self ._loop_started .is_set ():
172
+ self ._loop_started .wait ()
173
+ return self ._event_loop
174
+
175
+ async def wait_for_loop_async (self ) -> asyncio .AbstractEventLoop | None :
176
+ """Wait for the event loop to start from an asynchronous foreground thread."""
177
+ return await asyncio .to_thread (self .wait_for_loop )
178
+
166
179
def run_background_task (self , coro : Coroutine [Any , Any , T ]) -> T :
167
180
"""Run given coroutine in the background event loop and wait for the result."""
168
181
return self .schedule_background_task (coro ).result ()
@@ -178,62 +191,83 @@ def call_in_background(self, callback: Callable[[], Any]) -> None:
178
191
loop .call_soon_threadsafe (callback )
179
192
180
193
181
- # TODO: Allow multiple websockets to share a single event loop thread
182
- # (reduces thread usage in sync API, blocker for async API migration)
183
194
class AsyncWebsocketThread (BackgroundThread ):
195
+ def __init__ (self , log_context : LogEventContext | None = None ) -> None :
196
+ super ().__init__ (task_target = self ._run_main_task )
197
+ self ._logger = logger = get_logger (type (self ).__name__ )
198
+ logger .update_context (log_context , thread_id = self .name )
199
+
200
+ async def _run_main_task (self ) -> None :
201
+ self ._logger .info ("Websocket handling thread started" )
202
+ never_set = asyncio .Event ()
203
+ try :
204
+ # Run the event loop until termination is requested
205
+ await never_set .wait ()
206
+ except get_cancelled_exc_class ():
207
+ pass
208
+ except BaseException :
209
+ err_msg = "Terminating websocket thread due to exception"
210
+ self ._logger .debug (err_msg , exc_info = True )
211
+ self ._logger .info ("Websocket thread terminated" )
212
+
213
+
214
+ # TODO: Improve code sharing between AsyncWebsocketHandler and
215
+ # the async-native AsyncLMStudioWebsocket implementation
216
+ class AsyncWebsocketHandler (_BackgroundTaskHandlerMixin ):
217
+ """Async task handler for a single websocket connection."""
218
+
184
219
def __init__ (
185
220
self ,
221
+ ws_thread : AsyncWebsocketThread ,
186
222
ws_url : str ,
187
223
auth_details : DictObject ,
188
224
enqueue_message : Callable [[DictObject ], bool ],
189
- log_context : LogEventContext ,
225
+ log_context : LogEventContext | None = None ,
190
226
) -> None :
191
227
self ._auth_details = auth_details
192
228
self ._connection_attempted = asyncio .Event ()
193
229
self ._connection_failure : Exception | None = None
194
230
self ._auth_failure : Any | None = None
195
231
self ._terminate = asyncio .Event ()
232
+ self ._ws_thread = ws_thread
196
233
self ._ws_url = ws_url
197
234
self ._ws : AsyncWebSocketSession | None = None
198
235
self ._rx_task : asyncio .Task [None ] | None = None
199
236
self ._queue_message = enqueue_message
200
- super (). __init__ ( task_target = self . _run_main_task )
237
+ self . _logger = get_logger ( type ( self ). __name__ )
201
238
self ._logger = logger = get_logger (type (self ).__name__ )
202
- logger .update_context (log_context , thread_id = self . name )
239
+ logger .update_context (log_context )
203
240
204
241
def connect (self ) -> bool :
205
- if not self .is_alive ():
206
- self .start ()
207
- loop = self .wait_for_loop () # Block until connection has been attempted
242
+ ws_thread = self ._ws_thread
243
+ if not ws_thread .is_alive ():
244
+ raise RuntimeError ("Websocket handling thread has failed unexpectedly" )
245
+ loop = ws_thread .wait_for_loop () # Block until loop is available
208
246
if loop is None :
209
- return False
247
+ raise RuntimeError ("Websocket handling thread has no event loop" )
248
+ ws_thread .schedule_background_task (self ._run_until_terminated ())
210
249
asyncio .run_coroutine_threadsafe (
211
250
self ._connection_attempted .wait (), loop
212
251
).result ()
213
252
return self ._ws is not None
214
253
215
- def disconnect (self ) -> None :
216
- if self ._ws is not None :
217
- self .terminate ()
218
- # Ensure thread has terminated
219
- self .join ()
220
-
221
- async def _run_main_task (self ) -> None :
222
- self ._logger .info ("Websocket thread started" )
254
+ async def _task_target (self ) -> None :
255
+ self ._logger .info ("Websocket handling task started" )
256
+ self ._init_event_loop ()
223
257
try :
224
- await self ._main_task ()
258
+ await self ._handle_ws ()
259
+ except get_cancelled_exc_class ():
260
+ pass
225
261
except BaseException :
226
- err_msg = "Terminating websocket thread due to exception"
262
+ err_msg = "Terminating websocket task due to exception"
227
263
self ._logger .debug (err_msg , exc_info = True )
228
264
finally :
229
265
# Ensure the foreground thread is unblocked even if the
230
266
# background async task errors out completely
231
267
self ._connection_attempted .set ()
232
- self ._logger .info ("Websocket thread terminated" )
268
+ self ._logger .info ("Websocket task terminated" )
233
269
234
- # TODO: Improve code sharing between this background thread async websocket
235
- # and the async-native AsyncLMStudioWebsocket implementation
236
- async def _main_task (self ) -> None :
270
+ async def _handle_ws (self ) -> None :
237
271
resources = AsyncExitStack ()
238
272
try :
239
273
ws : AsyncWebSocketSession = await resources .enter_async_context (
@@ -274,6 +308,10 @@ async def _send_json(self, message: DictObject) -> None:
274
308
self ._logger .debug (str (err ), exc_info = True )
275
309
raise err from None
276
310
311
+ def send_json (self , message : DictObject ) -> None :
312
+ future = self .schedule_background_task (self ._send_json (message ))
313
+ future .result () # Block until the message is sent
314
+
277
315
async def _receive_json (self ) -> Any :
278
316
# This is only called if the websocket has been created
279
317
assert self .called_in_background_loop ()
@@ -335,8 +373,6 @@ async def _demultiplexing_task(self) -> None:
335
373
finally :
336
374
self ._logger .info ("Websocket closed, terminating demultiplexing task." )
337
375
338
- raise_on_termination , terminated_exc = self ._raise_on_termination ()
339
-
340
376
async def _receive_messages (self ) -> None :
341
377
"""Process received messages until task is cancelled."""
342
378
while True :
@@ -349,6 +385,38 @@ async def _receive_messages(self) -> None:
349
385
self ._terminate .set ()
350
386
break
351
387
388
+
389
+ class SyncToAsyncWebsocketBridge :
390
+ def __init__ (
391
+ self ,
392
+ ws_thread : AsyncWebsocketThread ,
393
+ ws_url : str ,
394
+ auth_details : DictObject ,
395
+ enqueue_message : Callable [[DictObject ], bool ],
396
+ log_context : LogEventContext ,
397
+ ) -> None :
398
+ self ._ws_handler = AsyncWebsocketHandler (
399
+ ws_thread , ws_url , auth_details , enqueue_message , log_context
400
+ )
401
+
402
+ def connect (self ) -> bool :
403
+ return self ._ws_handler .connect ()
404
+
405
+ def disconnect (self ) -> None :
406
+ self ._ws_handler .terminate ()
407
+
352
408
def send_json (self , message : DictObject ) -> None :
353
- # Block until message has been sent
354
- self .run_background_task (self ._send_json (message ))
409
+ self ._ws_handler .send_json (message )
410
+
411
+ # These attributes are currently accessed directly...
412
+ @property
413
+ def _ws (self ) -> AsyncWebSocketSession | None :
414
+ return self ._ws_handler ._ws
415
+
416
+ @property
417
+ def _connection_failure (self ) -> Exception | None :
418
+ return self ._ws_handler ._connection_failure
419
+
420
+ @property
421
+ def _auth_failure (self ) -> Any | None :
422
+ return self ._ws_handler ._auth_failure
0 commit comments