2020from typing import Any , Callable
2121
2222import zmq
23+ import zmq_anyio
2324from anyio import create_task_group , run , sleep , to_thread
2425from jupyter_client .session import extract_header
2526
@@ -55,11 +56,11 @@ def run(self):
5556 run (self ._main )
5657
5758 async def _main (self ):
58- async with create_task_group () as tg :
59+ async with create_task_group () as self . _task_group :
5960 for task in self ._tasks :
60- tg .start_soon (task )
61+ self . _task_group .start_soon (task )
6162 await to_thread .run_sync (self .__stop .wait )
62- tg .cancel_scope .cancel ()
63+ self . _task_group .cancel_scope .cancel ()
6364
6465 def stop (self ):
6566 """Stop the thread.
@@ -78,7 +79,7 @@ class IOPubThread:
7879 whose IO is always run in a thread.
7980 """
8081
81- def __init__ (self , socket , pipe = False ):
82+ def __init__ (self , socket : zmq_anyio . Socket , pipe = False ):
8283 """Create IOPub thread
8384
8485 Parameters
@@ -91,10 +92,7 @@ def __init__(self, socket, pipe=False):
9192 """
9293 # ensure all of our sockets as sync zmq.Sockets
9394 # don't create async wrappers until we are within the appropriate coroutines
94- self .socket : zmq .Socket [bytes ] | None = zmq .Socket (socket )
95- if self .socket .context is None :
96- # bug in pyzmq, shadow socket doesn't always inherit context attribute
97- self .socket .context = socket .context # type:ignore[unreachable]
95+ self .socket : zmq_anyio .Socket = socket
9896 self ._context = socket .context
9997
10098 self .background_socket = BackgroundSocket (self )
@@ -108,14 +106,14 @@ def __init__(self, socket, pipe=False):
108106 self ._event_pipe_gc_lock : threading .Lock = threading .Lock ()
109107 self ._event_pipe_gc_seconds : float = 10
110108 self ._setup_event_pipe ()
111- tasks = [self ._handle_event , self ._run_event_pipe_gc ]
109+ tasks = [self ._handle_event , self ._run_event_pipe_gc , self . socket . start ]
112110 if pipe :
113111 tasks .append (self ._handle_pipe_msgs )
114112 self .thread = _IOPubThread (tasks )
115113
116114 def _setup_event_pipe (self ):
117115 """Create the PULL socket listening for events that should fire in this thread."""
118- self ._pipe_in0 = self ._context .socket (zmq .PULL , socket_class = zmq . Socket )
116+ self ._pipe_in0 = self ._context .socket (zmq .PULL )
119117 self ._pipe_in0 .linger = 0
120118
121119 _uuid = b2a_hex (os .urandom (16 )).decode ("ascii" )
@@ -150,7 +148,7 @@ def _event_pipe(self):
150148 except AttributeError :
151149 # new thread, new event pipe
152150 # create sync base socket
153- event_pipe = self ._context .socket (zmq .PUSH , socket_class = zmq . Socket )
151+ event_pipe = self ._context .socket (zmq .PUSH )
154152 event_pipe .linger = 0
155153 event_pipe .connect (self ._event_interface )
156154 self ._local .event_pipe = event_pipe
@@ -169,30 +167,28 @@ async def _handle_event(self):
169167 Whenever *an* event arrives on the event stream,
170168 *all* waiting events are processed in order.
171169 """
172- # create async wrapper within coroutine
173- pipe_in = zmq . asyncio . Socket ( self . _pipe_in0 )
174- try :
175- while True :
176- await pipe_in .recv ()
177- # freeze event count so new writes don't extend the queue
178- # while we are processing
179- n_events = len (self ._events )
180- for _ in range (n_events ):
181- event_f = self ._events .popleft ()
182- event_f ()
183- except Exception :
184- if self .thread .__stop .is_set ():
185- return
186- raise
170+ pipe_in = zmq_anyio . Socket ( self . _pipe_in0 )
171+ async with pipe_in :
172+ try :
173+ while True :
174+ await pipe_in .arecv ()
175+ # freeze event count so new writes don't extend the queue
176+ # while we are processing
177+ n_events = len (self ._events )
178+ for _ in range (n_events ):
179+ event_f = self ._events .popleft ()
180+ event_f ()
181+ except Exception :
182+ if self .thread .__stop .is_set ():
183+ return
184+ raise
187185
188186 def _setup_pipe_in (self ):
189187 """setup listening pipe for IOPub from forked subprocesses"""
190- ctx = self ._context
191-
192188 # use UUID to authenticate pipe messages
193189 self ._pipe_uuid = os .urandom (16 )
194190
195- self ._pipe_in1 = ctx . socket (zmq .PULL , socket_class = zmq . Socket )
191+ self ._pipe_in1 = zmq_anyio . Socket ( self . _context . socket (zmq .PULL ) )
196192 self ._pipe_in1 .linger = 0
197193
198194 try :
@@ -210,18 +206,18 @@ def _setup_pipe_in(self):
210206 async def _handle_pipe_msgs (self ):
211207 """handle pipe messages from a subprocess"""
212208 # create async wrapper within coroutine
213- self . _async_pipe_in1 = zmq . asyncio . Socket ( self ._pipe_in1 )
214- try :
215- while True :
216- await self ._handle_pipe_msg ()
217- except Exception :
218- if self .thread .__stop .is_set ():
219- return
220- raise
209+ async with self ._pipe_in1 :
210+ try :
211+ while True :
212+ await self ._handle_pipe_msg ()
213+ except Exception :
214+ if self .thread .__stop .is_set ():
215+ return
216+ raise
221217
222218 async def _handle_pipe_msg (self , msg = None ):
223219 """handle a pipe message from a subprocess"""
224- msg = msg or await self ._async_pipe_in1 . recv_multipart ()
220+ msg = msg or await self ._pipe_in1 . arecv_multipart ()
225221 if not self ._pipe_flag or not self ._is_main_process ():
226222 return
227223 if msg [0 ] != self ._pipe_uuid :
0 commit comments