Skip to content

Commit 9d40e86

Browse files
authored
💥 BREAKING CHANGE - Lazy passthrough for sys.modules and OpenAI converter/sandbox improvements (#936)
Fixes #912
1 parent 633b90c commit 9d40e86

File tree

9 files changed

+1013
-1061
lines changed

9 files changed

+1013
-1061
lines changed

‎pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ opentelemetry = [
2525
]
2626
pydantic = ["pydantic>=2.0.0,<3"]
2727
openai-agents = [
28-
"openai-agents >= 0.0.19,<0.1",
28+
"openai-agents >= 0.1,<0.2",
2929
"eval-type-backport>=0.2.2; python_version < '3.10'"
3030
]
3131

‎temporalio/contrib/openai_agents/_temporal_model_stub.py

Lines changed: 27 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -11,36 +11,35 @@
1111

1212
logger = logging.getLogger(__name__)
1313

14-
with workflow.unsafe.imports_passed_through():
15-
from typing import Any, AsyncIterator, Optional, Sequence, Union, cast
14+
from typing import Any, AsyncIterator, Optional, Sequence, Union, cast
1615

17-
from agents import (
18-
AgentOutputSchema,
19-
AgentOutputSchemaBase,
20-
ComputerTool,
21-
FileSearchTool,
22-
FunctionTool,
23-
Handoff,
24-
Model,
25-
ModelResponse,
26-
ModelSettings,
27-
ModelTracing,
28-
Tool,
29-
TResponseInputItem,
30-
WebSearchTool,
31-
)
32-
from agents.items import TResponseStreamEvent
33-
from openai.types.responses.response_prompt_param import ResponsePromptParam
16+
from agents import (
17+
AgentOutputSchema,
18+
AgentOutputSchemaBase,
19+
ComputerTool,
20+
FileSearchTool,
21+
FunctionTool,
22+
Handoff,
23+
Model,
24+
ModelResponse,
25+
ModelSettings,
26+
ModelTracing,
27+
Tool,
28+
TResponseInputItem,
29+
WebSearchTool,
30+
)
31+
from agents.items import TResponseStreamEvent
32+
from openai.types.responses.response_prompt_param import ResponsePromptParam
3433

35-
from temporalio.contrib.openai_agents.invoke_model_activity import (
36-
ActivityModelInput,
37-
AgentOutputSchemaInput,
38-
FunctionToolInput,
39-
HandoffInput,
40-
ModelActivity,
41-
ModelTracingInput,
42-
ToolInput,
43-
)
34+
from temporalio.contrib.openai_agents.invoke_model_activity import (
35+
ActivityModelInput,
36+
AgentOutputSchemaInput,
37+
FunctionToolInput,
38+
HandoffInput,
39+
ModelActivity,
40+
ModelTracingInput,
41+
ToolInput,
42+
)
4443

4544

4645
class _TemporalModelStub(Model):

‎temporalio/contrib/openai_agents/open_ai_data_converter.py

Lines changed: 3 additions & 143 deletions
Original file line numberDiff line numberDiff line change
@@ -5,147 +5,7 @@
55

66
from __future__ import annotations
77

8-
import importlib
9-
import inspect
10-
from typing import Any, Optional, Type, TypeVar
8+
import temporalio.contrib.pydantic
119

12-
from agents import Usage
13-
from agents.items import TResponseOutputItem
14-
from openai import NOT_GIVEN, BaseModel
15-
from pydantic import RootModel, TypeAdapter
16-
17-
import temporalio.api.common.v1
18-
from temporalio import workflow
19-
from temporalio.converter import (
20-
CompositePayloadConverter,
21-
DataConverter,
22-
DefaultPayloadConverter,
23-
EncodingPayloadConverter,
24-
JSONPlainPayloadConverter,
25-
)
26-
27-
T = TypeVar("T", bound=BaseModel)
28-
29-
30-
class _WrapperModel(RootModel[T]):
31-
model_config = {
32-
"arbitrary_types_allowed": True,
33-
}
34-
35-
36-
class _OpenAIJSONPlainPayloadConverter(EncodingPayloadConverter):
37-
"""Payload converter for OpenAI agent types that supports Pydantic models and standard Python types.
38-
39-
This converter extends the standard JSON payload converter to handle OpenAI agent-specific
40-
types, particularly Pydantic models. It supports:
41-
42-
1. All Pydantic models and their nested structures
43-
2. Standard JSON-serializable types
44-
3. Python standard library types like:
45-
- dataclasses
46-
- datetime objects
47-
- sets
48-
- UUIDs
49-
4. Custom types composed of any of the above
50-
51-
The converter uses Pydantic's serialization capabilities to ensure proper handling
52-
of complex types while maintaining compatibility with Temporal's payload system.
53-
54-
See https://docs.pydantic.dev/latest/api/standard_library_types/ for details
55-
on supported types.
56-
"""
57-
58-
@property
59-
def encoding(self) -> str:
60-
"""Get the encoding identifier for this converter.
61-
62-
Returns:
63-
The string "json/plain" indicating this is a plain JSON converter.
64-
"""
65-
return "json/plain"
66-
67-
def to_payload(self, value: Any) -> Optional[temporalio.api.common.v1.Payload]:
68-
"""Convert a value to a Temporal payload.
69-
70-
This method wraps the value in a Pydantic RootModel to handle arbitrary types
71-
and serializes it to JSON.
72-
73-
Args:
74-
value: The value to convert to a payload.
75-
76-
Returns:
77-
A Temporal payload containing the serialized value, or None if the value
78-
cannot be converted.
79-
"""
80-
wrapper = _WrapperModel[Any](root=value)
81-
data = wrapper.model_dump_json(exclude_unset=True).encode()
82-
83-
return temporalio.api.common.v1.Payload(
84-
metadata={"encoding": self.encoding.encode()}, data=data
85-
)
86-
87-
def from_payload(
88-
self,
89-
payload: temporalio.api.common.v1.Payload,
90-
type_hint: Optional[Type] = None,
91-
) -> Any:
92-
"""Convert a Temporal payload back to a Python value.
93-
94-
This method deserializes the JSON payload and validates it against the
95-
provided type hint using Pydantic's validation system.
96-
97-
Args:
98-
payload: The Temporal payload to convert.
99-
type_hint: Optional type hint for validation.
100-
101-
Returns:
102-
The deserialized and validated value.
103-
104-
Note:
105-
The type hint is used for validation but the actual type returned
106-
may be a Pydantic model instance.
107-
"""
108-
_type_hint = type_hint or Any
109-
wrapper = _WrapperModel[_type_hint] # type: ignore[valid-type]
110-
111-
with workflow.unsafe.imports_passed_through():
112-
with workflow.unsafe.sandbox_unrestricted():
113-
wrapper.model_rebuild(
114-
_types_namespace=_get_openai_modules()
115-
| {"TResponseOutputItem": TResponseOutputItem, "Usage": Usage}
116-
)
117-
return TypeAdapter(wrapper).validate_json(payload.data.decode()).root
118-
119-
120-
def _get_openai_modules() -> dict[Any, Any]:
121-
def get_modules(module):
122-
result_dict: dict[Any, Any] = {}
123-
for _, mod in inspect.getmembers(module, inspect.ismodule):
124-
result_dict |= mod.__dict__ | get_modules(mod)
125-
return result_dict
126-
127-
return get_modules(importlib.import_module("openai.types"))
128-
129-
130-
class OpenAIPayloadConverter(CompositePayloadConverter):
131-
"""Payload converter for payloads containing pydantic model instances.
132-
133-
JSON conversion is replaced with a converter that uses
134-
:py:class:`PydanticJSONPlainPayloadConverter`.
135-
"""
136-
137-
def __init__(self) -> None:
138-
"""Initialize object"""
139-
json_payload_converter = _OpenAIJSONPlainPayloadConverter()
140-
super().__init__(
141-
*(
142-
c
143-
if not isinstance(c, JSONPlainPayloadConverter)
144-
else json_payload_converter
145-
for c in DefaultPayloadConverter.default_encoding_payload_converters
146-
)
147-
)
148-
149-
150-
open_ai_data_converter = DataConverter(payload_converter_class=OpenAIPayloadConverter)
151-
"""Open AI Agent library types data converter"""
10+
open_ai_data_converter = temporalio.contrib.pydantic.pydantic_data_converter
11+
"""DEPRECATED, use temporalio.contrib.pydantic.pydantic_data_converter"""

‎temporalio/contrib/openai_agents/temporal_tools.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,14 @@
44
from datetime import timedelta
55
from typing import Any, Callable, Optional
66

7+
from agents import FunctionTool, RunContextWrapper, Tool
8+
from agents.function_schema import function_schema
9+
710
from temporalio import activity, workflow
811
from temporalio.common import Priority, RetryPolicy
912
from temporalio.exceptions import ApplicationError, TemporalError
1013
from temporalio.workflow import ActivityCancellationType, VersioningIntent, unsafe
1114

12-
with unsafe.imports_passed_through():
13-
from agents import FunctionTool, RunContextWrapper, Tool
14-
from agents.function_schema import function_schema
15-
1615

1716
class ToolSerializationError(TemporalError):
1817
"""Error that occurs when a tool output could not be serialized."""

‎temporalio/worker/workflow_sandbox/_importer.py

Lines changed: 74 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -145,18 +145,30 @@ def applied(self) -> Iterator[None]:
145145
while it is running and therefore should be locked against other
146146
code running at the same time.
147147
"""
148-
with _thread_local_sys_modules.applied(sys, "modules", self.new_modules):
149-
with _thread_local_import.applied(builtins, "__import__", self.import_func):
150-
with self._builtins_restricted():
151-
yield None
148+
orig_importer = Importer.current_importer()
149+
Importer._thread_local_current.importer = self
150+
try:
151+
with _thread_local_sys_modules.applied(sys, "modules", self.new_modules):
152+
with _thread_local_import.applied(
153+
builtins, "__import__", self.import_func
154+
):
155+
with self._builtins_restricted():
156+
yield None
157+
finally:
158+
Importer._thread_local_current.importer = orig_importer
152159

153160
@contextmanager
154161
def _unapplied(self) -> Iterator[None]:
162+
orig_importer = Importer.current_importer()
163+
Importer._thread_local_current.importer = None
155164
# Set orig modules, then unset on complete
156-
with _thread_local_sys_modules.unapplied():
157-
with _thread_local_import.unapplied():
158-
with self._builtins_unrestricted():
159-
yield None
165+
try:
166+
with _thread_local_sys_modules.unapplied():
167+
with _thread_local_import.unapplied():
168+
with self._builtins_unrestricted():
169+
yield None
170+
finally:
171+
Importer._thread_local_current.importer = orig_importer
160172

161173
def _traced_import(
162174
self,
@@ -211,6 +223,8 @@ def _import(
211223
# Put it on the parent
212224
if parent:
213225
setattr(sys.modules[parent], child, sys.modules[full_name])
226+
# All children of this module that are on the original sys
227+
# modules but not here and are passthrough
214228

215229
# If the module is __temporal_main__ and not already in sys.modules,
216230
# we load it from whatever file __main__ was originally in
@@ -251,21 +265,31 @@ def _assert_valid_module(self, name: str) -> None:
251265
):
252266
raise RestrictedWorkflowAccessError(name)
253267

268+
def module_configured_passthrough(self, name: str) -> bool:
269+
"""Whether the given module name is configured as passthrough."""
270+
if (
271+
self.restrictions.passthrough_all_modules
272+
or name in self.restrictions.passthrough_modules
273+
):
274+
return True
275+
# Iterate backwards looking if configured passthrough
276+
end_dot = -1
277+
while True:
278+
end_dot = name.find(".", end_dot + 1)
279+
if end_dot == -1:
280+
return False
281+
elif name[:end_dot] in self.restrictions.passthrough_modules:
282+
break
283+
return True
284+
254285
def _maybe_passthrough_module(self, name: str) -> Optional[types.ModuleType]:
255286
# If imports not passed through and all modules are not passed through
256287
# and name not in passthrough modules, check parents
257288
if (
258289
not temporalio.workflow.unsafe.is_imports_passed_through()
259-
and not self.restrictions.passthrough_all_modules
260-
and name not in self.restrictions.passthrough_modules
290+
and not self.module_configured_passthrough(name)
261291
):
262-
end_dot = -1
263-
while True:
264-
end_dot = name.find(".", end_dot + 1)
265-
if end_dot == -1:
266-
return None
267-
elif name[:end_dot] in self.restrictions.passthrough_modules:
268-
break
292+
return None
269293
# Do the pass through
270294
with self._unapplied():
271295
_trace("Passing module %s through from host", name)
@@ -311,6 +335,13 @@ def _builtins_unrestricted(self) -> Iterator[None]:
311335
stack.enter_context(thread_local.unapplied())
312336
yield None
313337

338+
_thread_local_current = threading.local()
339+
340+
@staticmethod
341+
def current_importer() -> Optional[Importer]:
342+
"""Get the current importer if any."""
343+
return Importer._thread_local_current.__dict__.get("importer")
344+
314345

315346
_T = TypeVar("_T")
316347

@@ -385,13 +416,23 @@ class _ThreadLocalSysModules(
385416
MutableMapping[str, types.ModuleType],
386417
):
387418
def __contains__(self, key: object) -> bool:
388-
return key in self.current
419+
if key in self.current:
420+
return True
421+
return (
422+
isinstance(key, str)
423+
and self._lazily_passthrough_if_available(key) is not None
424+
)
389425

390426
def __delitem__(self, key: str) -> None:
391427
del self.current[key]
392428

393429
def __getitem__(self, key: str) -> types.ModuleType:
394-
return self.current[key]
430+
try:
431+
return self.current[key]
432+
except KeyError:
433+
if module := self._lazily_passthrough_if_available(key):
434+
return module
435+
raise
395436

396437
def __len__(self) -> int:
397438
return len(self.current)
@@ -431,6 +472,20 @@ def copy(self) -> Dict[str, types.ModuleType]:
431472
def fromkeys(cls, *args, **kwargs) -> Any:
432473
return dict.fromkeys(*args, **kwargs)
433474

475+
def _lazily_passthrough_if_available(self, key: str) -> Optional[types.ModuleType]:
476+
# We only lazily pass through if it's in orig, lazy not disabled, and
477+
# module configured as pass through
478+
if (
479+
key in self.orig
480+
and (importer := Importer.current_importer())
481+
and not importer.restrictions.disable_lazy_sys_module_passthrough
482+
and importer.module_configured_passthrough(key)
483+
):
484+
orig = self.orig[key]
485+
self.current[key] = orig
486+
return orig
487+
return None
488+
434489

435490
_thread_local_sys_modules = _ThreadLocalSysModules(sys.modules)
436491

0 commit comments

Comments
 (0)