Skip to content

Commit 95f2451

Browse files
authored
fix ContextVar persistence across cells (#1462)
1 parent c56a7aa commit 95f2451

File tree

5 files changed

+93
-5
lines changed

5 files changed

+93
-5
lines changed

ipykernel/kernelbase.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@
6060

6161
from ._version import kernel_protocol_version
6262
from .iostream import OutStream
63-
from .utils import LazyDict
63+
from .utils import LazyDict, _async_in_context
6464

6565
_AWAITABLE_MESSAGE: str = (
6666
"For consistency across implementations, it is recommended that `{func_name}`"
@@ -557,7 +557,7 @@ def start(self):
557557
self.shell_stream.on_recv(self.shell_channel_thread_main, copy=False)
558558
else:
559559
self.shell_stream.on_recv(
560-
partial(self.shell_main, None),
560+
_async_in_context(partial(self.shell_main, None)),
561561
copy=False,
562562
)
563563

ipykernel/subshell_manager.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from .socket_pair import SocketPair
1616
from .subshell import SubshellThread
1717
from .thread import SHELL_CHANNEL_THREAD_NAME
18+
from .utils import _async_in_context
1819

1920

2021
class SubshellManager:
@@ -129,7 +130,9 @@ def set_on_recv_callback(self, on_recv_callback):
129130
"""
130131
assert current_thread() == self._parent_thread
131132
self._on_recv_callback = on_recv_callback
132-
self._shell_channel_to_main.on_recv(IOLoop.current(), partial(self._on_recv_callback, None))
133+
self._shell_channel_to_main.on_recv(
134+
IOLoop.current(), _async_in_context(partial(on_recv_callback, None))
135+
)
133136

134137
def set_subshell_aborting(self, subshell_id: str, aborting: bool) -> None:
135138
"""Set the aborting flag of the specified subshell."""
@@ -165,7 +168,7 @@ def _create_subshell(self) -> str:
165168

166169
subshell_thread.shell_channel_to_subshell.on_recv(
167170
subshell_thread.io_loop,
168-
partial(self._on_recv_callback, subshell_id),
171+
_async_in_context(partial(self._on_recv_callback, subshell_id)),
169172
)
170173

171174
subshell_thread.subshell_to_shell_channel.on_recv(

ipykernel/utils.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,17 @@
11
"""Utilities"""
22

3+
from __future__ import annotations
4+
5+
import asyncio
6+
import sys
37
import typing as t
48
from collections.abc import Mapping
9+
from contextvars import copy_context
10+
from functools import partial, wraps
11+
12+
if t.TYPE_CHECKING:
13+
from collections.abc import Callable
14+
from contextvars import Context
515

616

717
class LazyDict(Mapping[str, t.Any]):
@@ -24,3 +34,48 @@ def __len__(self):
2434

2535
def __iter__(self):
2636
return iter(self._dict)
37+
38+
39+
T = t.TypeVar("T")
40+
U = t.TypeVar("U")
41+
V = t.TypeVar("V")
42+
43+
44+
def _async_in_context(
45+
f: Callable[..., t.Coroutine[T, U, V]], context: Context | None = None
46+
) -> Callable[..., t.Coroutine[T, U, V]]:
47+
"""
48+
Wrapper to run a coroutine in a persistent ContextVar Context.
49+
50+
Backports asyncio.create_task(context=...) behavior from Python 3.11
51+
"""
52+
if context is None:
53+
context = copy_context()
54+
55+
if sys.version_info >= (3, 11):
56+
57+
@wraps(f)
58+
async def run_in_context(*args, **kwargs):
59+
coro = f(*args, **kwargs)
60+
return await asyncio.create_task(coro, context=context)
61+
62+
return run_in_context
63+
64+
# don't need this backport when we require 3.11
65+
# context_holder so we have a modifiable container for later calls
66+
context_holder = [context] # type: ignore[unreachable]
67+
68+
async def preserve_context(f, *args, **kwargs):
69+
"""call a coroutine, preserving the context after it is called"""
70+
try:
71+
return await f(*args, **kwargs)
72+
finally:
73+
# persist changes to the context for future calls
74+
context_holder[0] = copy_context()
75+
76+
@wraps(f)
77+
async def run_in_context_pre311(*args, **kwargs):
78+
ctx = context_holder[0]
79+
return await ctx.run(partial(asyncio.create_task, preserve_context(f, *args, **kwargs)))
80+
81+
return run_in_context_pre311

ipykernel/zmqshell.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -570,7 +570,7 @@ def _update_exit_now(self, change):
570570
# Over ZeroMQ, GUI control isn't done with PyOS_InputHook as there is no
571571
# interactive input being read; we provide event loop support in ipkernel
572572
def enable_gui(self, gui: typing.Any = None) -> None:
573-
"""Enable a given guil."""
573+
"""Enable a given gui."""
574574
from .eventloops import enable_gui as real_enable_gui
575575

576576
try:

tests/test_kernel.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -848,3 +848,33 @@ def test_parent_header_and_ident():
848848
msg_id, _ = execute(kc=kc, code="print(k._parent_ident['control'])")
849849
stdout, _ = assemble_output(kc.get_iopub_msg, parent_msg_id=msg_id)
850850
assert stdout == f"[b'{session}']\n"
851+
852+
853+
def test_context_vars():
854+
with new_kernel() as kc:
855+
msg_id, _ = execute(
856+
kc=kc,
857+
code="from contextvars import ContextVar, copy_context\nctxvar = ContextVar('var', default='default')",
858+
)
859+
stdout, _ = assemble_output(kc.get_iopub_msg, parent_msg_id=msg_id)
860+
861+
msg_id, _ = execute(
862+
kc=kc,
863+
code="print(ctxvar.get())",
864+
)
865+
stdout, _ = assemble_output(kc.get_iopub_msg, parent_msg_id=msg_id)
866+
assert stdout.strip() == "default"
867+
868+
msg_id, _ = execute(
869+
kc=kc,
870+
code="ctxvar.set('set'); print(ctxvar.get())",
871+
)
872+
stdout, _ = assemble_output(kc.get_iopub_msg, parent_msg_id=msg_id)
873+
assert stdout.strip() == "set"
874+
875+
msg_id, _ = execute(
876+
kc=kc,
877+
code="print(ctxvar.get())",
878+
)
879+
stdout, _ = assemble_output(kc.get_iopub_msg, parent_msg_id=msg_id)
880+
assert stdout.strip() == "set"

0 commit comments

Comments
 (0)