Skip to content

Commit 7dedda3

Browse files
authored
Sandbox fixes (#197)
1 parent 4f1ee56 commit 7dedda3

File tree

6 files changed

+118
-3
lines changed

6 files changed

+118
-3
lines changed

temporalio/worker/workflow_sandbox/_importer.py

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,9 @@ def __init__(
6767
self.new_modules: Dict[str, types.ModuleType] = {
6868
"sys": sys,
6969
"builtins": builtins,
70+
# Even though we don't want to, we have to have __main__ because
71+
# stdlib packages like inspect and others expect it to be present
72+
"__main__": types.ModuleType("__main__"),
7073
}
7174
self.modules_checked_for_restrictions: Set[str] = set()
7275
self.import_func = self._import if not LOG_TRACE else self._traced_import
@@ -334,7 +337,7 @@ def unapplied(self) -> Iterator[None]:
334337

335338

336339
class _ThreadLocalSysModules(
337-
_ThreadLocalOverride[MutableMapping[str, types.ModuleType]],
340+
_ThreadLocalOverride[Dict[str, types.ModuleType]],
338341
MutableMapping[str, types.ModuleType],
339342
):
340343
def __contains__(self, key: object) -> bool:
@@ -355,6 +358,35 @@ def __iter__(self) -> Iterator[str]:
355358
def __setitem__(self, key: str, value: types.ModuleType) -> None:
356359
self.current[key] = value
357360

361+
# The methods below are not part of MutableMapping. Python chose not to put
362+
# everything that dict offers into MutableMapping (see
363+
# https://bugs.python.org/issue22101). So when someone calls
364+
# sys.modules.copy() it breaks (which is exactly what the inspect module
365+
# does sometimes).
366+
367+
def __or__(
    self, other: Mapping[str, types.ModuleType]
) -> Dict[str, types.ModuleType]:
    """Implement ``self | other`` by delegating to the current dict.

    Args:
        other: Mapping whose entries take precedence over the current ones.

    Returns:
        Whatever ``dict.__or__`` returns for the current underlying dict
        (a new plain dict when ``other`` is a dict).

    Raises:
        NotImplementedError: On Python older than 3.9, where dict has no
            union operators.
    """
    if sys.version_info < (3, 9):
        raise NotImplementedError
    return self.current.__or__(other)


def __ior__(
    self, other: Mapping[str, types.ModuleType]
) -> "_ThreadLocalSysModules":
    """Implement ``self |= other`` by updating the current dict in place.

    Returns ``self`` rather than the underlying dict. Python rebinds the
    left-hand name to the return value of ``__ior__``, so returning the
    plain dict would silently replace this wrapper with that dict.

    Args:
        other: Mapping whose entries are merged into the current dict.

    Returns:
        This same object.

    Raises:
        NotImplementedError: On Python older than 3.9.
    """
    if sys.version_info < (3, 9):
        raise NotImplementedError
    self.current.update(other)
    return self


def __ror__(
    self, other: Mapping[str, types.ModuleType]
) -> Dict[str, types.ModuleType]:
    """Implement ``other | self`` (reflected union).

    The right-hand operand wins on key conflicts, so the entries of the
    current dict must override those of ``other``. Aliasing this to
    ``__or__`` would give ``other`` precedence instead.

    Args:
        other: The left-hand mapping of the union.

    Returns:
        A new plain dict with ``other``'s entries overlaid by the current
        ones, or ``NotImplemented`` when ``other`` is not a mapping.

    Raises:
        NotImplementedError: On Python older than 3.9.
    """
    if sys.version_info < (3, 9):
        raise NotImplementedError
    if not isinstance(other, Mapping):
        return NotImplemented
    merged = dict(other)
    merged.update(self.current)
    return merged
382+
383+
def copy(self) -> Dict[str, types.ModuleType]:
    """Return a shallow copy of the current underlying module dict.

    MutableMapping does not provide ``copy``, but callers that treat this
    object like ``sys.modules`` may call it.

    Returns:
        The result of ``copy()`` on the current dict.
    """
    snapshot = self.current.copy()
    return snapshot
385+
386+
@classmethod
def fromkeys(cls, *args, **kwargs) -> Any:
    """Mirror ``dict.fromkeys`` for parity with the real ``sys.modules``.

    All arguments are forwarded unchanged; ``cls`` is not used, so the
    result is whatever ``dict.fromkeys`` builds (a plain dict).
    """
    build = dict.fromkeys
    return build(*args, **kwargs)
389+
358390

359391
_thread_local_sys_modules = _ThreadLocalSysModules(sys.modules)
360392

temporalio/worker/workflow_sandbox/_runner.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,9 @@ def _run_code(self, code: str, **extra_globals: Any) -> None:
156156
for k, v in extra_globals.items():
157157
self.globals_and_locals[k] = v
158158
try:
159+
temporalio.workflow.unsafe._set_in_sandbox(True)
159160
exec(code, self.globals_and_locals, self.globals_and_locals)
160161
finally:
162+
temporalio.workflow.unsafe._set_in_sandbox(False)
161163
for k, v in extra_globals.items():
162164
del self.globals_and_locals[k]

temporalio/workflow.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -713,6 +713,7 @@ async def wait_condition(
713713

714714

715715
_sandbox_unrestricted = threading.local()
716+
_in_sandbox = threading.local()
716717

717718

718719
class unsafe:
@@ -723,6 +724,19 @@ class unsafe:
723724
def __init__(self) -> None: # noqa: D107
724725
raise NotImplementedError
725726

727+
@staticmethod
def in_sandbox() -> bool:
    """Whether the code is executing on a sandboxed thread.

    Reads the ``value`` attribute of the thread-local ``_in_sandbox``.

    Returns:
        The flag stored for the calling thread, or False when no flag has
        been set on this thread.
    """
    try:
        return _in_sandbox.value
    except AttributeError:
        # Nothing has been stored on this thread yet.
        return False
735+
736+
@staticmethod
def _set_in_sandbox(v: bool) -> None:
    """Store the in-sandbox flag for the calling thread.

    Args:
        v: Value that ``in_sandbox()`` will subsequently report on this
            thread.
    """
    setattr(_in_sandbox, "value", v)
739+
726740
@staticmethod
727741
def is_replaying() -> bool:
728742
"""Whether the workflow is currently replaying.

tests/worker/test_workflow.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,6 @@
5151
from temporalio.converter import (
5252
DataConverter,
5353
DefaultFailureConverterWithEncodedAttributes,
54-
FailureConverter,
5554
PayloadCodec,
5655
PayloadConverter,
5756
)

tests/worker/workflow_sandbox/test_importer.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,12 @@
1+
import sys
2+
13
import pytest
24

3-
from temporalio.worker.workflow_sandbox._importer import Importer
5+
from temporalio.worker.workflow_sandbox._importer import (
6+
Importer,
7+
_thread_local_sys_modules,
8+
_ThreadLocalSysModules,
9+
)
410
from temporalio.worker.workflow_sandbox._restrictions import (
511
RestrictedWorkflowAccessError,
612
RestrictionContext,
@@ -76,3 +82,22 @@ def test_workflow_sandbox_importer_invalid_module_members():
7682
err.value.qualified_name
7783
== "tests.worker.workflow_sandbox.testmodules.invalid_module_members.invalid_function.__call__"
7884
)
85+
86+
87+
def test_thread_local_sys_module_attrs():
    """Check that the thread-local ``sys.modules`` stand-in mirrors ``dict``."""
    if sys.version_info < (3, 9):
        pytest.skip("Dict or methods only in >= 3.9")

    # MutableMapping deliberately omits some dict members (see
    # https://bugs.python.org/issue22101), so look up every attribute the
    # real sys.modules exposes on the thread-local wrapper. A missing one
    # raises AttributeError and fails the test.
    for attr_name in dir(sys.modules):
        getattr(_thread_local_sys_modules, attr_name)

    # Exercise "|", "|=" and copy() against a plain dict with equal content.
    plain = {"foo": 123}
    wrapped = _ThreadLocalSysModules({"foo": 123})
    extra = {"bar": 456}
    assert (plain | extra) == (wrapped | extra)
    plain |= {"baz": 789}
    wrapped |= {"baz": 789}
    assert plain.copy() == wrapped.copy()

tests/worker/workflow_sandbox/test_runner.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import dataclasses
44
import functools
5+
import inspect
56
import os
67
import time
78
import uuid
@@ -267,6 +268,48 @@ async def test_workflow_sandbox_operator(client: Client):
267268
)
268269

269270

271+
global_in_sandbox = workflow.unsafe.in_sandbox()
272+
273+
274+
@workflow.defn
class InSandboxWorkflow:
    """Workflow asserting that it observes the in-sandbox flag as true.

    ``global_in_sandbox`` is evaluated when this module is imported, so the
    asserts check both that import-time value and a fresh call.
    """

    def __init__(self) -> None:
        # Checked at construction time.
        assert global_in_sandbox
        assert workflow.unsafe.in_sandbox()

    @workflow.run
    async def run(self) -> None:
        """Repeat the same checks from inside the run method."""
        assert global_in_sandbox
        assert workflow.unsafe.in_sandbox()
284+
285+
286+
async def test_workflow_sandbox_assert(client: Client):
    """Run InSandboxWorkflow and check the flag is false in the test itself."""
    async with new_worker(client, InSandboxWorkflow) as sandbox_worker:
        # In the test body both the import-time value and a fresh call must
        # report "not in sandbox".
        assert not global_in_sandbox
        assert not workflow.unsafe.in_sandbox()

        # The workflow's own asserts fail the execution if it does not see
        # the flag as true.
        workflow_id = f"workflow-{uuid.uuid4()}"
        await client.execute_workflow(
            InSandboxWorkflow.run,
            id=workflow_id,
            task_queue=sandbox_worker.task_queue,
        )
295+
296+
297+
@workflow.defn
class AccessStackWorkflow:
    """Workflow that calls ``inspect.stack()`` and reports its own frame."""

    @workflow.run
    async def run(self) -> str:
        """Return the function name of the innermost stack frame."""
        frames = inspect.stack()
        # Index 0 is the frame that called inspect.stack(), i.e. this method.
        return frames[0].function
302+
303+
304+
async def test_workflow_sandbox_access_stack(client: Client):
    """Check that ``inspect.stack()`` works when called from a workflow."""
    async with new_worker(client, AccessStackWorkflow) as stack_worker:
        workflow_id = f"workflow-{uuid.uuid4()}"
        top_frame_name = await client.execute_workflow(
            AccessStackWorkflow.run,
            id=workflow_id,
            task_queue=stack_worker.task_queue,
        )
        # The workflow reports the name of the function that inspected the
        # stack, which is its own run method.
        assert "run" == top_frame_name
311+
312+
270313
def new_worker(
271314
client: Client,
272315
*workflows: Type,

0 commit comments

Comments
 (0)